0
0
mirror of https://github.com/django/django.git synced 2024-11-21 19:09:18 +01:00

added support for parallel database setups

This commit is contained in:
leondaz 2024-10-12 17:00:45 +03:00
parent 9423f8b476
commit bcf3914528

View File

@ -6,6 +6,7 @@ import re
import sys
import time
import warnings
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from functools import wraps
from io import StringIO
@ -51,6 +52,7 @@ __all__ = (
"tag",
"requires_tz_support",
"setup_databases",
"TestDatabaseSetup",
"setup_test_environment",
"teardown_test_environment",
)
@ -170,6 +172,137 @@ def teardown_test_environment():
del mail.outbox
class TestDatabaseSetup:
def __init__(
self,
verbosity,
interactive,
time_keeper=None,
keepdb=False,
debug_sql=False,
parallel=0,
aliases=None,
serialized_aliases=None,
**kwargs,
):
self.verbosity = verbosity
self.interactive = interactive
self.time_keeper = time_keeper or NullTimeKeeper()
self.keepdb = keepdb
self.debug_sql = debug_sql
self.parallel = parallel
self.aliases = aliases
self.serialized_aliases = serialized_aliases
self.test_databases = {}
self.mirrored_aliases = {}
self.old_names = []
self.kwargs = kwargs
def setup(self):
"""Main method to set up test databases."""
self.test_databases, self.mirrored_aliases = get_unique_databases_and_mirrors(
self.aliases
)
# Process each unique database
for db_sig, (db_name, aliases) in self.test_databases.items():
# Extract the first alias outside the loop
first_alias = aliases[0]
remaining_aliases = aliases[1:]
# Create the test database for the first alias
self.create_test_database(first_alias, db_name)
# Clone the test database if parallel testing is enabled
if self.parallel > 1:
self.clone_databases_in_parallel(first_alias)
# Configure remaining aliases as mirrors
for alias in remaining_aliases:
self.configure_as_mirror(alias, first_alias)
# Configure any additional test mirrors
self.configure_test_mirrors()
# Enable debug SQL logging if required
if self.debug_sql:
self.enable_debug_sql()
return self.old_names
def create_test_database(self, alias, db_name):
"""Create the test database for the given alias."""
connection = connections[alias]
self.old_names.append((connection, db_name, True))
serialize_alias = (
self.serialized_aliases is None or alias in self.serialized_aliases
)
with self.time_keeper.timed(f" Creating '{alias}'"):
connection.creation.create_test_db(
verbosity=self.verbosity,
autoclobber=not self.interactive,
keepdb=self.keepdb,
serialize=serialize_alias,
)
def clone_databases_sequentially(self, alias):
"""Clone the test database sequentially for parallel test execution."""
for index in range(self.parallel):
self.clone_test_database(alias, index)
def clone_databases_in_parallel(self, alias):
"""Clone the test database in parallel threads for parallel test execution."""
with ThreadPoolExecutor(max_workers=self.parallel) as executor:
futures = [
executor.submit(self.clone_test_database, alias, index)
for index in range(self.parallel)
]
# Optionally, wait for all futures to complete
# This just ensures that any exceptions are raised
for future in futures:
future.result()
def clone_test_database(self, alias, index):
"""Clone the test database for parallel execution."""
connection = connections[alias]
# re-init the connection per thread
connection.close()
connection.connect()
with self.time_keeper.timed(f" Cloning '{alias}'"):
connection.creation.clone_test_db(
suffix=str(index + 1),
verbosity=self.verbosity,
keepdb=self.keepdb,
)
def configure_as_mirror(self, alias, source_alias):
"""Configure an alias as a mirror of the source alias."""
connection = connections[alias]
self.old_names.append(
(connection, None, False)
) # False indicates no creation needed
source_settings = connections[source_alias].settings_dict
connection.creation.set_as_test_mirror(source_settings)
def configure_test_mirrors(self):
"""Configure any additional test mirrors."""
for alias, mirror_alias in self.mirrored_aliases.items():
connection = connections[alias]
mirror_settings = connections[mirror_alias].settings_dict
connection.creation.set_as_test_mirror(mirror_settings)
def enable_debug_sql(self):
"""Enable debug SQL logging for all connections."""
for alias in connections:
connections[alias].force_debug_cursor = True
def setup_databases(
verbosity,
interactive,
@ -182,58 +315,17 @@ def setup_databases(
serialized_aliases=None,
**kwargs,
):
"""Create the test databases."""
if time_keeper is None:
time_keeper = NullTimeKeeper()
test_databases, mirrored_aliases = get_unique_databases_and_mirrors(aliases)
old_names = []
for db_name, aliases in test_databases.values():
first_alias = None
for alias in aliases:
connection = connections[alias]
old_names.append((connection, db_name, first_alias is None))
# Actually create the database for the first connection
if first_alias is None:
first_alias = alias
with time_keeper.timed(" Creating '%s'" % alias):
serialize_alias = (
serialized_aliases is None or alias in serialized_aliases
)
connection.creation.create_test_db(
verbosity=verbosity,
autoclobber=not interactive,
keepdb=keepdb,
serialize=serialize_alias,
)
if parallel > 1:
for index in range(parallel):
with time_keeper.timed(" Cloning '%s'" % alias):
connection.creation.clone_test_db(
suffix=str(index + 1),
verbosity=verbosity,
keepdb=keepdb,
)
# Configure all other connections as mirrors of the first one
else:
connections[alias].creation.set_as_test_mirror(
connections[first_alias].settings_dict
)
# Configure the test mirrors.
for alias, mirror_alias in mirrored_aliases.items():
connections[alias].creation.set_as_test_mirror(
connections[mirror_alias].settings_dict
)
if debug_sql:
for alias in connections:
connections[alias].force_debug_cursor = True
return old_names
return TestDatabaseSetup(
verbosity,
interactive,
time_keeper,
keepdb,
debug_sql,
parallel,
aliases,
serialized_aliases,
**kwargs,
).setup()
def iter_test_cases(tests):