Skip to content

Commit

Permalink
enhance: [2.4] Expose GetIndexBuildProgress and bump version
Browse files Browse the repository at this point in the history
Cherry pick milvus-io#833

With extra change:
- Refine unit test of `GetIndexBuildProgress`
- Change internal API to `DescribeIndex`
- Bump version to v2.4.2

Signed-off-by: Congqi Xia <[email protected]>
  • Loading branch information
congqixia committed Oct 24, 2024
1 parent 9cc0d6c commit 5a6dd60
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 179 deletions.
29 changes: 13 additions & 16 deletions client/index.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"strconv"
"time"

"github.com/cockroachdb/errors"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-sdk-go/v2/entity"
Expand Down Expand Up @@ -104,7 +105,8 @@ func getIndexDef(opts ...IndexOption) indexDef {
// CreateIndex create index for collection
// Deprecated please use CreateIndexV2 instead.
func (c *GrpcClient) CreateIndex(ctx context.Context, collName string, fieldName string,
idx entity.Index, async bool, opts ...IndexOption) error {
idx entity.Index, async bool, opts ...IndexOption,
) error {
if c.Service == nil {
return ErrClientNotReady
}
Expand Down Expand Up @@ -219,7 +221,7 @@ func (c *GrpcClient) DropIndex(ctx context.Context, collName string, fieldName s
idxDef := getIndexDef(opts...)
req := &milvuspb.DropIndexRequest{
Base: idxDef.MsgBase,
DbName: "", //reserved,
DbName: "", // reserved,
CollectionName: collName,
FieldName: fieldName,
IndexName: idxDef.name,
Expand Down Expand Up @@ -267,25 +269,20 @@ func (c *GrpcClient) GetIndexBuildProgress(ctx context.Context, collName string,
if c.Service == nil {
return 0, 0, ErrClientNotReady
}
if err := c.checkCollField(ctx, collName, fieldName); err != nil {
return 0, 0, err
}

idxDef := getIndexDef(opts...)
req := &milvuspb.GetIndexBuildProgressRequest{
DbName: "",
CollectionName: collName,
FieldName: fieldName,
IndexName: idxDef.name,
}
resp, err := c.Service.GetIndexBuildProgress(ctx, req)
results, err := c.describeIndex(ctx, collName, fieldName, opts...)
if err != nil {
return 0, 0, err
}
if err = handleRespStatus(resp.GetStatus()); err != nil {
return 0, 0, err
if len(results) == 0 {
return 0, 0, errors.New("index not found")
}

idxDesc := results[0]
if idxDesc.GetState() == commonpb.IndexState_Failed {
return 0, 0, errors.Newf("index build failed: %s", idxDesc.IndexStateFailReason)
}
return resp.GetTotalRows(), resp.GetIndexedRows(), nil
return idxDesc.GetTotalRows(), idxDesc.GetIndexedRows(), nil
}

func (c *GrpcClient) describeIndex(ctx context.Context, collName string, fieldName string, opts ...IndexOption) ([]*milvuspb.IndexDescription, error) {
Expand Down
144 changes: 91 additions & 53 deletions client/index_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package client

import (
"context"
"fmt"
"math/rand"
"testing"
"time"
Expand All @@ -13,6 +14,8 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-sdk-go/v2/entity"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite"
)

func TestGrpcClientCreateIndex(t *testing.T) {
Expand Down Expand Up @@ -151,67 +154,102 @@ func TestGrpcClientDescribeIndex(t *testing.T) {
})
}

func TestGrpcGetIndexBuildProgress(t *testing.T) {
ctx := context.Background()
mockServer.SetInjection(MHasCollection, hasCollectionDefault)
mockServer.SetInjection(MDescribeCollection, describeCollectionInjection(t, 0, testCollectionName, defaultSchema()))
type IndexSuite struct {
MockSuiteBase
}

tc := testClient(ctx, t)
c := tc.(*GrpcClient) // since GetIndexBuildProgress is not exposed
func (s *IndexSuite) TestGetIndexBuildProgress() {
c := s.client
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

collectionName := fmt.Sprintf("coll_%s", randStr(6))
fieldName := fmt.Sprintf("field_%d", rand.Int31n(10))
indexName := fmt.Sprintf("index_%s", randStr(4))

s.Run("normal_case", func() {
totalRows := rand.Int63n(10000)
indexedRows := rand.Int63n(totalRows)

defer s.resetMock()

s.mock.EXPECT().DescribeIndex(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, dir *milvuspb.DescribeIndexRequest) (*milvuspb.DescribeIndexResponse, error) {
s.Equal(collectionName, dir.GetCollectionName())
s.Equal(fieldName, dir.GetFieldName())
s.Equal(indexName, dir.GetIndexName())
return &milvuspb.DescribeIndexResponse{
Status: s.getSuccessStatus(),
IndexDescriptions: []*milvuspb.IndexDescription{
{
IndexName: indexName,
TotalRows: totalRows,
IndexedRows: indexedRows,
State: commonpb.IndexState_InProgress,
},
},
}, nil
}).Once()

t.Run("normal get index build progress", func(t *testing.T) {
var total, built int64
totalResult, indexedResult, err := c.GetIndexBuildProgress(ctx, collectionName, fieldName, WithIndexName(indexName))
s.NoError(err)
s.Equal(totalRows, totalResult)
s.Equal(indexedRows, indexedResult)
})

mockServer.SetInjection(MGetIndexBuildProgress, func(_ context.Context, raw proto.Message) (proto.Message, error) {
req, ok := raw.(*milvuspb.GetIndexBuildProgressRequest)
if !ok {
t.FailNow()
}
assert.Equal(t, testCollectionName, req.GetCollectionName())
resp := &milvuspb.GetIndexBuildProgressResponse{
TotalRows: total,
IndexedRows: built,
}
s, err := SuccessStatus()
resp.Status = s
return resp, err
})
s.Run("index_not_found", func() {
defer s.resetMock()

s.mock.EXPECT().DescribeIndex(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, dir *milvuspb.DescribeIndexRequest) (*milvuspb.DescribeIndexResponse, error) {
s.Equal(collectionName, dir.GetCollectionName())
s.Equal(fieldName, dir.GetFieldName())
s.Equal(indexName, dir.GetIndexName())
return &milvuspb.DescribeIndexResponse{
Status: s.getSuccessStatus(),
IndexDescriptions: []*milvuspb.IndexDescription{},
}, nil
}).Once()

_, _, err := c.GetIndexBuildProgress(ctx, collectionName, fieldName, WithIndexName(indexName))
s.Error(err)
})

total = rand.Int63n(1000)
built = rand.Int63n(total)
rt, rb, err := c.GetIndexBuildProgress(ctx, testCollectionName, testVectorField)
assert.NoError(t, err)
assert.Equal(t, total, rt)
assert.Equal(t, built, rb)
s.Run("build_failed", func() {
defer s.resetMock()

s.mock.EXPECT().DescribeIndex(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, dir *milvuspb.DescribeIndexRequest) (*milvuspb.DescribeIndexResponse, error) {
s.Equal(collectionName, dir.GetCollectionName())
s.Equal(fieldName, dir.GetFieldName())
s.Equal(indexName, dir.GetIndexName())
return &milvuspb.DescribeIndexResponse{
Status: s.getSuccessStatus(),
IndexDescriptions: []*milvuspb.IndexDescription{
{
IndexName: indexName,
State: commonpb.IndexState_Failed,
},
},
}, nil
}).Once()

_, _, err := c.GetIndexBuildProgress(ctx, collectionName, fieldName, WithIndexName(indexName))
s.Error(err)
})

t.Run("Service return errors", func(t *testing.T) {
defer mockServer.DelInjection(MGetIndexBuildProgress)
mockServer.SetInjection(MGetIndexBuildProgress, func(_ context.Context, raw proto.Message) (proto.Message, error) {
_, ok := raw.(*milvuspb.GetIndexBuildProgressRequest)
if !ok {
t.FailNow()
}
resp := &milvuspb.GetIndexBuildProgressResponse{}
return resp, errors.New("mockServer.d error")
})
s.Run("server_error", func() {
defer s.resetMock()

_, _, err := c.GetIndexBuildProgress(ctx, testCollectionName, testVectorField)
assert.Error(t, err)
s.mock.EXPECT().DescribeIndex(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, dir *milvuspb.DescribeIndexRequest) (*milvuspb.DescribeIndexResponse, error) {
s.Equal(collectionName, dir.GetCollectionName())
s.Equal(fieldName, dir.GetFieldName())
s.Equal(indexName, dir.GetIndexName())
return nil, errors.New("mocked")
}).Once()

mockServer.SetInjection(MGetIndexBuildProgress, func(_ context.Context, raw proto.Message) (proto.Message, error) {
_, ok := raw.(*milvuspb.GetIndexBuildProgressRequest)
if !ok {
t.FailNow()
}
resp := &milvuspb.GetIndexBuildProgressResponse{}
resp.Status = &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
}
return resp, nil
})
_, _, err = c.GetIndexBuildProgress(ctx, testCollectionName, testVectorField)
assert.Error(t, err)
_, _, err := c.GetIndexBuildProgress(ctx, collectionName, fieldName, WithIndexName(indexName))
s.Error(err)
})
}

func TestIndex(t *testing.T) {
suite.Run(t, new(IndexSuite))
}
2 changes: 1 addition & 1 deletion common/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@ package common

const (
// SDKVersion const value for current version
SDKVersion = `v2.4.1`
SDKVersion = `v2.4.2`
)
Loading

0 comments on commit 5a6dd60

Please sign in to comment.