Skip to content

Commit

Permalink
fix: add unittest
Browse files Browse the repository at this point in the history
Signed-off-by: chyezh <[email protected]>
  • Loading branch information
chyezh committed Dec 29, 2024
1 parent 7d4920e commit e7cebb7
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 3 deletions.
2 changes: 1 addition & 1 deletion internal/querycoordv2/meta/replica_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ func (suite *ReplicaManagerSuite) TestRecover() {
replica := mgr.Get(ctx, 2100)
suite.NotNil(replica)
suite.EqualValues(1000, replica.GetCollectionID())
suite.EqualValues([]int64{1, 2, 3}, replica.GetNodes())
suite.ElementsMatch([]int64{1, 2, 3}, replica.GetNodes())
suite.Len(replica.GetNodes(), len(replica.GetNodes()))
for _, node := range replica.GetNodes() {
suite.True(replica.Contains(node))
Expand Down
7 changes: 6 additions & 1 deletion internal/streamingcoord/server/balancer/balancer.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,16 @@ package balancer
import (
"context"

"github.com/cockroachdb/errors"

"github.com/milvus-io/milvus/pkg/streaming/util/types"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)

var _ Balancer = (*balancerImpl)(nil)
var (
_ Balancer = (*balancerImpl)(nil)
ErrBalancerClosed = errors.New("balancer is closed")
)

// Balancer is a load balancer to balance the load of log node.
// Given the balance result to assign or remove channels to corresponding log node.
Expand Down
10 changes: 10 additions & 0 deletions internal/streamingcoord/server/balancer/balancer_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/milvus-io/milvus/internal/util/streamingutil/status"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/streaming/util/types"
"github.com/milvus-io/milvus/pkg/util/contextutil"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/syncutil"
"github.com/milvus-io/milvus/pkg/util/typeutil"
Expand All @@ -30,7 +31,10 @@ func RecoverBalancer(
if err != nil {
return nil, errors.Wrap(err, "fail to recover channel manager")
}
ctx, cancel := context.WithCancelCause(context.Background())
b := &balancerImpl{
ctx: ctx,
cancel: cancel,
lifetime: typeutil.NewLifetime(),
logger: log.With(zap.String("policy", policy)),
channelMetaManager: manager,
Expand All @@ -44,6 +48,8 @@ func RecoverBalancer(

// balancerImpl is a implementation of Balancer.
type balancerImpl struct {
ctx context.Context
cancel context.CancelCauseFunc
lifetime *typeutil.Lifetime
logger *log.MLogger
channelMetaManager *channel.ChannelManager
Expand All @@ -58,6 +64,8 @@ func (b *balancerImpl) WatchChannelAssignments(ctx context.Context, cb func(vers
return status.NewOnShutdownError("balancer is closing")
}
defer b.lifetime.Done()

ctx, _ = contextutil.MergeContext(ctx, b.ctx)
return b.channelMetaManager.WatchAssignmentResult(ctx, cb)
}

Expand Down Expand Up @@ -93,6 +101,8 @@ func (b *balancerImpl) sendRequestAndWaitFinish(ctx context.Context, newReq *req
// Close close the balancer.
func (b *balancerImpl) Close() {
b.lifetime.SetState(typeutil.LifetimeStateStopped)
// cancel all watch opeartion by context.
b.cancel(ErrBalancerClosed)
b.lifetime.Wait()

b.backgroundTaskNotifier.Cancel()
Expand Down
17 changes: 16 additions & 1 deletion internal/streamingcoord/server/balancer/balancer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package balancer_test
import (
"context"
"testing"
"time"

"github.com/cockroachdb/errors"
"github.com/stretchr/testify/assert"
Expand All @@ -16,6 +17,7 @@ import (
"github.com/milvus-io/milvus/pkg/streaming/proto/streamingpb"
"github.com/milvus-io/milvus/pkg/streaming/util/types"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/syncutil"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)

Expand Down Expand Up @@ -91,7 +93,6 @@ func TestBalancer(t *testing.T) {
b, err := balancer.RecoverBalancer(ctx, "pchannel_count_fair")
assert.NoError(t, err)
assert.NotNil(t, b)
defer b.Close()

b.MarkAsUnavailable(ctx, []types.PChannelInfo{{
Name: "test-channel-1",
Expand All @@ -113,4 +114,18 @@ func TestBalancer(t *testing.T) {
return nil
})
assert.ErrorIs(t, err, doneErr)

// create a inifite block watcher and can be interrupted by close of balancer.
f := syncutil.NewFuture[error]()
go func() {
err := b.WatchChannelAssignments(context.Background(), func(version typeutil.VersionInt64Pair, relations []types.PChannelInfoAssigned) error {
return nil
})
f.Set(err)
}()
time.Sleep(20 * time.Millisecond)
assert.False(t, f.Ready())

b.Close()
assert.ErrorIs(t, f.Get(), balancer.ErrBalancerClosed)
}
12 changes: 12 additions & 0 deletions pkg/util/contextutil/context_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,15 @@ func WithDeadlineCause(parent context.Context, deadline time.Time, err error) (c
cancel(context.Canceled)
}
}

// MergeContext create a cancellation context that cancels when any of the given contexts are canceled.
func MergeContext(ctx1 context.Context, ctx2 context.Context) (context.Context, context.CancelFunc) {
ctx, cancel := context.WithCancelCause(ctx1)
stop := context.AfterFunc(ctx2, func() {
cancel(context.Cause(ctx2))
})
return ctx, func() {
stop()
cancel(context.Canceled)
}

Check warning on line 134 in pkg/util/contextutil/context_util.go

View check run for this annotation

Codecov / codecov/patch

pkg/util/contextutil/context_util.go#L132-L134

Added lines #L132 - L134 were not covered by tests
}

0 comments on commit e7cebb7

Please sign in to comment.