diff --git a/django/db/migrations/operations/fields.py b/django/db/migrations/operations/fields.py index 568515d8ac..ecb2af8a28 100644 --- a/django/db/migrations/operations/fields.py +++ b/django/db/migrations/operations/fields.py @@ -3,9 +3,7 @@ from django.db.models import NOT_PROVIDED from django.utils.functional import cached_property from .base import Operation -from .utils import ( - field_references_model, is_referenced_by_foreign_key, resolve_relation, -) +from .utils import field_is_referenced, field_references, get_references class FieldOperation(Operation): @@ -33,9 +31,9 @@ class FieldOperation(Operation): if name_lower == self.model_name_lower: return True if self.field: - return field_references_model( + return bool(field_references( (app_label, self.model_name_lower), self.field, (app_label, name_lower) - ) + )) return False def references_field(self, model_name, name, app_label): @@ -47,20 +45,14 @@ class FieldOperation(Operation): elif self.field and hasattr(self.field, 'from_fields') and name in self.field.from_fields: return True # Check if this operation remotely references the field. - if self.field: - model_tuple = (app_label, model_name_lower) - remote_field = self.field.remote_field - if remote_field: - if (resolve_relation(remote_field.model, app_label, self.model_name_lower) == model_tuple and - (not hasattr(self.field, 'to_fields') or - name in self.field.to_fields or None in self.field.to_fields)): - return True - through = getattr(remote_field, 'through', None) - if (through and resolve_relation(through, app_label, self.model_name_lower) == model_tuple and - (getattr(remote_field, 'through_fields', None) is None or - name in remote_field.through_fields)): - return True - return False + if self.field is None: + return False + return bool(field_references( + (app_label, self.model_name_lower), + self.field, + (app_label, model_name_lower), + name, + )) def reduce(self, operation, app_label): return ( @@ -236,7 +228,9 @@ class AlterField(FieldOperation): # not referenced by a foreign key. delay = ( not field.is_relation and - not is_referenced_by_foreign_key(state, self.model_name_lower, self.field, self.name) + not field_is_referenced( + state, (app_label, self.model_name_lower), (self.name, field), + ) ) state.reload_model(app_label, self.model_name_lower, delay=delay) @@ -305,17 +299,11 @@ class RenameField(FieldOperation): model_state = state.models[app_label, self.model_name_lower] # Rename the field fields = model_state.fields - found = False + found = None for index, (name, field) in enumerate(fields): if not found and name == self.old_name: fields[index] = (self.new_name, field) - found = True - # Delay rendering of relationships if it's not a relational - # field and not referenced by a foreign key. - delay = ( - not field.is_relation and - not is_referenced_by_foreign_key(state, self.model_name_lower, field, self.name) - ) + found = field # Fix from_fields to refer to the new field. from_fields = getattr(field, 'from_fields', None) if from_fields: @@ -323,7 +311,7 @@ class RenameField(FieldOperation): self.new_name if from_field_name == self.old_name else from_field_name for from_field_name in from_fields ]) - if not found: + if found is None: raise FieldDoesNotExist( "%s.%s has no field named '%s'" % (app_label, self.model_name, self.old_name) ) @@ -336,23 +324,21 @@ class RenameField(FieldOperation): for together in options[option] ] # Fix to_fields to refer to the new field. - model_tuple = app_label, self.model_name_lower - for (model_app_label, model_name), model_state in state.models.items(): - for index, (name, field) in enumerate(model_state.fields): - remote_field = field.remote_field - if remote_field: - remote_model_tuple = resolve_relation( - remote_field.model, model_app_label, model_name - ) - if remote_model_tuple == model_tuple: - if getattr(remote_field, 'field_name', None) == self.old_name: - remote_field.field_name = self.new_name - to_fields = getattr(field, 'to_fields', None) - if to_fields: - field.to_fields = tuple([ - self.new_name if to_field_name == self.old_name else to_field_name - for to_field_name in to_fields - ]) + delay = True + references = get_references( + state, (app_label, self.model_name_lower), (self.old_name, found), + ) + for *_, field, reference in references: + delay = False + if reference.to: + remote_field, to_fields = reference.to + if getattr(remote_field, 'field_name', None) == self.old_name: + remote_field.field_name = self.new_name + if to_fields: + field.to_fields = tuple([ + self.new_name if to_field_name == self.old_name else to_field_name + for to_field_name in to_fields + ]) state.reload_model(app_label, self.model_name_lower, delay=delay) def database_forwards(self, app_label, schema_editor, from_state, to_state): diff --git a/django/db/migrations/operations/models.py b/django/db/migrations/operations/models.py index fa247f56eb..baa902089f 100644 --- a/django/db/migrations/operations/models.py +++ b/django/db/migrations/operations/models.py @@ -7,7 +7,7 @@ from django.utils.functional import cached_property from .fields import ( AddField, AlterField, FieldOperation, RemoveField, RenameField, ) -from .utils import field_references_model, resolve_relation +from .utils import field_references, get_references, resolve_relation def _check_for_duplicates(arg_name, objs): @@ -113,7 +113,7 @@ class CreateModel(ModelOperation): # Check we have no FKs/M2Ms with it for _name, field in self.fields: - if field_references_model((app_label, self.name_lower), field, reference_model_tuple): + if field_references((app_label, self.name_lower), field, reference_model_tuple): return True return False @@ -309,33 +309,19 @@ class RenameModel(ModelOperation): # Repoint all fields pointing to the old model to the new one. old_model_tuple = (app_label, self.old_name_lower) new_remote_model = '%s.%s' % (app_label, self.new_name) - to_reload = [] - for (model_app_label, model_name), model_state in state.models.items(): - model_changed = False - for index, (name, field) in enumerate(model_state.fields): - changed_field = None - remote_field = field.remote_field - if remote_field: - remote_model_tuple = resolve_relation( - remote_field.model, model_app_label, model_name - ) - if remote_model_tuple == old_model_tuple: - changed_field = field.clone() - changed_field.remote_field.model = new_remote_model - through_model = getattr(remote_field, 'through', None) - if through_model: - through_model_tuple = resolve_relation( - through_model, model_app_label, model_name - ) - if through_model_tuple == old_model_tuple: - if changed_field is None: - changed_field = field.clone() - changed_field.remote_field.through = new_remote_model - if changed_field: - model_state.fields[index] = name, changed_field - model_changed = True - if model_changed: - to_reload.append((model_app_label, model_name)) + to_reload = set() + for model_state, index, name, field, reference in get_references(state, old_model_tuple): + changed_field = None + if reference.to: + changed_field = field.clone() + changed_field.remote_field.model = new_remote_model + if reference.through: + if changed_field is None: + changed_field = field.clone() + changed_field.remote_field.through = new_remote_model + if changed_field: + model_state.fields[index] = name, changed_field + to_reload.add((model_state.app_label, model_state.name_lower)) # Reload models related to old model before removing the old model. state.reload_models(to_reload, delay=True) # Remove the old model. diff --git a/django/db/migrations/operations/utils.py b/django/db/migrations/operations/utils.py index 24319fb383..0295d60af9 100644 --- a/django/db/migrations/operations/utils.py +++ b/django/db/migrations/operations/utils.py @@ -1,17 +1,8 @@ +from collections import namedtuple + from django.db.models.fields.related import RECURSIVE_RELATIONSHIP_CONSTANT -def is_referenced_by_foreign_key(state, model_name_lower, field, field_name): - for state_app_label, state_model in state.models: - for _, f in state.models[state_app_label, state_model].fields: - if (f.related_model and - '%s.%s' % (state_app_label, model_name_lower) == f.related_model.lower() and - hasattr(f, 'to_fields')): - if (f.to_fields[0] is None and field.primary_key) or field_name in f.to_fields: - return True - return False - - def resolve_relation(model, app_label=None, model_name=None): """ Turn a model class or model reference string and return a model tuple. @@ -38,13 +29,73 @@ def resolve_relation(model, app_label=None, model_name=None): return model._meta.app_label, model._meta.model_name -def field_references_model(model_tuple, field, reference_model_tuple): - """Return whether or not field references reference_model_tuple.""" +FieldReference = namedtuple('FieldReference', 'to through') + + +def field_references( + model_tuple, + field, + reference_model_tuple, + reference_field_name=None, + reference_field=None, +): + """ + Return either False or a FieldReference if `field` references provided + context. + + False positives can be returned if `reference_field_name` is provided + without `reference_field` because of the introspection limitation it + incurs. This should not be an issue when this function is used to determine + whether or not an optimization can take place. + """ remote_field = field.remote_field - if remote_field: - if resolve_relation(remote_field.model, *model_tuple) == reference_model_tuple: - return True - through = getattr(remote_field, 'through', None) - if through and resolve_relation(through, *model_tuple) == reference_model_tuple: - return True - return False + if not remote_field: + return False + references_to = None + references_through = None + if resolve_relation(remote_field.model, *model_tuple) == reference_model_tuple: + to_fields = getattr(field, 'to_fields', None) + if ( + reference_field_name is None or + # Unspecified to_field(s). + to_fields is None or + # Reference to primary key. + (None in to_fields and (reference_field is None or reference_field.primary_key)) or + # Reference to field. + reference_field_name in to_fields + ): + references_to = (remote_field, to_fields) + through = getattr(remote_field, 'through', None) + if through and resolve_relation(through, *model_tuple) == reference_model_tuple: + through_fields = remote_field.through_fields + if ( + reference_field_name is None or + # Unspecified through_fields. + through_fields is None or + # Reference to field. + reference_field_name in through_fields + ): + references_through = (remote_field, through_fields) + if not (references_to or references_through): + return False + return FieldReference(references_to, references_through) + + +def get_references(state, model_tuple, field_tuple=()): + """ + Generator of (model_state, index, name, field, reference) referencing + provided context. + + If field_tuple is provided only references to this particular field of + model_tuple will be generated. + """ + for state_model_tuple, model_state in state.models.items(): + for index, (name, field) in enumerate(model_state.fields): + reference = field_references(state_model_tuple, field, model_tuple, *field_tuple) + if reference: + yield model_state, index, name, field, reference + + +def field_is_referenced(state, model_tuple, field_tuple): + """Return whether `field_tuple` is referenced by any state models.""" + return next(get_references(state, model_tuple, field_tuple), None) is not None