Skip to content

Commit

Permalink
feat: add smart_words and boost_phrases
Browse files Browse the repository at this point in the history
and tests
  • Loading branch information
alexgarel committed Aug 27, 2024
1 parent 91a0929 commit 47cfacf
Show file tree
Hide file tree
Showing 6 changed files with 326 additions and 51 deletions.
46 changes: 41 additions & 5 deletions app/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,32 @@ class SearchParameters(BaseModel):
),
] = None

boost_phrase: Annotated[
    bool,
    Query(
        description="""This enables a heuristic that will favor
        matching terms that are consecutive.

        Technically, if you have a query with the two words `whole milk`
        it will boost entries with `"whole milk"` as an exact match.

        The boost factor is defined by the `match_phrase_boost` value
        in the Configuration.

        Note that it only makes sense if you use best match sorting,
        so in any other case it is ignored."""
    ),
] = False  # default False: opt-in heuristic; keeps POST body backward-compatible

smart_words: Annotated[
    bool,
    Query(
        description="""This enables a heuristic that helps users search for what they mean.

        If they type `Whole Milk labels:vegan`, it will match
        `(whole OR milk) AND labels:vegan`.

        Also consider the `boost_phrase` parameter along with this one."""
    ),
] = False  # default False: opt-in heuristic; keeps POST body backward-compatible

langs: Annotated[
list[str],
Query(
Expand Down Expand Up @@ -214,8 +240,9 @@ class SearchParameters(BaseModel):
Query(
description=textwrap.dedent(
"""
Field name to use to sort results, the field should exist
and be sortable. If it is not provided, results are sorted by descending relevance score.
Field name to use to sort results, the field should exist and be sortable.
If it is not provided, results are sorted by descending relevance score.
(aka best match)
If you put a minus before the name, the results will be sorted by descending order.
Expand Down Expand Up @@ -315,7 +342,9 @@ def sort_by_is_field_or_script(self):
is_field = sort_by in index_config.fields
# TODO: verify field type is compatible with sorting
if not (self.sort_by is None or is_field or self.uses_sort_script):
raise ValueError("`sort_by` must be a valid field name or script name")
raise ValueError(
"`sort_by` must be a valid field name or script name or None"
)
return self

@model_validator(mode="after")
Expand Down Expand Up @@ -419,12 +448,15 @@ def _annotation_new_type(type_, annotation):
return Annotated[type_, *annotation.__metadata__]


# types for search parameters for GET
# types and annotations for search parameters for GET,
# created from POST search parameters
SEARCH_PARAMS_ANN = get_type_hints(SearchParameters, include_extras=True)


class GetSearchParamsTypes:
q = SEARCH_PARAMS_ANN["q"]
boost_phrase = SEARCH_PARAMS_ANN["boost_phrase"]
smart_words = SEARCH_PARAMS_ANN["smart_words"]
langs = _annotation_new_type(str, SEARCH_PARAMS_ANN["langs"])
page_size = SEARCH_PARAMS_ANN["page_size"]
page = SEARCH_PARAMS_ANN["page"]
Expand Down Expand Up @@ -453,7 +485,11 @@ class FetcherStatus(Enum):


class FetcherResult(BaseModel):
    """Result for a document fetcher.

    This is also used by pre-processors,
    who have the opportunity to discard an entry.
    """

    # outcome of the fetch / pre-processing step
    status: FetcherStatus
    # the fetched document; None when fetching failed or the entry was discarded
    document: JSONType | None
4 changes: 4 additions & 0 deletions app/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,8 @@ def parse_charts_get(charts_params: str):
@app.get("/search")
def search_get(
q: GetSearchParamsTypes.q = None,
boost_phrase: GetSearchParamsTypes.boost_phrase = False,
smart_words: GetSearchParamsTypes.smart_words = False,
langs: GetSearchParamsTypes.langs = None,
page_size: GetSearchParamsTypes.page_size = 10,
page: GetSearchParamsTypes.page = 1,
Expand All @@ -147,6 +149,8 @@ def search_get(
try:
search_parameters = SearchParameters(
q=q,
boost_phrase=boost_phrase,
smart_words=smart_words,
langs=langs_list,
page_size=page_size,
page=page,
Expand Down
3 changes: 2 additions & 1 deletion app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,7 +680,8 @@ class IndexConfig(BaseModel):
description=cd_(
"""How much we boost exact matches on individual fields
This only makes sense when using "best match" order.
This only makes sense when using
"boost_phrase" request parameters and "best match" order.
"""
)
),
Expand Down
73 changes: 30 additions & 43 deletions app/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@
from .es_scripts import get_script_id
from .indexing import generate_index_object
from .postprocessing import BaseResultProcessor
from .query_transformers import (
LanguageSuffixTransformer,
PhraseBoostTransformer,
SmartUnknownOperationResolver,
)
from .utils import get_logger, str_utils

logger = get_logger(__name__)
Expand Down Expand Up @@ -180,49 +185,6 @@ def create_aggregation_clauses(
return clauses


class LanguageSuffixTransformer(luqum.visitor.TreeTransformer):
    """Luqum tree transformer that expands language-dependent fields.

    A query on a field listed in `lang_fields` is replaced by an OR of
    queries on the per-language sub-fields (one per language in `langs`).
    """

    # NOTE(review): `lang_fields=set[str]` and `langs=list[str]` look like
    # mistyped annotations (`=` instead of `:`) — the *type objects* become
    # the default values. Callers must always pass both arguments.
    def __init__(self, lang_fields=set[str], langs=list[str], **kwargs):
        # we need to track parents to get full field name
        super().__init__(track_parents=True, track_new_parents=False, **kwargs)
        self.langs = langs
        self.lang_fields = lang_fields

    def visit_search_field(self, node, context):
        """As we reach a search_field,
        if it's one that have a lang,
        we replace single expression with a OR on sub-language fields
        """
        # FIXME: verify again the way luqum work on this side !
        field_name = node.name
        # add eventual parents
        prefix = ".".join(
            node.name
            for node in context["parents"]
            if isinstance(node, tree.SearchField)
        )
        if prefix:
            field_name = f"{prefix}.{field_name}"
        # is it a lang dependant field
        if field_name in self.lang_fields:
            # create a new expression for each languages
            new_nodes = []
            for lang in self.langs:
                # note: we don't have to care about having searchfield in children
                # because only complete field_name would match a self.lang_fields
                new_node = self.generic_visit(node)
                # add language prefix
                new_node.name = f"{new_node.name}.{lang}"
                new_nodes.append(new_node)
            if len(new_nodes) > 1:
                yield tree.OrOperation(*new_nodes)
            else:
                yield from new_nodes
        else:
            # default
            yield from self.generic_visit(node)


def add_languages_suffix(
analysis: QueryAnalysis, langs: list[str], config: IndexConfig
) -> QueryAnalysis:
Expand All @@ -239,6 +201,27 @@ def add_languages_suffix(
return analysis


def add_smart_words(analysis: QueryAnalysis) -> QueryAnalysis:
    """Add smart words heuristic

    see SearchParameters.smart_words
    """
    query_tree = analysis.luqum_tree
    if query_tree is not None:
        # group consecutive bare words with OR, join the rest with AND
        analysis.luqum_tree = SmartUnknownOperationResolver().visit(query_tree)
    return analysis


def boost_phrases(analysis: QueryAnalysis, boost: float | str) -> QueryAnalysis:
    """Boost all phrases in the query"""
    query_tree = analysis.luqum_tree
    if query_tree is not None:
        # rewrite the tree so runs of consecutive words also match as phrases
        analysis.luqum_tree = PhraseBoostTransformer(boost=boost).visit(query_tree)
    return analysis


def build_search_query(
params: SearchParameters,
es_query_builder: ElasticsearchQueryBuilder,
Expand All @@ -252,6 +235,10 @@ def build_search_query(
"""
analysis = parse_query(params.q)
analysis = compute_facets_filters(analysis)
if params.smart_words:
analysis = add_smart_words(analysis)
if params.boost_phrase and params.sort_by is None:
analysis = boost_phrases(analysis, params.index_config.match_phrase_boost)
# add languages for localized fields
analysis = add_languages_suffix(analysis, params.langs, params.index_config)

Expand Down
168 changes: 168 additions & 0 deletions app/query_transformers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
import luqum.visitor
from luqum import tree
from luqum.utils import UnknownOperationResolver


class LanguageSuffixTransformer(luqum.visitor.TreeTransformer):
    """This transformer adds a language suffix to lang_fields fields,
    for any languages in langs (the languages we want to query on).

    That is `field1:something` will become
    `field1.en:something OR field1.fr:something`
    """

    def __init__(
        self,
        lang_fields: set[str] | None = None,
        langs: list[str] | None = None,
        **kwargs,
    ):
        """:param lang_fields: full names of the language-dependent fields
        :param langs: languages to expand each field to
        """
        # FIX: parameters previously read `lang_fields=set[str]` /
        # `langs=list[str]` — mistyped annotations that made the *type
        # objects* the default values. Use real annotations instead.
        # we need to track parents to get full field name
        super().__init__(track_parents=True, track_new_parents=False, **kwargs)
        self.langs = langs if langs is not None else []
        self.lang_fields = lang_fields if lang_fields is not None else set()

    def visit_search_field(self, node, context):
        """As we reach a search_field,
        if it's one that have a lang,
        we replace single expression with a OR on sub-language fields
        """
        # FIXME: verify again the way luqum work on this side !
        field_name = node.name
        # add eventual parents (dotted path of enclosing SearchFields)
        prefix = ".".join(
            node.name
            for node in context["parents"]
            if isinstance(node, tree.SearchField)
        )
        if prefix:
            field_name = f"{prefix}.{field_name}"
        # is it a lang dependant field
        if field_name in self.lang_fields:
            # create a new expression for each languages
            new_nodes = []
            for lang in self.langs:
                # note: we don't have to care about having searchfield in children
                # because only complete field_name would match a self.lang_fields
                new_node = self.generic_visit(node)
                # add language suffix
                new_node.name = f"{new_node.name}.{lang}"
                new_nodes.append(new_node)
            if len(new_nodes) > 1:
                yield tree.OrOperation(*new_nodes)
            else:
                yield from new_nodes
        else:
            # default
            yield from self.generic_visit(node)


def get_consecutive_words(
    node: tree.BoolOperation,
) -> list[list[tuple[int, tree.Word]]]:
    """Return a list of list of consecutive words,
    with their index, in a bool operation

    Only runs of two or more consecutive words are kept.
    """
    runs: list[list[tuple[int, tree.Word]]] = []
    current_run: list[tuple[int, tree.Word]] = []
    for index, child in enumerate(node.children):
        if isinstance(child, tree.Word):
            current_run.append((index, child))
            continue
        # discontinuity: a single word alone is not a run
        if len(current_run) > 1:
            runs.append(current_run)
        current_run = []
    # flush the trailing run, if long enough
    if len(current_run) > 1:
        runs.append(current_run)
    return runs


class PhraseBoostTransformer(luqum.visitor.TreeTransformer):
    """This transformer boosts terms that are consecutive
    and might be found in a query

    For example if we have `Whole OR Milk OR Cream`
    we will boost items containing "Whole Milk Cream"

    We also only apply it to terms that are not for a specified field
    """

    def __init__(self, boost: float | str, **kwargs):
        """:param boost: boost factor applied to the generated phrases
        (forwarded to `tree.Boost`'s `force`)
        """
        # FIX: parameter previously read `boost=float`, a mistyped
        # annotation making the *type* `float` the default value.
        super().__init__(track_parents=True, track_new_parents=False, **kwargs)
        self.boost = boost

    def _get_consecutive_words(self, node):
        # keep only the word nodes, indexes are not needed here
        return [[word for _, word in words] for words in get_consecutive_words(node)]

    def _phrase_from_words(self, words):
        # build a quoted phrase from the words, wrapped in a boost node
        expr = " ".join(word.value for word in words)
        expr = f'"{expr}"'
        phrase = tree.Phrase(expr)
        return tree.Boost(phrase, force=self.boost, head=" ", tail=" ")

    def visit_or_operation(self, node, context):
        """As we find an OR operation try to boost consecutive word terms"""
        # get the or operation with cloned children
        (new_node,) = list(super().generic_visit(node, context))
        has_search_field = any(
            isinstance(p, tree.SearchField) for p in context.get("parents", [])
        )
        if not has_search_field:
            # we are in an expression with no field specified, transform
            consecutive = self._get_consecutive_words(new_node)
            if consecutive:
                # create new match phrase with terms from consecutive words
                new_terms = [self._phrase_from_words(words) for words in consecutive]
                # head / tail: move the last child's tail to the last new term
                new_terms[-1].tail = new_node.children[-1].tail
                new_node.children[-1].tail = " "
                # add operands
                new_node.children += tuple(new_terms)
        yield new_node


class SmartUnknownOperationResolver(UnknownOperationResolver):
    """A complex unknown operation resolver that fits what users might intend

    It replace UnknownOperation by a AND operation,
    but if consecutive words are found it will try to group them in a OR operation
    """

    def _get_consecutive_words(self, node):
        # runs of >= 2 consecutive Word children, with their indexes
        return get_consecutive_words(node)

    def _words_or_operation(self, words):
        """Wrap a run of word nodes in a parenthesized OR operation."""
        # transfer head and tail from the outermost words to the group
        head = words[0].head
        tail = words[-1].tail
        words[0].head = ""
        words[-1].tail = ""
        operation = tree.Group(tree.OrOperation(*words), head=head, tail=tail)
        return operation

    def visit_unknown_operation(self, node, context):
        # create the node as intended, this might be AND or OR operation
        (new_node,) = list(super().visit_unknown_operation(node, context))
        # if it's AND operation
        if isinstance(new_node, tree.AndOperation):
            # group consecutive terms in OROperations
            consecutive = self._get_consecutive_words(new_node)
            if consecutive:
                # change first word by the OR operation
                # (maps the run's first child index to its replacement group)
                index_to_change = {
                    words[0][0]: self._words_or_operation([word[1] for word in words])
                    for words in consecutive
                }
                # remove other words that are part of the expression
                index_to_remove = set(
                    word[0] for words in consecutive for word in words[1:]
                )
                new_children = []
                for i, child in enumerate(new_node.children):
                    if i in index_to_change:
                        new_children.append(index_to_change[i])
                    elif i not in index_to_remove:
                        new_children.append(child)
                # substitute children
                # NOTE(review): assigns a list where children was a tuple —
                # presumably luqum accepts both; confirm against luqum docs
                new_node.children = new_children
        yield new_node
Loading

0 comments on commit 47cfacf

Please sign in to comment.