
Add chunking/iterator() support to PageQuerySet.specific() (#8271)

* Add chunking support to SpecificIterable

* Add tests

* Remove the redundant specific_iterator() function
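
A minimal usage sketch of the new behaviour (assuming a configured Wagtail project; the queryset and the chunk_size value are illustrative, not part of this change):

from wagtail.models import Page

# Without iterator(): all matching pages are fetched and upcast to their
# specific classes in one pass, holding the whole result set in memory.
for page in Page.objects.live().specific():
    print(page.title)

# With iterator(): SpecificIterable now fetches and upcasts the rows in
# chunks, keeping peak memory usage low for large page trees.
for page in Page.objects.live().specific().iterator(chunk_size=500):
    print(page.title)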
Andy Babic 2022-04-12 13:08:00 +01:00 committed by GitHub
parent 1d73aa2cec
commit b46403969e
2 changed files with 199 additions and 89 deletions


@@ -1,6 +1,7 @@
import posixpath
import warnings
from collections import defaultdict
from typing import Any, Dict, Iterable, Tuple
from django.apps import apps
from django.contrib.contenttypes.models import ContentType
@@ -480,94 +481,106 @@ class PageQuerySet(SearchableQuerySetMixin, TreeQuerySet):
)
def specific_iterator(qs, defer=False):
"""
This efficiently iterates all the specific pages in a queryset, using
the minimum number of queries.
This should be called from ``PageQuerySet.specific``
"""
from wagtail.models import Page
annotation_aliases = qs.query.annotations.keys()
values = qs.values("pk", "content_type", *annotation_aliases)
annotations_by_pk = defaultdict(list)
if annotation_aliases:
# Extract annotation results keyed by pk so we can reapply to fetched pages.
for data in values:
annotations_by_pk[data["pk"]] = {
k: v for k, v in data.items() if k in annotation_aliases
}
pks_and_types = [[v["pk"], v["content_type"]] for v in values]
pks_by_type = defaultdict(list)
for pk, content_type in pks_and_types:
pks_by_type[content_type].append(pk)
# Content types are cached by ID, so this will not run any queries.
content_types = {pk: ContentType.objects.get_for_id(pk) for _, pk in pks_and_types}
# Get the specific instances of all pages, one model class at a time.
pages_by_type = {}
missing_pks = []
for content_type, pks in pks_by_type.items():
# look up model class for this content type, falling back on the original
# model (i.e. Page) if the more specific one is missing
model = content_types[content_type].model_class() or qs.model
pages = model.objects.filter(pk__in=pks)
if defer:
# Defer all specific fields
fields = [
field.attname for field in Page._meta.get_fields() if field.concrete
]
pages = pages.only(*fields)
elif qs._defer_streamfields:
pages = pages.defer_streamfields()
pages_for_type = {page.pk: page for page in pages}
pages_by_type[content_type] = pages_for_type
missing_pks.extend(pk for pk in pks if pk not in pages_for_type)
# Fetch generic pages to supplement missing items
if missing_pks:
generic_pages = (
Page.objects.filter(pk__in=missing_pks)
.select_related("content_type")
.in_bulk()
)
warnings.warn(
"Specific versions of the following pages could not be found. "
"This is most likely because a database migration has removed "
"the relevant table or record since the page was created:\n{}".format(
[
{"id": p.id, "title": p.title, "type": p.content_type}
for p in generic_pages.values()
]
),
category=RuntimeWarning,
)
else:
generic_pages = {}
# Yield all pages in the order they occurred in the original query.
for pk, content_type in pks_and_types:
try:
page = pages_by_type[content_type][pk]
except KeyError:
page = generic_pages[pk]
if annotation_aliases:
# Reapply annotations before returning
for annotation, value in annotations_by_pk.get(page.pk, {}).items():
setattr(page, annotation, value)
yield page
class SpecificIterable(BaseIterable):
def __iter__(self):
return specific_iterator(self.queryset)
"""
Identify and return all specific pages in a queryset, and return them
in the same order, with any annotations intact.
"""
from wagtail.models import Page
qs = self.queryset
annotation_aliases = qs.query.annotations.keys()
values_qs = qs.values("pk", "content_type", *annotation_aliases)
# Gather pages in batches to reduce peak memory usage
for values in self._get_chunks(values_qs):
annotations_by_pk = defaultdict(list)
if annotation_aliases:
# Extract annotation results keyed by pk so we can reapply to fetched pages.
for data in values:
annotations_by_pk[data["pk"]] = {
k: v for k, v in data.items() if k in annotation_aliases
}
pks_and_types = [[v["pk"], v["content_type"]] for v in values]
pks_by_type = defaultdict(list)
for pk, content_type in pks_and_types:
pks_by_type[content_type].append(pk)
# Content types are cached by ID, so this will not run any queries.
content_types = {
pk: ContentType.objects.get_for_id(pk) for _, pk in pks_and_types
}
# Get the specific instances of all pages, one model class at a time.
pages_by_type = {}
missing_pks = []
for content_type, pks in pks_by_type.items():
# look up model class for this content type, falling back on the original
# model (i.e. Page) if the more specific one is missing
model = content_types[content_type].model_class() or qs.model
pages = model.objects.filter(pk__in=pks)
if qs._defer_streamfields:
pages = pages.defer_streamfields()
pages_for_type = {page.pk: page for page in pages}
pages_by_type[content_type] = pages_for_type
missing_pks.extend(pk for pk in pks if pk not in pages_for_type)
# Fetch generic pages to supplement missing items
if missing_pks:
generic_pages = (
Page.objects.filter(pk__in=missing_pks)
.select_related("content_type")
.in_bulk()
)
warnings.warn(
"Specific versions of the following pages could not be found. "
"This is most likely because a database migration has removed "
"the relevant table or record since the page was created:\n{}".format(
[
{"id": p.id, "title": p.title, "type": p.content_type}
for p in generic_pages.values()
]
),
category=RuntimeWarning,
)
else:
generic_pages = {}
# Yield all pages in the order they occurred in the original query.
for pk, content_type in pks_and_types:
try:
page = pages_by_type[content_type][pk]
except KeyError:
page = generic_pages[pk]
if annotation_aliases:
# Reapply annotations before returning
for annotation, value in annotations_by_pk.get(page.pk, {}).items():
setattr(page, annotation, value)
yield page
def _get_chunks(self, queryset) -> Iterable[Tuple[Dict[str, Any]]]:
if not self.chunked_fetch:
# The entire result will be stored in memory, so there is no
# benefit to splitting the result
yield tuple(queryset)
else:
# Iterate through the queryset, returning the rows in manageable
# chunks for self.__iter__() to fetch full pages for
current_chunk = []
for r in queryset.iterator(self.chunk_size):
current_chunk.append(r)
if len(current_chunk) == self.chunk_size:
yield tuple(current_chunk)
current_chunk.clear()
# Return any left-overs
if current_chunk:
yield tuple(current_chunk)
class DeferredSpecificIterable(ModelIterable):
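
For reference, a standalone sketch of the chunking pattern that _get_chunks() applies above (the chunked() helper and the sample rows are hypothetical, used only to illustrate the grouping logic):

from typing import Any, Dict, Iterable, Iterator, Tuple

def chunked(rows: Iterable[Dict[str, Any]], chunk_size: int) -> Iterator[Tuple[Dict[str, Any], ...]]:
    # Collect rows until a full chunk is reached, then hand it over as a tuple.
    current_chunk = []
    for row in rows:
        current_chunk.append(row)
        if len(current_chunk) == chunk_size:
            yield tuple(current_chunk)
            current_chunk.clear()
    # Hand over any left-over rows smaller than a full chunk.
    if current_chunk:
        yield tuple(current_chunk)

rows = [{"pk": i, "content_type": 1} for i in range(7)]
assert [len(chunk) for chunk in chunked(rows, 3)] == [3, 3, 1]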


@@ -722,6 +722,12 @@ class TestSpecificQuery(TestCase):
fixtures = ["test_specific.json"]
def setUp(self):
self.live_pages = Page.objects.live().specific()
self.live_pages_with_annotations = (
Page.objects.live().specific().annotate(count=Count("pk"))
)
def test_specific(self):
root = Page.objects.get(url_path="/home/")
@@ -797,12 +803,12 @@ class TestSpecificQuery(TestCase):
def test_specific_query_with_annotations_performs_no_additional_queries(self):
with self.assertNumQueries(5):
pages = list(Page.objects.live().specific())
pages = list(self.live_pages)
self.assertEqual(len(pages), 7)
with self.assertNumQueries(5):
pages = list(Page.objects.live().specific().annotate(count=Count("pk")))
pages = list(self.live_pages_with_annotations)
self.assertEqual(len(pages), 7)
@@ -878,7 +884,7 @@ class TestSpecificQuery(TestCase):
# 5928 - PageQuerySet.specific should gracefully handle pages whose ContentType
# row in the specific table no longer exists
# Trick specific_iterator into always looking for EventPages
# Trick SpecificIterable.__init__() into always looking for EventPages
with mock.patch(
"wagtail.query.ContentType.objects.get_for_id",
return_value=ContentType.objects.get_for_model(EventPage),
@@ -943,6 +949,97 @@ class TestSpecificQuery(TestCase):
# <StreamPage: stream page>
pages[-1].body
def test_specific_query_with_iterator(self):
queryset = self.live_pages_with_annotations
# set benchmark without iterator()
with self.assertNumQueries(5):
benchmark_result = list(queryset.all())
self.assertEqual(len(benchmark_result), 7)
# the default chunk size for iterator() is much higher than 7, so all
# items should be fetched with the same number of queries
with self.assertNumQueries(5):
result_1 = list(queryset.all().iterator())
self.assertEqual(result_1, benchmark_result)
# specifying a smaller chunk_size for iterator() should force the
# results to be processed in multiple batches, increasing the number
# of queries
with self.assertNumQueries(7):
result_2 = list(queryset.all().iterator(chunk_size=5))
self.assertEqual(result_2, benchmark_result)
# repeat with a smaller chunk size for good measure
with self.assertNumQueries(6):
# The number of queries is actually lower, because
# each chunk contains fewer 'unique' page types
result_3 = list(queryset.all().iterator(chunk_size=2))
self.assertEqual(result_3, benchmark_result)
def test_bottom_sliced_specific_query_with_iterator(self):
queryset = self.live_pages_with_annotations[2:]
# set benchmark without iterator()
with self.assertNumQueries(4):
benchmark_result = list(queryset.all())
self.assertEqual(len(benchmark_result), 5)
# using plain iterator() with the same sliced queryset should produce
# an identical result with the same number of queries
with self.assertNumQueries(4):
result_1 = list(queryset.all().iterator())
self.assertEqual(result_1, benchmark_result)
# if the iterator() chunk size is smaller than the slice,
# SpecificIterable should still apply chunking whilst maintaining
# the slice starting point
with self.assertNumQueries(6):
result_2 = list(queryset.all().iterator(chunk_size=1))
self.assertEqual(result_2, benchmark_result)
def test_top_sliced_specific_query_with_iterator(self):
queryset = self.live_pages_with_annotations[:6]
# set benchmark without iterator()
with self.assertNumQueries(5):
benchmark_result = list(queryset.all())
self.assertEqual(len(benchmark_result), 6)
# using plain iterator() with the same sliced queryset should produce
# an identical result with the same number of queries
with self.assertNumQueries(5):
result_1 = list(queryset.all().iterator())
self.assertEqual(result_1, benchmark_result)
# if the iterator() chunk size is smaller than the slice,
# SpecificIterable should still apply chunking whilst maintaining
# the slice end point
with self.assertNumQueries(7):
result_2 = list(queryset.all().iterator(chunk_size=1))
self.assertEqual(result_2, benchmark_result)
def test_top_and_bottom_sliced_specific_query_with_iterator(self):
queryset = self.live_pages_with_annotations[2:6]
# set benchmark without iterator()
with self.assertNumQueries(4):
benchmark_result = list(queryset.all())
self.assertEqual(len(benchmark_result), 4)
# using plain iterator() with the same sliced queryset should produce
# an identical result with the same number of queries
with self.assertNumQueries(4):
result_1 = list(queryset.all().iterator())
self.assertEqual(result_1, benchmark_result)
# if the iterator() chunk size is smaller than the slice,
# SpecificIterable should still apply chunking whilst maintaining
# the slice's start and end point
with self.assertNumQueries(5):
result_2 = list(queryset.all().iterator(chunk_size=3))
self.assertEqual(result_2, benchmark_result)
class TestFirstCommonAncestor(TestCase):
"""