Skip to content

Commit

Permalink
Fix sort expression (#2271)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?

_Briefly describe what this PR aims to solve. Include background context
that will help reviewers understand the purpose of the PR._

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] Python SDK impacted, Need to update PyPI

---------

Signed-off-by: jinhai <[email protected]>
Signed-off-by: Jin Hai <[email protected]>
  • Loading branch information
JinHai-CN authored Nov 20, 2024
1 parent 6f5ec59 commit 9b5ef10
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 70 deletions.
16 changes: 8 additions & 8 deletions python/infinity_embedded/local_infinity/query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from pyarrow import Table
from sqlglot import condition, maybe_parse

from infinity_embedded.common import VEC, SparseVector, InfinityException
from infinity_embedded.common import VEC, SparseVector, InfinityException, SortType
from infinity_embedded.embedded_infinity_ext import *
from infinity_embedded.local_infinity.types import logic_type_to_dtype, make_match_tensor_expr
from infinity_embedded.local_infinity.utils import traverse_conditions, parse_expr
Expand Down Expand Up @@ -477,7 +477,7 @@ def highlight(self, columns: Optional[list]) -> InfinityLocalQueryBuilder:
self._highlight = highlight_list
return self

def sort(self, order_by_expr_list: Optional[List[list[str, bool]]]) -> InfinityLocalQueryBuilder:
def sort(self, order_by_expr_list: Optional[List[list[str, SortType]]]) -> InfinityLocalQueryBuilder:
sort_list: List[WrapOrderByExpr] = []
for order_by_expr in order_by_expr_list:
if isinstance(order_by_expr[0], str):
Expand All @@ -491,7 +491,7 @@ def sort(self, order_by_expr_list: Optional[List[list[str, bool]]]) -> InfinityL
parsed_expr = WrapParsedExpr(ParsedExprType.kColumn)
parsed_expr.column_expr = column_expr

order_by_expr = WrapOrderByExpr(parsed_expr, order_by_expr[1])
order_by_expr = WrapOrderByExpr(parsed_expr, order_by_expr[1] == SortType.Asc)
sort_list.append(order_by_expr)
case "_row_id":
func_expr = WrapFunctionExpr()
Expand All @@ -502,7 +502,7 @@ def sort(self, order_by_expr_list: Optional[List[list[str, bool]]]) -> InfinityL
parsed_expr = WrapParsedExpr(expr_type)
parsed_expr.function_expr = func_expr

order_by_expr = WrapOrderByExpr(parsed_expr, order_by_expr[1])
order_by_expr = WrapOrderByExpr(parsed_expr, order_by_expr[1] == SortType.Asc)
sort_list.append(order_by_expr)
case "_score":
func_expr = WrapFunctionExpr()
Expand All @@ -513,7 +513,7 @@ def sort(self, order_by_expr_list: Optional[List[list[str, bool]]]) -> InfinityL
parsed_expr = WrapParsedExpr(expr_type)
parsed_expr.function_expr = func_expr

order_by_expr = WrapOrderByExpr(parsed_expr, order_by_expr[1])
order_by_expr = WrapOrderByExpr(parsed_expr, order_by_expr[1] == SortType.Asc)
sort_list.append(order_by_expr)
case "_similarity":
func_expr = WrapFunctionExpr()
Expand All @@ -524,7 +524,7 @@ def sort(self, order_by_expr_list: Optional[List[list[str, bool]]]) -> InfinityL
parsed_expr = WrapParsedExpr(expr_type)
parsed_expr.function_expr = func_expr

order_by_expr = WrapOrderByExpr(parsed_expr, order_by_expr[1])
order_by_expr = WrapOrderByExpr(parsed_expr, order_by_expr[1] == SortType.Asc)
sort_list.append(order_by_expr)
case "_distance":
func_expr = WrapFunctionExpr()
Expand All @@ -535,11 +535,11 @@ def sort(self, order_by_expr_list: Optional[List[list[str, bool]]]) -> InfinityL
parsed_expr = WrapParsedExpr(expr_type)
parsed_expr.function_expr = func_expr

order_by_expr = WrapOrderByExpr(parsed_expr, order_by_expr[1])
order_by_expr = WrapOrderByExpr(parsed_expr, order_by_expr[1] == SortType.Asc)
sort_list.append(order_by_expr)
case _:
parsed_expr = parse_expr(maybe_parse(order_by_expr[0]))
order_by_expr = WrapOrderByExpr(parsed_expr, order_by_expr[1])
order_by_expr = WrapOrderByExpr(parsed_expr, order_by_expr[1] == SortType.Asc)
sort_list.append(order_by_expr)

self._sort = sort_list
Expand Down
8 changes: 2 additions & 6 deletions python/infinity_embedded/local_infinity/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,14 +381,10 @@ def offset(self, offset: Optional[int]):
def sort(self, order_by_expr_list: Optional[List[list[str, SortType]]]):
for order_by_expr in order_by_expr_list:
if len(order_by_expr) != 2:
raise InfinityException(ErrorCode.INVALID_PARAMETER,
raise InfinityException(ErrorCode.INVALID_PARAMETER_VALUE,
"order_by_expr_list must be a list of [column_name, sort_type]")
if order_by_expr[1] not in [SortType.Asc, SortType.Desc]:
raise InfinityException(ErrorCode.INVALID_PARAMETER, "sort_type must be SortType.Asc or SortType.Desc")
if order_by_expr[1] == SortType.Asc:
order_by_expr[1] = True
else:
order_by_expr[1] = False
raise InfinityException(ErrorCode.INVALID_PARAMETER_VALUE, "sort_type must be SortType.Asc or SortType.Desc")
self.query_builder.sort(order_by_expr_list)
return self

Expand Down
108 changes: 58 additions & 50 deletions python/infinity_sdk/infinity/remote_thrift/query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from pyarrow import Table
from sqlglot import condition, maybe_parse

from infinity.common import VEC, SparseVector, InfinityException
from infinity.common import VEC, SparseVector, InfinityException, SortType
from infinity.errors import ErrorCode
from infinity.remote_thrift.infinity_thrift_rpc.ttypes import *
from infinity.remote_thrift.types import (
Expand All @@ -39,15 +39,15 @@

class Query(ABC):
def __init__(
self,
columns: Optional[List[ParsedExpr]],
highlight: Optional[List[ParsedExpr]],
search: Optional[SearchExpr],
filter: Optional[ParsedExpr],
groupby: Optional[List[ParsedExpr]],
limit: Optional[ParsedExpr],
offset: Optional[ParsedExpr],
sort: Optional[List[OrderByExpr]],
self,
columns: Optional[List[ParsedExpr]],
highlight: Optional[List[ParsedExpr]],
search: Optional[SearchExpr],
filter: Optional[ParsedExpr],
groupby: Optional[List[ParsedExpr]],
limit: Optional[ParsedExpr],
offset: Optional[ParsedExpr],
sort: Optional[List[OrderByExpr]],
):
self.columns = columns
self.highlight = highlight
Expand All @@ -61,16 +61,16 @@ def __init__(

class ExplainQuery(Query):
def __init__(
self,
columns: Optional[List[ParsedExpr]],
highlight: Optional[List[ParsedExpr]],
search: Optional[SearchExpr],
filter: Optional[ParsedExpr],
groupby: Optional[List[ParsedExpr]],
limit: Optional[ParsedExpr],
offset: Optional[ParsedExpr],
sort: Optional[List[OrderByExpr]],
explain_type: Optional[ExplainType],
self,
columns: Optional[List[ParsedExpr]],
highlight: Optional[List[ParsedExpr]],
search: Optional[SearchExpr],
filter: Optional[ParsedExpr],
groupby: Optional[List[ParsedExpr]],
limit: Optional[ParsedExpr],
offset: Optional[ParsedExpr],
sort: Optional[List[OrderByExpr]],
explain_type: Optional[ExplainType],
):
super().__init__(columns, highlight, search, filter, groupby, limit, offset, sort)
self.explain_type = explain_type
Expand Down Expand Up @@ -99,13 +99,13 @@ def reset(self):
self._sort = None

def match_dense(
self,
vector_column_name: str,
embedding_data: VEC,
embedding_data_type: str,
distance_type: str,
topn: int,
knn_params: {} = None,
self,
vector_column_name: str,
embedding_data: VEC,
embedding_data_type: str,
distance_type: str,
topn: int,
knn_params: {} = None,
) -> InfinityThriftQueryBuilder:
if self._search is None:
self._search = SearchExpr()
Expand Down Expand Up @@ -134,7 +134,8 @@ def match_dense(
if embedding_data_type == "bit":
if len(embedding_data) % 8 != 0:
raise InfinityException(
ErrorCode.INVALID_EMBEDDING_DATA_TYPE, f"Embeddings with data bit must have dimension of times of 8!"
ErrorCode.INVALID_EMBEDDING_DATA_TYPE,
f"Embeddings with data bit must have dimension of times of 8!"
)
else:
new_embedding_data = []
Expand Down Expand Up @@ -186,7 +187,8 @@ def match_dense(
elem_type = ElementType.ElementBFloat16
data.bf16_array_value = embedding_data
else:
raise InfinityException(ErrorCode.INVALID_EMBEDDING_DATA_TYPE, f"Invalid embedding {embedding_data[0]} type")
raise InfinityException(ErrorCode.INVALID_EMBEDDING_DATA_TYPE,
f"Invalid embedding {embedding_data[0]} type")

dist_type = KnnDistanceType.L2
if distance_type == "l2":
Expand Down Expand Up @@ -223,12 +225,12 @@ def match_dense(
return self

def match_sparse(
self,
vector_column_name: str,
sparse_data: SparseVector | dict,
metric_type: str,
topn: int,
opt_params: Optional[dict] = None,
self,
vector_column_name: str,
sparse_data: SparseVector | dict,
metric_type: str,
topn: int,
opt_params: Optional[dict] = None,
) -> InfinityThriftQueryBuilder:
if self._search is None:
self._search = SearchExpr()
Expand All @@ -243,7 +245,7 @@ def match_sparse(
return self

def match_text(
self, fields: str, matching_text: str, topn: int, extra_options: Optional[dict]
self, fields: str, matching_text: str, topn: int, extra_options: Optional[dict]
) -> InfinityThriftQueryBuilder:
if self._search is None:
self._search = SearchExpr()
Expand All @@ -262,12 +264,12 @@ def match_text(
return self

def match_tensor(
self,
column_name: str,
query_data: VEC,
query_data_type: str,
topn: int,
extra_option: Optional[dict] = None,
self,
column_name: str,
query_data: VEC,
query_data_type: str,
topn: int,
extra_option: Optional[dict] = None,
) -> InfinityThriftQueryBuilder:
if self._search is None:
self._search = SearchExpr()
Expand Down Expand Up @@ -382,7 +384,7 @@ def highlight(self, columns: Optional[list]) -> InfinityThriftQueryBuilder:
self._highlight = highlight_list
return self

def sort(self, order_by_expr_list: Optional[List[list[str, bool]]]) -> InfinityThriftQueryBuilder:
def sort(self, order_by_expr_list: Optional[List[list[str, SortType]]]) -> InfinityThriftQueryBuilder:
sort_list: List[OrderByExpr] = []
for order_by_expr in order_by_expr_list:
if isinstance(order_by_expr[0], str):
Expand All @@ -393,35 +395,41 @@ def sort(self, order_by_expr_list: Optional[List[list[str, bool]]]) -> InfinityT
column_expr = ColumnExpr(star=True, column_name=[])
expr_type = ParsedExprType(column_expr=column_expr)
parsed_expr = ParsedExpr(type=expr_type)
order_by_expr = OrderByExpr(expr = parsed_expr, asc = order_by_expr[1])
order_by_flag: bool = order_by_expr[1] == SortType.Asc
order_by_expr = OrderByExpr(expr=parsed_expr, asc=order_by_flag)
sort_list.append(order_by_expr)
case "_row_id":
func_expr = FunctionExpr(function_name="row_id", arguments=[])
expr_type = ParsedExprType(function_expr=func_expr)
parsed_expr = ParsedExpr(type=expr_type)
order_by_expr = OrderByExpr(expr = parsed_expr, asc = order_by_expr[1])
order_by_flag: bool = order_by_expr[1] == SortType.Asc
order_by_expr = OrderByExpr(expr=parsed_expr, asc=order_by_flag)
sort_list.append(order_by_expr)
case "_score":
func_expr = FunctionExpr(function_name="score", arguments=[])
expr_type = ParsedExprType(function_expr=func_expr)
parsed_expr = ParsedExpr(type=expr_type)
order_by_expr = OrderByExpr(expr = parsed_expr, asc = order_by_expr[1])
order_by_flag: bool = order_by_expr[1] == SortType.Asc
order_by_expr = OrderByExpr(expr=parsed_expr, asc=order_by_flag)
sort_list.append(order_by_expr)
case "_similarity":
func_expr = FunctionExpr(function_name="similarity", arguments=[])
expr_type = ParsedExprType(function_expr=func_expr)
parsed_expr = ParsedExpr(type=expr_type)
order_by_expr = OrderByExpr(expr = parsed_expr, asc = order_by_expr[1])
order_by_flag: bool = order_by_expr[1] == SortType.Asc
order_by_expr = OrderByExpr(expr=parsed_expr, asc=order_by_flag)
sort_list.append(order_by_expr)
case "_distance":
func_expr = FunctionExpr(function_name="distance", arguments=[])
expr_type = ParsedExprType(function_expr=func_expr)
parsed_expr = ParsedExpr(type=expr_type)
order_by_expr = OrderByExpr(expr = parsed_expr, asc = order_by_expr[1])
order_by_flag: bool = order_by_expr[1] == SortType.Asc
order_by_expr = OrderByExpr(expr=parsed_expr, asc=order_by_flag)
sort_list.append(order_by_expr)
case _:
parsed_expr = parse_expr(maybe_parse(order_by_expr[0]))
sort_list.append(OrderByExpr(expr = parsed_expr, asc = order_by_expr[1]))
order_by_flag: bool = order_by_expr[1] == SortType.Asc
sort_list.append(OrderByExpr(expr=parsed_expr, asc=order_by_flag))

self._sort = sort_list
return self
Expand Down Expand Up @@ -463,7 +471,7 @@ def explain(self, explain_type=ExplainType.Physical) -> Any:
groupby=self._groupby,
limit=self._limit,
offset=self._offset,
sort = self._sort,
sort=self._sort,
explain_type=explain_type,
)
return self._table._explain_query(query)
8 changes: 2 additions & 6 deletions python/infinity_sdk/infinity/remote_thrift/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,14 +397,10 @@ def offset(self, offset: Optional[int]):
def sort(self, order_by_expr_list: Optional[List[list[str, SortType]]]):
for order_by_expr in order_by_expr_list:
if len(order_by_expr) != 2:
raise InfinityException(ErrorCode.INVALID_PARAMETER,
raise InfinityException(ErrorCode.INVALID_PARAMETER_VALUE,
"order_by_expr_list must be a list of [column_name, sort_type]")
if order_by_expr[1] not in [SortType.Asc, SortType.Desc]:
raise InfinityException(ErrorCode.INVALID_PARAMETER, "sort_type must be SortType.Asc or SortType.Desc")
if order_by_expr[1] == SortType.Asc:
order_by_expr[1] = True
else:
order_by_expr[1] = False
raise InfinityException(ErrorCode.INVALID_PARAMETER_VALUE, "sort_type must be SortType.Asc or SortType.Desc")
self.query_builder.sort(order_by_expr_list)
return self

Expand Down

0 comments on commit 9b5ef10

Please sign in to comment.