diff --git a/wagtail/contrib/postgres_search/apps.py b/wagtail/contrib/postgres_search/apps.py index 8d0b93fbce..530fdb72a1 100644 --- a/wagtail/contrib/postgres_search/apps.py +++ b/wagtail/contrib/postgres_search/apps.py @@ -1,11 +1,10 @@ -from __future__ import absolute_import, unicode_literals +from __future__ import absolute_import, division, unicode_literals from django.apps import AppConfig from django.core.checks import Error, Tags, register from .utils import ( - BOOSTS_WEIGHTS, WEIGHTS_COUNT, WEIGHTS_VALUES, determine_boosts_weights, - get_postgresql_connections) + BOOSTS_WEIGHTS, WEIGHTS_VALUES, determine_boosts_weights, get_postgresql_connections) class PostgresSearchConfig(AppConfig): @@ -21,9 +20,6 @@ class PostgresSearchConfig(AppConfig): id='wagtail.contrib.postgres_search.E001')] BOOSTS_WEIGHTS.extend(determine_boosts_weights()) - sorted_boosts_weights = sorted(BOOSTS_WEIGHTS, key=lambda t: t[0]) - max_weight = sorted_boosts_weights[-1][0] + max_weight = BOOSTS_WEIGHTS[0][0] WEIGHTS_VALUES.extend([v / max_weight - for v, w in sorted_boosts_weights]) - for _ in range(WEIGHTS_COUNT - len(WEIGHTS_VALUES)): - WEIGHTS_VALUES.insert(0, 0) + for v, w in reversed(BOOSTS_WEIGHTS)]) diff --git a/wagtail/contrib/postgres_search/tests/test_backend.py b/wagtail/contrib/postgres_search/tests/test_backend.py index 9ce885f0c5..df1a745ede 100644 --- a/wagtail/contrib/postgres_search/tests/test_backend.py +++ b/wagtail/contrib/postgres_search/tests/test_backend.py @@ -7,6 +7,8 @@ from django.utils.six import StringIO from wagtail.tests.search.models import SearchTest from wagtail.wagtailsearch.tests.test_backends import BackendTests +from ..utils import BOOSTS_WEIGHTS, WEIGHTS_VALUES, determine_boosts_weights, get_weight + class TestPostgresSearchBackend(BackendTests, TestCase): backend_path = 'wagtail.contrib.postgres_search.backend' @@ -40,3 +42,31 @@ class TestPostgresSearchBackend(BackendTests, TestCase): results = self.backend.search('world', SearchTest) self.assertSetEqual(set(results), {self.testa, self.testd.searchtest_ptr}) + + def test_weights(self): + self.assertListEqual(BOOSTS_WEIGHTS, + [(10, 'A'), (2, 'B'), (0, 'C'), (0, 'D')]) + self.assertListEqual(WEIGHTS_VALUES, [0, 0, 0.2, 1.0]) + + self.assertEqual(get_weight(15), 'A') + self.assertEqual(get_weight(10), 'A') + self.assertEqual(get_weight(9.9), 'B') + self.assertEqual(get_weight(2), 'B') + self.assertEqual(get_weight(1.9), 'C') + self.assertEqual(get_weight(0), 'C') + self.assertEqual(get_weight(-1), 'D') + + self.assertListEqual(determine_boosts_weights([1]), + [(1, 'A'), (0, 'B'), (0, 'C'), (0, 'D')]) + self.assertListEqual(determine_boosts_weights([-1]), + [(-1, 'A'), (-1, 'B'), (-1, 'C'), (-1, 'D')]) + self.assertListEqual(determine_boosts_weights([-1, 1, 2]), + [(2, 'A'), (1, 'B'), (-1, 'C'), (-1, 'D')]) + self.assertListEqual(determine_boosts_weights([0, 1, 2, 3]), + [(3, 'A'), (2, 'B'), (1, 'C'), (0, 'D')]) + self.assertListEqual(determine_boosts_weights([0, 0.25, 0.75, 1, 1.5]), + [(1.5, 'A'), (1, 'B'), (0.5, 'C'), (0, 'D')]) + self.assertListEqual(determine_boosts_weights([0, 1, 2, 3, 4, 5, 6]), + [(6, 'A'), (4, 'B'), (2, 'C'), (0, 'D')]) + self.assertListEqual(determine_boosts_weights([-2, -1, 0, 1, 2, 3, 4]), + [(4, 'A'), (2, 'B'), (0, 'C'), (-2, 'D')]) diff --git a/wagtail/contrib/postgres_search/utils.py b/wagtail/contrib/postgres_search/utils.py index 0d8393c0d4..68b500e135 100644 --- a/wagtail/contrib/postgres_search/utils.py +++ b/wagtail/contrib/postgres_search/utils.py @@ -1,4 +1,4 @@ -from __future__ import absolute_import, unicode_literals +from __future__ import absolute_import, division, unicode_literals import operator import re @@ -8,6 +8,7 @@ from django.apps import apps from django.db import connections from django.db.models import Q from django.utils.lru_cache import lru_cache +from django.utils.six.moves import zip_longest from wagtail.wagtailsearch.index import Indexed, RelatedFields, SearchField @@ -94,25 +95,34 @@ BOOSTS_WEIGHTS = [] WEIGHTS_VALUES = [] -def determine_boosts_weights(): +def get_boosts(): boosts = set() for model in apps.get_models(): if issubclass(model, Indexed): for search_field in get_search_fields(model.get_search_fields()): boost = search_field.boost - boosts.add(0 if boost is None else boost) + if boost is not None: + boosts.add(boost) + return boosts + + +def determine_boosts_weights(boosts=()): + if not boosts: + boosts = get_boosts() + boosts = list(sorted(boosts, reverse=True)) + min_boost = boosts[-1] if len(boosts) <= WEIGHTS_COUNT: - return zip(reversed(sorted(boosts)), WEIGHTS) - min_boost = min(boosts) - max_boost = max(boosts) - boost_step = (max_boost - min_boost) / WEIGHTS_COUNT - return [(min_boost + (i * boost_step), weight) - for i, weight in zip(range(WEIGHTS_COUNT), WEIGHTS)] + return list(zip_longest(boosts, WEIGHTS, fillvalue=min(min_boost, 0))) + max_boost = boosts[0] + boost_step = (max_boost - min_boost) / (WEIGHTS_COUNT - 1) + return [(max_boost - (i * boost_step), weight) + for i, weight in enumerate(WEIGHTS)] def get_weight(boost): if boost is None: - boost = 0 + return WEIGHTS[-1] for max_boost, weight in BOOSTS_WEIGHTS: if boost >= max_boost: return weight + return weight