diff --git a/internal/querycoordv2/balance/streaming_query_node_channel_helper.go b/internal/querycoordv2/balance/streaming_query_node_channel_helper.go index f1771009f4b69..53c3897633a68 100644 --- a/internal/querycoordv2/balance/streaming_query_node_channel_helper.go +++ b/internal/querycoordv2/balance/streaming_query_node_channel_helper.go @@ -15,6 +15,8 @@ func assignChannelToWALLocatedFirst( for _, c := range channels { nodeID := snmanager.StaticStreamingNodeManager.GetWALLocated(c.GetChannelName()) // Check if nodeID is in the list of nodeItems + // The nodeID may be absent from nodeItems in multi-replica mode: + // only one replica can be assigned to the node where the WAL is located. found := false for _, item := range nodeItems { if item.nodeID == nodeID { @@ -45,6 +47,8 @@ func assignChannelToWALLocatedFirstForNodeInfo( for _, c := range channels { nodeID := snmanager.StaticStreamingNodeManager.GetWALLocated(c.GetChannelName()) // Check if nodeID is in the list of nodeItems + // The nodeID may be absent from nodeItems in multi-replica mode: + // only one replica can be assigned to the node where the WAL is located. 
found := false for _, item := range nodeItems { if item.ID() == nodeID { diff --git a/internal/querycoordv2/balance/streaming_query_node_channel_helper_test.go b/internal/querycoordv2/balance/streaming_query_node_channel_helper_test.go new file mode 100644 index 0000000000000..0cbf73359cb8a --- /dev/null +++ b/internal/querycoordv2/balance/streaming_query_node_channel_helper_test.go @@ -0,0 +1,91 @@ +package balance + +import ( + "context" + "testing" + + "github.com/milvus-io/milvus/internal/coordinator/snmanager" + "github.com/milvus-io/milvus/internal/mocks/streamingcoord/server/mock_balancer" + "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/internal/querycoordv2/meta" + "github.com/milvus-io/milvus/internal/querycoordv2/session" + "github.com/milvus-io/milvus/pkg/streaming/util/types" + "github.com/milvus-io/milvus/pkg/util/typeutil" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestAssignChannelToWALLocatedFirst(t *testing.T) { + balancer := mock_balancer.NewMockBalancer(t) + snmanager.StaticStreamingNodeManager.SetBalancerReady(balancer) + + balancer.EXPECT().WatchChannelAssignments(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, cb func(typeutil.VersionInt64Pair, []types.PChannelInfoAssigned) error) error { + versions := []typeutil.VersionInt64Pair{ + {Global: 1, Local: 2}, + } + pchans := [][]types.PChannelInfoAssigned{ + { + types.PChannelInfoAssigned{ + Channel: types.PChannelInfo{Name: "pchannel", Term: 1}, + Node: types.StreamingNodeInfo{ServerID: 1, Address: "localhost:1"}, + }, + types.PChannelInfoAssigned{ + Channel: types.PChannelInfo{Name: "pchannel2", Term: 1}, + Node: types.StreamingNodeInfo{ServerID: 2, Address: "localhost:1"}, + }, + types.PChannelInfoAssigned{ + Channel: types.PChannelInfo{Name: "pchannel3", Term: 1}, + Node: types.StreamingNodeInfo{ServerID: 3, Address: "localhost:1"}, + }, + }, + } + for i := 0; i < len(versions); i++ { + 
cb(versions[i], pchans[i]) + } + <-ctx.Done() + return context.Cause(ctx) + }) + + channels := []*meta.DmChannel{ + {VchannelInfo: &datapb.VchannelInfo{ChannelName: "pchannel_v1"}}, + {VchannelInfo: &datapb.VchannelInfo{ChannelName: "pchannel2_v2"}}, + {VchannelInfo: &datapb.VchannelInfo{ChannelName: "pchannel3_v1"}}, + } + nodeItems := []*nodeItem{ + {nodeID: 1}, + {nodeID: 2}, + } + + notFounChannels, plans := assignChannelToWALLocatedFirst(channels, nodeItems) + assert.Len(t, notFounChannels, 1) + assert.Equal(t, notFounChannels[0].GetChannelName(), "pchannel3_v1") + assert.Len(t, plans, 2) + for _, plan := range plans { + if plan.Channel.GetChannelName() == "pchannel_v1" { + assert.Equal(t, plan.To, int64(1)) + } else { + assert.Equal(t, plan.To, int64(2)) + } + } + + var scoreDelta map[int64]int + nodeInfos := []*session.NodeInfo{ + session.NewNodeInfo(session.ImmutableNodeInfo{NodeID: 1}), + session.NewNodeInfo(session.ImmutableNodeInfo{NodeID: 2}), + } + + notFounChannels, plans, scoreDelta = assignChannelToWALLocatedFirstForNodeInfo(channels, nodeInfos) + assert.Len(t, notFounChannels, 1) + assert.Equal(t, notFounChannels[0].GetChannelName(), "pchannel3_v1") + assert.Len(t, plans, 2) + assert.Len(t, scoreDelta, 2) + for _, plan := range plans { + if plan.Channel.GetChannelName() == "pchannel_v1" { + assert.Equal(t, plan.To, int64(1)) + assert.Equal(t, scoreDelta[1], 1) + } else { + assert.Equal(t, plan.To, int64(2)) + assert.Equal(t, scoreDelta[2], 1) + } + } +} diff --git a/internal/querycoordv2/observers/replica_observer_test.go b/internal/querycoordv2/observers/replica_observer_test.go index 266d731a00d22..a619cdd06ce82 100644 --- a/internal/querycoordv2/observers/replica_observer_test.go +++ b/internal/querycoordv2/observers/replica_observer_test.go @@ -20,16 +20,21 @@ import ( "testing" "time" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" "github.com/milvus-io/milvus-proto/go-api/v2/rgpb" + 
"github.com/milvus-io/milvus/internal/coordinator/snmanager" etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" "github.com/milvus-io/milvus/internal/metastore/kv/querycoord" + "github.com/milvus-io/milvus/internal/mocks/streamingcoord/server/mock_balancer" "github.com/milvus-io/milvus/internal/querycoordv2/meta" . "github.com/milvus-io/milvus/internal/querycoordv2/params" "github.com/milvus-io/milvus/internal/querycoordv2/session" "github.com/milvus-io/milvus/internal/querycoordv2/utils" + "github.com/milvus-io/milvus/internal/util/streamingutil" "github.com/milvus-io/milvus/pkg/kv" + "github.com/milvus-io/milvus/pkg/streaming/util/types" "github.com/milvus-io/milvus/pkg/util/etcd" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/typeutil" @@ -52,6 +57,7 @@ type ReplicaObserverSuite struct { } func (suite *ReplicaObserverSuite) SetupSuite() { + streamingutil.SetStreamingServiceEnabled() paramtable.Init() paramtable.Get().Save(Params.QueryCoordCfg.CheckNodeInReplicaInterval.Key, "1") } @@ -196,9 +202,105 @@ func (suite *ReplicaObserverSuite) TestCheckNodesInReplica() { }, 30*time.Second, 2*time.Second) } +func (suite *ReplicaObserverSuite) TestCheckSQnodesInReplica() { + balancer := mock_balancer.NewMockBalancer(suite.T()) + snmanager.StaticStreamingNodeManager.SetBalancerReady(balancer) + + change := make(chan struct{}) + balancer.EXPECT().WatchChannelAssignments(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, cb func(typeutil.VersionInt64Pair, []types.PChannelInfoAssigned) error) error { + versions := []typeutil.VersionInt64Pair{ + {Global: 1, Local: 2}, + {Global: 1, Local: 3}, + } + pchans := [][]types.PChannelInfoAssigned{ + { + types.PChannelInfoAssigned{ + Channel: types.PChannelInfo{Name: "pchannel", Term: 1}, + Node: types.StreamingNodeInfo{ServerID: 1, Address: "localhost:1"}, + }, + types.PChannelInfoAssigned{ + Channel: types.PChannelInfo{Name: "pchannel2", Term: 1}, + Node: 
types.StreamingNodeInfo{ServerID: 2, Address: "localhost:1"}, + }, + types.PChannelInfoAssigned{ + Channel: types.PChannelInfo{Name: "pchannel3", Term: 1}, + Node: types.StreamingNodeInfo{ServerID: 3, Address: "localhost:1"}, + }, + }, + { + types.PChannelInfoAssigned{ + Channel: types.PChannelInfo{Name: "pchannel", Term: 1}, + Node: types.StreamingNodeInfo{ServerID: 1, Address: "localhost:1"}, + }, + types.PChannelInfoAssigned{ + Channel: types.PChannelInfo{Name: "pchannel2", Term: 1}, + Node: types.StreamingNodeInfo{ServerID: 2, Address: "localhost:1"}, + }, + types.PChannelInfoAssigned{ + Channel: types.PChannelInfo{Name: "pchannel3", Term: 2}, + Node: types.StreamingNodeInfo{ServerID: 2, Address: "localhost:1"}, + }, + }, + } + for i := 0; i < len(versions); i++ { + cb(versions[i], pchans[i]) + <-change + } + <-ctx.Done() + return context.Cause(ctx) + }) + + ctx := context.Background() + err := suite.meta.CollectionManager.PutCollection(ctx, utils.CreateTestCollection(suite.collectionID, 2)) + suite.NoError(err) + replicas, err := suite.meta.Spawn(ctx, suite.collectionID, map[string]int{ + "rg1": 1, + "rg2": 1, + }, nil) + suite.NoError(err) + suite.Equal(2, len(replicas)) + + suite.Eventually(func() bool { + replica := suite.meta.ReplicaManager.GetByCollection(ctx, suite.collectionID) + total := 0 + for _, r := range replica { + total += r.RWSQNodesCount() + } + return total == 3 + }, 6*time.Second, 2*time.Second) + replica := suite.meta.ReplicaManager.GetByCollection(ctx, suite.collectionID) + nodes := typeutil.NewUniqueSet() + for _, r := range replica { + suite.LessOrEqual(r.RWSQNodesCount(), 2) + suite.Equal(r.ROSQNodesCount(), 0) + nodes.Insert(r.GetRWSQNodes()...) 
+ } + suite.Equal(nodes.Len(), 3) + + close(change) + + suite.Eventually(func() bool { + replica := suite.meta.ReplicaManager.GetByCollection(ctx, suite.collectionID) + total := 0 + for _, r := range replica { + total += r.RWSQNodesCount() + } + return total == 2 + }, 6*time.Second, 2*time.Second) + replica = suite.meta.ReplicaManager.GetByCollection(ctx, suite.collectionID) + nodes = typeutil.NewUniqueSet() + for _, r := range replica { + suite.Equal(r.RWSQNodesCount(), 1) + suite.Equal(r.ROSQNodesCount(), 0) + nodes.Insert(r.GetRWSQNodes()...) + } + suite.Equal(nodes.Len(), 2) +} + func (suite *ReplicaObserverSuite) TearDownSuite() { suite.kv.Close() suite.observer.Stop() + streamingutil.UnsetStreamingServiceEnabled() } func TestReplicaObserver(t *testing.T) {