From 086c6b1b0fe8d47ebd15512d7bdcb64c60a360f0 Mon Sep 17 00:00:00 2001 From: Serhiy Storchaka Date: Sun, 8 May 2022 17:49:09 +0300 Subject: [PATCH] bpo-45046: Support context managers in unittest (GH-28045) Add methods enterContext() and enterClassContext() in TestCase. Add method enterAsyncContext() in IsolatedAsyncioTestCase. Add function enterModuleContext(). --- Doc/library/unittest.rst | 42 +++++++ Doc/whatsnew/3.11.rst | 12 ++ Lib/distutils/tests/test_build_ext.py | 4 +- Lib/test/test__osx_support.py | 3 +- Lib/test/test_argparse.py | 6 +- Lib/test/test_getopt.py | 6 +- Lib/test/test_gettext.py | 7 +- Lib/test/test_global.py | 11 +- Lib/test/test_importlib/source/test_finder.py | 19 +-- .../test_importlib/test_namespace_pkgs.py | 7 +- Lib/test/test_logging.py | 4 +- Lib/test/test_nntplib.py | 3 +- Lib/test/test_peg_generator/test_c_parser.py | 4 +- Lib/test/test_poll.py | 3 +- Lib/test/test_posix.py | 12 +- Lib/test/test_set.py | 6 +- Lib/test/test_socket.py | 4 +- Lib/test/test_ssl.py | 6 +- Lib/test/test_tempfile.py | 6 +- Lib/test/test_urllib.py | 7 +- Lib/unittest/__init__.py | 5 +- Lib/unittest/async_case.py | 20 ++++ Lib/unittest/case.py | 32 +++++ Lib/unittest/test/test_async_case.py | 53 +++++++++ Lib/unittest/test/test_runner.py | 110 ++++++++++++++++++ .../2021-08-29-19-59-16.bpo-45046.eGq0NC.rst | 7 ++ 26 files changed, 307 insertions(+), 92 deletions(-) create mode 100644 Misc/NEWS.d/next/Library/2021-08-29-19-59-16.bpo-45046.eGq0NC.rst diff --git a/Doc/library/unittest.rst b/Doc/library/unittest.rst index 9b8b75acce5..f6bcba06d90 100644 --- a/Doc/library/unittest.rst +++ b/Doc/library/unittest.rst @@ -1495,6 +1495,16 @@ Test cases .. versionadded:: 3.1 + .. method:: enterContext(cm) + + Enter the supplied :term:`context manager`. If successful, also + add its :meth:`~object.__exit__` method as a cleanup function by + :meth:`addCleanup` and return the result of the + :meth:`~object.__enter__` method. + + .. versionadded:: 3.11 + + .. method:: doCleanups() This method is called unconditionally after :meth:`tearDown`, or @@ -1510,6 +1520,7 @@ Test cases .. versionadded:: 3.1 + .. classmethod:: addClassCleanup(function, /, *args, **kwargs) Add a function to be called after :meth:`tearDownClass` to cleanup @@ -1524,6 +1535,16 @@ Test cases .. versionadded:: 3.8 + .. classmethod:: enterClassContext(cm) + + Enter the supplied :term:`context manager`. If successful, also + add its :meth:`~object.__exit__` method as a cleanup function by + :meth:`addClassCleanup` and return the result of the + :meth:`~object.__enter__` method. + + .. versionadded:: 3.11 + + .. classmethod:: doClassCleanups() This method is called unconditionally after :meth:`tearDownClass`, or @@ -1571,6 +1592,16 @@ Test cases This method accepts a coroutine that can be used as a cleanup function. + .. coroutinemethod:: enterAsyncContext(cm) + + Enter the supplied :term:`asynchronous context manager`. If successful, + also add its :meth:`~object.__aexit__` method as a cleanup function by + :meth:`addAsyncCleanup` and return the result of the + :meth:`~object.__aenter__` method. + + .. versionadded:: 3.11 + + .. method:: run(result=None) Sets up a new event loop to run the test, collecting the result into @@ -2465,6 +2496,16 @@ To add cleanup code that must be run even in the case of an exception, use .. versionadded:: 3.8 +.. classmethod:: enterModuleContext(cm) + + Enter the supplied :term:`context manager`. If successful, also + add its :meth:`~object.__exit__` method as a cleanup function by + :func:`addModuleCleanup` and return the result of the + :meth:`~object.__enter__` method. + + .. versionadded:: 3.11 + + .. function:: doModuleCleanups() This function is called unconditionally after :func:`tearDownModule`, or @@ -2480,6 +2521,7 @@ To add cleanup code that must be run even in the case of an exception, use .. versionadded:: 3.8 + Signal Handling --------------- diff --git a/Doc/whatsnew/3.11.rst b/Doc/whatsnew/3.11.rst index c4e8e6f9a10..defaeebc773 100644 --- a/Doc/whatsnew/3.11.rst +++ b/Doc/whatsnew/3.11.rst @@ -758,6 +758,18 @@ unicodedata * The Unicode database has been updated to version 14.0.0. (:issue:`45190`). +unittest +-------- + +* Added methods :meth:`~unittest.TestCase.enterContext` and + :meth:`~unittest.TestCase.enterClassContext` of class + :class:`~unittest.TestCase`, method + :meth:`~unittest.IsolatedAsyncioTestCase.enterAsyncContext` of + class :class:`~unittest.IsolatedAsyncioTestCase` and function + :func:`unittest.enterModuleContext`. + (Contributed by Serhiy Storchaka in :issue:`45046`.) + + venv ---- diff --git a/Lib/distutils/tests/test_build_ext.py b/Lib/distutils/tests/test_build_ext.py index 031897bd273..4ebeafecef0 100644 --- a/Lib/distutils/tests/test_build_ext.py +++ b/Lib/distutils/tests/test_build_ext.py @@ -41,9 +41,7 @@ class BuildExtTestCase(TempdirManager, # bpo-30132: On Windows, a .pdb file may be created in the current # working directory. Create a temporary working directory to cleanup # everything at the end of the test. - change_cwd = os_helper.change_cwd(self.tmp_dir) - change_cwd.__enter__() - self.addCleanup(change_cwd.__exit__, None, None, None) + self.enterContext(os_helper.change_cwd(self.tmp_dir)) def tearDown(self): import site diff --git a/Lib/test/test__osx_support.py b/Lib/test/test__osx_support.py index 907ae27d529..4a14cb35213 100644 --- a/Lib/test/test__osx_support.py +++ b/Lib/test/test__osx_support.py @@ -19,8 +19,7 @@ class Test_OSXSupport(unittest.TestCase): self.maxDiff = None self.prog_name = 'bogus_program_xxxx' self.temp_path_dir = os.path.abspath(os.getcwd()) - self.env = os_helper.EnvironmentVarGuard() - self.addCleanup(self.env.__exit__) + self.env = self.enterContext(os_helper.EnvironmentVarGuard()) for cv in ('CFLAGS', 'LDFLAGS', 'CPPFLAGS', 'BASECFLAGS', 'BLDSHARED', 'LDSHARED', 'CC', 'CXX', 'PY_CFLAGS', 'PY_LDFLAGS', 'PY_CPPFLAGS', diff --git a/Lib/test/test_argparse.py b/Lib/test/test_argparse.py index 8509deb93f1..273db45c00f 100644 --- a/Lib/test/test_argparse.py +++ b/Lib/test/test_argparse.py @@ -41,9 +41,8 @@ class TestCase(unittest.TestCase): # The tests assume that line wrapping occurs at 80 columns, but this # behaviour can be overridden by setting the COLUMNS environment # variable. To ensure that this width is used, set COLUMNS to 80. - env = os_helper.EnvironmentVarGuard() + env = self.enterContext(os_helper.EnvironmentVarGuard()) env['COLUMNS'] = '80' - self.addCleanup(env.__exit__) class TempDirMixin(object): @@ -3428,9 +3427,8 @@ class TestShortColumns(HelpTestCase): but we don't want any exceptions thrown in such cases. Only ugly representation. ''' def setUp(self): - env = os_helper.EnvironmentVarGuard() + env = self.enterContext(os_helper.EnvironmentVarGuard()) env.set("COLUMNS", '15') - self.addCleanup(env.__exit__) parser_signature = TestHelpBiggerOptionals.parser_signature argument_signatures = TestHelpBiggerOptionals.argument_signatures diff --git a/Lib/test/test_getopt.py b/Lib/test/test_getopt.py index 9261276ebb9..64b9ce01e05 100644 --- a/Lib/test/test_getopt.py +++ b/Lib/test/test_getopt.py @@ -11,14 +11,10 @@ sentinel = object() class GetoptTests(unittest.TestCase): def setUp(self): - self.env = EnvironmentVarGuard() + self.env = self.enterContext(EnvironmentVarGuard()) if "POSIXLY_CORRECT" in self.env: del self.env["POSIXLY_CORRECT"] - def tearDown(self): - self.env.__exit__() - del self.env - def assertError(self, *args, **kwargs): self.assertRaises(getopt.GetoptError, *args, **kwargs) diff --git a/Lib/test/test_gettext.py b/Lib/test/test_gettext.py index 467652a41f0..1608d1b18e9 100644 --- a/Lib/test/test_gettext.py +++ b/Lib/test/test_gettext.py @@ -117,6 +117,7 @@ MMOFILE = os.path.join(LOCALEDIR, 'metadata.mo') class GettextBaseTest(unittest.TestCase): def setUp(self): + self.addCleanup(os_helper.rmtree, os.path.split(LOCALEDIR)[0]) if not os.path.isdir(LOCALEDIR): os.makedirs(LOCALEDIR) with open(MOFILE, 'wb') as fp: @@ -129,14 +130,10 @@ class GettextBaseTest(unittest.TestCase): fp.write(base64.decodebytes(UMO_DATA)) with open(MMOFILE, 'wb') as fp: fp.write(base64.decodebytes(MMO_DATA)) - self.env = os_helper.EnvironmentVarGuard() + self.env = self.enterContext(os_helper.EnvironmentVarGuard()) self.env['LANGUAGE'] = 'xx' gettext._translations.clear() - def tearDown(self): - self.env.__exit__() - del self.env - os_helper.rmtree(os.path.split(LOCALEDIR)[0]) GNU_MO_DATA_ISSUE_17898 = b'''\ 3hIElQAAAAABAAAAHAAAACQAAAAAAAAAAAAAAAAAAAAsAAAAggAAAC0AAAAAUGx1cmFsLUZvcm1z diff --git a/Lib/test/test_global.py b/Lib/test/test_global.py index d0bde3fd040..f5b38c25ea0 100644 --- a/Lib/test/test_global.py +++ b/Lib/test/test_global.py @@ -9,14 +9,9 @@ import warnings class GlobalTests(unittest.TestCase): def setUp(self): - self._warnings_manager = check_warnings() - self._warnings_manager.__enter__() + self.enterContext(check_warnings()) warnings.filterwarnings("error", module="") - def tearDown(self): - self._warnings_manager.__exit__(None, None, None) - - def test1(self): prog_text_1 = """\ def wrong1(): @@ -54,9 +49,7 @@ x = 2 def setUpModule(): - cm = warnings.catch_warnings() - cm.__enter__() - unittest.addModuleCleanup(cm.__exit__, None, None, None) + unittest.enterModuleContext(warnings.catch_warnings()) warnings.filterwarnings("error", module="") diff --git a/Lib/test/test_importlib/source/test_finder.py b/Lib/test/test_importlib/source/test_finder.py index 6a23e9d50f6..bed9d56dca8 100644 --- a/Lib/test/test_importlib/source/test_finder.py +++ b/Lib/test/test_importlib/source/test_finder.py @@ -157,21 +157,12 @@ class FinderTests(abc.FinderTests): def test_no_read_directory(self): # Issue #16730 tempdir = tempfile.TemporaryDirectory() + self.enterContext(tempdir) + # Since we muck with the permissions, we want to set them back to + # their original values to make sure the directory can be properly + # cleaned up. original_mode = os.stat(tempdir.name).st_mode - def cleanup(tempdir): - """Cleanup function for the temporary directory. - - Since we muck with the permissions, we want to set them back to - their original values to make sure the directory can be properly - cleaned up. - - """ - os.chmod(tempdir.name, original_mode) - # If this is not explicitly called then the __del__ method is used, - # but since already mucking around might as well explicitly clean - # up. - tempdir.__exit__(None, None, None) - self.addCleanup(cleanup, tempdir) + self.addCleanup(os.chmod, tempdir.name, original_mode) os.chmod(tempdir.name, stat.S_IWUSR | stat.S_IXUSR) finder = self.get_finder(tempdir.name) found = self._find(finder, 'doesnotexist') diff --git a/Lib/test/test_importlib/test_namespace_pkgs.py b/Lib/test/test_importlib/test_namespace_pkgs.py index 2ea41b7a4c5..cd08498545e 100644 --- a/Lib/test/test_importlib/test_namespace_pkgs.py +++ b/Lib/test/test_importlib/test_namespace_pkgs.py @@ -65,12 +65,7 @@ class NamespacePackageTest(unittest.TestCase): self.resolved_paths = [ os.path.join(self.root, path) for path in self.paths ] - self.ctx = namespace_tree_context(path=self.resolved_paths) - self.ctx.__enter__() - - def tearDown(self): - # TODO: will we ever want to pass exc_info to __exit__? - self.ctx.__exit__(None, None, None) + self.enterContext(namespace_tree_context(path=self.resolved_paths)) class SingleNamespacePackage(NamespacePackageTest): diff --git a/Lib/test/test_logging.py b/Lib/test/test_logging.py index 5d4ddedd059..e69afae484a 100644 --- a/Lib/test/test_logging.py +++ b/Lib/test/test_logging.py @@ -5650,9 +5650,7 @@ class MiscTestCase(unittest.TestCase): # why the test does this, but in any case we save the current locale # first and restore it at the end. def setUpModule(): - cm = support.run_with_locale('LC_ALL', '') - cm.__enter__() - unittest.addModuleCleanup(cm.__exit__, None, None, None) + unittest.enterModuleContext(support.run_with_locale('LC_ALL', '')) if __name__ == "__main__": diff --git a/Lib/test/test_nntplib.py b/Lib/test/test_nntplib.py index 9812c055193..31a02f86abb 100644 --- a/Lib/test/test_nntplib.py +++ b/Lib/test/test_nntplib.py @@ -1593,8 +1593,7 @@ class LocalServerTests(unittest.TestCase): self.background.start() self.addCleanup(self.background.join) - self.nntp = NNTP(socket_helper.HOST, port, usenetrc=False).__enter__() - self.addCleanup(self.nntp.__exit__, None, None, None) + self.nntp = self.enterContext(NNTP(socket_helper.HOST, port, usenetrc=False)) def run_server(self, sock): # Could be generalized to handle more commands in separate methods diff --git a/Lib/test/test_peg_generator/test_c_parser.py b/Lib/test/test_peg_generator/test_c_parser.py index 13b83a9db9e..d25bc112cfd 100644 --- a/Lib/test/test_peg_generator/test_c_parser.py +++ b/Lib/test/test_peg_generator/test_c_parser.py @@ -96,9 +96,7 @@ class TestCParser(unittest.TestCase): self.skipTest("The %r command is not found" % cmd) self.old_cwd = os.getcwd() self.tmp_path = tempfile.mkdtemp(dir=self.tmp_base) - change_cwd = os_helper.change_cwd(self.tmp_path) - change_cwd.__enter__() - self.addCleanup(change_cwd.__exit__, None, None, None) + self.enterContext(os_helper.change_cwd(self.tmp_path)) def tearDown(self): os.chdir(self.old_cwd) diff --git a/Lib/test/test_poll.py b/Lib/test/test_poll.py index 7d542b5cfd7..02165a0244d 100644 --- a/Lib/test/test_poll.py +++ b/Lib/test/test_poll.py @@ -128,8 +128,7 @@ class PollTests(unittest.TestCase): cmd = 'for i in 0 1 2 3 4 5 6 7 8 9; do echo testing...; sleep 1; done' proc = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, bufsize=0) - proc.__enter__() - self.addCleanup(proc.__exit__, None, None, None) + self.enterContext(proc) p = proc.stdout pollster = select.poll() pollster.register( p, select.POLLIN ) diff --git a/Lib/test/test_posix.py b/Lib/test/test_posix.py index f44b8d0403f..28e5e90297e 100644 --- a/Lib/test/test_posix.py +++ b/Lib/test/test_posix.py @@ -53,19 +53,13 @@ class PosixTester(unittest.TestCase): def setUp(self): # create empty file + self.addCleanup(os_helper.unlink, os_helper.TESTFN) with open(os_helper.TESTFN, "wb"): pass - self.teardown_files = [ os_helper.TESTFN ] - self._warnings_manager = warnings_helper.check_warnings() - self._warnings_manager.__enter__() + self.enterContext(warnings_helper.check_warnings()) warnings.filterwarnings('ignore', '.* potential security risk .*', RuntimeWarning) - def tearDown(self): - for teardown_file in self.teardown_files: - os_helper.unlink(teardown_file) - self._warnings_manager.__exit__(None, None, None) - def testNoArgFunctions(self): # test posix functions which take no arguments and have # no side-effects which we need to cleanup (e.g., fork, wait, abort) @@ -973,8 +967,8 @@ class PosixTester(unittest.TestCase): self.assertTrue(hasattr(testfn_st, 'st_flags')) + self.addCleanup(os_helper.unlink, _DUMMY_SYMLINK) os.symlink(os_helper.TESTFN, _DUMMY_SYMLINK) - self.teardown_files.append(_DUMMY_SYMLINK) dummy_symlink_st = os.lstat(_DUMMY_SYMLINK) def chflags_nofollow(path, flags): diff --git a/Lib/test/test_set.py b/Lib/test/test_set.py index 3b57517a861..43f23dbbf9b 100644 --- a/Lib/test/test_set.py +++ b/Lib/test/test_set.py @@ -1022,8 +1022,7 @@ class TestBasicOpsBytes(TestBasicOps, unittest.TestCase): class TestBasicOpsMixedStringBytes(TestBasicOps, unittest.TestCase): def setUp(self): - self._warning_filters = warnings_helper.check_warnings() - self._warning_filters.__enter__() + self.enterContext(warnings_helper.check_warnings()) warnings.simplefilter('ignore', BytesWarning) self.case = "string and bytes set" self.values = ["a", "b", b"a", b"b"] @@ -1031,9 +1030,6 @@ class TestBasicOpsMixedStringBytes(TestBasicOps, unittest.TestCase): self.dup = set(self.values) self.length = 4 - def tearDown(self): - self._warning_filters.__exit__(None, None, None) - def test_repr(self): self.check_repr_against_values() diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py index 613363722cf..1aaa9e44f90 100755 --- a/Lib/test/test_socket.py +++ b/Lib/test/test_socket.py @@ -338,9 +338,7 @@ class ThreadableTest: self.server_ready.set() def _setUp(self): - self.wait_threads = threading_helper.wait_threads_exit() - self.wait_threads.__enter__() - self.addCleanup(self.wait_threads.__exit__, None, None, None) + self.enterContext(threading_helper.wait_threads_exit()) self.server_ready = threading.Event() self.client_ready = threading.Event() diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py index 0eb8d18b356..fed76378726 100644 --- a/Lib/test/test_ssl.py +++ b/Lib/test/test_ssl.py @@ -1999,9 +1999,8 @@ class SimpleBackgroundTests(unittest.TestCase): self.server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) self.server_context.load_cert_chain(SIGNED_CERTFILE) server = ThreadedEchoServer(context=self.server_context) + self.enterContext(server) self.server_addr = (HOST, server.port) - server.__enter__() - self.addCleanup(server.__exit__, None, None, None) def test_connect(self): with test_wrap_socket(socket.socket(socket.AF_INET), @@ -3713,8 +3712,7 @@ class ThreadedTests(unittest.TestCase): def test_recv_zero(self): server = ThreadedEchoServer(CERTFILE) - server.__enter__() - self.addCleanup(server.__exit__, None, None) + self.enterContext(server) s = socket.create_connection((HOST, server.port)) self.addCleanup(s.close) s = test_wrap_socket(s, suppress_ragged_eofs=False) diff --git a/Lib/test/test_tempfile.py b/Lib/test/test_tempfile.py index a05f3c84ccf..f056e5ccb17 100644 --- a/Lib/test/test_tempfile.py +++ b/Lib/test/test_tempfile.py @@ -90,14 +90,10 @@ class BaseTestCase(unittest.TestCase): b_check = re.compile(br"^[a-z0-9_-]{8}$") def setUp(self): - self._warnings_manager = warnings_helper.check_warnings() - self._warnings_manager.__enter__() + self.enterContext(warnings_helper.check_warnings()) warnings.filterwarnings("ignore", category=RuntimeWarning, message="mktemp", module=__name__) - def tearDown(self): - self._warnings_manager.__exit__(None, None, None) - def nameCheck(self, name, dir, pre, suf): (ndir, nbase) = os.path.split(name) npre = nbase[:len(pre)] diff --git a/Lib/test/test_urllib.py b/Lib/test/test_urllib.py index 82f1d9dc2e7..bc6e74c291a 100644 --- a/Lib/test/test_urllib.py +++ b/Lib/test/test_urllib.py @@ -232,17 +232,12 @@ class ProxyTests(unittest.TestCase): def setUp(self): # Records changes to env vars - self.env = os_helper.EnvironmentVarGuard() + self.env = self.enterContext(os_helper.EnvironmentVarGuard()) # Delete all proxy related env vars for k in list(os.environ): if 'proxy' in k.lower(): self.env.unset(k) - def tearDown(self): - # Restore all proxy related env vars - self.env.__exit__() - del self.env - def test_getproxies_environment_keep_no_proxies(self): self.env.set('NO_PROXY', 'localhost') proxies = urllib.request.getproxies_environment() diff --git a/Lib/unittest/__init__.py b/Lib/unittest/__init__.py index eda951ce73e..005d23f6d00 100644 --- a/Lib/unittest/__init__.py +++ b/Lib/unittest/__init__.py @@ -49,7 +49,7 @@ __all__ = ['TestResult', 'TestCase', 'IsolatedAsyncioTestCase', 'TestSuite', 'defaultTestLoader', 'SkipTest', 'skip', 'skipIf', 'skipUnless', 'expectedFailure', 'TextTestResult', 'installHandler', 'registerResult', 'removeResult', 'removeHandler', - 'addModuleCleanup', 'doModuleCleanups'] + 'addModuleCleanup', 'doModuleCleanups', 'enterModuleContext'] # Expose obsolete functions for backwards compatibility # bpo-5846: Deprecated in Python 3.11, scheduled for removal in Python 3.13. @@ -59,7 +59,8 @@ __unittest = True from .result import TestResult from .case import (addModuleCleanup, TestCase, FunctionTestCase, SkipTest, skip, - skipIf, skipUnless, expectedFailure, doModuleCleanups) + skipIf, skipUnless, expectedFailure, doModuleCleanups, + enterModuleContext) from .suite import BaseTestSuite, TestSuite from .loader import TestLoader, defaultTestLoader from .main import TestProgram, main diff --git a/Lib/unittest/async_case.py b/Lib/unittest/async_case.py index 85b938fb293..a90eed98f87 100644 --- a/Lib/unittest/async_case.py +++ b/Lib/unittest/async_case.py @@ -58,6 +58,26 @@ class IsolatedAsyncioTestCase(TestCase): # 3. Regular "def func()" that returns awaitable object self.addCleanup(*(func, *args), **kwargs) + async def enterAsyncContext(self, cm): + """Enters the supplied asynchronous context manager. + + If successful, also adds its __aexit__ method as a cleanup + function and returns the result of the __aenter__ method. + """ + # We look up the special methods on the type to match the with + # statement. + cls = type(cm) + try: + enter = cls.__aenter__ + exit = cls.__aexit__ + except AttributeError: + raise TypeError(f"'{cls.__module__}.{cls.__qualname__}' object does " + f"not support the asynchronous context manager protocol" + ) from None + result = await enter(cm) + self.addAsyncCleanup(exit, cm, None, None, None) + return result + def _callSetUp(self): self._asyncioTestContext.run(self.setUp) self._callAsync(self.asyncSetUp) diff --git a/Lib/unittest/case.py b/Lib/unittest/case.py index 55770c06d7c..ffc8f19ddd3 100644 --- a/Lib/unittest/case.py +++ b/Lib/unittest/case.py @@ -102,12 +102,31 @@ def _id(obj): return obj +def _enter_context(cm, addcleanup): + # We look up the special methods on the type to match the with + # statement. + cls = type(cm) + try: + enter = cls.__enter__ + exit = cls.__exit__ + except AttributeError: + raise TypeError(f"'{cls.__module__}.{cls.__qualname__}' object does " + f"not support the context manager protocol") from None + result = enter(cm) + addcleanup(exit, cm, None, None, None) + return result + + _module_cleanups = [] def addModuleCleanup(function, /, *args, **kwargs): """Same as addCleanup, except the cleanup items are called even if setUpModule fails (unlike tearDownModule).""" _module_cleanups.append((function, args, kwargs)) +def enterModuleContext(cm): + """Same as enterContext, but module-wide.""" + return _enter_context(cm, addModuleCleanup) + def doModuleCleanups(): """Execute all module cleanup functions. Normally called for you after @@ -426,12 +445,25 @@ class TestCase(object): Cleanup items are called even if setUp fails (unlike tearDown).""" self._cleanups.append((function, args, kwargs)) + def enterContext(self, cm): + """Enters the supplied context manager. + + If successful, also adds its __exit__ method as a cleanup + function and returns the result of the __enter__ method. + """ + return _enter_context(cm, self.addCleanup) + @classmethod def addClassCleanup(cls, function, /, *args, **kwargs): """Same as addCleanup, except the cleanup items are called even if setUpClass fails (unlike tearDownClass).""" cls._class_cleanups.append((function, args, kwargs)) + @classmethod + def enterClassContext(cls, cm): + """Same as enterContext, but class-wide.""" + return _enter_context(cm, cls.addClassCleanup) + def setUp(self): "Hook method for setting up the test fixture before exercising it." pass diff --git a/Lib/unittest/test/test_async_case.py b/Lib/unittest/test/test_async_case.py index 1b910a44eea..beadcac070b 100644 --- a/Lib/unittest/test/test_async_case.py +++ b/Lib/unittest/test/test_async_case.py @@ -14,6 +14,29 @@ def tearDownModule(): asyncio.set_event_loop_policy(None) +class TestCM: + def __init__(self, ordering, enter_result=None): + self.ordering = ordering + self.enter_result = enter_result + + async def __aenter__(self): + self.ordering.append('enter') + return self.enter_result + + async def __aexit__(self, *exc_info): + self.ordering.append('exit') + + +class LacksEnterAndExit: + pass +class LacksEnter: + async def __aexit__(self, *exc_info): + pass +class LacksExit: + async def __aenter__(self): + pass + + VAR = contextvars.ContextVar('VAR', default=()) @@ -337,6 +360,36 @@ class TestAsyncCase(unittest.TestCase): output = test.run() self.assertTrue(cancelled) + def test_enterAsyncContext(self): + events = [] + + class Test(unittest.IsolatedAsyncioTestCase): + async def test_func(slf): + slf.addAsyncCleanup(events.append, 'cleanup1') + cm = TestCM(events, 42) + self.assertEqual(await slf.enterAsyncContext(cm), 42) + slf.addAsyncCleanup(events.append, 'cleanup2') + events.append('test') + + test = Test('test_func') + output = test.run() + self.assertTrue(output.wasSuccessful(), output) + self.assertEqual(events, ['enter', 'test', 'cleanup2', 'exit', 'cleanup1']) + + def test_enterAsyncContext_arg_errors(self): + class Test(unittest.IsolatedAsyncioTestCase): + async def test_func(slf): + with self.assertRaisesRegex(TypeError, 'asynchronous context manager'): + await slf.enterAsyncContext(LacksEnterAndExit()) + with self.assertRaisesRegex(TypeError, 'asynchronous context manager'): + await slf.enterAsyncContext(LacksEnter()) + with self.assertRaisesRegex(TypeError, 'asynchronous context manager'): + await slf.enterAsyncContext(LacksExit()) + + test = Test('test_func') + output = test.run() + self.assertTrue(output.wasSuccessful()) + def test_debug_cleanup_same_loop(self): class Test(unittest.IsolatedAsyncioTestCase): async def asyncSetUp(self): diff --git a/Lib/unittest/test/test_runner.py b/Lib/unittest/test/test_runner.py index 18062ae5a58..d3488b40e82 100644 --- a/Lib/unittest/test/test_runner.py +++ b/Lib/unittest/test/test_runner.py @@ -46,6 +46,29 @@ def cleanup(ordering, blowUp=False): raise Exception('CleanUpExc') +class TestCM: + def __init__(self, ordering, enter_result=None): + self.ordering = ordering + self.enter_result = enter_result + + def __enter__(self): + self.ordering.append('enter') + return self.enter_result + + def __exit__(self, *exc_info): + self.ordering.append('exit') + + +class LacksEnterAndExit: + pass +class LacksEnter: + def __exit__(self, *exc_info): + pass +class LacksExit: + def __enter__(self): + pass + + class TestCleanUp(unittest.TestCase): def testCleanUp(self): class TestableTest(unittest.TestCase): @@ -173,6 +196,39 @@ class TestCleanUp(unittest.TestCase): self.assertEqual(ordering, ['setUp', 'test', 'tearDown', 'cleanup1', 'cleanup2']) + def test_enterContext(self): + class TestableTest(unittest.TestCase): + def testNothing(self): + pass + + test = TestableTest('testNothing') + cleanups = [] + + test.addCleanup(cleanups.append, 'cleanup1') + cm = TestCM(cleanups, 42) + self.assertEqual(test.enterContext(cm), 42) + test.addCleanup(cleanups.append, 'cleanup2') + + self.assertTrue(test.doCleanups()) + self.assertEqual(cleanups, ['enter', 'cleanup2', 'exit', 'cleanup1']) + + def test_enterContext_arg_errors(self): + class TestableTest(unittest.TestCase): + def testNothing(self): + pass + + test = TestableTest('testNothing') + + with self.assertRaisesRegex(TypeError, 'the context manager'): + test.enterContext(LacksEnterAndExit()) + with self.assertRaisesRegex(TypeError, 'the context manager'): + test.enterContext(LacksEnter()) + with self.assertRaisesRegex(TypeError, 'the context manager'): + test.enterContext(LacksExit()) + + self.assertEqual(test._cleanups, []) + + class TestClassCleanup(unittest.TestCase): def test_addClassCleanUp(self): class TestableTest(unittest.TestCase): @@ -451,6 +507,35 @@ class TestClassCleanup(unittest.TestCase): self.assertEqual(ordering, ['setUpClass', 'test', 'tearDownClass', 'cleanup_good']) + def test_enterClassContext(self): + class TestableTest(unittest.TestCase): + def testNothing(self): + pass + + cleanups = [] + + TestableTest.addClassCleanup(cleanups.append, 'cleanup1') + cm = TestCM(cleanups, 42) + self.assertEqual(TestableTest.enterClassContext(cm), 42) + TestableTest.addClassCleanup(cleanups.append, 'cleanup2') + + TestableTest.doClassCleanups() + self.assertEqual(cleanups, ['enter', 'cleanup2', 'exit', 'cleanup1']) + + def test_enterClassContext_arg_errors(self): + class TestableTest(unittest.TestCase): + def testNothing(self): + pass + + with self.assertRaisesRegex(TypeError, 'the context manager'): + TestableTest.enterClassContext(LacksEnterAndExit()) + with self.assertRaisesRegex(TypeError, 'the context manager'): + TestableTest.enterClassContext(LacksEnter()) + with self.assertRaisesRegex(TypeError, 'the context manager'): + TestableTest.enterClassContext(LacksExit()) + + self.assertEqual(TestableTest._class_cleanups, []) + class TestModuleCleanUp(unittest.TestCase): def test_add_and_do_ModuleCleanup(self): @@ -1000,6 +1085,31 @@ class TestModuleCleanUp(unittest.TestCase): 'cleanup2', 'setUp2', 'test2', 'tearDown2', 'cleanup3', 'tearDownModule', 'cleanup1']) + def test_enterModuleContext(self): + cleanups = [] + + unittest.addModuleCleanup(cleanups.append, 'cleanup1') + cm = TestCM(cleanups, 42) + self.assertEqual(unittest.enterModuleContext(cm), 42) + unittest.addModuleCleanup(cleanups.append, 'cleanup2') + + unittest.case.doModuleCleanups() + self.assertEqual(cleanups, ['enter', 'cleanup2', 'exit', 'cleanup1']) + + def test_enterModuleContext_arg_errors(self): + class TestableTest(unittest.TestCase): + def testNothing(self): + pass + + with self.assertRaisesRegex(TypeError, 'the context manager'): + unittest.enterModuleContext(LacksEnterAndExit()) + with self.assertRaisesRegex(TypeError, 'the context manager'): + unittest.enterModuleContext(LacksEnter()) + with self.assertRaisesRegex(TypeError, 'the context manager'): + unittest.enterModuleContext(LacksExit()) + + self.assertEqual(unittest.case._module_cleanups, []) + class Test_TextTestRunner(unittest.TestCase): """Tests for TextTestRunner.""" diff --git a/Misc/NEWS.d/next/Library/2021-08-29-19-59-16.bpo-45046.eGq0NC.rst b/Misc/NEWS.d/next/Library/2021-08-29-19-59-16.bpo-45046.eGq0NC.rst new file mode 100644 index 00000000000..8072afaf445 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2021-08-29-19-59-16.bpo-45046.eGq0NC.rst @@ -0,0 +1,7 @@ +Add support of context managers in :mod:`unittest`: methods +:meth:`~unittest.TestCase.enterContext` and +:meth:`~unittest.TestCase.enterClassContext` of class +:class:`~unittest.TestCase`, method +:meth:`~unittest.IsolatedAsyncioTestCase.enterAsyncContext` of class +:class:`~unittest.IsolatedAsyncioTestCase` and function +:func:`unittest.enterModuleContext`.