From 13fce3c43c4e5b6deaea42a203ff835e11913465 Mon Sep 17 00:00:00 2001 From: unfode Date: Tue, 17 Oct 2023 22:46:08 -0400 Subject: [PATCH] Implement SearchByPks Signed-off-by: unfode --- client/client.go | 13 ++- client/client_test.go | 2 +- client/data.go | 200 +++++++++++++++++++++++++++++------------- 3 files changed, 151 insertions(+), 64 deletions(-) diff --git a/client/client.go b/client/client.go index b974b53d..da1e4e72 100644 --- a/client/client.go +++ b/client/client.go @@ -135,8 +135,17 @@ type Client interface { // Upsert column-based data of collection, returns id column values Upsert(ctx context.Context, collName string, partitionName string, columns ...entity.Column) (entity.Column, error) // Search with bool expression - Search(ctx context.Context, collName string, partitions []string, - expr string, outputFields []string, vectors []entity.Vector, vectorField string, metricType entity.MetricType, topK int, sp entity.SearchParam, opts ...SearchQueryOptionFunc) ([]SearchResult, error) + Search( + ctx context.Context, collName string, partitions []string, expr string, outputFields []string, + vectors []entity.Vector, vectorField string, metricType entity.MetricType, topK int, + sp entity.SearchParam, opts ...SearchQueryOptionFunc, + ) ([]SearchResult, error) + // SearchByPks searches using the vectors corresponding to the provided primary keys + SearchByPks( + ctx context.Context, collName string, partitions []string, expr string, outputFields []string, + primaryKeys entity.Column, vectorField string, metricType entity.MetricType, topK int, + sp entity.SearchParam, opts ...SearchQueryOptionFunc, + ) ([]SearchResult, error) // QueryByPks query record by specified primary key(s). QueryByPks(ctx context.Context, collectionName string, partitionNames []string, ids entity.Column, outputFields []string, opts ...SearchQueryOptionFunc) (ResultSet, error) // Query performs query records with boolean expression. diff --git a/client/client_test.go b/client/client_test.go index d369cba3..8f67b8fb 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -149,7 +149,7 @@ func TestGrpcClientNil(t *testing.T) { mt := m.Type // type of function if m.Name == "Close" || m.Name == "Connect" || // skip connect & close m.Name == "UsingDatabase" || // skip use database - m.Name == "Search" || // type alias MetricType treated as string + m.Name == "Search" || m.Name == "SearchByPks" || // type alias MetricType treated as string m.Name == "CalcDistance" || m.Name == "ManualCompaction" || // time.Duration hard to detect in reflect m.Name == "Insert" || m.Name == "Upsert" { // complex methods with ... diff --git a/client/data.go b/client/data.go index 4d18ed4a..00a89647 100644 --- a/client/data.go +++ b/client/data.go @@ -20,6 +20,7 @@ import ( "strings" "github.com/cockroachdb/errors" + "github.com/golang/protobuf/proto" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" @@ -35,34 +36,115 @@ const ( ) // Search with bool expression -func (c *GrpcClient) Search(ctx context.Context, collName string, partitions []string, - expr string, outputFields []string, vectors []entity.Vector, vectorField string, metricType entity.MetricType, topK int, sp entity.SearchParam, opts ...SearchQueryOptionFunc) ([]SearchResult, error) { +func (c *GrpcClient) Search( + ctx context.Context, collName string, partitions []string, expr string, outputFields []string, vectors []entity.Vector, + vectorField string, metricType entity.MetricType, topK int, sp entity.SearchParam, opts ...SearchQueryOptionFunc, +) ([]SearchResult, error) { if c.Service == nil { return []SearchResult{}, ErrClientNotReady } - var schema *entity.Schema - collInfo, ok := MetaCache.getCollectionInfo(collName) + + _, ok := MetaCache.getCollectionInfo(collName) if !ok { - coll, err := c.DescribeCollection(ctx, collName) + _, err := c.DescribeCollection(ctx, collName) if err != nil { return nil, err } - schema = coll.Schema - } else { - schema = collInfo.Schema } option, err := makeSearchQueryOption(collName, opts...) if err != nil { return nil, err } - // 2. Request milvus Service - req, err := prepareSearchRequest(collName, partitions, expr, outputFields, vectors, vectorField, metricType, topK, sp, option) + + params := sp.Params() + bs, err := json.Marshal(params) if err != nil { return nil, err } - sr := make([]SearchResult, 0, len(vectors)) + searchParams := prepareSearchParamsForSearchRequest( + vectorField, metricType, topK, bs, option, + ) + + req := &milvuspb.SearchRequest{ + DbName: "", + CollectionName: collName, + PartitionNames: partitions, + Dsl: expr, + PlaceholderGroup: vector2PlaceholderGroupBytes(vectors), + DslType: commonpb.DslType_BoolExprV1, + OutputFields: outputFields, + SearchParams: searchParams, + GuaranteeTimestamp: option.GuaranteeTimestamp, + Nq: int64(len(vectors)), + SearchByPrimaryKeys: false, + } + + resp, err := c.Service.Search(ctx, req) + if err != nil { + return nil, err + } + if err := handleRespStatus(resp.GetStatus()); err != nil { + return nil, err + } + + return processSearchResponse(resp, outputFields), nil +} + +func (c *GrpcClient) SearchByPks( + ctx context.Context, collName string, partitions []string, expr string, outputFields []string, + primaryKeys entity.Column, vectorField string, metricType entity.MetricType, topK int, + sp entity.SearchParam, opts ...SearchQueryOptionFunc, +) ([]SearchResult, error) { + if c.Service == nil { + return []SearchResult{}, ErrClientNotReady + } + + if primaryKeys.Len() == 0 { + return nil, errors.New("expected at least one primary key, but got zero") + } + if primaryKeys.Type() != entity.FieldTypeInt64 && primaryKeys.Type() != entity.FieldTypeVarChar { + return nil, errors.New("only int64 and varchar column can be primary key for now") + } + + _, ok := MetaCache.getCollectionInfo(collName) + if !ok { + _, err := c.DescribeCollection(ctx, collName) + if err != nil { + return nil, err + } + } + + option, err := makeSearchQueryOption(collName, opts...) + if err != nil { + return nil, err + } + + params := sp.Params() + bs, err := json.Marshal(params) + if err != nil { + return nil, err + } + + searchParams := prepareSearchParamsForSearchRequest( + vectorField, metricType, topK, bs, option, + ) + + req := &milvuspb.SearchRequest{ + DbName: "", + CollectionName: collName, + PartitionNames: partitions, + Dsl: expr, + PlaceholderGroup: primaryKeysToPlaceholderGroupBytes(primaryKeys), + DslType: commonpb.DslType_BoolExprV1, + OutputFields: outputFields, + SearchParams: searchParams, + GuaranteeTimestamp: option.GuaranteeTimestamp, + Nq: int64(primaryKeys.Len()), + SearchByPrimaryKeys: true, + } + resp, err := c.Service.Search(ctx, req) if err != nil { return nil, err @@ -70,10 +152,33 @@ func (c *GrpcClient) Search(ctx context.Context, collName string, partitions []s if err := handleRespStatus(resp.GetStatus()); err != nil { return nil, err } - // 3. parse result into result - results := resp.GetResults() + + return processSearchResponse(resp, outputFields), nil +} + +func prepareSearchParamsForSearchRequest( + vectorField string, metricType entity.MetricType, topK int, bs []byte, opt *SearchQueryOption, +) []*commonpb.KeyValuePair { + searchParams := entity.MapKvPairs(map[string]string{ + "anns_field": vectorField, + "topk": fmt.Sprintf("%d", topK), + "params": string(bs), + "metric_type": string(metricType), + "round_decimal": "-1", + ignoreGrowingKey: strconv.FormatBool(opt.IgnoreGrowing), + offsetKey: fmt.Sprintf("%d", opt.Offset), + }) + + return searchParams +} + +func processSearchResponse(response *milvuspb.SearchResults, outputFields []string) []SearchResult { + results := response.GetResults() + + sr := make([]SearchResult, 0, results.GetNumQueries()) offset := 0 fieldDataList := results.GetFieldsData() + for i := 0; i < int(results.GetNumQueries()); i++ { rc := int(results.GetTopks()[i]) // result entry count for current query entry := SearchResult{ @@ -85,14 +190,15 @@ func (c *GrpcClient) Search(ctx context.Context, collName string, partitions []s offset += rc continue } - entry.Fields, entry.Err = c.parseSearchResult(schema, outputFields, fieldDataList, i, offset, offset+rc) + entry.Fields, entry.Err = parseSearchResult(outputFields, fieldDataList, offset, offset+rc) sr = append(sr, entry) offset += rc } - return sr, nil + + return sr } -func (c *GrpcClient) parseSearchResult(_ *entity.Schema, outputFields []string, fieldDataList []*schemapb.FieldData, _, from, to int) ([]entity.Column, error) { +func parseSearchResult(outputFields []string, fieldDataList []*schemapb.FieldData, from, to int) ([]entity.Column, error) { // duplicated name will have only one column now outputSet := make(map[string]struct{}) for _, output := range outputFields { @@ -208,16 +314,12 @@ func (c *GrpcClient) Query(ctx context.Context, collectionName string, partition return nil, ErrClientNotReady } - var sch *entity.Schema - collInfo, ok := MetaCache.getCollectionInfo(collectionName) + _, ok := MetaCache.getCollectionInfo(collectionName) if !ok { - coll, err := c.DescribeCollection(ctx, collectionName) + _, err := c.DescribeCollection(ctx, collectionName) if err != nil { return nil, err } - sch = coll.Schema - } else { - sch = collInfo.Schema } option, err := makeSearchQueryOption(collectionName, opts...) @@ -254,7 +356,7 @@ func (c *GrpcClient) Query(ctx context.Context, collectionName string, partition fieldsData := resp.GetFieldsData() - columns, err := c.parseSearchResult(sch, outputFields, fieldsData, 0, 0, -1) //entity.FieldDataColumn(fieldData, 0, -1) + columns, err := parseSearchResult(outputFields, fieldsData, 0, -1) //entity.FieldDataColumn(fieldData, 0, -1) if err != nil { return nil, err } @@ -271,47 +373,23 @@ func getPKField(schema *entity.Schema) *entity.Field { return nil } -func getVectorField(schema *entity.Schema) *entity.Field { - for _, f := range schema.Fields { - if f.DataType == entity.FieldTypeFloatVector || f.DataType == entity.FieldTypeBinaryVector { - return f - } - } - return nil -} +func primaryKeysToPlaceholderGroupBytes(primaryKeys entity.Column) []byte { -func prepareSearchRequest(collName string, partitions []string, - expr string, outputFields []string, vectors []entity.Vector, vectorField string, - metricType entity.MetricType, topK int, sp entity.SearchParam, opt *SearchQueryOption) (*milvuspb.SearchRequest, error) { - params := sp.Params() - params[forTuningKey] = opt.ForTuning - bs, err := json.Marshal(params) - if err != nil { - return nil, err - } + queryExpr := PKs2Expr("", primaryKeys) + queryExprBytes := []byte(queryExpr) - searchParams := entity.MapKvPairs(map[string]string{ - "anns_field": vectorField, - "topk": fmt.Sprintf("%d", topK), - "params": string(bs), - "metric_type": string(metricType), - "round_decimal": "-1", - ignoreGrowingKey: strconv.FormatBool(opt.IgnoreGrowing), - offsetKey: fmt.Sprintf("%d", opt.Offset), - }) - req := &milvuspb.SearchRequest{ - DbName: "", - CollectionName: collName, - PartitionNames: partitions, - Dsl: expr, - PlaceholderGroup: vector2PlaceholderGroupBytes(vectors), - DslType: commonpb.DslType_BoolExprV1, - OutputFields: outputFields, - SearchParams: searchParams, - GuaranteeTimestamp: opt.GuaranteeTimestamp, - Nq: int64(len(vectors)), + placeholderGroup := &commonpb.PlaceholderGroup{ + Placeholders: []*commonpb.PlaceholderValue{ + { + Tag: "$0", + Type: commonpb.PlaceholderType_None, + Values: [][]byte{queryExprBytes}, + }, + }, } - return req, nil + + bs, _ := proto.Marshal(placeholderGroup) + return bs } // GetPersistentSegmentInfo get persistent segment info