Skip to content

Commit

Permalink
enhance: support milvus-client iterator
Browse files Browse the repository at this point in the history
  • Loading branch information
MrPresent-Han committed Dec 17, 2024
1 parent 712e9b6 commit df3625d
Show file tree
Hide file tree
Showing 3 changed files with 290 additions and 0 deletions.
111 changes: 111 additions & 0 deletions examples/hybrid_search_groupby.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
import numpy as np
from pymilvus import (
connections,
utility,
FieldSchema, CollectionSchema, DataType,
Collection,
AnnSearchRequest, RRFRanker, WeightedRanker,
)
import secrets

def generate_random_hex_string(length):
    """Return a cryptographically strong random hex string of exactly *length* chars.

    ``secrets.token_hex(n)`` yields ``2*n`` hex characters, so round up and
    trim — the original ``token_hex(length // 2)`` silently returned
    ``length - 1`` characters whenever *length* was odd.
    """
    return secrets.token_hex((length + 1) // 2)[:length]

# Demo: hybrid (multi-vector) search with group-by on a scalar field.
collection_name = 'test_group_by_' + generate_random_hex_string(24)
clean_exist = True
prepare_data = True

fmt = "\n=== {:30} ===\n"
search_latency_fmt = "search latency = {:.4f}s"
num_entities, dim = 3000, 8

print(fmt.format("start connecting to Milvus"))
connections.connect("default", host="localhost", port="19530")

# Start from a clean slate so repeated runs don't collide.
if clean_exist and utility.has_collection(collection_name):
    utility.drop_collection(collection_name)

fields = [
    FieldSchema(name="pk", dtype=DataType.VARCHAR, is_primary=True, auto_id=False, max_length=100),
    FieldSchema(name="random", dtype=DataType.DOUBLE),
    FieldSchema(name="embeddings", dtype=DataType.FLOAT_VECTOR, dim=dim),
    FieldSchema(name="embeddings2", dtype=DataType.FLOAT_VECTOR, dim=dim),
    FieldSchema(name="string", dtype=DataType.VARCHAR, max_length=512),
]

schema = CollectionSchema(fields)

print(fmt.format(f"Create collection `{collection_name}`"))
collection = Collection(collection_name, schema, consistency_level="Strong", num_shards = 4)

print(fmt.format("Start inserting entities"))
rng = np.random.default_rng(seed=19530)

batch_num = 1
for i in range(batch_num):
    entities = [
        # provide the pk field because `auto_id` is set to False
        [str(i) for i in range(num_entities*i, num_entities*(i+1))],
        rng.random(num_entities).tolist(),  # field random, only supports list
        rng.random((num_entities, dim)),    # field embeddings, supports numpy.ndarray and list
        rng.random((num_entities, dim)),    # field embeddings2, supports numpy.ndarray and list
        # 33 distinct values so group-by has small, repeated groups
        [str(i % 33) for i in range(num_entities)],
    ]
    insert_result = collection.insert(entities)
    collection.flush()


print(f"Number of entities in Milvus: {collection.num_entities}")  # check the num_entities

print(fmt.format("Start Creating index IVF_FLAT"))
index = {
    "index_type": "IVF_FLAT",
    "metric_type": "L2",
    "params": {"nlist": 128},
}

# Both vector fields need an index before the collection can be loaded.
collection.create_index("embeddings", index)
collection.create_index("embeddings2", index)

print(fmt.format("Start loading"))
collection.load()

field_names = ["embeddings", "embeddings2"]

req_list = []
nq = 3
weights = [0.2]
default_limit = 10
vectors_to_search = []

for i in range(len(field_names)-1):
    # 4. generate search data
    vectors_to_search = rng.random((nq, dim))
    search_param = {
        "data": vectors_to_search,
        "anns_field": field_names[i],
        "param": {"metric_type": "L2"},
        "limit": default_limit,
        "expr": "random > 0.5"}
    req = AnnSearchRequest(**search_param)
    req_list.append(req)

hybrid_res = collection.hybrid_search(
    req_list,
    WeightedRanker(*weights),
    default_limit,
    output_fields=["string"],
    group_by_field="string",
    group_size=2,
    rank_group_scorer="max",
    # BUG FIX: was the string "False", which is truthy — pass a real boolean.
    group_strict_size=False,
)

print("rank by WeightedRanker")  # BUG FIX: message previously misspelled "WightedRanker"
for hits in hybrid_res:
    print(f" hybrid search hit_size: {len(hits)}")
    for hit in hits:
        print(f" hybrid search hit: {hit}")



'''
print("rank by RRFRanker")
hybrid_res = collection.hybrid_search(req_list, RRFRanker(), default_limit, output_fields=["string"],
group_by_field="string", group_size=5)
for hits in hybrid_res:
    for hit in hits:
        print(f" hybrid search hit: {hit}")
'''
99 changes: 99 additions & 0 deletions examples/milvus_client/iterator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from pymilvus.milvus_client.milvus_client import MilvusClient
from pymilvus import (
FieldSchema, CollectionSchema, DataType,
)
import numpy as np

collection_name = "test_milvus_client_iterator"
prepare_new_data = True
clean_exist = True

USER_ID = "id"
AGE = "age"
DEPOSIT = "deposit"
PICTURE = "picture"
DIM = 8
NUM_ENTITIES = 10000
rng = np.random.default_rng(seed=19530)


def test_query_iterator(milvus_client: MilvusClient):
    """Page through rows whose age falls in [10, 25] using query_iterator."""
    iterator = milvus_client.query_iterator(
        collection_name,
        filter=f"10 <= {AGE} <= 25",
        batch_size=50,
        output_fields=[USER_ID, AGE],
    )
    page_no = 0
    while True:
        page = iterator.next()
        # An empty page signals the iterator is exhausted.
        if not page:
            print("query iteration finished, close")
            iterator.close()
            break
        for row in page:
            print(row)
        page_no += 1
        print(f"page{page_no}-------------------------")

def test_search_iterator(milvus_client: MilvusClient):
    """Drain a vector search_iterator page by page, printing every hit."""
    vector_to_search = rng.random((1, DIM), np.float32)
    search_iterator = milvus_client.search_iterator(
        collection_name, data=vector_to_search, batch_size=100, anns_field=PICTURE
    )

    page_idx = 0
    while True:
        res = search_iterator.next()
        if len(res) == 0:
            # BUG FIX: message was copy-pasted from the query test and said
            # "query iteration finished" — this is the search iterator.
            print("search iteration finished, close")
            search_iterator.close()
            break
        for i in range(len(res)):
            print(res[i])
        page_idx += 1
        print(f"page{page_idx}-------------------------")


def main():
    """End-to-end demo: (re)create the collection, ingest rows, build an
    index, load, then exercise the client-level search iterator."""
    client = MilvusClient("http://localhost:19530")

    if client.has_collection(collection_name) and clean_exist:
        client.drop_collection(collection_name)
        print(f"dropped existed collection{collection_name}")

    if not client.has_collection(collection_name):
        schema = CollectionSchema([
            FieldSchema(name=USER_ID, dtype=DataType.INT64, is_primary=True, auto_id=False),
            FieldSchema(name=AGE, dtype=DataType.INT64),
            FieldSchema(name=DEPOSIT, dtype=DataType.DOUBLE),
            FieldSchema(name=PICTURE, dtype=DataType.FLOAT_VECTOR, dim=DIM),
        ])
        client.create_collection(collection_name, dimension=DIM, schema=schema)

    if prepare_new_data:
        rows = [
            {
                USER_ID: idx,
                AGE: (idx % 100),
                DEPOSIT: float(idx),
                PICTURE: rng.random((1, DIM))[0],
            }
            for idx in range(NUM_ENTITIES)
        ]
        client.insert(collection_name, rows)
        client.flush(collection_name)
        print(f"Finish flush collections:{collection_name}")

    # A vector index is required before the collection can be loaded.
    index_params = client.prepare_index_params()
    index_params.add_index(
        field_name=PICTURE,
        index_type='IVF_FLAT',
        metric_type='L2',
        params={"nlist": 1024},
    )
    client.create_index(collection_name, index_params)
    client.load_collection(collection_name)
    # test_query_iterator(client)  # kept disabled, as in the original
    test_search_iterator(client)


if __name__ == '__main__':
    main()
80 changes: 80 additions & 0 deletions pymilvus/milvus_client/milvus_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from pymilvus.orm import utility
from pymilvus.orm.collection import CollectionSchema
from pymilvus.orm.connections import connections
from pymilvus.orm.constants import UNLIMITED
from pymilvus.orm.iterator import QueryIterator, SearchIterator
from pymilvus.orm.types import DataType

from .index import IndexParams
Expand Down Expand Up @@ -480,6 +482,84 @@ def query(

return res

def query_iterator(
    self,
    collection_name: str,
    batch_size: Optional[int] = 1000,
    limit: Optional[int] = UNLIMITED,
    filter: Optional[str] = "",
    output_fields: Optional[List[str]] = None,
    partition_names: Optional[List[str]] = None,
    timeout: Optional[float] = None,
    **kwargs,
):
    """Return a QueryIterator that pages through rows matching *filter*.

    Args:
        collection_name: Name of the collection to iterate over.
        batch_size: Number of rows fetched per round trip.
        limit: Total cap on returned rows; UNLIMITED by default.
        filter: Boolean expression restricting rows (empty string = all rows).
        output_fields: Field names to return; None means the default fields.
        partition_names: Restrict the scan to these partitions.
        timeout: Per-RPC timeout in seconds.

    Raises:
        DataTypeNotMatchException: If *filter* is neither None nor a str.
        Exception: Propagated from describe_collection if the collection
            cannot be described (e.g. it does not exist).
    """
    if filter is not None and not isinstance(filter, str):
        raise DataTypeNotMatchException(message=ExceptionsMessage.ExprType % type(filter))

    conn = self._get_connection()
    # The iterator needs the collection schema (primary key, fields) up front.
    try:
        schema_dict = conn.describe_collection(collection_name, timeout=timeout, **kwargs)
    except Exception as ex:
        logger.error("Failed to describe collection: %s", collection_name)
        # BUG FIX: was `raise ex from ex`, which chains the exception to
        # itself; a bare `raise` re-raises with the original traceback intact.
        raise

    return QueryIterator(
        connection=conn,
        collection_name=collection_name,
        batch_size=batch_size,
        limit=limit,
        expr=filter,
        output_fields=output_fields,
        partition_names=partition_names,
        schema=schema_dict,
        timeout=timeout,
        **kwargs,
    )

def search_iterator(
    self,
    collection_name: str,
    data: Union[List[list], list],
    batch_size: Optional[int] = 1000,
    filter: Optional[str] = None,
    limit: Optional[int] = UNLIMITED,
    output_fields: Optional[List[str]] = None,
    search_params: Optional[dict] = None,
    timeout: Optional[float] = None,
    partition_names: Optional[List[str]] = None,
    anns_field: Optional[str] = None,
    round_decimal: int = -1,
    **kwargs,
):
    """Return a SearchIterator that pages through ANN search results.

    Args:
        collection_name: Name of the collection to search.
        data: Query vector(s) to search with.
        batch_size: Number of hits fetched per round trip.
        filter: Optional boolean expression restricting candidate rows.
        limit: Total cap on returned hits; UNLIMITED by default.
        output_fields: Field names to return alongside each hit.
        search_params: Index-specific search parameters (e.g. metric_type).
        timeout: Per-RPC timeout in seconds.
        partition_names: Restrict the search to these partitions.
        anns_field: Vector field to search; None lets the server infer it.
        round_decimal: Decimal places to round distances to; -1 = no rounding.

    Raises:
        DataTypeNotMatchException: If *filter* is neither None nor a str.
        Exception: Propagated from describe_collection if the collection
            cannot be described (e.g. it does not exist).
    """
    if filter is not None and not isinstance(filter, str):
        raise DataTypeNotMatchException(message=ExceptionsMessage.ExprType % type(filter))

    conn = self._get_connection()
    # The iterator needs the collection schema (primary key, vector fields) up front.
    try:
        schema_dict = conn.describe_collection(collection_name, timeout=timeout, **kwargs)
    except Exception as ex:
        logger.error("Failed to describe collection: %s", collection_name)
        # BUG FIX: was `raise ex from ex`, which chains the exception to
        # itself; a bare `raise` re-raises with the original traceback intact.
        raise

    return SearchIterator(
        # BUG FIX: reuse the connection fetched above instead of calling
        # self._get_connection() a second time.
        connection=conn,
        collection_name=collection_name,
        data=data,
        ann_field=anns_field,
        param=search_params,
        batch_size=batch_size,
        limit=limit,
        expr=filter,
        partition_names=partition_names,
        output_fields=output_fields,
        timeout=timeout,
        round_decimal=round_decimal,
        schema=schema_dict,
        **kwargs,
    )

def get(
self,
collection_name: str,
Expand Down

0 comments on commit df3625d

Please sign in to comment.