diff --git a/internal/querycoordv2/meta/replica_manager_test.go b/internal/querycoordv2/meta/replica_manager_test.go index e498562be930c..268babcc6ad11 100644 --- a/internal/querycoordv2/meta/replica_manager_test.go +++ b/internal/querycoordv2/meta/replica_manager_test.go @@ -221,7 +221,7 @@ func (suite *ReplicaManagerSuite) TestRecover() { replica := mgr.Get(ctx, 2100) suite.NotNil(replica) suite.EqualValues(1000, replica.GetCollectionID()) - suite.EqualValues([]int64{1, 2, 3}, replica.GetNodes()) + suite.ElementsMatch([]int64{1, 2, 3}, replica.GetNodes()) suite.Len(replica.GetNodes(), len(replica.GetNodes())) for _, node := range replica.GetNodes() { suite.True(replica.Contains(node)) diff --git a/internal/streamingcoord/server/balancer/balancer.go b/internal/streamingcoord/server/balancer/balancer.go index abe35d51ec6ae..98d2bd85bc141 100644 --- a/internal/streamingcoord/server/balancer/balancer.go +++ b/internal/streamingcoord/server/balancer/balancer.go @@ -3,11 +3,16 @@ package balancer import ( "context" + "github.com/cockroachdb/errors" + "github.com/milvus-io/milvus/pkg/streaming/util/types" "github.com/milvus-io/milvus/pkg/util/typeutil" ) -var _ Balancer = (*balancerImpl)(nil) +var ( + _ Balancer = (*balancerImpl)(nil) + ErrBalancerClosed = errors.New("balancer is closed") +) // Balancer is a load balancer to balance the load of log node. // Given the balance result to assign or remove channels to corresponding log node. diff --git a/internal/streamingcoord/server/balancer/balancer_impl.go b/internal/streamingcoord/server/balancer/balancer_impl.go index f40bbcb3c62b4..1b8967653a820 100644 --- a/internal/streamingcoord/server/balancer/balancer_impl.go +++ b/internal/streamingcoord/server/balancer/balancer_impl.go @@ -13,6 +13,7 @@ import ( "github.com/milvus-io/milvus/internal/util/streamingutil/status" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/streaming/util/types" + "github.com/milvus-io/milvus/pkg/util/contextutil" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/syncutil" "github.com/milvus-io/milvus/pkg/util/typeutil" @@ -30,7 +31,10 @@ func RecoverBalancer( if err != nil { return nil, errors.Wrap(err, "fail to recover channel manager") } + ctx, cancel := context.WithCancelCause(context.Background()) b := &balancerImpl{ + ctx: ctx, + cancel: cancel, lifetime: typeutil.NewLifetime(), logger: log.With(zap.String("policy", policy)), channelMetaManager: manager, @@ -44,6 +48,8 @@ func RecoverBalancer( // balancerImpl is a implementation of Balancer. type balancerImpl struct { + ctx context.Context + cancel context.CancelCauseFunc lifetime *typeutil.Lifetime logger *log.MLogger channelMetaManager *channel.ChannelManager @@ -58,6 +64,8 @@ func (b *balancerImpl) WatchChannelAssignments(ctx context.Context, cb func(vers return status.NewOnShutdownError("balancer is closing") } defer b.lifetime.Done() + + ctx, _ = contextutil.MergeContext(ctx, b.ctx) return b.channelMetaManager.WatchAssignmentResult(ctx, cb) } @@ -93,6 +101,8 @@ func (b *balancerImpl) sendRequestAndWaitFinish(ctx context.Context, newReq *req // Close close the balancer. func (b *balancerImpl) Close() { b.lifetime.SetState(typeutil.LifetimeStateStopped) + // cancel all watch opeartion by context. + b.cancel(ErrBalancerClosed) b.lifetime.Wait() b.backgroundTaskNotifier.Cancel() diff --git a/internal/streamingcoord/server/balancer/balancer_test.go b/internal/streamingcoord/server/balancer/balancer_test.go index f0a738044c9b5..3a954edc31f9b 100644 --- a/internal/streamingcoord/server/balancer/balancer_test.go +++ b/internal/streamingcoord/server/balancer/balancer_test.go @@ -3,6 +3,7 @@ package balancer_test import ( "context" "testing" + "time" "github.com/cockroachdb/errors" "github.com/stretchr/testify/assert" @@ -16,6 +17,7 @@ import ( "github.com/milvus-io/milvus/pkg/streaming/proto/streamingpb" "github.com/milvus-io/milvus/pkg/streaming/util/types" "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/syncutil" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -91,7 +93,6 @@ func TestBalancer(t *testing.T) { b, err := balancer.RecoverBalancer(ctx, "pchannel_count_fair") assert.NoError(t, err) assert.NotNil(t, b) - defer b.Close() b.MarkAsUnavailable(ctx, []types.PChannelInfo{{ Name: "test-channel-1", @@ -113,4 +114,18 @@ func TestBalancer(t *testing.T) { return nil }) assert.ErrorIs(t, err, doneErr) + + // create a inifite block watcher and can be interrupted by close of balancer. + f := syncutil.NewFuture[error]() + go func() { + err := b.WatchChannelAssignments(context.Background(), func(version typeutil.VersionInt64Pair, relations []types.PChannelInfoAssigned) error { + return nil + }) + f.Set(err) + }() + time.Sleep(20 * time.Millisecond) + assert.False(t, f.Ready()) + + b.Close() + assert.ErrorIs(t, f.Get(), balancer.ErrBalancerClosed) } diff --git a/pkg/util/contextutil/context_util.go b/pkg/util/contextutil/context_util.go index 8cf699b43079b..2bded437d1ec5 100644 --- a/pkg/util/contextutil/context_util.go +++ b/pkg/util/contextutil/context_util.go @@ -121,3 +121,15 @@ func WithDeadlineCause(parent context.Context, deadline time.Time, err error) (c cancel(context.Canceled) } } + +// MergeContext create a cancellation context that cancels when any of the given contexts are canceled. +func MergeContext(ctx1 context.Context, ctx2 context.Context) (context.Context, context.CancelFunc) { + ctx, cancel := context.WithCancelCause(ctx1) + stop := context.AfterFunc(ctx2, func() { + cancel(context.Cause(ctx2)) + }) + return ctx, func() { + stop() + cancel(context.Canceled) + } +}