Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allows to pass config name to search class configuration #2

Open
wants to merge 9 commits into
base: develop
Choose a base branch
from
14 changes: 8 additions & 6 deletions tortoise/contrib/postgres/functions.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,33 @@
from pypika.terms import Function, Term

DEFAULT_TEXT_SEARCH_CONFIG = "pg_catalog.simple"


class ToTsVector(Function): # type: ignore
"""
to to_tsvector function
"""

def __init__(self, field: Term):
super(ToTsVector, self).__init__("TO_TSVECTOR", field)
def __init__(self, field: Term, config_name: str = DEFAULT_TEXT_SEARCH_CONFIG):
super(ToTsVector, self).__init__("TO_TSVECTOR", config_name, field)


class ToTsQuery(Function): # type: ignore
"""
to_tsquery function
"""

def __init__(self, field: Term):
super(ToTsQuery, self).__init__("TO_TSQUERY", field)
def __init__(self, field: Term, config_name: str = DEFAULT_TEXT_SEARCH_CONFIG):
super(ToTsQuery, self).__init__("TO_TSQUERY", config_name, field)


class PlainToTsQuery(Function): # type: ignore
"""
plainto_tsquery function
"""

def __init__(self, field: Term):
super(PlainToTsQuery, self).__init__("PLAINTO_TSQUERY", field)
def __init__(self, field: Term, config_name: str = DEFAULT_TEXT_SEARCH_CONFIG):
super(PlainToTsQuery, self).__init__("PLAINTO_TSQUERY", config_name, field)


class Random(Function): # type: ignore
Expand Down
12 changes: 4 additions & 8 deletions tortoise/contrib/postgres/search.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from typing import Union

from pypika.enums import Comparator
from pypika.terms import BasicCriterion, Function, Term

from pypika.terms import BasicCriterion, Term, Function
from tortoise.contrib.postgres.functions import ToTsQuery, ToTsVector


Expand All @@ -12,8 +10,6 @@ class Comp(Comparator): # type: ignore

class SearchCriterion(BasicCriterion): # type: ignore
def __init__(self, field: Term, expr: Union[Term, Function]):
if isinstance(expr, Function):
_expr = expr
else:
_expr = ToTsQuery(expr)
super().__init__(Comp.search, ToTsVector(field), _expr)
if not isinstance(expr, Function):
expr = ToTsQuery(expr)
super().__init__(Comp.search, ToTsVector(config_name=expr.args[0].value, field=field), expr)