diff --git a/client/client_grpc_collection.go b/client/client_grpc_collection.go index 30be56e7..86cf4364 100644 --- a/client/client_grpc_collection.go +++ b/client/client_grpc_collection.go @@ -373,15 +373,15 @@ func (c *GrpcClient) LoadCollection(ctx context.Context, collName string, async default: } - coll, err := c.ShowCollection(ctx, collName) + progress, err := c.GetLoadingProgress(ctx, collName, nil) if err != nil { return err } - if coll.Loaded { - break + if progress == 100 { + return nil } - time.Sleep(200 * time.Millisecond) // TODO change to configuration + time.Sleep(500 * time.Millisecond) } } return nil diff --git a/client/client_grpc_collection_test.go b/client/client_grpc_collection_test.go index 7e1c1c60..68b0b465 100644 --- a/client/client_grpc_collection_test.go +++ b/client/client_grpc_collection_test.go @@ -389,8 +389,8 @@ func TestGrpcClientLoadCollection(t *testing.T) { start := time.Now() mockServer.SetInjection(MShowCollections, func(_ context.Context, raw proto.Message) (proto.Message, error) { - req, ok := raw.(*server.ShowCollectionsRequest) - r := &server.ShowCollectionsResponse{} + req, ok := raw.(*server.GetLoadingProgressRequest) + r := &server.GetLoadingProgressResponse{} if !ok || req == nil { s, err := BadRequestStatus() r.Status = s @@ -398,16 +398,16 @@ func TestGrpcClientLoadCollection(t *testing.T) { } s, err := SuccessStatus() r.Status = s - r.CollectionIds = []int64{1} var perc int64 if time.Since(start) > time.Duration(loadTime)*time.Millisecond { t.Log("passed") perc = 100 passed = true } - r.InMemoryPercentages = []int64{perc} + r.Progress = perc return r, err }) + assert.Nil(t, c.LoadCollection(ctx, testCollectionName, false)) assert.True(t, passed) @@ -418,7 +418,7 @@ func TestGrpcClientLoadCollection(t *testing.T) { assert.NotNil(t, c.LoadCollection(quickCtx, testCollectionName, false)) // remove injection - mockServer.DelInjection(MShowCollections) + mockServer.DelInjection(MGetLoadingProgress) }) t.Run("Load default replica", func(t *testing.T) { mockServer.SetInjection(MLoadCollection, func(ctx context.Context, raw proto.Message) (proto.Message, error) { diff --git a/client/client_grpc_partition.go b/client/client_grpc_partition.go index 87612331..4324b27c 100644 --- a/client/client_grpc_partition.go +++ b/client/client_grpc_partition.go @@ -149,23 +149,6 @@ func (c *GrpcClient) LoadPartitions(ctx context.Context, collName string, partit return err } } - partitions, err := c.ShowPartitions(ctx, collName) - if err != nil { - return err - } - m := make(map[string]int64) - for _, partition := range partitions { - m[partition.Name] = partition.ID - } - // load partitions ids - ids := make(map[int64]struct{}) - for _, partitionName := range partitionNames { - id, has := m[partitionName] - if !has { - return fmt.Errorf("Collection %s does not has partitions %s", collName, partitionName) - } - ids[id] = struct{}{} - } req := &server.LoadPartitionsRequest{ DbName: "", // reserved @@ -187,28 +170,16 @@ func (c *GrpcClient) LoadPartitions(ctx context.Context, collName string, partit return errors.New("context deadline exceeded") default: } - partitions, err := c.ShowPartitions(ctx, collName) + percentage, err := c.GetLoadingProgress(ctx, collName, partitionNames) if err != nil { return err } - foundLoading := false - loaded := 0 - for _, partition := range partitions { - if _, has := ids[partition.ID]; !has { - continue - } - if !partition.Loaded { - //Not loaded - foundLoading = true - break - } - loaded++ - } - if foundLoading || loaded < len(partitionNames) { - time.Sleep(time.Millisecond * 100) - continue + + if percentage == 100 { + return nil } - break + + time.Sleep(500 * time.Millisecond) } }