mirror of
https://github.com/wagtail/wagtail.git
synced 2024-11-29 17:36:49 +01:00
[Postgres search] Improves and tests weight calculus.
This commit is contained in:
parent
ea746ee74f
commit
344b5d4bf7
@ -1,11 +1,10 @@
|
||||
from __future__ import absolute_import, unicode_literals
|
||||
from __future__ import absolute_import, division, unicode_literals
|
||||
|
||||
from django.apps import AppConfig
|
||||
from django.core.checks import Error, Tags, register
|
||||
|
||||
from .utils import (
|
||||
BOOSTS_WEIGHTS, WEIGHTS_COUNT, WEIGHTS_VALUES, determine_boosts_weights,
|
||||
get_postgresql_connections)
|
||||
BOOSTS_WEIGHTS, WEIGHTS_VALUES, determine_boosts_weights, get_postgresql_connections)
|
||||
|
||||
|
||||
class PostgresSearchConfig(AppConfig):
|
||||
@ -21,9 +20,6 @@ class PostgresSearchConfig(AppConfig):
|
||||
id='wagtail.contrib.postgres_search.E001')]
|
||||
|
||||
BOOSTS_WEIGHTS.extend(determine_boosts_weights())
|
||||
sorted_boosts_weights = sorted(BOOSTS_WEIGHTS, key=lambda t: t[0])
|
||||
max_weight = sorted_boosts_weights[-1][0]
|
||||
max_weight = BOOSTS_WEIGHTS[0][0]
|
||||
WEIGHTS_VALUES.extend([v / max_weight
|
||||
for v, w in sorted_boosts_weights])
|
||||
for _ in range(WEIGHTS_COUNT - len(WEIGHTS_VALUES)):
|
||||
WEIGHTS_VALUES.insert(0, 0)
|
||||
for v, w in reversed(BOOSTS_WEIGHTS)])
|
||||
|
@ -7,6 +7,8 @@ from django.utils.six import StringIO
|
||||
from wagtail.tests.search.models import SearchTest
|
||||
from wagtail.wagtailsearch.tests.test_backends import BackendTests
|
||||
|
||||
from ..utils import BOOSTS_WEIGHTS, WEIGHTS_VALUES, determine_boosts_weights, get_weight
|
||||
|
||||
|
||||
class TestPostgresSearchBackend(BackendTests, TestCase):
|
||||
backend_path = 'wagtail.contrib.postgres_search.backend'
|
||||
@ -40,3 +42,31 @@ class TestPostgresSearchBackend(BackendTests, TestCase):
|
||||
results = self.backend.search('world', SearchTest)
|
||||
self.assertSetEqual(set(results), {self.testa,
|
||||
self.testd.searchtest_ptr})
|
||||
|
||||
def test_weights(self):
|
||||
self.assertListEqual(BOOSTS_WEIGHTS,
|
||||
[(10, 'A'), (2, 'B'), (0, 'C'), (0, 'D')])
|
||||
self.assertListEqual(WEIGHTS_VALUES, [0, 0, 0.2, 1.0])
|
||||
|
||||
self.assertEqual(get_weight(15), 'A')
|
||||
self.assertEqual(get_weight(10), 'A')
|
||||
self.assertEqual(get_weight(9.9), 'B')
|
||||
self.assertEqual(get_weight(2), 'B')
|
||||
self.assertEqual(get_weight(1.9), 'C')
|
||||
self.assertEqual(get_weight(0), 'C')
|
||||
self.assertEqual(get_weight(-1), 'D')
|
||||
|
||||
self.assertListEqual(determine_boosts_weights([1]),
|
||||
[(1, 'A'), (0, 'B'), (0, 'C'), (0, 'D')])
|
||||
self.assertListEqual(determine_boosts_weights([-1]),
|
||||
[(-1, 'A'), (-1, 'B'), (-1, 'C'), (-1, 'D')])
|
||||
self.assertListEqual(determine_boosts_weights([-1, 1, 2]),
|
||||
[(2, 'A'), (1, 'B'), (-1, 'C'), (-1, 'D')])
|
||||
self.assertListEqual(determine_boosts_weights([0, 1, 2, 3]),
|
||||
[(3, 'A'), (2, 'B'), (1, 'C'), (0, 'D')])
|
||||
self.assertListEqual(determine_boosts_weights([0, 0.25, 0.75, 1, 1.5]),
|
||||
[(1.5, 'A'), (1, 'B'), (0.5, 'C'), (0, 'D')])
|
||||
self.assertListEqual(determine_boosts_weights([0, 1, 2, 3, 4, 5, 6]),
|
||||
[(6, 'A'), (4, 'B'), (2, 'C'), (0, 'D')])
|
||||
self.assertListEqual(determine_boosts_weights([-2, -1, 0, 1, 2, 3, 4]),
|
||||
[(4, 'A'), (2, 'B'), (0, 'C'), (-2, 'D')])
|
||||
|
@ -1,4 +1,4 @@
|
||||
from __future__ import absolute_import, unicode_literals
|
||||
from __future__ import absolute_import, division, unicode_literals
|
||||
|
||||
import operator
|
||||
import re
|
||||
@ -8,6 +8,7 @@ from django.apps import apps
|
||||
from django.db import connections
|
||||
from django.db.models import Q
|
||||
from django.utils.lru_cache import lru_cache
|
||||
from django.utils.six.moves import zip_longest
|
||||
|
||||
from wagtail.wagtailsearch.index import Indexed, RelatedFields, SearchField
|
||||
|
||||
@ -94,25 +95,34 @@ BOOSTS_WEIGHTS = []
|
||||
WEIGHTS_VALUES = []
|
||||
|
||||
|
||||
def determine_boosts_weights():
|
||||
def get_boosts():
|
||||
boosts = set()
|
||||
for model in apps.get_models():
|
||||
if issubclass(model, Indexed):
|
||||
for search_field in get_search_fields(model.get_search_fields()):
|
||||
boost = search_field.boost
|
||||
boosts.add(0 if boost is None else boost)
|
||||
if boost is not None:
|
||||
boosts.add(boost)
|
||||
return boosts
|
||||
|
||||
|
||||
def determine_boosts_weights(boosts=()):
|
||||
if not boosts:
|
||||
boosts = get_boosts()
|
||||
boosts = list(sorted(boosts, reverse=True))
|
||||
min_boost = boosts[-1]
|
||||
if len(boosts) <= WEIGHTS_COUNT:
|
||||
return zip(reversed(sorted(boosts)), WEIGHTS)
|
||||
min_boost = min(boosts)
|
||||
max_boost = max(boosts)
|
||||
boost_step = (max_boost - min_boost) / WEIGHTS_COUNT
|
||||
return [(min_boost + (i * boost_step), weight)
|
||||
for i, weight in zip(range(WEIGHTS_COUNT), WEIGHTS)]
|
||||
return list(zip_longest(boosts, WEIGHTS, fillvalue=min(min_boost, 0)))
|
||||
max_boost = boosts[0]
|
||||
boost_step = (max_boost - min_boost) / (WEIGHTS_COUNT - 1)
|
||||
return [(max_boost - (i * boost_step), weight)
|
||||
for i, weight in enumerate(WEIGHTS)]
|
||||
|
||||
|
||||
def get_weight(boost):
|
||||
if boost is None:
|
||||
boost = 0
|
||||
return WEIGHTS[-1]
|
||||
for max_boost, weight in BOOSTS_WEIGHTS:
|
||||
if boost >= max_boost:
|
||||
return weight
|
||||
return weight
|
||||
|
Loading…
Reference in New Issue
Block a user