From b81e974e9ea16bd693b194a728f77fb825ec8e54 Mon Sep 17 00:00:00 2001 From: Anders Kaseorg Date: Mon, 29 May 2023 21:59:22 -0700 Subject: [PATCH] Fixed #34604 -- Corrected fallback SQL for n-ary logical XOR. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit An n-ary logical XOR Q(…) ^ Q(…) ^ … ^ Q(…) should evaluate to true when an odd number of its operands evaluate to true, not when exactly one operand evaluates to true. --- django/db/models/sql/where.py | 7 ++++++- docs/ref/models/querysets.txt | 15 ++++++++++++--- docs/releases/5.0.txt | 5 +++++ tests/xor_lookups/tests.py | 21 +++++++++++++++++++++ 4 files changed, 44 insertions(+), 4 deletions(-) diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py index aaab1730b7..2f23a2932c 100644 --- a/django/db/models/sql/where.py +++ b/django/db/models/sql/where.py @@ -6,6 +6,7 @@ from functools import reduce from django.core.exceptions import EmptyResultSet, FullResultSet from django.db.models.expressions import Case, When +from django.db.models.functions import Mod from django.db.models.lookups import Exact from django.utils import tree from django.utils.functional import cached_property @@ -129,12 +130,16 @@ class WhereNode(tree.Node): # Convert if the database doesn't support XOR: # a XOR b XOR c XOR ... # to: - # (a OR b OR c OR ...) AND (a + b + c + ...) == 1 + # (a OR b OR c OR ...) AND MOD(a + b + c + ..., 2) == 1 + # The result of an n-ary XOR is true when an odd number of operands + # are true. lhs = self.__class__(self.children, OR) rhs_sum = reduce( operator.add, (Case(When(c, then=1), default=0) for c in self.children), ) + if len(self.children) > 2: + rhs_sum = Mod(rhs_sum, 2) rhs = Exact(1, rhs_sum) return self.__class__([lhs, rhs], AND, self.negated).as_sql( compiler, connection diff --git a/docs/ref/models/querysets.txt b/docs/ref/models/querysets.txt index 1d684607f1..640818d2a5 100644 --- a/docs/ref/models/querysets.txt +++ b/docs/ref/models/querysets.txt @@ -2021,7 +2021,8 @@ may be generated. XOR (``^``) ~~~~~~~~~~~ -Combines two ``QuerySet``\s using the SQL ``XOR`` operator. +Combines two ``QuerySet``\s using the SQL ``XOR`` operator. A ``XOR`` +expression matches rows that are matched by an odd number of operands. The following are equivalent:: @@ -2044,13 +2045,21 @@ SQL equivalent: .. code-block:: sql (x OR y OR ... OR z) AND - 1=( + 1=MOD( (CASE WHEN x THEN 1 ELSE 0 END) + (CASE WHEN y THEN 1 ELSE 0 END) + ... - (CASE WHEN z THEN 1 ELSE 0 END) + + (CASE WHEN z THEN 1 ELSE 0 END), + 2 ) + .. versionchanged:: 5.0 + + In older versions, on databases without native support for the SQL + ``XOR`` operator, ``XOR`` returned rows that were matched by exactly + one operand. The previous behavior was not consistent with MySQL, + MariaDB, and Python behavior. + Methods that do not return ``QuerySet``\s ----------------------------------------- diff --git a/docs/releases/5.0.txt b/docs/releases/5.0.txt index 98bd5d2a9f..fb446fcd7a 100644 --- a/docs/releases/5.0.txt +++ b/docs/releases/5.0.txt @@ -424,6 +424,11 @@ Miscellaneous a page. Having two ``

`` elements was confusing and the site header wasn't helpful as it is repeated on all pages. +* On databases without native support for the SQL ``XOR`` operator, ``^`` as + the exclusive or (``XOR``) operator now returns rows that are matched by an + odd number of operands rather than exactly one operand. This is consistent + with the behavior of MySQL, MariaDB, and Python. + .. _deprecated-features-5.0: Features deprecated in 5.0 diff --git a/tests/xor_lookups/tests.py b/tests/xor_lookups/tests.py index a9cdf9cb31..d58d16cf11 100644 --- a/tests/xor_lookups/tests.py +++ b/tests/xor_lookups/tests.py @@ -19,6 +19,27 @@ class XorLookupsTests(TestCase): self.numbers[:3] + self.numbers[8:], ) + def test_filter_multiple(self): + qs = Number.objects.filter( + Q(num__gte=1) + ^ Q(num__gte=3) + ^ Q(num__gte=5) + ^ Q(num__gte=7) + ^ Q(num__gte=9) + ) + self.assertCountEqual( + qs, + self.numbers[1:3] + self.numbers[5:7] + self.numbers[9:], + ) + self.assertCountEqual( + qs.values_list("num", flat=True), + [ + i + for i in range(10) + if (i >= 1) ^ (i >= 3) ^ (i >= 5) ^ (i >= 7) ^ (i >= 9) + ], + ) + def test_filter_negated(self): self.assertCountEqual( Number.objects.filter(Q(num__lte=7) ^ ~Q(num__lt=3)),