From 1b754d638dc5682265b5b4d1fcf5a94d0312113e Mon Sep 17 00:00:00 2001 From: Olivier Le Thanh Duong Date: Sat, 3 Jun 2023 15:54:22 +0200 Subject: [PATCH] Fixed #34629 -- Added filtering support to GIS aggregates. --- AUTHORS | 1 + django/contrib/gis/db/models/aggregates.py | 4 +- docs/ref/contrib/gis/geoquerysets.txt | 32 ++++- docs/releases/5.0.txt | 3 + tests/gis_tests/geo3d/tests.py | 12 +- .../relatedapp/fixtures/initial.json | 52 +++++++- tests/gis_tests/relatedapp/tests.py | 113 +++++++++++++++++- 7 files changed, 206 insertions(+), 11 deletions(-) diff --git a/AUTHORS b/AUTHORS index d5e3af2310..2148c64322 100644 --- a/AUTHORS +++ b/AUTHORS @@ -756,6 +756,7 @@ answer newbie questions, and generally made Django that much better: oggy Oliver Beattie Oliver Rutherfurd + Olivier Le Thanh Duong Olivier Sels Olivier Tabone Orestis Markou diff --git a/django/contrib/gis/db/models/aggregates.py b/django/contrib/gis/db/models/aggregates.py index 8493882696..7ba2e756e7 100644 --- a/django/contrib/gis/db/models/aggregates.py +++ b/django/contrib/gis/db/models/aggregates.py @@ -53,8 +53,8 @@ class GeoAggregate(Aggregate): self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False ): c = super().resolve_expression(query, allow_joins, reuse, summarize, for_save) - for expr in c.get_source_expressions(): - if not hasattr(expr.field, "geom_type"): + for field in c.get_source_fields(): + if not hasattr(field, "geom_type"): raise ValueError( "Geospatial aggregates only allowed on geometry fields." ) diff --git a/docs/ref/contrib/gis/geoquerysets.txt b/docs/ref/contrib/gis/geoquerysets.txt index 4d4c0c7de9..1c98b6f896 100644 --- a/docs/ref/contrib/gis/geoquerysets.txt +++ b/docs/ref/contrib/gis/geoquerysets.txt @@ -839,6 +839,8 @@ Oracle ``SDO_WITHIN_DISTANCE(poly, geom, 5)`` SpatiaLite ``PtDistWithin(poly, geom, 5)`` ========== ====================================== +.. _gis-aggregation-functions: + Aggregate Functions ------------------- @@ -868,7 +870,7 @@ Example: ``Collect`` ~~~~~~~~~~~ -.. class:: Collect(geo_field) +.. class:: Collect(geo_field, filter=None) *Availability*: `PostGIS `__, SpatiaLite @@ -879,10 +881,14 @@ aggregate, except it can be several orders of magnitude faster than performing a union because it rolls up geometries into a collection or multi object, not caring about dissolving boundaries. +.. versionchanged:: 5.0 + + Support for using the ``filter`` argument was added. + ``Extent`` ~~~~~~~~~~ -.. class:: Extent(geo_field) +.. class:: Extent(geo_field, filter=None) *Availability*: `PostGIS `__, Oracle, SpatiaLite @@ -898,10 +904,14 @@ Example: >>> print(qs["poly__extent"]) (-96.8016128540039, 29.7633724212646, -95.3631439208984, 32.782058715820) +.. versionchanged:: 5.0 + + Support for using the ``filter`` argument was added. + ``Extent3D`` ~~~~~~~~~~~~ -.. class:: Extent3D(geo_field) +.. class:: Extent3D(geo_field, filter=None) *Availability*: `PostGIS `__ @@ -917,10 +927,14 @@ Example: >>> print(qs["poly__extent3d"]) (-96.8016128540039, 29.7633724212646, 0, -95.3631439208984, 32.782058715820, 0) +.. versionchanged:: 5.0 + + Support for using the ``filter`` argument was added. + ``MakeLine`` ~~~~~~~~~~~~ -.. class:: MakeLine(geo_field) +.. class:: MakeLine(geo_field, filter=None) *Availability*: `PostGIS `__, SpatiaLite @@ -936,10 +950,14 @@ Example: >>> print(qs["poly__makeline"]) LINESTRING (-95.3631510000000020 29.7633739999999989, -96.8016109999999941 32.7820570000000018) +.. versionchanged:: 5.0 + + Support for using the ``filter`` argument was added. + ``Union`` ~~~~~~~~~ -.. class:: Union(geo_field) +.. class:: Union(geo_field, filter=None) *Availability*: `PostGIS `__, Oracle, SpatiaLite @@ -963,6 +981,10 @@ Example: ... Union(poly) ... ) # A more sensible approach. +.. versionchanged:: 5.0 + + Support for using the ``filter`` argument was added. + .. rubric:: Footnotes .. [#fnde9im] *See* `OpenGIS Simple Feature Specification For SQL `_, at Ch. 2.1.13.2, p. 2-13 (The Dimensionally Extended Nine-Intersection Model). .. [#fnsdorelate] *See* `SDO_RELATE documentation ` now support the ``filter`` + argument. + :mod:`django.contrib.messages` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/tests/gis_tests/geo3d/tests.py b/tests/gis_tests/geo3d/tests.py index 2e99d741ef..b37deabb46 100644 --- a/tests/gis_tests/geo3d/tests.py +++ b/tests/gis_tests/geo3d/tests.py @@ -1,7 +1,7 @@ import os import re -from django.contrib.gis.db.models import Extent3D, Union +from django.contrib.gis.db.models import Extent3D, Q, Union from django.contrib.gis.db.models.functions import ( AsGeoJSON, AsKML, @@ -244,6 +244,16 @@ class Geo3DTest(Geo3DLoadingHelper, TestCase): City3D.objects.none().aggregate(Extent3D("point"))["point__extent3d"] ) + @skipUnlessDBFeature("supports_3d_functions") + def test_extent3d_filter(self): + self._load_city_data() + extent3d = City3D.objects.aggregate( + ll_cities=Extent3D("point", filter=Q(name__contains="ll")) + )["ll_cities"] + ref_extent3d = (-96.801611, -41.315268, 14.0, 174.783117, 32.782057, 147.0) + for ref_val, ext_val in zip(ref_extent3d, extent3d): + self.assertAlmostEqual(ref_val, ext_val, 6) + @skipUnlessDBFeature("supports_3d_functions") class Geo3DFunctionsTests(FuncTestMixin, Geo3DLoadingHelper, TestCase): diff --git a/tests/gis_tests/relatedapp/fixtures/initial.json b/tests/gis_tests/relatedapp/fixtures/initial.json index e744aae03a..4adf9ef854 100644 --- a/tests/gis_tests/relatedapp/fixtures/initial.json +++ b/tests/gis_tests/relatedapp/fixtures/initial.json @@ -135,5 +135,53 @@ "title": "Patry on Copyright", "author": 2 } - } -] \ No newline at end of file + }, + { + "model": "relatedapp.parcel", + "pk": 1, + "fields": { + "name": "Aurora Parcel Alpha", + "city": 1, + "center1": "POINT (1.7128 -2.0060)", + "center2": "POINT (3.7128 -5.0060)", + "border1": "POLYGON((0 0, 5 5, 12 12, 0 0))", + "border2": "POLYGON((0 0, 5 5, 8 8, 0 0))" + } + }, + { + "model": "relatedapp.parcel", + "pk": 2, + "fields": { + "name": "Aurora Parcel Beta", + "city": 1, + "center1": "POINT (4.7128 5.0060)", + "center2": "POINT (12.75 10.05)", + "border1": "POLYGON((10 10, 15 15, 22 22, 10 10))", + "border2": "POLYGON((10 10, 15 15, 22 22, 10 10))" + } + }, + { + "model": "relatedapp.parcel", + "pk": 3, + "fields": { + "name": "Aurora Parcel Ignore", + "city": 1, + "center1": "POINT (9.7128 12.0060)", + "center2": "POINT (1.7128 -2.0060)", + "border1": "POLYGON ((24 23, 25 25, 32 32, 24 23))", + "border2": "POLYGON ((24 23, 25 25, 32 32, 24 23))" + } + }, + { + "model": "relatedapp.parcel", + "pk": 4, + "fields": { + "name": "Roswell Parcel Ignore", + "city": 2, + "center1": "POINT (-9.7128 -12.0060)", + "center2": "POINT (-1.7128 2.0060)", + "border1": "POLYGON ((30 30, 35 35, 42 32, 30 30))", + "border2": "POLYGON ((30 30, 35 35, 42 32, 30 30))" + } + } +] diff --git a/tests/gis_tests/relatedapp/tests.py b/tests/gis_tests/relatedapp/tests.py index e11a410afc..7fcd59b84f 100644 --- a/tests/gis_tests/relatedapp/tests.py +++ b/tests/gis_tests/relatedapp/tests.py @@ -1,4 +1,5 @@ -from django.contrib.gis.db.models import Collect, Count, Extent, F, Union +from django.contrib.gis.db.models import Collect, Count, Extent, F, MakeLine, Q, Union +from django.contrib.gis.db.models.functions import Centroid from django.contrib.gis.geos import GEOSGeometry, MultiPoint, Point from django.db import NotSupportedError, connection from django.test import TestCase, skipUnlessDBFeature @@ -304,6 +305,116 @@ class RelatedGeoModelTest(TestCase): self.assertEqual(4, len(coll)) self.assertTrue(ref_geom.equals(coll)) + @skipUnlessDBFeature("supports_collect_aggr") + def test_collect_filter(self): + qs = City.objects.annotate( + parcel_center=Collect( + "parcel__center1", + filter=~Q(parcel__name__icontains="ignore"), + ), + parcel_center_nonexistent=Collect( + "parcel__center1", + filter=Q(parcel__name__icontains="nonexistent"), + ), + parcel_center_single=Collect( + "parcel__center1", + filter=Q(parcel__name__contains="Alpha"), + ), + ) + city = qs.get(name="Aurora") + self.assertEqual( + city.parcel_center.wkt, "MULTIPOINT (1.7128 -2.006, 4.7128 5.006)" + ) + self.assertIsNone(city.parcel_center_nonexistent) + self.assertIn( + city.parcel_center_single.wkt, + [ + "MULTIPOINT (1.7128 -2.006)", + "POINT (1.7128 -2.006)", # SpatiaLite collapse to POINT. + ], + ) + + @skipUnlessDBFeature("has_Centroid_function", "supports_collect_aggr") + def test_centroid_collect_filter(self): + qs = City.objects.annotate( + parcel_centroid=Centroid( + Collect( + "parcel__center1", + filter=~Q(parcel__name__icontains="ignore"), + ) + ) + ) + city = qs.get(name="Aurora") + self.assertEqual(city.parcel_centroid.wkt, "POINT (3.2128 1.5)") + + @skipUnlessDBFeature("supports_make_line_aggr") + def test_make_line_filter(self): + qs = City.objects.annotate( + parcel_line=MakeLine( + "parcel__center1", + filter=~Q(parcel__name__icontains="ignore"), + ), + parcel_line_nonexistent=MakeLine( + "parcel__center1", + filter=Q(parcel__name__icontains="nonexistent"), + ), + ) + city = qs.get(name="Aurora") + self.assertIn( + city.parcel_line.wkt, + # The default ordering is flaky, so check both. + [ + "LINESTRING (1.7128 -2.006, 4.7128 5.006)", + "LINESTRING (4.7128 5.006, 1.7128 -2.006)", + ], + ) + self.assertIsNone(city.parcel_line_nonexistent) + + @skipUnlessDBFeature("supports_extent_aggr") + def test_extent_filter(self): + qs = City.objects.annotate( + parcel_border=Extent( + "parcel__border1", + filter=~Q(parcel__name__icontains="ignore"), + ), + parcel_border_nonexistent=Extent( + "parcel__border1", + filter=Q(parcel__name__icontains="nonexistent"), + ), + parcel_border_no_filter=Extent("parcel__border1"), + ) + city = qs.get(name="Aurora") + self.assertEqual(city.parcel_border, (0.0, 0.0, 22.0, 22.0)) + self.assertIsNone(city.parcel_border_nonexistent) + self.assertEqual(city.parcel_border_no_filter, (0.0, 0.0, 32.0, 32.0)) + + @skipUnlessDBFeature("supports_union_aggr") + def test_union_filter(self): + qs = City.objects.annotate( + parcel_point_union=Union( + "parcel__center2", + filter=~Q(parcel__name__icontains="ignore"), + ), + parcel_point_nonexistent=Union( + "parcel__center2", + filter=Q(parcel__name__icontains="nonexistent"), + ), + parcel_point_union_single=Union( + "parcel__center2", + filter=Q(parcel__name__contains="Alpha"), + ), + ) + city = qs.get(name="Aurora") + self.assertIn( + city.parcel_point_union.wkt, + [ + "MULTIPOINT (12.75 10.05, 3.7128 -5.006)", + "MULTIPOINT (3.7128 -5.006, 12.75 10.05)", + ], + ) + self.assertIsNone(city.parcel_point_nonexistent) + self.assertEqual(city.parcel_point_union_single.wkt, "POINT (3.7128 -5.006)") + def test15_invalid_select_related(self): """ select_related on the related name manager of a unique FK.