Skip to content

Commit

Permalink
enhance: add more unittest
Browse files Browse the repository at this point in the history
Signed-off-by: chyezh <[email protected]>
  • Loading branch information
chyezh committed Dec 30, 2024
1 parent 292f824 commit ba62e5d
Show file tree
Hide file tree
Showing 3 changed files with 197 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ func assignChannelToWALLocatedFirst(
for _, c := range channels {
nodeID := snmanager.StaticStreamingNodeManager.GetWALLocated(c.GetChannelName())
// Check if nodeID is in the list of nodeItems
// The nodeID may not be in the nodeItems when multi replica mode.
// Only one replica can be assigned to the node that wal is located.
found := false
for _, item := range nodeItems {
if item.nodeID == nodeID {
Expand Down Expand Up @@ -45,6 +47,8 @@ func assignChannelToWALLocatedFirstForNodeInfo(
for _, c := range channels {
nodeID := snmanager.StaticStreamingNodeManager.GetWALLocated(c.GetChannelName())
// Check if nodeID is in the list of nodeItems
// The nodeID may not be in the nodeItems when multi replica mode.
// Only one replica can be assigned to the node that wal is located.
found := false
for _, item := range nodeItems {
if item.ID() == nodeID {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
package balance

import (
"context"
"testing"

"github.com/milvus-io/milvus/internal/coordinator/snmanager"
"github.com/milvus-io/milvus/internal/mocks/streamingcoord/server/mock_balancer"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/querycoordv2/meta"
"github.com/milvus-io/milvus/internal/querycoordv2/session"
"github.com/milvus-io/milvus/pkg/streaming/util/types"
"github.com/milvus-io/milvus/pkg/util/typeutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)

func TestAssignChannelToWALLocatedFirst(t *testing.T) {
balancer := mock_balancer.NewMockBalancer(t)
snmanager.StaticStreamingNodeManager.SetBalancerReady(balancer)

balancer.EXPECT().WatchChannelAssignments(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, cb func(typeutil.VersionInt64Pair, []types.PChannelInfoAssigned) error) error {
versions := []typeutil.VersionInt64Pair{
{Global: 1, Local: 2},
}
pchans := [][]types.PChannelInfoAssigned{
{
types.PChannelInfoAssigned{
Channel: types.PChannelInfo{Name: "pchannel", Term: 1},
Node: types.StreamingNodeInfo{ServerID: 1, Address: "localhost:1"},
},
types.PChannelInfoAssigned{
Channel: types.PChannelInfo{Name: "pchannel2", Term: 1},
Node: types.StreamingNodeInfo{ServerID: 2, Address: "localhost:1"},
},
types.PChannelInfoAssigned{
Channel: types.PChannelInfo{Name: "pchannel3", Term: 1},
Node: types.StreamingNodeInfo{ServerID: 3, Address: "localhost:1"},
},
},
}
for i := 0; i < len(versions); i++ {
cb(versions[i], pchans[i])
}
<-ctx.Done()
return context.Cause(ctx)
})

channels := []*meta.DmChannel{
{VchannelInfo: &datapb.VchannelInfo{ChannelName: "pchannel_v1"}},
{VchannelInfo: &datapb.VchannelInfo{ChannelName: "pchannel2_v2"}},
{VchannelInfo: &datapb.VchannelInfo{ChannelName: "pchannel3_v1"}},
}
nodeItems := []*nodeItem{
{nodeID: 1},
{nodeID: 2},
}

notFounChannels, plans := assignChannelToWALLocatedFirst(channels, nodeItems)
assert.Len(t, notFounChannels, 1)
assert.Equal(t, notFounChannels[0].GetChannelName(), "pchannel3_v1")
assert.Len(t, plans, 2)
for _, plan := range plans {
if plan.Channel.GetChannelName() == "pchannel_v1" {
assert.Equal(t, plan.To, int64(1))
} else {
assert.Equal(t, plan.To, int64(2))
}
}

var scoreDelta map[int64]int
nodeInfos := []*session.NodeInfo{
session.NewNodeInfo(session.ImmutableNodeInfo{NodeID: 1}),
session.NewNodeInfo(session.ImmutableNodeInfo{NodeID: 2}),
}

notFounChannels, plans, scoreDelta = assignChannelToWALLocatedFirstForNodeInfo(channels, nodeInfos)
assert.Len(t, notFounChannels, 1)
assert.Equal(t, notFounChannels[0].GetChannelName(), "pchannel3_v1")
assert.Len(t, plans, 2)
assert.Len(t, scoreDelta, 2)
for _, plan := range plans {
if plan.Channel.GetChannelName() == "pchannel_v1" {
assert.Equal(t, plan.To, int64(1))
assert.Equal(t, scoreDelta[1], 1)
} else {
assert.Equal(t, plan.To, int64(2))
assert.Equal(t, scoreDelta[2], 1)
}
}
}
102 changes: 102 additions & 0 deletions internal/querycoordv2/observers/replica_observer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,21 @@ import (
"testing"
"time"

"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite"

"github.com/milvus-io/milvus-proto/go-api/v2/rgpb"
"github.com/milvus-io/milvus/internal/coordinator/snmanager"
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
"github.com/milvus-io/milvus/internal/metastore/kv/querycoord"
"github.com/milvus-io/milvus/internal/mocks/streamingcoord/server/mock_balancer"
"github.com/milvus-io/milvus/internal/querycoordv2/meta"
. "github.com/milvus-io/milvus/internal/querycoordv2/params"
"github.com/milvus-io/milvus/internal/querycoordv2/session"
"github.com/milvus-io/milvus/internal/querycoordv2/utils"
"github.com/milvus-io/milvus/internal/util/streamingutil"
"github.com/milvus-io/milvus/pkg/kv"
"github.com/milvus-io/milvus/pkg/streaming/util/types"
"github.com/milvus-io/milvus/pkg/util/etcd"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/typeutil"
Expand All @@ -52,6 +57,7 @@ type ReplicaObserverSuite struct {
}

func (suite *ReplicaObserverSuite) SetupSuite() {
streamingutil.SetStreamingServiceEnabled()
paramtable.Init()
paramtable.Get().Save(Params.QueryCoordCfg.CheckNodeInReplicaInterval.Key, "1")
}
Expand Down Expand Up @@ -196,9 +202,105 @@ func (suite *ReplicaObserverSuite) TestCheckNodesInReplica() {
}, 30*time.Second, 2*time.Second)
}

func (suite *ReplicaObserverSuite) TestCheckSQnodesInReplica() {
balancer := mock_balancer.NewMockBalancer(suite.T())
snmanager.StaticStreamingNodeManager.SetBalancerReady(balancer)

change := make(chan struct{})
balancer.EXPECT().WatchChannelAssignments(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, cb func(typeutil.VersionInt64Pair, []types.PChannelInfoAssigned) error) error {
versions := []typeutil.VersionInt64Pair{
{Global: 1, Local: 2},
{Global: 1, Local: 3},
}
pchans := [][]types.PChannelInfoAssigned{
{
types.PChannelInfoAssigned{
Channel: types.PChannelInfo{Name: "pchannel", Term: 1},
Node: types.StreamingNodeInfo{ServerID: 1, Address: "localhost:1"},
},
types.PChannelInfoAssigned{
Channel: types.PChannelInfo{Name: "pchannel2", Term: 1},
Node: types.StreamingNodeInfo{ServerID: 2, Address: "localhost:1"},
},
types.PChannelInfoAssigned{
Channel: types.PChannelInfo{Name: "pchannel3", Term: 1},
Node: types.StreamingNodeInfo{ServerID: 3, Address: "localhost:1"},
},
},
{
types.PChannelInfoAssigned{
Channel: types.PChannelInfo{Name: "pchannel", Term: 1},
Node: types.StreamingNodeInfo{ServerID: 1, Address: "localhost:1"},
},
types.PChannelInfoAssigned{
Channel: types.PChannelInfo{Name: "pchannel2", Term: 1},
Node: types.StreamingNodeInfo{ServerID: 2, Address: "localhost:1"},
},
types.PChannelInfoAssigned{
Channel: types.PChannelInfo{Name: "pchannel3", Term: 2},
Node: types.StreamingNodeInfo{ServerID: 2, Address: "localhost:1"},
},
},
}
for i := 0; i < len(versions); i++ {
cb(versions[i], pchans[i])
<-change
}
<-ctx.Done()
return context.Cause(ctx)
})

ctx := context.Background()
err := suite.meta.CollectionManager.PutCollection(ctx, utils.CreateTestCollection(suite.collectionID, 2))
suite.NoError(err)
replicas, err := suite.meta.Spawn(ctx, suite.collectionID, map[string]int{
"rg1": 1,
"rg2": 1,
}, nil)
suite.NoError(err)
suite.Equal(2, len(replicas))

suite.Eventually(func() bool {
replica := suite.meta.ReplicaManager.GetByCollection(ctx, suite.collectionID)
total := 0
for _, r := range replica {
total += r.RWSQNodesCount()
}
return total == 3
}, 6*time.Second, 2*time.Second)
replica := suite.meta.ReplicaManager.GetByCollection(ctx, suite.collectionID)
nodes := typeutil.NewUniqueSet()
for _, r := range replica {
suite.LessOrEqual(r.RWSQNodesCount(), 2)
suite.Equal(r.ROSQNodesCount(), 0)
nodes.Insert(r.GetRWSQNodes()...)
}
suite.Equal(nodes.Len(), 3)

close(change)

suite.Eventually(func() bool {
replica := suite.meta.ReplicaManager.GetByCollection(ctx, suite.collectionID)
total := 0
for _, r := range replica {
total += r.RWSQNodesCount()
}
return total == 2
}, 6*time.Second, 2*time.Second)
replica = suite.meta.ReplicaManager.GetByCollection(ctx, suite.collectionID)
nodes = typeutil.NewUniqueSet()
for _, r := range replica {
suite.Equal(r.RWSQNodesCount(), 1)
suite.Equal(r.ROSQNodesCount(), 0)
nodes.Insert(r.GetRWSQNodes()...)
}
suite.Equal(nodes.Len(), 2)
}

func (suite *ReplicaObserverSuite) TearDownSuite() {
suite.kv.Close()
suite.observer.Stop()
streamingutil.UnsetStreamingServiceEnabled()
}

func TestReplicaObserver(t *testing.T) {
Expand Down

0 comments on commit ba62e5d

Please sign in to comment.