diff --git a/django/test/utils.py b/django/test/utils.py index ddb85127dc..934d9bad48 100644 --- a/django/test/utils.py +++ b/django/test/utils.py @@ -6,6 +6,7 @@ import re import sys import time import warnings +from concurrent.futures import ThreadPoolExecutor from contextlib import contextmanager from functools import wraps from io import StringIO @@ -51,6 +52,7 @@ __all__ = ( "tag", "requires_tz_support", "setup_databases", + "TestDatabaseSetup", "setup_test_environment", "teardown_test_environment", ) @@ -170,6 +172,137 @@ def teardown_test_environment(): del mail.outbox +class TestDatabaseSetup: + def __init__( + self, + verbosity, + interactive, + time_keeper=None, + keepdb=False, + debug_sql=False, + parallel=0, + aliases=None, + serialized_aliases=None, + **kwargs, + ): + self.verbosity = verbosity + self.interactive = interactive + self.time_keeper = time_keeper or NullTimeKeeper() + self.keepdb = keepdb + self.debug_sql = debug_sql + self.parallel = parallel + self.aliases = aliases + self.serialized_aliases = serialized_aliases + self.test_databases = {} + self.mirrored_aliases = {} + self.old_names = [] + self.kwargs = kwargs + + def setup(self): + """Main method to set up test databases.""" + self.test_databases, self.mirrored_aliases = get_unique_databases_and_mirrors( + self.aliases + ) + + # Process each unique database + for db_sig, (db_name, aliases) in self.test_databases.items(): + # Extract the first alias outside the loop + first_alias = aliases[0] + remaining_aliases = aliases[1:] + + # Create the test database for the first alias + self.create_test_database(first_alias, db_name) + + # Clone the test database if parallel testing is enabled + if self.parallel > 1: + self.clone_databases_in_parallel(first_alias) + + # Configure remaining aliases as mirrors + for alias in remaining_aliases: + self.configure_as_mirror(alias, first_alias) + + # Configure any additional test mirrors + self.configure_test_mirrors() + + # Enable debug SQL logging if required + if self.debug_sql: + self.enable_debug_sql() + + return self.old_names + + def create_test_database(self, alias, db_name): + """Create the test database for the given alias.""" + connection = connections[alias] + self.old_names.append((connection, db_name, True)) + + serialize_alias = ( + self.serialized_aliases is None or alias in self.serialized_aliases + ) + + with self.time_keeper.timed(f" Creating '{alias}'"): + connection.creation.create_test_db( + verbosity=self.verbosity, + autoclobber=not self.interactive, + keepdb=self.keepdb, + serialize=serialize_alias, + ) + + def clone_databases_sequentially(self, alias): + """Clone the test database sequentially for parallel test execution.""" + for index in range(self.parallel): + self.clone_test_database(alias, index) + + def clone_databases_in_parallel(self, alias): + """Clone the test database in parallel threads for parallel test execution.""" + with ThreadPoolExecutor(max_workers=self.parallel) as executor: + futures = [ + executor.submit(self.clone_test_database, alias, index) + for index in range(self.parallel) + ] + + # Optionally, wait for all futures to complete + # This just ensures that any exceptions are raised + for future in futures: + future.result() + + def clone_test_database(self, alias, index): + """Clone the test database for parallel execution.""" + connection = connections[alias] + + # re-init the connection per thread + connection.close() + connection.connect() + + with self.time_keeper.timed(f" Cloning '{alias}'"): + connection.creation.clone_test_db( + suffix=str(index + 1), + verbosity=self.verbosity, + keepdb=self.keepdb, + ) + + def configure_as_mirror(self, alias, source_alias): + """Configure an alias as a mirror of the source alias.""" + connection = connections[alias] + self.old_names.append( + (connection, None, False) + ) # False indicates no creation needed + + source_settings = connections[source_alias].settings_dict + connection.creation.set_as_test_mirror(source_settings) + + def configure_test_mirrors(self): + """Configure any additional test mirrors.""" + for alias, mirror_alias in self.mirrored_aliases.items(): + connection = connections[alias] + mirror_settings = connections[mirror_alias].settings_dict + connection.creation.set_as_test_mirror(mirror_settings) + + def enable_debug_sql(self): + """Enable debug SQL logging for all connections.""" + for alias in connections: + connections[alias].force_debug_cursor = True + + def setup_databases( verbosity, interactive, @@ -182,58 +315,17 @@ def setup_databases( serialized_aliases=None, **kwargs, ): - """Create the test databases.""" - if time_keeper is None: - time_keeper = NullTimeKeeper() - - test_databases, mirrored_aliases = get_unique_databases_and_mirrors(aliases) - - old_names = [] - - for db_name, aliases in test_databases.values(): - first_alias = None - for alias in aliases: - connection = connections[alias] - old_names.append((connection, db_name, first_alias is None)) - - # Actually create the database for the first connection - if first_alias is None: - first_alias = alias - with time_keeper.timed(" Creating '%s'" % alias): - serialize_alias = ( - serialized_aliases is None or alias in serialized_aliases - ) - connection.creation.create_test_db( - verbosity=verbosity, - autoclobber=not interactive, - keepdb=keepdb, - serialize=serialize_alias, - ) - if parallel > 1: - for index in range(parallel): - with time_keeper.timed(" Cloning '%s'" % alias): - connection.creation.clone_test_db( - suffix=str(index + 1), - verbosity=verbosity, - keepdb=keepdb, - ) - # Configure all other connections as mirrors of the first one - else: - connections[alias].creation.set_as_test_mirror( - connections[first_alias].settings_dict - ) - - # Configure the test mirrors. - for alias, mirror_alias in mirrored_aliases.items(): - connections[alias].creation.set_as_test_mirror( - connections[mirror_alias].settings_dict - ) - - if debug_sql: - for alias in connections: - connections[alias].force_debug_cursor = True - - return old_names + return TestDatabaseSetup( + verbosity, + interactive, + time_keeper, + keepdb, + debug_sql, + parallel, + aliases, + serialized_aliases, + **kwargs, + ).setup() def iter_test_cases(tests):