From 6398b3a4c9f0c3c27bde9afc3a09b864784544b8 Mon Sep 17 00:00:00 2001 From: Patrick Weizhi Xu Date: Fri, 29 Nov 2024 15:29:49 +0800 Subject: [PATCH] enhance: add search iterator v2 Signed-off-by: Patrick Weizhi Xu --- pymilvus/client/abstract.py | 4 ++ pymilvus/client/constants.py | 4 ++ pymilvus/client/prepare.py | 20 ++++++ pymilvus/orm/collection.py | 38 +++++----- pymilvus/orm/constants.py | 5 ++ pymilvus/orm/iterator.py | 135 ++++++++++++++++++++++++++++------- 6 files changed, 165 insertions(+), 41 deletions(-) diff --git a/pymilvus/client/abstract.py b/pymilvus/client/abstract.py index a7711749b..aeb4f9f63 100644 --- a/pymilvus/client/abstract.py +++ b/pymilvus/client/abstract.py @@ -497,11 +497,15 @@ def __init__( ) nq_thres += topk self._session_ts = session_ts + self._search_iterator_v2_results = res.search_iterator_v2_results super().__init__(data) def get_session_ts(self): return self._session_ts + def get_search_iterator_v2_results_info(self): + return self._search_iterator_v2_results + def get_fields_by_range( self, start: int, end: int, all_fields_data: List[schema_pb2.FieldData] ) -> Dict[str, Tuple[List[Any], schema_pb2.FieldData]]: diff --git a/pymilvus/client/constants.py b/pymilvus/client/constants.py index efb38aa7a..989a9bf33 100644 --- a/pymilvus/client/constants.py +++ b/pymilvus/client/constants.py @@ -16,6 +16,10 @@ STRICT_GROUP_SIZE = "strict_group_size" ITERATOR_FIELD = "iterator" ITERATOR_SESSION_TS_FIELD = "iterator_session_ts" +ITER_SEARCH_V2_KEY = "search_iter_v2" +ITER_SEARCH_BATCH_SIZE_KEY = "search_iter_batch_size" +ITER_SEARCH_LAST_BOUND_KEY = "search_iter_last_bound" +ITER_SEARCH_ID_KEY = "search_iter_id" PAGE_RETAIN_ORDER_FIELD = "page_retain_order" RANKER_TYPE_RRF = "rrf" diff --git a/pymilvus/client/prepare.py b/pymilvus/client/prepare.py index 4bbc95b9a..99611ad84 100644 --- a/pymilvus/client/prepare.py +++ b/pymilvus/client/prepare.py @@ -19,6 +19,10 @@ DYNAMIC_FIELD_NAME, GROUP_BY_FIELD, GROUP_SIZE, + 
ITER_SEARCH_BATCH_SIZE_KEY, + ITER_SEARCH_ID_KEY, + ITER_SEARCH_LAST_BOUND_KEY, + ITER_SEARCH_V2_KEY, ITERATOR_FIELD, PAGE_RETAIN_ORDER_FIELD, RANK_GROUP_SCORER, @@ -941,6 +945,22 @@ def search_requests_with_expr( if is_iterator is not None: search_params[ITERATOR_FIELD] = is_iterator + is_search_iter_v2 = kwargs.get(ITER_SEARCH_V2_KEY) + if is_search_iter_v2 is not None: + search_params[ITER_SEARCH_V2_KEY] = is_search_iter_v2 + + search_iter_batch_size = kwargs.get(ITER_SEARCH_BATCH_SIZE_KEY) + if search_iter_batch_size is not None: + search_params[ITER_SEARCH_BATCH_SIZE_KEY] = search_iter_batch_size + + search_iter_last_bound = kwargs.get(ITER_SEARCH_LAST_BOUND_KEY) + if search_iter_last_bound is not None: + search_params[ITER_SEARCH_LAST_BOUND_KEY] = search_iter_last_bound + + search_iter_id = kwargs.get(ITER_SEARCH_ID_KEY) + if search_iter_id is not None: + search_params[ITER_SEARCH_ID_KEY] = search_iter_id + group_by_field = kwargs.get(GROUP_BY_FIELD) if group_by_field is not None: search_params[GROUP_BY_FIELD] = group_by_field diff --git a/pymilvus/orm/collection.py b/pymilvus/orm/collection.py index a31d374cf..37c05b271 100644 --- a/pymilvus/orm/collection.py +++ b/pymilvus/orm/collection.py @@ -42,7 +42,7 @@ from .constants import UNLIMITED from .future import MutationFuture, SearchFuture from .index import Index -from .iterator import QueryIterator, SearchIterator +from .iterator import QueryIterator, SearchIterator, SearchIteratorV2 from .mutation import MutationResult from .partition import Partition from .prepare import Prepare @@ -969,26 +969,32 @@ def search_iterator( output_fields: Optional[List[str]] = None, timeout: Optional[float] = None, round_decimal: int = -1, + use_v1: Optional[bool] = False, **kwargs, ): if expr is not None and not isinstance(expr, str): raise DataTypeNotMatchException(message=ExceptionsMessage.ExprType % type(expr)) - return SearchIterator( - connection=self._get_connection(), - collection_name=self._name, - data=data, - 
ann_field=anns_field, - param=param, - batch_size=batch_size, - limit=limit, - expr=expr, - partition_names=partition_names, - output_fields=output_fields, - timeout=timeout, - round_decimal=round_decimal, - schema=self._schema_dict, + + iterator_params = { + "connection": self._get_connection(), + "collection_name": self._name, + "data": data, + "anns_field": anns_field, + "param": param, + "batch_size": batch_size, + "limit": limit, + "expr": expr, + "partition_names": partition_names, + "output_fields": output_fields, + "timeout": timeout, + "round_decimal": round_decimal, + "schema": self._schema_dict, **kwargs, - ) + } + + if use_v1: + return SearchIterator(**iterator_params) + return SearchIteratorV2(**iterator_params) def query( self, diff --git a/pymilvus/orm/constants.py b/pymilvus/orm/constants.py index 6862ab75f..8df1ed053 100644 --- a/pymilvus/orm/constants.py +++ b/pymilvus/orm/constants.py @@ -48,6 +48,11 @@ REDUCE_STOP_FOR_BEST = "reduce_stop_for_best" ITERATOR_FIELD = "iterator" ITERATOR_SESSION_TS_FIELD = "iterator_session_ts" +ITER_SEARCH_V2_KEY = "search_iter_v2" +ITER_SEARCH_BATCH_SIZE_KEY = "search_iter_batch_size" +ITER_SEARCH_LAST_BOUND_KEY = "search_iter_last_bound" +ITER_SEARCH_ID_KEY = "search_iter_id" +ITER_SEARCH_TTL_KEY = "search_iter_ttl" PRINT_ITERATOR_CURSOR = "print_iterator_cursor" DEFAULT_MAX_L2_DISTANCE = 99999999.0 DEFAULT_MIN_IP_DISTANCE = -99999999.0 diff --git a/pymilvus/orm/iterator.py b/pymilvus/orm/iterator.py index 15118523f..f031f4358 100644 --- a/pymilvus/orm/iterator.py +++ b/pymilvus/orm/iterator.py @@ -27,6 +27,11 @@ GUARANTEE_TIMESTAMP, INT64_MAX, IS_PRIMARY, + ITER_SEARCH_BATCH_SIZE_KEY, + ITER_SEARCH_ID_KEY, + ITER_SEARCH_LAST_BOUND_KEY, + ITER_SEARCH_TTL_KEY, + ITER_SEARCH_V2_KEY, ITERATOR_FIELD, ITERATOR_SESSION_CP_FILE, ITERATOR_SESSION_TS_FIELD, @@ -51,7 +56,7 @@ LOGGER.setLevel(logging.INFO) QueryIterator = TypeVar("QueryIterator") SearchIterator = TypeVar("SearchIterator") - +SearchIteratorV2 = 
TypeVar("SearchIteratorV2") log = logging.getLogger(__name__) @@ -87,6 +92,13 @@ def check_set_flag(obj: Any, flag_name: str, kwargs: Dict[str, Any], key: str): setattr(obj, flag_name, kwargs.get(key, False)) +def check_batch_size(batch_size: int): + if batch_size < 0: + raise ParamError(message="batch size cannot be less than zero") + if batch_size > MAX_BATCH_SIZE: + raise ParamError(message=f"batch size cannot be larger than {MAX_BATCH_SIZE}") + + class QueryIterator: def __init__( self, @@ -192,10 +204,7 @@ def __check_set_reduce_stop_for_best(self): self._kwargs[REDUCE_STOP_FOR_BEST] = "False" def __check_set_batch_size(self, batch_size: int): - if batch_size < 0: - raise ParamError(message="batch size cannot be less than zero") - if batch_size > MAX_BATCH_SIZE: - raise ParamError(message=f"batch size cannot be larger than {MAX_BATCH_SIZE}") + check_batch_size(batch_size) self._kwargs[BATCH_SIZE] = batch_size self._kwargs[MILVUS_LIMIT] = batch_size @@ -432,13 +441,31 @@ def distances(self): return distances +def check_num_queries(data: Union[List, utils.SparseMatrixInputType]): + rows = entity_helper.get_input_num_rows(data) + if rows > 1: + raise ParamError(message="Not support search iteration over multiple vectors at present") + if rows == 0: + raise ParamError(message="vector_data for search cannot be empty") + + +def check_metrics(param: Dict): + if param[METRIC_TYPE] is None or param[METRIC_TYPE] == "": + raise ParamError(message="must specify metrics type for search iterator") + + +def check_offset(kwargs: Dict): + if kwargs.get(OFFSET, 0) != 0: + raise ParamError(message="Not support offset when searching iteration") + + class SearchIterator: def __init__( self, connection: Connections, collection_name: str, data: Union[List, utils.SparseMatrixInputType], - ann_field: str, + anns_field: str, param: Dict, batch_size: Optional[int] = 1000, limit: Optional[int] = UNLIMITED, @@ -450,18 +477,14 @@ def __init__( schema: Optional[CollectionSchema] = None, 
**kwargs, ) -> SearchIterator: - rows = entity_helper.get_input_num_rows(data) - if rows > 1: - raise ParamError( - message="Not support search iteration over multiple vectors at present" - ) - if rows == 0: - raise ParamError(message="vector_data for search cannot be empty") + check_num_queries(data) + check_metrics(param) + check_offset(kwargs) self._conn = connection self._iterator_params = { "collection_name": collection_name, "data": data, - "ann_field": ann_field, + "anns_field": anns_field, BATCH_SIZE: batch_size, "output_fields": output_fields, "partition_names": partition_names, @@ -478,8 +501,6 @@ def __init__( self._schema = schema self._limit = limit self._returned_count = 0 - self.__check_metrics() - self.__check_offset() self.__check_rm_range_search_parameters() self.__setup__pk_prop() check_set_flag(self, "_print_iterator_cursor", self._kwargs, PRINT_ITERATOR_CURSOR) @@ -561,10 +582,6 @@ def __setup__pk_prop(self): if self._pk_field_name is None or self._pk_field_name == "": raise ParamError(message="schema must contain pk field, broke") - def __check_metrics(self): - if self._param[METRIC_TYPE] is None or self._param[METRIC_TYPE] == "": - raise ParamError(message="must specify metrics type for search iterator") - """we use search && range search to implement search iterator, so range search parameters are disabled to clients""" @@ -587,10 +604,6 @@ def __check_rm_range_search_parameters(self): f"smalled than range_filter, please adjust your parameter" ) - def __check_offset(self): - if self._kwargs.get(OFFSET, 0) != 0: - raise ParamError(message="Not support offset when searching iteration") - def __update_filtered_ids(self, res: SearchPage): if len(res) == 0: return @@ -698,7 +711,7 @@ def __execute_next_search( res = self._conn.search( self._iterator_params["collection_name"], self._iterator_params["data"], - self._iterator_params["ann_field"], + self._iterator_params["anns_field"], next_params, extend_batch_size(self._iterator_params[BATCH_SIZE], 
next_params, to_extend_batch), next_expr, @@ -784,3 +797,75 @@ def release_cache(self, cache_id: int): NO_CACHE_ID = -1 # Singleton Mode in Python iterator_cache = IteratorCache() + + +class SearchIteratorV2: + def __init__( + self, + connection: Connections, + collection_name: str, + data: Union[List, utils.SparseMatrixInputType], + anns_field: str, + param: Dict, + batch_size: int = 1000, + expr: Optional[str] = None, + partition_names: Optional[List[str]] = None, + output_fields: Optional[List[str]] = None, + timeout: Optional[float] = None, + ttl: Optional[int] = None, + round_decimal: int = -1, + **kwargs, + ) -> SearchIteratorV2: + check_num_queries(data) + check_metrics(param) + check_offset(kwargs) + check_batch_size(batch_size) + + # delete limit from incoming for compatibility + if MILVUS_LIMIT in kwargs: + del kwargs[MILVUS_LIMIT] + + self._conn = connection + self._params = { + "collection_name": collection_name, + "data": data, + "anns_field": anns_field, + "param": deepcopy(param), + "limit": batch_size, + "expression": expr, + "partition_names": partition_names, + "output_fields": output_fields, + "round_decimal": round_decimal, + "timeout": timeout, + ITERATOR_FIELD: True, + ITER_SEARCH_V2_KEY: True, + ITER_SEARCH_BATCH_SIZE_KEY: batch_size, + ITER_SEARCH_TTL_KEY: ttl, + GUARANTEE_TIMESTAMP: 0, + **kwargs, + } + + def next(self): + res = self._conn.search(**self._params) + iter_info = res.get_search_iterator_v2_results_info() + self._params[ITER_SEARCH_LAST_BOUND_KEY] = iter_info.last_bound + + # patch token and guarantee timestamp for the first next() call + if ITER_SEARCH_ID_KEY not in self._params: + self._params[ITER_SEARCH_ID_KEY] = iter_info.token + if self._params[GUARANTEE_TIMESTAMP] <= 0: + if res.get_session_ts() > 0: + self._params[GUARANTEE_TIMESTAMP] = res.get_session_ts() + else: + log.warning( + "failed to set up mvccTs from milvus server, use client-side ts instead" + ) + self._params[GUARANTEE_TIMESTAMP] = 
fall_back_to_latest_session_ts() + + # return SearchPage for compatibility + if len(res) > 0: + return SearchPage(res[0]) + return SearchPage(None) + + def close(self): + pass