diff --git a/cluster.go b/cluster.go index 04b175b3..a54a89b1 100644 --- a/cluster.go +++ b/cluster.go @@ -18,6 +18,8 @@ import ( var ErrNoSlot = errors.New("the slot has no redis node") var ErrReplicaOnlyConflict = errors.New("ReplicaOnly conflicts with SendToReplicas option") var ErrInvalidShardsRefreshInterval = errors.New("ShardsRefreshInterval must be greater than or equal to 0") +var ErrReplicaOnlyConflictWithReplicaSelector = errors.New("ReplicaOnly conflicts with ReplicaSelector option") +var ErrSendToReplicasNotSet = errors.New("SendToReplicas must be set when ReplicaSelector is set") type clusterClient struct { pslots [16384]conn @@ -42,6 +44,10 @@ type connrole struct { //replica bool <- this field is removed because a server may have mixed roles at the same time in the future. https://github.com/valkey-io/valkey/issues/1372 } +var replicaOnlySelector = func(_ uint16, replicas []ReplicaInfo) int { + return util.FastRand(len(replicas)) +} + func newClusterClient(opt *ClientOption, connFn connFn, retryer retryHandler) (*clusterClient, error) { client := &clusterClient{ cmd: cmds.NewBuilder(cmds.InitSlot), @@ -56,6 +62,16 @@ func newClusterClient(opt *ClientOption, connFn connFn, retryer retryHandler) (* if opt.ReplicaOnly && opt.SendToReplicas != nil { return nil, ErrReplicaOnlyConflict } + if opt.ReplicaOnly && opt.ReplicaSelector != nil { + return nil, ErrReplicaOnlyConflictWithReplicaSelector + } + if opt.ReplicaSelector != nil && opt.SendToReplicas == nil { + return nil, ErrSendToReplicasNotSet + } + + if opt.SendToReplicas != nil && opt.ReplicaSelector == nil { + opt.ReplicaSelector = replicaOnlySelector + } if opt.SendToReplicas != nil { rOpt := *opt @@ -194,12 +210,12 @@ func (c *clusterClient) _refresh() (err error) { for master, g := range groups { conns[master] = connrole{conn: c.connFn(master, c.opt)} if c.rOpt != nil { - for _, addr := range g.nodes[1:] { - conns[addr] = connrole{conn: c.connFn(addr, c.rOpt)} + for _, nodeInfo := range g.nodes[1:] { + conns[nodeInfo.Addr] = connrole{conn: c.connFn(nodeInfo.Addr, c.rOpt)} } } else { - for _, addr := range g.nodes[1:] { - conns[addr] = connrole{conn: c.connFn(addr, c.opt)} + for _, nodeInfo := range g.nodes[1:] { + conns[nodeInfo.Addr] = connrole{conn: c.connFn(nodeInfo.Addr, c.opt)} } } } @@ -234,18 +250,25 @@ func (c *clusterClient) _refresh() (err error) { nodesCount := len(g.nodes) for _, slot := range g.slots { for i := slot[0]; i <= slot[1]; i++ { - pslots[i] = conns[g.nodes[1+util.FastRand(nodesCount-1)]].conn + pslots[i] = conns[g.nodes[1+util.FastRand(nodesCount-1)].Addr].conn } } - case c.rOpt != nil: // implies c.opt.SendToReplicas != nil + case c.rOpt != nil: if len(rslots) == 0 { // lazy init rslots = make([]conn, 16384) } if len(g.nodes) > 1 { + n := len(g.nodes) - 1 for _, slot := range g.slots { for i := slot[0]; i <= slot[1]; i++ { pslots[i] = conns[master].conn - rslots[i] = conns[g.nodes[1+util.FastRand(len(g.nodes)-1)]].conn + + rIndex := c.opt.ReplicaSelector(uint16(i), g.nodes[1:]) + if rIndex >= 0 && rIndex < n { + rslots[i] = conns[g.nodes[1+rIndex].Addr].conn + } else { + rslots[i] = conns[master].conn + } } } } else { @@ -297,8 +320,10 @@ func (c *clusterClient) nodes() []string { return nodes } +type nodes []ReplicaInfo + type group struct { - nodes []string + nodes nodes slots [][2]int64 } @@ -324,10 +349,10 @@ func parseSlots(slots RedisMessage, defaultAddr string) map[string]group { g, ok := groups[master] if !ok { g.slots = make([][2]int64, 0) - g.nodes = make([]string, 0, len(v.values)-2) 
+ g.nodes = make(nodes, 0, len(v.values)-2) for i := 2; i < len(v.values); i++ { if dst := parseEndpoint(defaultAddr, v.values[i].values[0].string, v.values[i].values[1].integer); dst != "" { - g.nodes = append(g.nodes, dst) + g.nodes = append(g.nodes, ReplicaInfo{Addr: dst}) } } } @@ -345,16 +370,16 @@ func parseShards(shards RedisMessage, defaultAddr string, tls bool) map[string]g m := -1 shard, _ := v.AsMap() slots := shard["slots"].values - nodes := shard["nodes"].values + _nodes := shard["nodes"].values g := group{ - nodes: make([]string, 0, len(nodes)), + nodes: make(nodes, 0, len(_nodes)), slots: make([][2]int64, len(slots)/2), } for i := range g.slots { g.slots[i][0], _ = slots[i*2].AsInt64() g.slots[i][1], _ = slots[i*2+1].AsInt64() } - for _, n := range nodes { + for _, n := range _nodes { dict, _ := n.AsMap() if dict["health"].string != "online" { continue @@ -367,12 +392,12 @@ func parseShards(shards RedisMessage, defaultAddr string, tls bool) map[string]g if dict["role"].string == "master" { m = len(g.nodes) } - g.nodes = append(g.nodes, dst) + g.nodes = append(g.nodes, ReplicaInfo{Addr: dst}) } } if m >= 0 { g.nodes[0], g.nodes[m] = g.nodes[m], g.nodes[0] - groups[g.nodes[0]] = g + groups[g.nodes[0].Addr] = g } } return groups @@ -1078,15 +1103,15 @@ func (c *clusterClient) Dedicate() (DedicatedClient, func()) { func (c *clusterClient) Nodes() map[string]Client { c.mu.RLock() - nodes := make(map[string]Client, len(c.conns)) + _nodes := make(map[string]Client, len(c.conns)) disableCache := c.opt != nil && c.opt.DisableCache for addr, cc := range c.conns { if !cc.hidden { - nodes[addr] = newSingleClientWithConn(cc.conn, c.cmd, c.retry, disableCache, c.retryHandler) + _nodes[addr] = newSingleClientWithConn(cc.conn, c.cmd, c.retry, disableCache, c.retryHandler) } } c.mu.RUnlock() - return nodes + return _nodes } func (c *clusterClient) Close() { diff --git a/cluster_test.go b/cluster_test.go index 9a152cb2..1620b54c 100644 --- a/cluster_test.go +++ b/cluster_test.go @@ -1188,6 +1188,236 @@ func TestClusterClientInit(t *testing.T) { t.Fatalf("unexpected err %v", err) } }) + + t.Run("Refresh cluster which has only primary node per shard with ReplicaSelector option", func(t *testing.T) { + m := &mockConn{ + DoFn: func(cmd Completed) RedisResult { + if strings.Join(cmd.Commands(), " ") == "CLUSTER SLOTS" { + return slotsMultiRespWithoutReplicas + } + return RedisResult{} + }, + } + + client, err := newClusterClient( + &ClientOption{ + InitAddress: []string{"127.0.0.1:0"}, + SendToReplicas: func(cmd Completed) bool { + return true + }, + ReplicaSelector: func(slot uint16, replicas []ReplicaInfo) int { + return 0 + }, + }, + func(dst string, opt *ClientOption) conn { + copiedM := *m + return &copiedM + }, + newRetryer(defaultRetryDelayFn), + ) + if err != nil { + t.Fatalf("unexpected err %v", err) + } + + if client.pslots[0] != client.conns["127.0.0.1:0"].conn { + t.Fatalf("unexpected node assigned to pslot 0") + } + if client.pslots[8192] != client.conns["127.0.0.1:0"].conn { + t.Fatalf("unexpected node assigned to pslot 8192") + } + if client.pslots[8193] != client.conns["127.0.1.1:0"].conn { + t.Fatalf("unexpected node assigned to pslot 8193") + } + if client.pslots[16383] != client.conns["127.0.1.1:0"].conn { + t.Fatalf("unexpected node assigned to pslot 16383") + } + if client.rslots[0] != client.conns["127.0.0.1:0"].conn { + t.Fatalf("unexpected node assigned to rslot 0") + } + if client.rslots[8192] != client.conns["127.0.0.1:0"].conn { + t.Fatalf("unexpected node assigned to 
rslot 8192") + } + if client.rslots[8193] != client.conns["127.0.1.1:0"].conn { + t.Fatalf("unexpected node assigned to rslot 8193") + } + if client.rslots[16383] != client.conns["127.0.1.1:0"].conn { + t.Fatalf("unexpected node assigned to rslot 16383") + } + }) + + t.Run("Refresh cluster which has multi replicas per shard with ReplicaSelector option. Returned index is within range", func(t *testing.T) { + primaryNodeConn := &mockConn{ + DoFn: func(cmd Completed) RedisResult { + if strings.Join(cmd.Commands(), " ") == "CLUSTER SLOTS" { + return slotsMultiRespWithMultiReplicas + } + return RedisResult{ + err: errors.New("unexpected call"), + } + }, + } + replicaNodeConn1 := &mockConn{ + DoFn: func(cmd Completed) RedisResult { + return RedisResult{ + err: errors.New("unexpected call"), + } + }, + } + replicaNodeConn2 := &mockConn{ + DoFn: func(cmd Completed) RedisResult { + return RedisResult{ + err: errors.New("unexpected call"), + } + }, + } + replicaNodeConn3 := &mockConn{ + DoFn: func(cmd Completed) RedisResult { + return RedisResult{ + err: errors.New("unexpected call"), + } + }, + } + + client, err := newClusterClient( + &ClientOption{ + InitAddress: []string{"127.0.0.1:0"}, + SendToReplicas: func(cmd Completed) bool { + return true + }, + ReplicaSelector: func(slot uint16, replicas []ReplicaInfo) int { + return 1 + }, + }, + func(dst string, opt *ClientOption) conn { + switch { + case dst == "127.0.0.2:1" || dst == "127.0.1.2:1": + return replicaNodeConn1 + case dst == "127.0.0.3:2" || dst == "127.0.1.3:2": + return replicaNodeConn2 + case dst == "127.0.0.4:3" || dst == "127.0.1.4:3": + return replicaNodeConn3 + default: + return primaryNodeConn + } + }, + newRetryer(defaultRetryDelayFn), + ) + if err != nil { + t.Fatalf("unexpected err %v", err) + } + + if client.pslots[0] != primaryNodeConn { + t.Fatalf("unexpected node assigned to pslot 0") + } + if client.pslots[8192] != primaryNodeConn { + t.Fatalf("unexpected node assigned to pslot 8192") + } + if client.pslots[8193] != primaryNodeConn { + t.Fatalf("unexpected node assigned to pslot 8193") + } + if client.pslots[16383] != primaryNodeConn { + t.Fatalf("unexpected node assigned to pslot 16383") + } + if client.rslots[0] != replicaNodeConn2 { + t.Fatalf("unexpected node assigned to rslot 0") + } + if client.rslots[8192] != replicaNodeConn2 { + t.Fatalf("unexpected node assigned to rslot 8192") + } + if client.rslots[8193] != replicaNodeConn2 { + t.Fatalf("unexpected node assigned to rslot 8193") + } + if client.rslots[16383] != replicaNodeConn2 { + t.Fatalf("unexpected node assigned to rslot 16383") + } + }) + + t.Run("Refresh cluster which has multi replicas per shard with ReplicaSelector option. 
Returned index is out of range", func(t *testing.T) { + primaryNodeConn := &mockConn{ + DoFn: func(cmd Completed) RedisResult { + if strings.Join(cmd.Commands(), " ") == "CLUSTER SLOTS" { + return slotsMultiRespWithMultiReplicas + } + return RedisResult{ + err: errors.New("unexpected call"), + } + }, + } + replicaNodeConn1 := &mockConn{ + DoFn: func(cmd Completed) RedisResult { + return RedisResult{ + err: errors.New("unexpected call"), + } + }, + } + replicaNodeConn2 := &mockConn{ + DoFn: func(cmd Completed) RedisResult { + return RedisResult{ + err: errors.New("unexpected call"), + } + }, + } + replicaNodeConn3 := &mockConn{ + DoFn: func(cmd Completed) RedisResult { + return RedisResult{ + err: errors.New("unexpected call"), + } + }, + } + + client, err := newClusterClient( + &ClientOption{ + InitAddress: []string{"127.0.0.1:0"}, + SendToReplicas: func(cmd Completed) bool { + return true + }, + ReplicaSelector: func(slot uint16, replicas []ReplicaInfo) int { + return -1 + }, + }, + func(dst string, opt *ClientOption) conn { + switch { + case dst == "127.0.0.2:1" || dst == "127.0.1.2:1": + return replicaNodeConn1 + case dst == "127.0.0.3:2" || dst == "127.0.1.3:2": + return replicaNodeConn2 + case dst == "127.0.0.4:3" || dst == "127.0.1.4:3": + return replicaNodeConn3 + default: + return primaryNodeConn + } + }, + newRetryer(defaultRetryDelayFn), + ) + if err != nil { + t.Fatalf("unexpected err %v", err) + } + + if client.pslots[0] != primaryNodeConn { + t.Fatalf("unexpected node assigned to pslot 0") + } + if client.pslots[8192] != primaryNodeConn { + t.Fatalf("unexpected node assigned to pslot 8192") + } + if client.pslots[8193] != primaryNodeConn { + t.Fatalf("unexpected node assigned to pslot 8193") + } + if client.pslots[16383] != primaryNodeConn { + t.Fatalf("unexpected node assigned to pslot 16383") + } + if client.rslots[0] != primaryNodeConn { + t.Fatalf("unexpected node assigned to rslot 0") + } + if client.rslots[8192] != primaryNodeConn { + t.Fatalf("unexpected node assigned to rslot 8192") + } + if client.rslots[8193] != primaryNodeConn { + t.Fatalf("unexpected node assigned to rslot 8193") + } + if client.rslots[16383] != primaryNodeConn { + t.Fatalf("unexpected node assigned to rslot 16383") + } + }) } //gocyclo:ignore @@ -4818,13 +5048,15 @@ func TestClusterShardsParsing(t *testing.T) { t.Fatalf("unexpected result %v", result) } for _, val := range result { - nodes := val.nodes - sort.Strings(nodes) - if len(nodes) != 3 || - nodes[0] != "127.0.1.1:1" || - nodes[1] != "127.0.2.1:2" || - nodes[2] != "127.0.3.1:3" { - t.Fatalf("unexpected nodes %v", nodes) + _nodes := val.nodes + sort.Slice(_nodes, func(i, j int) bool { + return _nodes[i].Addr < _nodes[j].Addr + }) + if len(_nodes) != 3 || + _nodes[0].Addr != "127.0.1.1:1" || + _nodes[1].Addr != "127.0.2.1:2" || + _nodes[2].Addr != "127.0.3.1:3" { + t.Fatalf("unexpected nodes %v", _nodes) } } @@ -4833,13 +5065,15 @@ func TestClusterShardsParsing(t *testing.T) { t.Fatalf("unexpected result %v", result) } for _, val := range result { - nodes := val.nodes - sort.Strings(nodes) - if len(nodes) != 3 || - nodes[0] != "127.0.1.1:0" || - nodes[1] != "127.0.2.1:0" || - nodes[2] != "127.0.3.1:3" { - t.Fatalf("unexpected nodes %v", nodes) + _nodes := val.nodes + sort.Slice(_nodes, func(i, j int) bool { + return _nodes[i].Addr < _nodes[j].Addr + }) + if len(_nodes) != 3 || + _nodes[0].Addr != "127.0.1.1:0" || + _nodes[1].Addr != "127.0.2.1:0" || + _nodes[2].Addr != "127.0.3.1:3" { + t.Fatalf("unexpected nodes %v", _nodes) } } }) @@ 
-4850,7 +5084,7 @@ func TestClusterShardsParsing(t *testing.T) { t.Fatalf("unexpected result %v", result) } for master, group := range result { - if len(group.nodes) == 0 || group.nodes[0] != master { + if len(group.nodes) == 0 || group.nodes[0].Addr != master { t.Fatalf("unexpected first node %v", group) } } @@ -5306,5 +5540,1103 @@ func TestClusterClientMovedRetry(t *testing.T) { t.Fatalf("unexpected response %v %v", v, err) } }) +} + +//gocyclo:ignore +func TestClusterClient_SendReadOperationToReplicaNodeWriteOperationToPrimaryNode(t *testing.T) { + defer ShouldNotLeaked(SetupLeakDetection()) + + primaryNodeConn := &mockConn{ + DoOverride: map[string]func(cmd Completed) RedisResult{ + "CLUSTER SLOTS": func(cmd Completed) RedisResult { + return slotsMultiResp + }, + "INFO": func(cmd Completed) RedisResult { + return newResult(RedisMessage{typ: '+', string: "INFO"}, nil) + }, + "SET Do V": func(cmd Completed) RedisResult { + return newResult(RedisMessage{typ: '+', string: "SET Do V"}, nil) + }, + "SET K2{a} V2{a}": func(cmd Completed) RedisResult { + return newResult(RedisMessage{typ: '+', string: "SET K2{a} V2{a}"}, nil) + }, + }, + DoMultiFn: func(multi ...Completed) *redisresults { + resps := make([]RedisResult, len(multi)) + for i, cmd := range multi { + if strings.HasPrefix(strings.Join(cmd.Commands(), " "), "SET K1") { + resps[i] = newResult(RedisMessage{typ: '+', string: strings.Join(cmd.Commands(), " ")}, nil) + continue + } + if strings.HasPrefix(strings.Join(cmd.Commands(), " "), "SET K2") { + resps[i] = newResult(RedisMessage{typ: '+', string: strings.Join(cmd.Commands(), " ")}, nil) + continue + } + if strings.HasPrefix(strings.Join(cmd.Commands(), " "), "MULTI") { + resps[i] = newResult(RedisMessage{typ: '+', string: "MULTI"}, nil) + continue + } + if strings.HasPrefix(strings.Join(cmd.Commands(), " "), "EXEC") { + resps[i] = newResult(RedisMessage{typ: '+', string: "EXEC"}, nil) + continue + } + return &redisresults{ + s: []RedisResult{}, + } + } + return &redisresults{s: resps} + }, + } + replicaNodeConn := &mockConn{ + DoOverride: map[string]func(cmd Completed) RedisResult{ + "GET Do": func(cmd Completed) RedisResult { + return newResult(RedisMessage{typ: '+', string: "GET Do"}, nil) + }, + "GET K1{a}": func(cmd Completed) RedisResult { + return newResult(RedisMessage{typ: '+', string: "GET K1{a}"}, nil) + }, + "GET K2{a}": func(cmd Completed) RedisResult { + return newResult(RedisMessage{typ: '+', string: "GET K2{a}"}, nil) + }, + }, + DoMultiFn: func(multi ...Completed) *redisresults { + resps := make([]RedisResult, len(multi)) + for i, cmd := range multi { + if strings.HasPrefix(strings.Join(cmd.Commands(), " "), "GET K1") { + resps[i] = newResult(RedisMessage{typ: '+', string: strings.Join(cmd.Commands(), " ")}, nil) + continue + } + + return &redisresults{ + s: []RedisResult{}, + } + } + return &redisresults{s: resps} + }, + DoCacheOverride: map[string]func(cmd Cacheable, ttl time.Duration) RedisResult{ + "GET DoCache": func(cmd Cacheable, ttl time.Duration) RedisResult { + return newResult(RedisMessage{typ: '+', string: "GET DoCache"}, nil) + }, + "GET K1{a}": func(cmd Cacheable, ttl time.Duration) RedisResult { + return newResult(RedisMessage{typ: '+', string: "GET K1{a}"}, nil) + }, + "GET K2{a}": func(cmd Cacheable, ttl time.Duration) RedisResult { + return newResult(RedisMessage{typ: '+', string: "GET K2{a}"}, nil) + }, + }, + DoMultiCacheFn: func(multi ...CacheableTTL) *redisresults { + resps := make([]RedisResult, len(multi)) + for i, cmd := range multi { + if 
strings.HasPrefix(strings.Join(cmd.Cmd.Commands(), " "), "GET K1") { + resps[i] = newResult(RedisMessage{typ: '+', string: strings.Join(cmd.Cmd.Commands(), " ")}, nil) + continue + } + + return &redisresults{ + s: []RedisResult{}, + } + } + return &redisresults{s: resps} + }, + } + + client, err := newClusterClient( + &ClientOption{ + InitAddress: []string{"127.0.0.1:0"}, + SendToReplicas: func(cmd Completed) bool { + return cmd.IsReadOnly() + }, + ReplicaSelector: func(slot uint16, replicas []ReplicaInfo) int { + return 0 + }, + }, + func(dst string, opt *ClientOption) conn { + if dst == "127.0.0.1:0" || dst == "127.0.2.1:0" { // primary node + return primaryNodeConn + } else { // replica node + return replicaNodeConn + } + }, + newRetryer(defaultRetryDelayFn), + ) + if err != nil { + t.Fatalf("unexpected err %v", err) + } + + t.Run("Do read operation", func(t *testing.T) { + c := client.B().Get().Key("Do").Build() + if v, err := client.Do(context.Background(), c).ToString(); err != nil || v != "GET Do" { + t.Fatalf("unexpected response %v %v", v, err) + } + }) + + t.Run("Do write operation", func(t *testing.T) { + c := client.B().Set().Key("Do").Value("V").Build() + if v, err := client.Do(context.Background(), c).ToString(); err != nil || v != "SET Do V" { + t.Fatalf("unexpected response %v %v", v, err) + } + }) + + t.Run("DoMulti Single Slot All Read Operations", func(t *testing.T) { + c1 := client.B().Get().Key("K1{a}").Build() + c2 := client.B().Get().Key("K2{a}").Build() + resps := client.DoMulti(context.Background(), c1, c2) + if v, err := resps[0].ToString(); err != nil || v != "GET K1{a}" { + t.Fatalf("unexpected response %v %v", v, err) + } + if v, err := resps[1].ToString(); err != nil || v != "GET K2{a}" { + t.Fatalf("unexpected response %v %v", v, err) + } + }) + + t.Run("DoMulti Single Slot Read Operation And Write Operation", func(t *testing.T) { + c1 := client.B().Get().Key("K1{a}").Build() + c2 := client.B().Set().Key("K2{a}").Value("V2{a}").Build() + resps := client.DoMulti(context.Background(), c1, c2) + if v, err := resps[0].ToString(); err != nil || v != "GET K1{a}" { + t.Fatalf("unexpected response %v %v", v, err) + } + if v, err := resps[1].ToString(); err != nil || v != "SET K2{a} V2{a}" { + t.Fatalf("unexpected response %v %v", v, err) + } + }) + + t.Run("DoMulti Single Slot Operations + Init Slot", func(t *testing.T) { + c1 := client.B().Multi().Build() + c2 := client.B().Set().Key("K1{a}").Value("V1{a}").Build() + c3 := client.B().Set().Key("K2{a}").Value("V2{a}").Build() + c4 := client.B().Exec().Build() + resps := client.DoMulti(context.Background(), c1, c2, c3, c4) + if v, err := resps[0].ToString(); err != nil || v != "MULTI" { + t.Fatalf("unexpected response %v %v", v, err) + } + if v, err := resps[1].ToString(); err != nil || v != "SET K1{a} V1{a}" { + t.Fatalf("unexpected response %v %v", v, err) + } + if v, err := resps[2].ToString(); err != nil || v != "SET K2{a} V2{a}" { + t.Fatalf("unexpected response %v %v", v, err) + } + if v, err := resps[3].ToString(); err != nil || v != "EXEC" { + t.Fatalf("unexpected response %v %v", v, err) + } + }) + + t.Run("DoMulti Cross Slot + Init Slot", func(t *testing.T) { + defer func() { + if err := recover(); err != panicMixCxSlot { + t.Errorf("DoMulti should panic if Cross Slot + Init Slot") + } + }() + c1 := client.B().Get().Key("K1{a}").Build() + c2 := client.B().Get().Key("K1{b}").Build() + c3 := client.B().Info().Build() + client.DoMulti(context.Background(), c1, c2, c3) + }) + + t.Run("DoMulti Multi Slot All Read 
Operations", func(t *testing.T) { + multi := make([]Completed, 500) + for i := 0; i < len(multi); i++ { + multi[i] = client.B().Get().Key(fmt.Sprintf("K1{%d}", i)).Build() + } + resps := client.DoMulti(context.Background(), multi...) + for i := 0; i < len(multi); i++ { + if v, err := resps[i].ToString(); err != nil || v != fmt.Sprintf("GET K1{%d}", i) { + t.Fatalf("unexpected response %v %v", v, err) + } + } + }) + t.Run("DoMulti Multi Slot Read & Write Operations", func(t *testing.T) { + multi := make([]Completed, 500) + for i := 0; i < len(multi); i++ { + if i%2 == 0 { + multi[i] = client.B().Get().Key(fmt.Sprintf("K1{%d}", i)).Build() + } else { + multi[i] = client.B().Set().Key(fmt.Sprintf("K2{%d}", i)).Value(fmt.Sprintf("V2{%d}", i)).Build() + } + } + resps := client.DoMulti(context.Background(), multi...) + for i := 0; i < len(multi); i++ { + if i%2 == 0 { + if v, err := resps[i].ToString(); err != nil || v != fmt.Sprintf("GET K1{%d}", i) { + t.Fatalf("unexpected response %v %v", v, err) + } + } else { + if v, err := resps[i].ToString(); err != nil || v != fmt.Sprintf("SET K2{%d} V2{%d}", i, i) { + t.Fatalf("unexpected response %v %v", v, err) + } + } + } + }) + + t.Run("DoCache Operation", func(t *testing.T) { + c := client.B().Get().Key("DoCache").Cache() + if v, err := client.DoCache(context.Background(), c, 100).ToString(); err != nil || v != "GET DoCache" { + t.Fatalf("unexpected response %v %v", v, err) + } + }) + + t.Run("DoMultiCache Single Slot", func(t *testing.T) { + c1 := client.B().Get().Key("K1{a}").Cache() + c2 := client.B().Get().Key("K2{a}").Cache() + resps := client.DoMultiCache(context.Background(), CT(c1, time.Second), CT(c2, time.Second)) + if v, err := resps[0].ToString(); err != nil || v != "GET K1{a}" { + t.Fatalf("unexpected response %v %v", v, err) + } + if v, err := resps[1].ToString(); err != nil || v != "GET K2{a}" { + t.Fatalf("unexpected response %v %v", v, err) + } + }) + + t.Run("DoMultiCache Multi Slot", func(t *testing.T) { + multi := make([]CacheableTTL, 500) + for i := 0; i < len(multi); i++ { + multi[i] = CT(client.B().Get().Key(fmt.Sprintf("K1{%d}", i)).Cache(), time.Second) + } + resps := client.DoMultiCache(context.Background(), multi...) 
+ for i := 0; i < len(multi); i++ { + if v, err := resps[i].ToString(); err != nil || v != fmt.Sprintf("GET K1{%d}", i) { + t.Fatalf("unexpected response %v %v", v, err) + } + } + }) + + t.Run("Receive", func(t *testing.T) { + c := client.B().Subscribe().Channel("ch").Build() + hdl := func(message PubSubMessage) {} + primaryNodeConn.ReceiveFn = func(ctx context.Context, subscribe Completed, fn func(message PubSubMessage)) error { + if !reflect.DeepEqual(subscribe.Commands(), c.Commands()) { + t.Fatalf("unexpected command %v", subscribe) + } + return nil + } + replicaNodeConn.ReceiveFn = func(ctx context.Context, subscribe Completed, fn func(message PubSubMessage)) error { + if !reflect.DeepEqual(subscribe.Commands(), c.Commands()) { + t.Fatalf("unexpected command %v", subscribe) + } + return nil + } + if err := client.Receive(context.Background(), c, hdl); err != nil { + t.Fatalf("unexpected response %v", err) + } + }) + + t.Run("Receive Redis Err", func(t *testing.T) { + c := client.B().Ssubscribe().Channel("ch").Build() + e := &RedisError{} + primaryNodeConn.ReceiveFn = func(ctx context.Context, subscribe Completed, fn func(message PubSubMessage)) error { + return e + } + replicaNodeConn.ReceiveFn = func(ctx context.Context, subscribe Completed, fn func(message PubSubMessage)) error { + return e + } + if err := client.Receive(context.Background(), c, func(message PubSubMessage) {}); err != e { + t.Fatalf("unexpected response %v", err) + } + }) + + t.Run("Dedicated Cross Slot Err", func(t *testing.T) { + defer func() { + if err := recover(); err != panicMsgCxSlot { + t.Errorf("Dedicated should panic if cross slots is used") + } + }() + primaryNodeConn.AcquireFn = func() wire { return &mockWire{} } + client.Dedicated(func(c DedicatedClient) error { + c.Do(context.Background(), c.B().Get().Key("a").Build()).Error() + return c.Do(context.Background(), c.B().Get().Key("b").Build()).Error() + }) + }) + + t.Run("Dedicated Cross Slot Err Multi", func(t *testing.T) { + defer func() { + if err := recover(); err != panicMsgCxSlot { + t.Errorf("Dedicated should panic if cross slots is used") + } + }() + primaryNodeConn.AcquireFn = func() wire { + return &mockWire{ + DoMultiFn: func(multi ...Completed) *redisresults { + return &redisresults{s: []RedisResult{ + newResult(RedisMessage{typ: '+', string: "OK"}, nil), + newResult(RedisMessage{typ: '+', string: "OK"}, nil), + newResult(RedisMessage{typ: '*', values: []RedisMessage{{typ: '+', string: "a"}}}, nil), + }} + }, + } + } + client.Dedicated(func(c DedicatedClient) (err error) { + c.DoMulti( + context.Background(), + c.B().Multi().Build(), + c.B().Get().Key("a").Build(), + c.B().Exec().Build(), + ) + c.DoMulti( + context.Background(), + c.B().Multi().Build(), + c.B().Get().Key("b").Build(), + c.B().Exec().Build(), + ) + return nil + }) + }) + + t.Run("Dedicated Multi Cross Slot Err", func(t *testing.T) { + primaryNodeConn.AcquireFn = func() wire { return &mockWire{} } + err := client.Dedicated(func(c DedicatedClient) (err error) { + defer func() { + err = errors.New(recover().(string)) + }() + c.DoMulti( + context.Background(), + c.B().Get().Key("a").Build(), + c.B().Get().Key("b").Build(), + ) + return nil + }) + if err == nil || err.Error() != panicMsgCxSlot { + t.Errorf("Multi should panic if cross slots is used") + } + }) + + t.Run("Dedicated Receive Redis Err", func(t *testing.T) { + e := &RedisError{} + w := &mockWire{ + ReceiveFn: func(ctx context.Context, subscribe Completed, fn func(message PubSubMessage)) error { + return e + }, + } + 
primaryNodeConn.AcquireFn = func() wire {
+			return w
+		}
+		if err := client.Dedicated(func(c DedicatedClient) error {
+			return c.Receive(context.Background(), c.B().Ssubscribe().Channel("a").Build(), func(msg PubSubMessage) {})
+		}); err != e {
+			t.Fatalf("unexpected err %v", err)
+		}
+	})
+
+	t.Run("Dedicated", func(t *testing.T) {
+		closed := false
+		w := &mockWire{
+			DoFn: func(cmd Completed) RedisResult {
+				return newResult(RedisMessage{typ: '+', string: "Delegate"}, nil)
+			},
+			DoMultiFn: func(cmd ...Completed) *redisresults {
+				if len(cmd) == 4 {
+					return &redisresults{s: []RedisResult{
+						newResult(RedisMessage{typ: '+', string: "OK"}, nil),
+						newResult(RedisMessage{typ: '+', string: "OK"}, nil),
+						newResult(RedisMessage{typ: '+', string: "OK"}, nil),
+						newResult(RedisMessage{typ: '*', values: []RedisMessage{
+							{typ: '+', string: "Delegate0"},
+							{typ: '+', string: "Delegate1"},
+						}}, nil),
+					}}
+				}
+				return &redisresults{s: []RedisResult{
+					newResult(RedisMessage{typ: '+', string: "Delegate0"}, nil),
+					newResult(RedisMessage{typ: '+', string: "Delegate1"}, nil),
+				}}
+			},
+			ReceiveFn: func(ctx context.Context, subscribe Completed, fn func(message PubSubMessage)) error {
+				return ErrClosing
+			},
+			SetPubSubHooksFn: func(hooks PubSubHooks) <-chan error {
+				ch := make(chan error, 1)
+				ch <- ErrClosing
+				close(ch)
+				return ch
+			},
+			ErrorFn: func() error {
+				return ErrClosing
+			},
+			CloseFn: func() {
+				closed = true
+			},
+		}
+		primaryNodeConn.AcquireFn = func() wire {
+			return w
+		}
+		stored := false
+		primaryNodeConn.StoreFn = func(ww wire) {
+			if ww != w {
+				t.Fatalf("received unexpected wire %v", ww)
+			}
+			stored = true
+		}
+		if err := client.Dedicated(func(c DedicatedClient) error {
+			ch := c.SetPubSubHooks(PubSubHooks{OnMessage: func(m PubSubMessage) {}})
+			if v, err := c.Do(context.Background(), c.B().Get().Key("a").Build()).ToString(); err != nil || v != "Delegate" {
+				t.Fatalf("unexpected response %v %v", v, err)
+			}
+			if v := c.DoMulti(context.Background()); len(v) != 0 {
+				t.Fatalf("received unexpected response %v", v)
+			}
+			for i, resp := range c.DoMulti(
+				context.Background(),
+				c.B().Info().Build(),
+				c.B().Info().Build(),
+			) {
+				if v, err := resp.ToString(); err != nil || v != "Delegate"+strconv.Itoa(i) {
+					t.Fatalf("unexpected response %v %v", v, err)
+				}
+			}
+			for i, resp := range c.DoMulti(
+				context.Background(),
+				c.B().Get().Key("a").Build(),
+				c.B().Get().Key("a").Build(),
+			) {
+				if v, err := resp.ToString(); err != nil || v != "Delegate"+strconv.Itoa(i) {
+					t.Fatalf("unexpected response %v %v", v, err)
+				}
+			}
+			for i, resp := range c.DoMulti(
+				context.Background(),
+				c.B().Multi().Build(),
+				c.B().Get().Key("a").Build(),
+				c.B().Get().Key("a").Build(),
+				c.B().Exec().Build(),
+			)[3].val.values {
+				if v, err := resp.ToString(); err != nil || v != "Delegate"+strconv.Itoa(i) {
+					t.Fatalf("unexpected response %v %v", v, err)
+				}
+			}
+			if err := c.Receive(context.Background(), c.B().Ssubscribe().Channel("a").Build(), func(msg PubSubMessage) {}); err != ErrClosing {
+				t.Fatalf("unexpected ret %v", err)
+			}
+			if err := <-ch; err != ErrClosing {
+				t.Fatalf("unexpected ret %v", err)
+			}
+			if err := <-c.SetPubSubHooks(PubSubHooks{OnMessage: func(m PubSubMessage) {}}); err != ErrClosing {
+				t.Fatalf("unexpected ret %v", err)
+			}
+			c.Close()
+			return nil
+		}); err != nil {
+			t.Fatalf("unexpected err %v", err)
+		}
+		if !stored {
+			t.Fatalf("Dedicated doesn't put back the wire")
+		}
+		if !closed {
+			t.Fatalf("Dedicated doesn't delegate Close")
+		}
+	})
+
+	t.Run("Dedicate", func(t *testing.T) {
+		closed := false
+		w := &mockWire{
+			DoFn: func(cmd Completed) RedisResult {
+				return newResult(RedisMessage{typ: '+', string: "Delegate"}, nil)
+			},
+			DoMultiFn: func(cmd ...Completed) *redisresults {
+				if len(cmd) == 4 {
+					return &redisresults{s: []RedisResult{
+						newResult(RedisMessage{typ: '+', string: "OK"}, nil),
+						newResult(RedisMessage{typ: '+', string: "OK"}, nil),
+						newResult(RedisMessage{typ: '+', string: "OK"}, nil),
+						newResult(RedisMessage{typ: '*', values: []RedisMessage{
+							{typ: '+', string: "Delegate0"},
+							{typ: '+', string: "Delegate1"},
+						}}, nil),
+					}}
+				}
+				return &redisresults{s: []RedisResult{
+					newResult(RedisMessage{typ: '+', string: "Delegate0"}, nil),
+					newResult(RedisMessage{typ: '+', string: "Delegate1"}, nil),
+				}}
+			},
+			ReceiveFn: func(ctx context.Context, subscribe Completed, fn func(message PubSubMessage)) error {
+				return ErrClosing
+			},
+			SetPubSubHooksFn: func(hooks PubSubHooks) <-chan error {
+				ch := make(chan error, 1)
+				ch <- ErrClosing
+				close(ch)
+				return ch
+			},
+			ErrorFn: func() error {
+				return ErrClosing
+			},
+			CloseFn: func() {
+				closed = true
+			},
+		}
+		primaryNodeConn.AcquireFn = func() wire {
+			return w
+		}
+		stored := false
+		primaryNodeConn.StoreFn = func(ww wire) {
+			if ww != w {
+				t.Fatalf("received unexpected wire %v", ww)
+			}
+			stored = true
+		}
+		c, cancel := client.Dedicate()
+		ch := c.SetPubSubHooks(PubSubHooks{OnMessage: func(m PubSubMessage) {}})
+		if v, err := c.Do(context.Background(), c.B().Get().Key("a").Build()).ToString(); err != nil || v != "Delegate" {
+			t.Fatalf("unexpected response %v %v", v, err)
+		}
+		if v := c.DoMulti(context.Background()); len(v) != 0 {
+			t.Fatalf("received unexpected response %v", v)
+		}
+		for i, resp := range c.DoMulti(
+			context.Background(),
+			c.B().Info().Build(),
+			c.B().Info().Build(),
+		) {
+			if v, err := resp.ToString(); err != nil || v != "Delegate"+strconv.Itoa(i) {
+				t.Fatalf("unexpected response %v %v", v, err)
+			}
+		}
+		for i, resp := range c.DoMulti(
+			context.Background(),
+			c.B().Get().Key("a").Build(),
+			c.B().Get().Key("a").Build(),
+		) {
+			if v, err := resp.ToString(); err != nil || v != "Delegate"+strconv.Itoa(i) {
+				t.Fatalf("unexpected response %v %v", v, err)
+			}
+		}
+		for i, resp := range c.DoMulti(
+			context.Background(),
+			c.B().Multi().Build(),
+			c.B().Get().Key("a").Build(),
+			c.B().Get().Key("a").Build(),
+			c.B().Exec().Build(),
+		)[3].val.values {
+			if v, err := resp.ToString(); err != nil || v != "Delegate"+strconv.Itoa(i) {
+				t.Fatalf("unexpected response %v %v", v, err)
+			}
+		}
+		if err := c.Receive(context.Background(), c.B().Ssubscribe().Channel("a").Build(), func(msg PubSubMessage) {}); err != ErrClosing {
+			t.Fatalf("unexpected ret %v", err)
+		}
+		if err := <-ch; err != ErrClosing {
+			t.Fatalf("unexpected ret %v", err)
+		}
+		if err := <-c.SetPubSubHooks(PubSubHooks{OnMessage: func(m PubSubMessage) {}}); err != ErrClosing {
+			t.Fatalf("unexpected ret %v", err)
+		}
+		c.Close()
+		cancel()
+
+		if !stored {
+			t.Fatalf("Dedicated doesn't put back the wire")
+		}
+		if !closed {
+			t.Fatalf("Dedicated doesn't delegate Close")
+		}
+	})
+}
+
+//gocyclo:ignore
+func TestClusterClient_SendToOnlyPrimaryNodeWhenPrimaryNodeSelected(t *testing.T) {
+	defer ShouldNotLeaked(SetupLeakDetection())
+
+	primaryNodeConn := &mockConn{
+		DoOverride: map[string]func(cmd Completed) RedisResult{
+			"CLUSTER SLOTS": func(cmd Completed) RedisResult {
+				return slotsMultiResp
+			},
+			"GET Do": func(cmd Completed) RedisResult {
+				return newResult(RedisMessage{typ: '+', string: 
"GET Do"}, nil) + }, + "GET K1{a}": func(cmd Completed) RedisResult { + return newResult(RedisMessage{typ: '+', string: "GET K1{a}"}, nil) + }, + "GET K2{a}": func(cmd Completed) RedisResult { + return newResult(RedisMessage{typ: '+', string: "GET K2{a}"}, nil) + }, + "INFO": func(cmd Completed) RedisResult { + return newResult(RedisMessage{typ: '+', string: "INFO"}, nil) + }, + }, + DoMultiFn: func(multi ...Completed) *redisresults { + resps := make([]RedisResult, len(multi)) + for i, cmd := range multi { + resps[i] = newResult(RedisMessage{typ: '+', string: strings.Join(cmd.Commands(), " ")}, nil) + } + return &redisresults{s: resps} + }, + DoCacheOverride: map[string]func(cmd Cacheable, ttl time.Duration) RedisResult{ + "GET DoCache": func(cmd Cacheable, ttl time.Duration) RedisResult { + return newResult(RedisMessage{typ: '+', string: "GET DoCache"}, nil) + }, + "GET K1{a}": func(cmd Cacheable, ttl time.Duration) RedisResult { + return newResult(RedisMessage{typ: '+', string: "GET K1{a}"}, nil) + }, + "GET K2{a}": func(cmd Cacheable, ttl time.Duration) RedisResult { + return newResult(RedisMessage{typ: '+', string: "GET K2{a}"}, nil) + }, + }, + DoMultiCacheFn: func(multi ...CacheableTTL) *redisresults { + resps := make([]RedisResult, len(multi)) + for i, cmd := range multi { + resps[i] = newResult(RedisMessage{typ: '+', string: strings.Join(cmd.Cmd.Commands(), " ")}, nil) + } + return &redisresults{s: resps} + }, + } + replicaNodeConn := &mockConn{} + + client, err := newClusterClient( + &ClientOption{ + InitAddress: []string{"127.0.0.1:0"}, + SendToReplicas: func(cmd Completed) bool { + return true + }, + ReplicaSelector: func(slot uint16, replicas []ReplicaInfo) int { + return -1 + }, + }, + func(dst string, opt *ClientOption) conn { + if dst == "127.0.0.1:0" || dst == "127.0.2.1:0" { // primary nodes + return primaryNodeConn + } else { // replica nodes + return replicaNodeConn + } + }, + newRetryer(defaultRetryDelayFn), + ) + if err != nil { + t.Fatalf("unexpected err %v", err) + } + + t.Run("Do with no slot", func(t *testing.T) { + c := client.B().Info().Build() + if v, err := client.Do(context.Background(), c).ToString(); err != nil || v != "INFO" { + t.Fatalf("unexpected response %v %v", v, err) + } + }) + + t.Run("Do", func(t *testing.T) { + c := client.B().Get().Key("Do").Build() + if v, err := client.Do(context.Background(), c).ToString(); err != nil || v != "GET Do" { + t.Fatalf("unexpected response %v %v", v, err) + } + }) + + t.Run("DoMulti Single Slot", func(t *testing.T) { + c1 := client.B().Get().Key("K1{a}").Build() + c2 := client.B().Get().Key("K2{a}").Build() + resps := client.DoMulti(context.Background(), c1, c2) + if v, err := resps[0].ToString(); err != nil || v != "GET K1{a}" { + t.Fatalf("unexpected response %v %v", v, err) + } + if v, err := resps[1].ToString(); err != nil || v != "GET K2{a}" { + t.Fatalf("unexpected response %v %v", v, err) + } + }) + + t.Run("DoMulti Single Slot + Init Slot", func(t *testing.T) { + c1 := client.B().Get().Key("K1{a}").Build() + c2 := client.B().Info().Build() + resps := client.DoMulti(context.Background(), c1, c2) + if v, err := resps[0].ToString(); err != nil || v != "GET K1{a}" { + t.Fatalf("unexpected response %v %v", v, err) + } + if v, err := resps[1].ToString(); err != nil || v != "INFO" { + t.Fatalf("unexpected response %v %v", v, err) + } + }) + + t.Run("DoMulti Cross Slot + Init Slot", func(t *testing.T) { + defer func() { + if err := recover(); err != panicMixCxSlot { + t.Errorf("DoMulti should panic if Cross Slot + 
Init Slot") + } + }() + c1 := client.B().Get().Key("K1{a}").Build() + c2 := client.B().Get().Key("K1{b}").Build() + c3 := client.B().Info().Build() + client.DoMulti(context.Background(), c1, c2, c3) + }) + + t.Run("DoMulti Multi Slot", func(t *testing.T) { + multi := make([]Completed, 500) + for i := 0; i < len(multi); i++ { + multi[i] = client.B().Get().Key(fmt.Sprintf("K1{%d}", i)).Build() + } + resps := client.DoMulti(context.Background(), multi...) + for i := 0; i < len(multi); i++ { + if v, err := resps[i].ToString(); err != nil || v != fmt.Sprintf("GET K1{%d}", i) { + t.Fatalf("unexpected response %v %v", v, err) + } + } + }) + + t.Run("DoCache", func(t *testing.T) { + c := client.B().Get().Key("DoCache").Cache() + if v, err := client.DoCache(context.Background(), c, 100).ToString(); err != nil || v != "GET DoCache" { + t.Fatalf("unexpected response %v %v", v, err) + } + }) + + t.Run("DoMultiCache Single Slot", func(t *testing.T) { + c1 := client.B().Get().Key("K1{a}").Cache() + c2 := client.B().Get().Key("K2{a}").Cache() + resps := client.DoMultiCache(context.Background(), CT(c1, time.Second), CT(c2, time.Second)) + if v, err := resps[0].ToString(); err != nil || v != "GET K1{a}" { + t.Fatalf("unexpected response %v %v", v, err) + } + if v, err := resps[1].ToString(); err != nil || v != "GET K2{a}" { + t.Fatalf("unexpected response %v %v", v, err) + } + }) + + t.Run("DoMultiCache Multi Slot", func(t *testing.T) { + multi := make([]CacheableTTL, 500) + for i := 0; i < len(multi); i++ { + multi[i] = CT(client.B().Get().Key(fmt.Sprintf("K1{%d}", i)).Cache(), time.Second) + } + resps := client.DoMultiCache(context.Background(), multi...) + for i := 0; i < len(multi); i++ { + if v, err := resps[i].ToString(); err != nil || v != fmt.Sprintf("GET K1{%d}", i) { + t.Fatalf("unexpected response %v %v", v, err) + } + } + }) + + t.Run("Receive", func(t *testing.T) { + c := client.B().Subscribe().Channel("ch").Build() + hdl := func(message PubSubMessage) {} + primaryNodeConn.ReceiveFn = func(ctx context.Context, subscribe Completed, fn func(message PubSubMessage)) error { + if !reflect.DeepEqual(subscribe.Commands(), c.Commands()) { + t.Fatalf("unexpected command %v", subscribe) + } + return nil + } + replicaNodeConn.ReceiveFn = func(ctx context.Context, subscribe Completed, fn func(message PubSubMessage)) error { + if !reflect.DeepEqual(subscribe.Commands(), c.Commands()) { + t.Fatalf("unexpected command %v", subscribe) + } + return nil + } + + if err := client.Receive(context.Background(), c, hdl); err != nil { + t.Fatalf("unexpected response %v", err) + } + }) + + t.Run("Receive Redis Err", func(t *testing.T) { + c := client.B().Subscribe().Channel("ch").Build() + e := &RedisError{} + primaryNodeConn.ReceiveFn = func(ctx context.Context, subscribe Completed, fn func(message PubSubMessage)) error { + return e + } + replicaNodeConn.ReceiveFn = func(ctx context.Context, subscribe Completed, fn func(message PubSubMessage)) error { + return e + } + if err := client.Receive(context.Background(), c, func(message PubSubMessage) {}); err != e { + t.Fatalf("unexpected response %v", err) + } + }) + + t.Run("Dedicated Err", func(t *testing.T) { + v := errors.New("fn err") + if err := client.Dedicated(func(client DedicatedClient) error { + return v + }); err != v { + t.Fatalf("unexpected err %v", err) + } + }) + + t.Run("Dedicated Cross Slot Err", func(t *testing.T) { + defer func() { + if err := recover(); err != panicMsgCxSlot { + t.Errorf("Dedicated should panic if cross slots is used") + } + }() + 
primaryNodeConn.AcquireFn = func() wire { return &mockWire{} } + client.Dedicated(func(c DedicatedClient) error { + c.Do(context.Background(), c.B().Get().Key("a").Build()).Error() + return c.Do(context.Background(), c.B().Get().Key("b").Build()).Error() + }) + }) + + t.Run("Dedicated Cross Slot Err Multi", func(t *testing.T) { + defer func() { + if err := recover(); err != panicMsgCxSlot { + t.Errorf("Dedicated should panic if cross slots is used") + } + }() + primaryNodeConn.AcquireFn = func() wire { + return &mockWire{ + DoMultiFn: func(multi ...Completed) *redisresults { + return &redisresults{s: []RedisResult{ + newResult(RedisMessage{typ: '+', string: "OK"}, nil), + newResult(RedisMessage{typ: '+', string: "OK"}, nil), + newResult(RedisMessage{typ: '*', values: []RedisMessage{{typ: '+', string: "a"}}}, nil), + }} + }, + } + } + client.Dedicated(func(c DedicatedClient) (err error) { + c.DoMulti( + context.Background(), + c.B().Multi().Build(), + c.B().Get().Key("a").Build(), + c.B().Exec().Build(), + ) + c.DoMulti( + context.Background(), + c.B().Multi().Build(), + c.B().Get().Key("b").Build(), + c.B().Exec().Build(), + ) + return nil + }) + }) + + t.Run("Dedicated Multi Cross Slot Err", func(t *testing.T) { + primaryNodeConn.AcquireFn = func() wire { return &mockWire{} } + err := client.Dedicated(func(c DedicatedClient) (err error) { + defer func() { + err = errors.New(recover().(string)) + }() + c.DoMulti( + context.Background(), + c.B().Get().Key("a").Build(), + c.B().Get().Key("b").Build(), + ) + return nil + }) + if err == nil || err.Error() != panicMsgCxSlot { + t.Errorf("Multi should panic if cross slots is used") + } + }) + + t.Run("Dedicated Receive Redis Err", func(t *testing.T) { + e := &RedisError{} + w := &mockWire{ + ReceiveFn: func(ctx context.Context, subscribe Completed, fn func(message PubSubMessage)) error { + return e + }, + } + primaryNodeConn.AcquireFn = func() wire { + return w + } + if err := client.Dedicated(func(c DedicatedClient) error { + return c.Receive(context.Background(), c.B().Ssubscribe().Channel("a").Build(), func(msg PubSubMessage) {}) + }); err != e { + t.Fatalf("unexpected err %v", err) + } + }) + + t.Run("Dedicated", func(t *testing.T) { + closed := false + w := &mockWire{ + DoFn: func(cmd Completed) RedisResult { + return newResult(RedisMessage{typ: '+', string: "Delegate"}, nil) + }, + DoMultiFn: func(cmd ...Completed) *redisresults { + if len(cmd) == 4 { + return &redisresults{s: []RedisResult{ + newResult(RedisMessage{typ: '+', string: "OK"}, nil), + newResult(RedisMessage{typ: '+', string: "OK"}, nil), + newResult(RedisMessage{typ: '+', string: "OK"}, nil), + newResult(RedisMessage{typ: '*', values: []RedisMessage{ + {typ: '+', string: "Delegate0"}, + {typ: '+', string: "Delegate1"}, + }}, nil), + }} + } + return &redisresults{s: []RedisResult{ + newResult(RedisMessage{typ: '+', string: "Delegate0"}, nil), + newResult(RedisMessage{typ: '+', string: "Delegate1"}, nil), + }} + }, + ReceiveFn: func(ctx context.Context, subscribe Completed, fn func(message PubSubMessage)) error { + return ErrClosing + }, + SetPubSubHooksFn: func(hooks PubSubHooks) <-chan error { + ch := make(chan error, 1) + ch <- ErrClosing + close(ch) + return ch + }, + ErrorFn: func() error { + return ErrClosing + }, + CloseFn: func() { + closed = true + }, + } + primaryNodeConn.AcquireFn = func() wire { + return w + } + stored := false + primaryNodeConn.StoreFn = func(ww wire) { + if ww != w { + t.Fatalf("received unexpected wire %v", ww) + } + stored = true + } + if err := 
client.Dedicated(func(c DedicatedClient) error {
+			ch := c.SetPubSubHooks(PubSubHooks{OnMessage: func(m PubSubMessage) {}})
+			if v, err := c.Do(context.Background(), c.B().Get().Key("a").Build()).ToString(); err != nil || v != "Delegate" {
+				t.Fatalf("unexpected response %v %v", v, err)
+			}
+			if v := c.DoMulti(context.Background()); len(v) != 0 {
+				t.Fatalf("received unexpected response %v", v)
+			}
+			for i, resp := range c.DoMulti(
+				context.Background(),
+				c.B().Info().Build(),
+				c.B().Info().Build(),
+			) {
+				if v, err := resp.ToString(); err != nil || v != "Delegate"+strconv.Itoa(i) {
+					t.Fatalf("unexpected response %v %v", v, err)
+				}
+			}
+			for i, resp := range c.DoMulti(
+				context.Background(),
+				c.B().Get().Key("a").Build(),
+				c.B().Get().Key("a").Build(),
+			) {
+				if v, err := resp.ToString(); err != nil || v != "Delegate"+strconv.Itoa(i) {
+					t.Fatalf("unexpected response %v %v", v, err)
+				}
+			}
+			for i, resp := range c.DoMulti(
+				context.Background(),
+				c.B().Multi().Build(),
+				c.B().Get().Key("a").Build(),
+				c.B().Get().Key("a").Build(),
+				c.B().Exec().Build(),
+			)[3].val.values {
+				if v, err := resp.ToString(); err != nil || v != "Delegate"+strconv.Itoa(i) {
+					t.Fatalf("unexpected response %v %v", v, err)
+				}
+			}
+			if err := c.Receive(context.Background(), c.B().Ssubscribe().Channel("a").Build(), func(msg PubSubMessage) {}); err != ErrClosing {
+				t.Fatalf("unexpected ret %v", err)
+			}
+			if err := <-ch; err != ErrClosing {
+				t.Fatalf("unexpected ret %v", err)
+			}
+			if err := <-c.SetPubSubHooks(PubSubHooks{OnMessage: func(m PubSubMessage) {}}); err != ErrClosing {
+				t.Fatalf("unexpected ret %v", err)
+			}
+			c.Close()
+			return nil
+		}); err != nil {
+			t.Fatalf("unexpected err %v", err)
+		}
+		if !stored {
+			t.Fatalf("Dedicated doesn't put back the wire")
+		}
+		if !closed {
+			t.Fatalf("Dedicated doesn't delegate Close")
+		}
+	})
+
+	t.Run("Dedicate", func(t *testing.T) {
+		closed := false
+		w := &mockWire{
+			DoFn: func(cmd Completed) RedisResult {
+				return newResult(RedisMessage{typ: '+', string: "Delegate"}, nil)
+			},
+			DoMultiFn: func(cmd ...Completed) *redisresults {
+				if len(cmd) == 4 {
+					return &redisresults{s: []RedisResult{
+						newResult(RedisMessage{typ: '+', string: "OK"}, nil),
+						newResult(RedisMessage{typ: '+', string: "OK"}, nil),
+						newResult(RedisMessage{typ: '+', string: "OK"}, nil),
+						newResult(RedisMessage{typ: '*', values: []RedisMessage{
+							{typ: '+', string: "Delegate0"},
+							{typ: '+', string: "Delegate1"},
+						}}, nil),
+					}}
+				}
+				return &redisresults{s: []RedisResult{
+					newResult(RedisMessage{typ: '+', string: "Delegate0"}, nil),
+					newResult(RedisMessage{typ: '+', string: "Delegate1"}, nil),
+				}}
+			},
+			ReceiveFn: func(ctx context.Context, subscribe Completed, fn func(message PubSubMessage)) error {
+				return ErrClosing
+			},
+			SetPubSubHooksFn: func(hooks PubSubHooks) <-chan error {
+				ch := make(chan error, 1)
+				ch <- ErrClosing
+				close(ch)
+				return ch
+			},
+			ErrorFn: func() error {
+				return ErrClosing
+			},
+			CloseFn: func() {
+				closed = true
+			},
+		}
+		primaryNodeConn.AcquireFn = func() wire {
+			return w
+		}
+		stored := false
+		primaryNodeConn.StoreFn = func(ww wire) {
+			if ww != w {
+				t.Fatalf("received unexpected wire %v", ww)
+			}
+			stored = true
+		}
+		c, cancel := client.Dedicate()
+		ch := c.SetPubSubHooks(PubSubHooks{OnMessage: func(m PubSubMessage) {}})
+		if v, err := c.Do(context.Background(), c.B().Get().Key("a").Build()).ToString(); err != nil || v != "Delegate" {
+			t.Fatalf("unexpected response %v %v", v, err)
+		}
+		if v := c.DoMulti(context.Background()); len(v) != 0 {
+			t.Fatalf("received unexpected response %v", v)
+		}
+		for i, resp := range c.DoMulti(
+			context.Background(),
+			c.B().Info().Build(),
+			c.B().Info().Build(),
+		) {
+			if v, err := resp.ToString(); err != nil || v != "Delegate"+strconv.Itoa(i) {
+				t.Fatalf("unexpected response %v %v", v, err)
+			}
+		}
+		for i, resp := range c.DoMulti(
+			context.Background(),
+			c.B().Get().Key("a").Build(),
+			c.B().Get().Key("a").Build(),
+		) {
+			if v, err := resp.ToString(); err != nil || v != "Delegate"+strconv.Itoa(i) {
+				t.Fatalf("unexpected response %v %v", v, err)
+			}
+		}
+		for i, resp := range c.DoMulti(
+			context.Background(),
+			c.B().Multi().Build(),
+			c.B().Get().Key("a").Build(),
+			c.B().Get().Key("a").Build(),
+			c.B().Exec().Build(),
+		)[3].val.values {
+			if v, err := resp.ToString(); err != nil || v != "Delegate"+strconv.Itoa(i) {
+				t.Fatalf("unexpected response %v %v", v, err)
+			}
+		}
+		if err := c.Receive(context.Background(), c.B().Ssubscribe().Channel("a").Build(), func(msg PubSubMessage) {}); err != ErrClosing {
+			t.Fatalf("unexpected ret %v", err)
+		}
+		if err := <-ch; err != ErrClosing {
+			t.Fatalf("unexpected ret %v", err)
+		}
+		if err := <-c.SetPubSubHooks(PubSubHooks{OnMessage: func(m PubSubMessage) {}}); err != ErrClosing {
+			t.Fatalf("unexpected ret %v", err)
+		}
+		c.Close()
+		cancel()
+
+		if !stored {
+			t.Fatalf("Dedicated doesn't put back the wire")
+		}
+		if !closed {
+			t.Fatalf("Dedicated doesn't delegate Close")
+		}
+	})
 }
diff --git a/rueidis.go b/rueidis.go
index 2c6c9ace..bfe56648 100644
--- a/rueidis.go
+++ b/rueidis.go
@@ -208,6 +208,18 @@ type ClientOption struct {
 
 	// ClusterOption is the options for the redis cluster client.
 	ClusterOption ClusterOption
+
+	// ReplicaSelector selects a replica node when `SendToReplicas` returns true.
+	// If the function is set, the client sends the selected command to the replica
+	// at the returned index of the replicas slice.
+	// If the returned index is out of range, the primary node is selected instead.
+	// If the primary node has no replicas, the primary node is selected and the
+	// function is not called.
+	// Currently only used by the cluster client.
+	// Each ReplicaInfo must not be modified.
+	// NOTE: This function can't be used with the ReplicaOnly option.
+	// NOTE: This function must be used together with the SendToReplicas function.
+	ReplicaSelector func(slot uint16, replicas []ReplicaInfo) int
 }
 
 // SentinelOption contains MasterSet,
@@ -234,6 +246,11 @@ type ClusterOption struct {
 	ShardsRefreshInterval time.Duration
 }
 
+// ReplicaInfo is the information of a replica node in a redis cluster.
+type ReplicaInfo struct {
+	Addr string
+}
+
 // Client is the redis client interface for both single redis instance and redis cluster. 
It should be created from the NewClient() type Client interface { CoreClient diff --git a/rueidis_test.go b/rueidis_test.go index e644b232..4da05a95 100644 --- a/rueidis_test.go +++ b/rueidis_test.go @@ -156,6 +156,53 @@ func TestNewClusterClientError(t *testing.T) { t.Errorf("unexpected error %v", err) } }) + + t.Run("replica only and replica selector option conflict", func(t *testing.T) { + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer ln.Close() + + _, port, _ := net.SplitHostPort(ln.Addr().String()) + client, err := NewClient(ClientOption{ + InitAddress: []string{"127.0.0.1:" + port}, + ReplicaOnly: true, + ReplicaSelector: func(slot uint16, replicas []ReplicaInfo) int { + return 0 + }, + }) + if client != nil || err == nil { + t.Errorf("unexpected return %v %v", client, err) + } + + if !strings.Contains(err.Error(), ErrReplicaOnlyConflictWithReplicaSelector.Error()) { + t.Errorf("unexpected error %v", err) + } + }) + + t.Run("send to replicas should be set when replica selector is set", func(t *testing.T) { + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer ln.Close() + + _, port, _ := net.SplitHostPort(ln.Addr().String()) + client, err := NewClient(ClientOption{ + InitAddress: []string{"127.0.0.1:" + port}, + ReplicaSelector: func(slot uint16, replicas []ReplicaInfo) int { + return 0 + }, + }) + if client != nil || err == nil { + t.Errorf("unexpected return %v %v", client, err) + } + + if !strings.Contains(err.Error(), ErrSendToReplicasNotSet.Error()) { + t.Errorf("unexpected error %v", err) + } + }) } func TestFallBackSingleClient(t *testing.T) {
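
Example usage of the new options (a minimal sketch, not part of the patch): it assumes the github.com/redis/rueidis module path and a placeholder cluster address. SendToReplicas, ReplicaSelector, and ReplicaInfo are the options added by this diff, and cmd.IsReadOnly() is the same predicate the new tests use. Because _refresh keeps the primary at g.nodes[0] and passes g.nodes[1:] to the selector, the replicas slice is never empty when the selector is called, so the modulo below cannot divide by zero.

	package main

	import (
		"context"
		"fmt"

		"github.com/redis/rueidis"
	)

	func main() {
		client, err := rueidis.NewClient(rueidis.ClientOption{
			InitAddress: []string{"127.0.0.1:7001"}, // placeholder seed address
			// Route read-only commands to replicas; writes still go to the primary.
			SendToReplicas: func(cmd rueidis.Completed) bool {
				return cmd.IsReadOnly()
			},
			// Pin each slot to one replica deterministically instead of the
			// default random pick. Returning an out-of-range index (e.g. -1)
			// falls back to the primary, per the handling added in cluster.go.
			ReplicaSelector: func(slot uint16, replicas []rueidis.ReplicaInfo) int {
				return int(slot) % len(replicas)
			},
		})
		if err != nil {
			panic(err)
		}
		defer client.Close()

		// This read is served by the replica selected for the key's slot.
		v, err := client.Do(context.Background(), client.B().Get().Key("k").Build()).ToString()
		fmt.Println(v, err)
	}

A slot-pinned selector like this trades load spreading for per-replica cache locality; the default selector installed when ReplicaSelector is left nil (replicaOnlySelector above) instead picks a replica uniformly at random per slot during each refresh.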