From 41bd9d959ccdb1095b6662b903bb3cbd2a47087b Mon Sep 17 00:00:00 2001 From: Peter Bierma Date: Thu, 24 Oct 2024 12:51:45 -0400 Subject: [PATCH] gh-125864: Propagate `pickle.loads()` failures in `InterpreterPoolExecutor` (gh-125898) Authored-by: Peter Bierma --- Lib/concurrent/futures/interpreter.py | 3 ++- .../test_interpreter_pool.py | 18 ++++++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/Lib/concurrent/futures/interpreter.py b/Lib/concurrent/futures/interpreter.py index fd7941adb76..d17688dc9d7 100644 --- a/Lib/concurrent/futures/interpreter.py +++ b/Lib/concurrent/futures/interpreter.py @@ -107,7 +107,8 @@ class WorkerContext(_thread.WorkerContext): @classmethod def _call_pickled(cls, pickled, resultsid): - fn, args, kwargs = pickle.loads(pickled) + with cls._capture_exc(resultsid): + fn, args, kwargs = pickle.loads(pickled) cls._call(fn, args, kwargs, resultsid) def __init__(self, initdata, shared=None): diff --git a/Lib/test/test_concurrent_futures/test_interpreter_pool.py b/Lib/test/test_concurrent_futures/test_interpreter_pool.py index 5264b1bb6e9..ea1512fc830 100644 --- a/Lib/test/test_concurrent_futures/test_interpreter_pool.py +++ b/Lib/test/test_concurrent_futures/test_interpreter_pool.py @@ -56,6 +56,16 @@ class InterpretersMixin(InterpreterPoolMixin): return r, w +class PickleShenanigans: + """Succeeds with pickle.dumps(), but fails with pickle.loads()""" + def __init__(self, value): + if value == 1: + raise RuntimeError("gotcha") + + def __reduce__(self): + return (self.__class__, (1,)) + + class InterpreterPoolExecutorTest( InterpretersMixin, ExecutorTest, BaseTestCase): @@ -279,6 +289,14 @@ class InterpreterPoolExecutorTest( self.assertEqual(len(executor._threads), 1) executor.shutdown(wait=True) + def test_pickle_errors_propagate(self): + # GH-125864: Pickle errors happen before the script tries to execute, so the + # queue used to wait infinitely. + + fut = self.executor.submit(PickleShenanigans(0)) + with self.assertRaisesRegex(RuntimeError, "gotcha"): + fut.result() + class AsyncioTest(InterpretersMixin, testasyncio_utils.TestCase):