mirror of
https://github.com/wagtail/wagtail.git
synced 2024-12-01 11:41:20 +01:00
Factor out repeated code
This commit is contained in:
parent
688b025187
commit
0bcf5a951e
@ -439,6 +439,30 @@ class ElasticsearchSearchResults(BaseSearchResults):
|
||||
|
||||
return body
|
||||
|
||||
def _get_results_from_hits(self, hits):
|
||||
"""
|
||||
Yields Django model instances from a page of hits returned by Elasticsearch
|
||||
"""
|
||||
# Get pks from results
|
||||
pks = [hit['fields']['pk'][0] for hit in hits]
|
||||
scores = {str(hit['fields']['pk'][0]): hit['_score'] for hit in hits}
|
||||
|
||||
# Initialise results dictionary
|
||||
results = {str(pk): None for pk in pks}
|
||||
|
||||
# Find objects in database and add them to dict
|
||||
for obj in self.query.queryset.filter(pk__in=pks):
|
||||
results[str(obj.pk)] = obj
|
||||
|
||||
if self._score_field:
|
||||
setattr(obj, self._score_field, scores.get(str(obj.pk)))
|
||||
|
||||
# Yield results in order given by Elasticsearch
|
||||
for pk in pks:
|
||||
result = results[str(pk)]
|
||||
if result:
|
||||
yield result
|
||||
|
||||
def _do_search(self):
|
||||
PAGE_SIZE = 100
|
||||
|
||||
@ -468,34 +492,19 @@ class ElasticsearchSearchResults(BaseSearchResults):
|
||||
|
||||
while True:
|
||||
hits = page['hits']['hits']
|
||||
|
||||
if len(hits) == 0:
|
||||
break
|
||||
|
||||
# Get pks from results
|
||||
pks = [hit['fields']['pk'][0] for hit in hits]
|
||||
scores = {str(hit['fields']['pk'][0]): hit['_score'] for hit in hits}
|
||||
# Get results
|
||||
for result in self._get_results_from_hits(hits):
|
||||
if limit is not None and limit == 0:
|
||||
break
|
||||
|
||||
# Initialise results dictionary
|
||||
results = {str(pk): None for pk in pks}
|
||||
yield result
|
||||
|
||||
# Find objects in database and add them to dict
|
||||
for obj in self.query.queryset.filter(pk__in=pks):
|
||||
results[str(obj.pk)] = obj
|
||||
|
||||
if self._score_field:
|
||||
setattr(obj, self._score_field, scores.get(str(obj.pk)))
|
||||
|
||||
# Yield results in order given by Elasticsearch
|
||||
for pk in pks:
|
||||
result = results[str(pk)]
|
||||
if result:
|
||||
yield result
|
||||
|
||||
if limit is not None:
|
||||
limit -= 1
|
||||
|
||||
if limit == 0:
|
||||
break
|
||||
if limit is not None:
|
||||
limit -= 1
|
||||
|
||||
if limit is not None and limit == 0:
|
||||
break
|
||||
@ -517,25 +526,9 @@ class ElasticsearchSearchResults(BaseSearchResults):
|
||||
# Send to Elasticsearch
|
||||
hits = self.backend.es.search(**params)['hits']['hits']
|
||||
|
||||
# Get pks from results
|
||||
pks = [hit['fields']['pk'][0] for hit in hits]
|
||||
scores = {str(hit['fields']['pk'][0]): hit['_score'] for hit in hits}
|
||||
|
||||
# Initialise results dictionary
|
||||
results = {str(pk): None for pk in pks}
|
||||
|
||||
# Find objects in database and add them to dict
|
||||
for obj in self.query.queryset.filter(pk__in=pks):
|
||||
results[str(obj.pk)] = obj
|
||||
|
||||
if self._score_field:
|
||||
setattr(obj, self._score_field, scores.get(str(obj.pk)))
|
||||
|
||||
# Yield results in order given by Elasticsearch
|
||||
for pk in pks:
|
||||
result = results[str(pk)]
|
||||
if result:
|
||||
yield result
|
||||
# Get results
|
||||
for result in self._get_results_from_hits(hits):
|
||||
yield result
|
||||
|
||||
def _do_count(self):
|
||||
# Get count
|
||||
|
Loading…
Reference in New Issue
Block a user