diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 54b4e86245..caaeaefa6e 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -319,10 +319,10 @@ class SQLCompiler(object): for name in self.query.distinct_fields: parts = name.split(LOOKUP_SEP) - field, cols, alias, _, _ = self._setup_joins(parts, opts, None) - cols, alias = self._final_join_removal(cols, alias) - for col in cols: - result.append("%s.%s" % (qn(alias), qn2(col))) + _, targets, alias, joins, path, _ = self._setup_joins(parts, opts, None) + targets, alias, _ = self.query.trim_joins(targets, joins, path) + for target in targets: + result.append("%s.%s" % (qn(alias), qn2(target.column))) return result def get_ordering(self): @@ -421,7 +421,7 @@ class SQLCompiler(object): return result, params, group_by def find_ordering_name(self, name, opts, alias=None, default_order='ASC', - already_seen=None): + already_seen=None): """ Returns the table alias (the name might be ambiguous, the alias will not be) and column name for ordering by the given 'name' parameter. @@ -429,11 +429,11 @@ class SQLCompiler(object): """ name, order = get_order_dir(name, default_order) pieces = name.split(LOOKUP_SEP) - field, cols, alias, joins, opts = self._setup_joins(pieces, opts, alias) + field, targets, alias, joins, path, opts = self._setup_joins(pieces, opts, alias) # If we get to this point and the field is a relation to another model, # append the default ordering for that model. - if field.rel and len(joins) > 1 and opts.ordering: + if field.rel and path and opts.ordering: # Firstly, avoid infinite loops. if not already_seen: already_seen = set() @@ -445,10 +445,10 @@ class SQLCompiler(object): results = [] for item in opts.ordering: results.extend(self.find_ordering_name(item, opts, alias, - order, already_seen)) + order, already_seen)) return results - cols, alias = self._final_join_removal(cols, alias) - return [(alias, cols, order)] + targets, alias, _ = self.query.trim_joins(targets, joins, path) + return [(alias, [t.column for t in targets], order)] def _setup_joins(self, pieces, opts, alias): """ @@ -461,13 +461,12 @@ class SQLCompiler(object): """ if not alias: alias = self.query.get_initial_alias() - field, targets, opts, joins, _ = self.query.setup_joins( + field, targets, opts, joins, path = self.query.setup_joins( pieces, opts, alias) # We will later on need to promote those joins that were added to the # query afresh above. joins_to_promote = [j for j in joins if self.query.alias_refcount[j] < 2] alias = joins[-1] - cols = [target.column for target in targets] if not field.rel: # To avoid inadvertent trimming of a necessary alias, use the # refcount to show that we are referencing a non-relation field on @@ -478,28 +477,7 @@ class SQLCompiler(object): # Ordering or distinct must not affect the returned set, and INNER # JOINS for nullable fields could do this. self.query.promote_joins(joins_to_promote) - return field, cols, alias, joins, opts - - def _final_join_removal(self, cols, alias): - """ - A helper method for get_distinct and get_ordering. This method will - trim extra not-needed joins from the tail of the join chain. - - This is very similar to what is done in trim_joins, but we will - trim LEFT JOINS here. It would be a good idea to consolidate this - method and query.trim_joins(). - """ - if alias: - while 1: - join = self.query.alias_map[alias] - lhs_cols, rhs_cols = zip(*[(lhs_col, rhs_col) for lhs_col, rhs_col in join.join_cols]) - if set(cols) != set(rhs_cols): - break - - cols = [lhs_cols[rhs_cols.index(col)] for col in cols] - self.query.unref_alias(alias) - alias = join.lhs_alias - return cols, alias + return field, targets, alias, joins, path, opts def get_from_clause(self): """ diff --git a/tests/queries/tests.py b/tests/queries/tests.py index 4d9ffb353f..bb4e9eee8f 100644 --- a/tests/queries/tests.py +++ b/tests/queries/tests.py @@ -25,7 +25,7 @@ from .models import ( OneToOneCategory, NullableName, ProxyCategory, SingleObject, RelatedObject, ModelA, ModelB, ModelC, ModelD, Responsibility, Job, JobResponsibilities, BaseA, FK1, Identifier, Program, Channel, Page, Paragraph, Chapter, Book, - MyObject, Order, OrderItem) + MyObject, Order, OrderItem, SharedConnection) class BaseQuerysetTest(TestCase): def assertValueQuerysetEqual(self, qs, values): @@ -2977,3 +2977,14 @@ class RelatedLookupTypeTests(TestCase): self.assertQuerysetEqual( ObjectB.objects.filter(objecta__in=[wrong_type]), [ob], lambda x: x) + +class Ticket14056Tests(TestCase): + def test_ticket_14056(self): + s1 = SharedConnection.objects.create(data='s1') + s2 = SharedConnection.objects.create(data='s2') + s3 = SharedConnection.objects.create(data='s3') + PointerA.objects.create(connection=s2) + self.assertQuerysetEqual( + SharedConnection.objects.order_by('pointera__connection', 'pk'), + [s1, s3, s2], lambda x: x + )