diff --git a/internal/proxy/task.go b/internal/proxy/task.go index 2cb2395e81e21..bb1591f6acbe0 100644 --- a/internal/proxy/task.go +++ b/internal/proxy/task.go @@ -1382,7 +1382,7 @@ func (t *dropPartitionTask) PreExecute(ctx context.Context) error { return err } if collLoaded { - loaded, err := isPartitionLoaded(ctx, t.queryCoord, collID, []int64{partID}) + loaded, err := isPartitionLoaded(ctx, t.queryCoord, collID, partID) if err != nil { return err } diff --git a/internal/proxy/util.go b/internal/proxy/util.go index 14a335e5bbc74..0968d26f1ea8d 100644 --- a/internal/proxy/util.go +++ b/internal/proxy/util.go @@ -47,7 +47,6 @@ import ( "github.com/milvus-io/milvus/pkg/util/commonpbutil" "github.com/milvus-io/milvus/pkg/util/contextutil" "github.com/milvus-io/milvus/pkg/util/crypto" - "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/indexparamcheck" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metric" @@ -1299,11 +1298,11 @@ func isCollectionLoaded(ctx context.Context, qc types.QueryCoordClient, collID i return false, nil } -func isPartitionLoaded(ctx context.Context, qc types.QueryCoordClient, collID int64, partIDs []int64) (bool, error) { +func isPartitionLoaded(ctx context.Context, qc types.QueryCoordClient, collID int64, partID int64) (bool, error) { // get all loading collections resp, err := qc.ShowPartitions(ctx, &querypb.ShowPartitionsRequest{ CollectionID: collID, - PartitionIDs: partIDs, + PartitionIDs: []int64{partID}, }) if err := merr.CheckRPCCall(resp, err); err != nil { // qc returns error if partition not loaded @@ -1313,7 +1312,7 @@ func isPartitionLoaded(ctx context.Context, qc types.QueryCoordClient, collID in return false, err } - return funcutil.SliceSetEqual(partIDs, resp.GetPartitionIDs()), nil + return true, nil } func checkFieldsDataBySchema(schema *schemapb.CollectionSchema, insertMsg *msgstream.InsertMsg, inInsert bool) error { diff --git a/internal/proxy/util_test.go b/internal/proxy/util_test.go index 491d6cda6db4b..b7418595d51b1 100644 --- a/internal/proxy/util_test.go +++ b/internal/proxy/util_test.go @@ -1063,7 +1063,7 @@ func Test_isPartitionIsLoaded(t *testing.T) { Status: merr.Success(), PartitionIDs: []int64{partID}, }, nil) - loaded, err := isPartitionLoaded(ctx, qc, collID, []int64{partID}) + loaded, err := isPartitionLoaded(ctx, qc, collID, partID) assert.NoError(t, err) assert.True(t, loaded) }) @@ -1088,7 +1088,7 @@ func Test_isPartitionIsLoaded(t *testing.T) { Status: merr.Success(), PartitionIDs: []int64{partID}, }, errors.New("error")) - loaded, err := isPartitionLoaded(ctx, qc, collID, []int64{partID}) + loaded, err := isPartitionLoaded(ctx, qc, collID, partID) assert.Error(t, err) assert.False(t, loaded) }) @@ -1116,7 +1116,7 @@ func Test_isPartitionIsLoaded(t *testing.T) { }, PartitionIDs: []int64{partID}, }, nil) - loaded, err := isPartitionLoaded(ctx, qc, collID, []int64{partID}) + loaded, err := isPartitionLoaded(ctx, qc, collID, partID) assert.Error(t, err) assert.False(t, loaded) }) diff --git a/internal/querycoordv2/services.go b/internal/querycoordv2/services.go index c1bde68331e23..cd07000f0ac7f 100644 --- a/internal/querycoordv2/services.go +++ b/internal/querycoordv2/services.go @@ -161,15 +161,16 @@ func (s *Server) ShowPartitions(ctx context.Context, req *querypb.ShowPartitions if percentage < 0 { err := meta.GlobalFailedLoadCache.Get(req.GetCollectionID()) if err != nil { - status := merr.Status(err) - log.Warn("show partition failed", zap.Error(err)) + partitionErr := merr.WrapErrPartitionNotLoaded(partitionID, err.Error()) + status := merr.Status(partitionErr) + log.Warn("show partition failed", zap.Error(partitionErr)) return &querypb.ShowPartitionsResponse{ Status: status, }, nil } err = merr.WrapErrPartitionNotLoaded(partitionID) - log.Warn("show partitions failed", zap.Error(err)) + log.Warn("show partition failed", zap.Error(err)) return &querypb.ShowPartitionsResponse{ Status: merr.Status(err), }, nil diff --git a/internal/querycoordv2/services_test.go b/internal/querycoordv2/services_test.go index 4025ed11e71f5..b0860ae3d1108 100644 --- a/internal/querycoordv2/services_test.go +++ b/internal/querycoordv2/services_test.go @@ -26,6 +26,7 @@ import ( "github.com/cockroachdb/errors" "github.com/samber/lo" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" @@ -315,7 +316,8 @@ func (suite *ServiceSuite) TestShowPartitions() { meta.GlobalFailedLoadCache.Put(collection, merr.WrapErrServiceMemoryLimitExceeded(100, 10)) resp, err = server.ShowPartitions(ctx, req) suite.NoError(err) - suite.Equal(commonpb.ErrorCode_InsufficientMemoryToLoad, resp.GetStatus().GetErrorCode()) + err := merr.CheckRPCCall(resp, err) + assert.True(suite.T(), errors.Is(err, merr.ErrPartitionNotLoaded)) meta.GlobalFailedLoadCache.Remove(collection) err = suite.meta.CollectionManager.PutCollection(colBak) suite.NoError(err) @@ -327,7 +329,8 @@ func (suite *ServiceSuite) TestShowPartitions() { meta.GlobalFailedLoadCache.Put(collection, merr.WrapErrServiceMemoryLimitExceeded(100, 10)) resp, err = server.ShowPartitions(ctx, req) suite.NoError(err) - suite.Equal(commonpb.ErrorCode_InsufficientMemoryToLoad, resp.GetStatus().GetErrorCode()) + err := merr.CheckRPCCall(resp, err) + assert.True(suite.T(), errors.Is(err, merr.ErrPartitionNotLoaded)) meta.GlobalFailedLoadCache.Remove(collection) err = suite.meta.CollectionManager.PutPartition(parBak) suite.NoError(err) diff --git a/pkg/util/merr/utils.go b/pkg/util/merr/utils.go index 26d1801f2e1f7..1bc53cbb43ce5 100644 --- a/pkg/util/merr/utils.go +++ b/pkg/util/merr/utils.go @@ -40,14 +40,14 @@ func Code(err error) int32 { } cause := errors.Cause(err) - switch cause := cause.(type) { + switch specificErr := cause.(type) { case milvusError: - return cause.code() + return specificErr.code() default: - if errors.Is(cause, context.Canceled) { + if errors.Is(specificErr, context.Canceled) { return CanceledCode - } else if errors.Is(cause, context.DeadlineExceeded) { + } else if errors.Is(specificErr, context.DeadlineExceeded) { return TimeoutCode } else { return errUnexpected.code()