0
0
mirror of https://github.com/django/django.git synced 2024-12-01 15:42:04 +01:00

Refs #28305 -- Consolidated field referencing detection in migrations.

This moves all the field referencing resolution methods to shared
functions instead of duplicating efforts amongst state_forwards and
references methods.
This commit is contained in:
Simon Charette 2020-04-03 16:38:06 -04:00 committed by Mariusz Felisiak
parent 734fde7714
commit f5ede1cb6d
3 changed files with 118 additions and 95 deletions

View File

@ -3,9 +3,7 @@ from django.db.models import NOT_PROVIDED
from django.utils.functional import cached_property
from .base import Operation
from .utils import (
field_references_model, is_referenced_by_foreign_key, resolve_relation,
)
from .utils import field_is_referenced, field_references, get_references
class FieldOperation(Operation):
@ -33,9 +31,9 @@ class FieldOperation(Operation):
if name_lower == self.model_name_lower:
return True
if self.field:
return field_references_model(
return bool(field_references(
(app_label, self.model_name_lower), self.field, (app_label, name_lower)
)
))
return False
def references_field(self, model_name, name, app_label):
@ -47,20 +45,14 @@ class FieldOperation(Operation):
elif self.field and hasattr(self.field, 'from_fields') and name in self.field.from_fields:
return True
# Check if this operation remotely references the field.
if self.field:
model_tuple = (app_label, model_name_lower)
remote_field = self.field.remote_field
if remote_field:
if (resolve_relation(remote_field.model, app_label, self.model_name_lower) == model_tuple and
(not hasattr(self.field, 'to_fields') or
name in self.field.to_fields or None in self.field.to_fields)):
return True
through = getattr(remote_field, 'through', None)
if (through and resolve_relation(through, app_label, self.model_name_lower) == model_tuple and
(getattr(remote_field, 'through_fields', None) is None or
name in remote_field.through_fields)):
return True
return False
if self.field is None:
return False
return bool(field_references(
(app_label, self.model_name_lower),
self.field,
(app_label, model_name_lower),
name,
))
def reduce(self, operation, app_label):
return (
@ -236,7 +228,9 @@ class AlterField(FieldOperation):
# not referenced by a foreign key.
delay = (
not field.is_relation and
not is_referenced_by_foreign_key(state, self.model_name_lower, self.field, self.name)
not field_is_referenced(
state, (app_label, self.model_name_lower), (self.name, field),
)
)
state.reload_model(app_label, self.model_name_lower, delay=delay)
@ -305,17 +299,11 @@ class RenameField(FieldOperation):
model_state = state.models[app_label, self.model_name_lower]
# Rename the field
fields = model_state.fields
found = False
found = None
for index, (name, field) in enumerate(fields):
if not found and name == self.old_name:
fields[index] = (self.new_name, field)
found = True
# Delay rendering of relationships if it's not a relational
# field and not referenced by a foreign key.
delay = (
not field.is_relation and
not is_referenced_by_foreign_key(state, self.model_name_lower, field, self.name)
)
found = field
# Fix from_fields to refer to the new field.
from_fields = getattr(field, 'from_fields', None)
if from_fields:
@ -323,7 +311,7 @@ class RenameField(FieldOperation):
self.new_name if from_field_name == self.old_name else from_field_name
for from_field_name in from_fields
])
if not found:
if found is None:
raise FieldDoesNotExist(
"%s.%s has no field named '%s'" % (app_label, self.model_name, self.old_name)
)
@ -336,23 +324,21 @@ class RenameField(FieldOperation):
for together in options[option]
]
# Fix to_fields to refer to the new field.
model_tuple = app_label, self.model_name_lower
for (model_app_label, model_name), model_state in state.models.items():
for index, (name, field) in enumerate(model_state.fields):
remote_field = field.remote_field
if remote_field:
remote_model_tuple = resolve_relation(
remote_field.model, model_app_label, model_name
)
if remote_model_tuple == model_tuple:
if getattr(remote_field, 'field_name', None) == self.old_name:
remote_field.field_name = self.new_name
to_fields = getattr(field, 'to_fields', None)
if to_fields:
field.to_fields = tuple([
self.new_name if to_field_name == self.old_name else to_field_name
for to_field_name in to_fields
])
delay = True
references = get_references(
state, (app_label, self.model_name_lower), (self.old_name, found),
)
for *_, field, reference in references:
delay = False
if reference.to:
remote_field, to_fields = reference.to
if getattr(remote_field, 'field_name', None) == self.old_name:
remote_field.field_name = self.new_name
if to_fields:
field.to_fields = tuple([
self.new_name if to_field_name == self.old_name else to_field_name
for to_field_name in to_fields
])
state.reload_model(app_label, self.model_name_lower, delay=delay)
def database_forwards(self, app_label, schema_editor, from_state, to_state):

View File

@ -7,7 +7,7 @@ from django.utils.functional import cached_property
from .fields import (
AddField, AlterField, FieldOperation, RemoveField, RenameField,
)
from .utils import field_references_model, resolve_relation
from .utils import field_references, get_references, resolve_relation
def _check_for_duplicates(arg_name, objs):
@ -113,7 +113,7 @@ class CreateModel(ModelOperation):
# Check we have no FKs/M2Ms with it
for _name, field in self.fields:
if field_references_model((app_label, self.name_lower), field, reference_model_tuple):
if field_references((app_label, self.name_lower), field, reference_model_tuple):
return True
return False
@ -309,33 +309,19 @@ class RenameModel(ModelOperation):
# Repoint all fields pointing to the old model to the new one.
old_model_tuple = (app_label, self.old_name_lower)
new_remote_model = '%s.%s' % (app_label, self.new_name)
to_reload = []
for (model_app_label, model_name), model_state in state.models.items():
model_changed = False
for index, (name, field) in enumerate(model_state.fields):
changed_field = None
remote_field = field.remote_field
if remote_field:
remote_model_tuple = resolve_relation(
remote_field.model, model_app_label, model_name
)
if remote_model_tuple == old_model_tuple:
changed_field = field.clone()
changed_field.remote_field.model = new_remote_model
through_model = getattr(remote_field, 'through', None)
if through_model:
through_model_tuple = resolve_relation(
through_model, model_app_label, model_name
)
if through_model_tuple == old_model_tuple:
if changed_field is None:
changed_field = field.clone()
changed_field.remote_field.through = new_remote_model
if changed_field:
model_state.fields[index] = name, changed_field
model_changed = True
if model_changed:
to_reload.append((model_app_label, model_name))
to_reload = set()
for model_state, index, name, field, reference in get_references(state, old_model_tuple):
changed_field = None
if reference.to:
changed_field = field.clone()
changed_field.remote_field.model = new_remote_model
if reference.through:
if changed_field is None:
changed_field = field.clone()
changed_field.remote_field.through = new_remote_model
if changed_field:
model_state.fields[index] = name, changed_field
to_reload.add((model_state.app_label, model_state.name_lower))
# Reload models related to old model before removing the old model.
state.reload_models(to_reload, delay=True)
# Remove the old model.

View File

@ -1,17 +1,8 @@
from collections import namedtuple
from django.db.models.fields.related import RECURSIVE_RELATIONSHIP_CONSTANT
def is_referenced_by_foreign_key(state, model_name_lower, field, field_name):
for state_app_label, state_model in state.models:
for _, f in state.models[state_app_label, state_model].fields:
if (f.related_model and
'%s.%s' % (state_app_label, model_name_lower) == f.related_model.lower() and
hasattr(f, 'to_fields')):
if (f.to_fields[0] is None and field.primary_key) or field_name in f.to_fields:
return True
return False
def resolve_relation(model, app_label=None, model_name=None):
"""
Turn a model class or model reference string and return a model tuple.
@ -38,13 +29,73 @@ def resolve_relation(model, app_label=None, model_name=None):
return model._meta.app_label, model._meta.model_name
def field_references_model(model_tuple, field, reference_model_tuple):
"""Return whether or not field references reference_model_tuple."""
FieldReference = namedtuple('FieldReference', 'to through')
def field_references(
model_tuple,
field,
reference_model_tuple,
reference_field_name=None,
reference_field=None,
):
"""
Return either False or a FieldReference if `field` references provided
context.
False positives can be returned if `reference_field_name` is provided
without `reference_field` because of the introspection limitation it
incurs. This should not be an issue when this function is used to determine
whether or not an optimization can take place.
"""
remote_field = field.remote_field
if remote_field:
if resolve_relation(remote_field.model, *model_tuple) == reference_model_tuple:
return True
through = getattr(remote_field, 'through', None)
if through and resolve_relation(through, *model_tuple) == reference_model_tuple:
return True
return False
if not remote_field:
return False
references_to = None
references_through = None
if resolve_relation(remote_field.model, *model_tuple) == reference_model_tuple:
to_fields = getattr(field, 'to_fields', None)
if (
reference_field_name is None or
# Unspecified to_field(s).
to_fields is None or
# Reference to primary key.
(None in to_fields and (reference_field is None or reference_field.primary_key)) or
# Reference to field.
reference_field_name in to_fields
):
references_to = (remote_field, to_fields)
through = getattr(remote_field, 'through', None)
if through and resolve_relation(through, *model_tuple) == reference_model_tuple:
through_fields = remote_field.through_fields
if (
reference_field_name is None or
# Unspecified through_fields.
through_fields is None or
# Reference to field.
reference_field_name in through_fields
):
references_through = (remote_field, through_fields)
if not (references_to or references_through):
return False
return FieldReference(references_to, references_through)
def get_references(state, model_tuple, field_tuple=()):
"""
Generator of (model_state, index, name, field, reference) referencing
provided context.
If field_tuple is provided only references to this particular field of
model_tuple will be generated.
"""
for state_model_tuple, model_state in state.models.items():
for index, (name, field) in enumerate(model_state.fields):
reference = field_references(state_model_tuple, field, model_tuple, *field_tuple)
if reference:
yield model_state, index, name, field, reference
def field_is_referenced(state, model_tuple, field_tuple):
"""Return whether `field_tuple` is referenced by any state models."""
return next(get_references(state, model_tuple, field_tuple), None) is not None