diff --git a/go.mod b/go.mod index bd80767..2729b12 100644 --- a/go.mod +++ b/go.mod @@ -10,14 +10,17 @@ module github.com/codecrafters-io/redis-starter-go go 1.22 +require ( + github.com/redis/go-redis/v9 v9.6.0 + github.com/test-go/testify v1.1.4 + github.com/zhuyie/golzf v0.0.0-20161112031142-8387b0307ade +) + require ( github.com/cespare/xxhash/v2 v2.2.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/redis/go-redis/v9 v9.5.3 // indirect github.com/stretchr/testify v1.9.0 // indirect - github.com/test-go/testify v1.1.4 // indirect - github.com/zhuyie/golzf v0.0.0-20161112031142-8387b0307ade // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index caf1544..ef0c83e 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,7 @@ +github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= +github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= +github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= +github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= @@ -6,14 +10,15 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/redis/go-redis/v9 v9.5.3 h1:fOAp1/uJG+ZtcITgZOfYFmTKPE7n4Vclj1wZFgRciUU= -github.com/redis/go-redis/v9 v9.5.3/go.mod h1:hdY0cQFCN4fnSYT6TkisLufl/4W5UIXyv0b/CLO2V2M= +github.com/redis/go-redis/v9 v9.6.0 h1:NLck+Rab3AOTHw21CGRpvQpgTrAU4sgdCswqGtlhGRA= +github.com/redis/go-redis/v9 v9.6.0/go.mod h1:hdY0cQFCN4fnSYT6TkisLufl/4W5UIXyv0b/CLO2V2M= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/test-go/testify v1.1.4 h1:Tf9lntrKUMHiXQ07qBScBTSA0dhYQlu83hswqelv1iE= github.com/test-go/testify v1.1.4/go.mod h1:rH7cfJo/47vWGdi4GPj16x3/t1xGOj2YxzmNQzk2ghU= github.com/zhuyie/golzf v0.0.0-20161112031142-8387b0307ade h1:bafvQukPrIYwYWcft4rl3WpHo3qO0/voaAgnCwgdhi0= github.com/zhuyie/golzf v0.0.0-20161112031142-8387b0307ade/go.mod h1:juNhYdla04C276MyU4zR0BA7t90ziLKPwkjDgddGYV0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/app/server/server_test.go b/internal/app/server/server_test.go index 29e42ec..8142d7d 100644 --- a/internal/app/server/server_test.go +++ b/internal/app/server/server_test.go @@ -295,6 +295,7 @@ func TestXReadCommand(t *testing.T) { action: func() (interface{}, error) { result, err := rdb.XRead(ctx, &redis.XReadArgs{ Streams: []string{"mystream", "0-0"}, + Block: -1, }).Result() if err == redis.Nil { return nil, nil @@ -316,6 +317,7 @@ func TestXReadCommand(t *testing.T) { action: func() (interface{}, error) { return rdb.XRead(ctx, &redis.XReadArgs{ Streams: []string{"mystream", "0-0"}, + Block: -1, }).Result() }, expected: []redis.XStream{{Stream: "mystream", Messages: []redis.XMessage{{ID: "1-1", Values: map[string]interface{}{"field1": "value1"}}}}}, @@ -339,6 +341,7 @@ func TestXReadCommand(t *testing.T) { streams, err := rdb.XRead(ctx, &redis.XReadArgs{ Streams: []string{"mystream2", "0-0"}, Count: 3, + Block: -1, }).Result() if err != nil { return nil, err @@ -416,6 +419,61 @@ func TestXReadCommand(t *testing.T) { }, expected: []redis.XStream{{Stream: "mystream7", Messages: []redis.XMessage{{ID: "1-1", Values: map[string]interface{}{"field1": "value1"}}}}}, }, + { + name: "XRead (multiple streams and multiple entries)", + setup: func() error { + _, err := rdb.XAdd(ctx, &redis.XAddArgs{ + Stream: "mystream20", + ID: "1-1", + Values: map[string]interface{}{"field1": "value1"}, + }).Result() + if err != nil { + return err + } + _, err = rdb.XAdd(ctx, &redis.XAddArgs{ + Stream: "mystream21", + ID: "2-2", + Values: map[string]interface{}{"field2": "value2"}, + }).Result() + if err != nil { + return err + } + _, err = rdb.XAdd(ctx, &redis.XAddArgs{ + Stream: "mystream20", + ID: "1-2", + Values: map[string]interface{}{"field1": "value1"}, + }).Result() + if err != nil { + return err + } + _, err = rdb.XAdd(ctx, &redis.XAddArgs{ + Stream: "mystream21", + ID: "2-3", + Values: map[string]interface{}{"field2": "value2"}, + }).Result() + if err != nil { + return err + } + return nil + }, + action: func() (interface{}, error) { + return rdb.XRead(ctx, &redis.XReadArgs{ + Streams: []string{"mystream20", "mystream21", "0-0", "0-0"}, + Block: -1, + }).Result() + }, + expected: []redis.XStream{{Stream: "mystream20", Messages: []redis.XMessage{ + {ID: "1-1", Values: map[string]interface{}{"field1": "value1"}}, + {ID: "1-2", Values: map[string]interface{}{"field1": "value1"}}, + }}, + { + Stream: "mystream21", Messages: []redis.XMessage{ + {ID: "2-2", Values: map[string]interface{}{"field2": "value2"}}, + {ID: "2-3", Values: map[string]interface{}{"field2": "value2"}}, + }, + }, + }, + }, { name: "XRead with BLOCK and COUNT", setup: func() error { @@ -492,6 +550,7 @@ func TestXReadCommand(t *testing.T) { action: func() (interface{}, error) { streams, err := rdb.XRead(ctx, &redis.XReadArgs{ Streams: []string{"mystream4", "mystream5", "0-0", "0-0"}, + Block: -1, }).Result() if err != nil { return nil, err diff --git a/pkg/command/xread.go b/pkg/command/xread.go index 7d1f8a9..c0cde4d 100644 --- a/pkg/command/xread.go +++ b/pkg/command/xread.go @@ -2,6 +2,7 @@ package command import ( "errors" + "reflect" "strconv" "strings" "time" @@ -18,7 +19,7 @@ type XRead struct { type XReadOptions struct { Count uint64 - Block time.Duration + Block *time.Duration Streams map[string]string } @@ -71,7 +72,8 @@ func (x *XRead) parseArgs(args []*resp.Resp) (*XReadOptions, error) { if err != nil { return nil, errors.New("ERR value is not an integer or out of range") } - opts.Block = time.Duration(block) * time.Millisecond + duration := time.Duration(block) * time.Millisecond + opts.Block = &duration case "STREAMS": i++ streamCount := (len(args) - i) / 2 @@ -113,7 +115,7 @@ func (x *XRead) readStreams(opts *XReadOptions) (map[string][]keyval.StreamEntry // Ignore this case as $ means only new entries } else if lastID == "+" { entries = append(entries, stream.Range(stream.LastID(), "+", 1)...) - } else if opts.Block <= 0 { + } else if opts.Block == nil { // TODO: remove this condition entries = append(entries, stream.Range(lastID, "+", opts.Count)...) } if len(entries) > 0 { @@ -126,7 +128,7 @@ func (x *XRead) readStreams(opts *XReadOptions) (map[string][]keyval.StreamEntry return result, nil } - if opts.Block > 0 { + if opts.Block != nil { return x.blockingRead(opts) } @@ -155,27 +157,35 @@ func (x *XRead) blockingRead(opts *XReadOptions) (map[string][]keyval.StreamEntr subscriptions[streamName] = ch } - timer := time.NewTimer(opts.Block) - defer timer.Stop() + var timer *time.Timer + if opts.Block != nil && *opts.Block > 0 { + timer = time.NewTimer(*opts.Block) + defer timer.Stop() + } - for { - select { - case <-timer.C: - return nil, nil - default: - for streamName, ch := range subscriptions { - select { - case entry := <-ch: - result[streamName] = []keyval.StreamEntry{entry} - return result, nil - default: - // No entry available for this stream, continue to next - } - } - // Small sleep to prevent busy-waiting - time.Sleep(1 * time.Millisecond) - } + cases := make([]reflect.SelectCase, 0, len(subscriptions)+1) + for _, ch := range subscriptions { + cases = append(cases, reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(ch)}) + } + + if timer != nil { + cases = append(cases, reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(timer.C)}) } + + chosen, value, _ := reflect.Select(cases) + if chosen == len(cases)-1 && timer != nil { + // Timeout + return nil, nil + } + + if chosen < len(subscriptions) { + streamName := x.streamOrder[chosen] + entry := value.Interface().(keyval.StreamEntry) + result[streamName] = []keyval.StreamEntry{entry} + return result, nil + } + + return nil, nil } func (x *XRead) writeResult(cl *client.Client, wr *resp.Writer, result map[string][]keyval.StreamEntry) error { @@ -224,7 +234,7 @@ func (x *XRead) IsBlocking(args []*resp.Resp) bool { for i, arg := range args { if strings.ToUpper(arg.String()) == "BLOCK" { if i+1 < len(args) { - if block, err := strconv.ParseInt(args[i+1].String(), 10, 64); err == nil && block > 0 { + if _, err := strconv.ParseInt(args[i+1].String(), 10, 64); err == nil { return true } }