0
0
mirror of https://github.com/django/django.git synced 2024-11-25 07:59:34 +01:00

Refs #373 -- Updated TupleIsNull lookup to check if any is NULL rather than all.

Regression in 1eac690d25.
This commit is contained in:
Bendeguz Csirmaz 2024-09-20 10:03:47 +02:00 committed by Sarah Boyce
parent 1857b6663b
commit c2c7dbb2f8
3 changed files with 39 additions and 15 deletions

View File

@ -57,18 +57,25 @@ class TupleExact(TupleLookupMixin, Exact):
return root.as_sql(compiler, connection)
class TupleIsNull(IsNull):
class TupleIsNull(TupleLookupMixin, IsNull):
def get_prep_lookup(self):
rhs = self.rhs
if isinstance(rhs, (tuple, list)) and len(rhs) == 1:
rhs = rhs[0]
if isinstance(rhs, bool):
return rhs
raise ValueError(
"The QuerySet value for an isnull lookup must be True or False."
)
def as_sql(self, compiler, connection):
# e.g.: (a, b, c) is None as SQL:
# WHERE a IS NULL AND b IS NULL AND c IS NULL
vals = self.rhs
if isinstance(vals, bool):
vals = [vals] * len(self.lhs)
cols = self.lhs.get_cols()
lookups = [IsNull(col, val) for col, val in zip(cols, vals)]
root = WhereNode(lookups, connector=AND)
# WHERE a IS NULL OR b IS NULL OR c IS NULL
# e.g.: (a, b, c) is not None as SQL:
# WHERE a IS NOT NULL AND b IS NOT NULL AND c IS NOT NULL
rhs = self.rhs
lookups = [IsNull(col, rhs) for col in self.lhs]
root = WhereNode(lookups, connector=OR if rhs else AND)
return root.as_sql(compiler, connection)

View File

@ -49,7 +49,7 @@ class Group(models.Model):
class Membership(models.Model):
# Table Column Fields
membership_country = models.ForeignKey(Country, models.CASCADE)
membership_country = models.ForeignKey(Country, models.CASCADE, null=True)
date_joined = models.DateTimeField(default=datetime.datetime.now)
invite_reason = models.CharField(max_length=64, null=True)
person_id = models.IntegerField()

View File

@ -516,18 +516,35 @@ class MultiColumnFKTests(TestCase):
def test_isnull_lookup(self):
m1 = Membership.objects.create(
membership_country=self.usa, person=self.bob, group_id=None
person_id=self.bob.id,
membership_country_id=self.usa.id,
group_id=None,
)
m2 = Membership.objects.create(
membership_country=self.usa, person=self.bob, group=self.cia
person_id=self.jim.id,
membership_country_id=None,
group_id=self.cia.id,
)
m3 = Membership.objects.create(
person_id=self.jane.id,
membership_country_id=None,
group_id=None,
)
m4 = Membership.objects.create(
person_id=self.george.id,
membership_country_id=self.soviet_union.id,
group_id=self.kgb.id,
)
for member in [m1, m2, m3]:
with self.assertRaises(Membership.group.RelatedObjectDoesNotExist):
getattr(member, "group")
self.assertSequenceEqual(
Membership.objects.filter(group__isnull=True),
[m1],
[m1, m2, m3],
)
self.assertSequenceEqual(
Membership.objects.filter(group__isnull=False),
[m2],
[m4],
)