diff --git a/go.mod b/go.mod index 58e2fe11ce..1f239b2820 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,7 @@ replace github.com/ethereum/go-ethereum => ./go-ethereum require ( github.com/Knetic/govaluate v3.0.1-0.20171022003610-9aa49832a739+incompatible github.com/Shopify/toxiproxy v2.1.4+incompatible - github.com/alicebob/miniredis/v2 v2.21.0 + github.com/alicebob/miniredis/v2 v2.32.1 github.com/andybalholm/brotli v1.0.4 github.com/aws/aws-sdk-go-v2 v1.16.4 github.com/aws/aws-sdk-go-v2/config v1.15.5 @@ -262,7 +262,7 @@ require ( github.com/whyrusleeping/go-keyspace v0.0.0-20160322163242-5b898ac5add1 // indirect github.com/whyrusleeping/multiaddr-filter v0.0.0-20160516205228-e903e4adabd7 // indirect github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect - github.com/yuin/gopher-lua v0.0.0-20210529063254-f4c35e4016d9 // indirect + github.com/yuin/gopher-lua v1.1.1 // indirect go.opencensus.io v0.24.0 // indirect go.opentelemetry.io/otel v1.7.0 // indirect go.opentelemetry.io/otel/exporters/jaeger v1.7.0 // indirect @@ -319,7 +319,7 @@ require ( github.com/go-redis/redis/v8 v8.11.4 github.com/go-stack/stack v1.8.1 // indirect github.com/golang/snappy v0.0.5-0.20220116011046-fa5810519dcb // indirect - github.com/google/uuid v1.3.1 // indirect + github.com/google/uuid v1.3.1 github.com/gorilla/websocket v1.5.0 // indirect github.com/hashicorp/go-bexpr v0.1.10 // indirect github.com/hashicorp/golang-lru v0.5.5-0.20210104140557-80c98217689d // indirect diff --git a/go.sum b/go.sum index 39b1caffe4..283f305312 100644 --- a/go.sum +++ b/go.sum @@ -83,8 +83,8 @@ github.com/alexbrainman/goissue34681 v0.0.0-20191006012335-3fc7a47baff5 h1:iW0a5 github.com/alexbrainman/goissue34681 v0.0.0-20191006012335-3fc7a47baff5/go.mod h1:Y2QMoi1vgtOIfc+6DhrMOGkLoGzqSV2rKp4Sm+opsyA= github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a h1:HbKu58rmZpUGpz5+4FfNmIU+FmZg2P3Xaj2v2bfNWmk= github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a/go.mod h1:SGnFV6hVsYE877CKEZ6tDNTjaSXYUk6QqoIK6PrAtcc= -github.com/alicebob/miniredis/v2 v2.21.0 h1:CdmwIlKUWFBDS+4464GtQiQ0R1vpzOgu4Vnd74rBL7M= -github.com/alicebob/miniredis/v2 v2.21.0/go.mod h1:XNqvJdQJv5mSuVMc0ynneafpnL/zv52acZ6kqeS0t88= +github.com/alicebob/miniredis/v2 v2.32.1 h1:Bz7CciDnYSaa0mX5xODh6GUITRSx+cVhjNoOR4JssBo= +github.com/alicebob/miniredis/v2 v2.32.1/go.mod h1:AqkLNAfUm0K07J28hnAyyQKf/x0YkCY/g5DCtuL01Mw= github.com/allegro/bigcache v1.2.1-0.20190218064605-e24eb225f156 h1:eMwmnE/GDgah4HI848JfFxHt+iPb26b4zyfspmqY0/8= github.com/allegro/bigcache v1.2.1-0.20190218064605-e24eb225f156/go.mod h1:Cb/ax3seSYIx7SuZdm2G2xzfwmv3TPSk2ucNfQESPXM= github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY= @@ -1688,8 +1688,8 @@ github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9de github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= -github.com/yuin/gopher-lua v0.0.0-20210529063254-f4c35e4016d9 h1:k/gmLsJDWwWqbLCur2yWnJzwQEKRcAHXo6seXGuSwWw= -github.com/yuin/gopher-lua v0.0.0-20210529063254-f4c35e4016d9/go.mod h1:E1AXubJBdNmFERAOucpDIxNzeGfLzg0mYh+UfMWdChA= +github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M= +github.com/yuin/gopher-lua v1.1.1/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw= go.etcd.io/bbolt v1.3.3/go.mod 
h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU=
go.etcd.io/etcd v0.0.0-20191023171146-3cf2f69b5738/go.mod h1:dnLIgRNXwCJa5e+c6mIZCrds/GIG4ncV9HhK5PX7jPg=
go.opencensus.io v0.18.0/go.mod h1:vKdFvxhtzZ9onBp9VKHK8z/sRpBMnKAsufL7wlDrCOA=
diff --git a/pubsub/consumer.go b/pubsub/consumer.go
new file mode 100644
index 0000000000..3de313f120
--- /dev/null
+++ b/pubsub/consumer.go
@@ -0,0 +1,177 @@
+package pubsub
+
+import (
+	"context"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"time"
+
+	"github.com/ethereum/go-ethereum/log"
+	"github.com/go-redis/redis/v8"
+	"github.com/google/uuid"
+	"github.com/offchainlabs/nitro/util/redisutil"
+	"github.com/offchainlabs/nitro/util/stopwaiter"
+	"github.com/spf13/pflag"
+)
+
+type ConsumerConfig struct {
+	// Timeout of result entry in Redis.
+	ResponseEntryTimeout time.Duration `koanf:"response-entry-timeout"`
+	// Duration after which consumer is considered to be dead if heartbeat
+	// is not updated.
+	KeepAliveTimeout time.Duration `koanf:"keepalive-timeout"`
+	// Redis url for Redis streams and locks.
+	RedisURL string `koanf:"redis-url"`
+	// Redis stream name.
+	RedisStream string `koanf:"redis-stream"`
+	// Redis consumer group name.
+	RedisGroup string `koanf:"redis-group"`
+}
+
+var DefaultConsumerConfig = &ConsumerConfig{
+	ResponseEntryTimeout: time.Hour,
+	KeepAliveTimeout:     5 * time.Minute,
+	RedisStream:          "",
+	RedisGroup:           "",
+}
+
+var DefaultTestConsumerConfig = &ConsumerConfig{
+	RedisStream:          "",
+	RedisGroup:           "",
+	ResponseEntryTimeout: time.Minute,
+	KeepAliveTimeout:     30 * time.Millisecond,
+}
+
+func ConsumerConfigAddOptions(prefix string, f *pflag.FlagSet) {
+	f.Duration(prefix+".response-entry-timeout", DefaultConsumerConfig.ResponseEntryTimeout, "timeout for response entry")
+	f.Duration(prefix+".keepalive-timeout", DefaultConsumerConfig.KeepAliveTimeout, "timeout after which consumer is considered inactive if heartbeat wasn't performed")
+	f.String(prefix+".redis-url", DefaultConsumerConfig.RedisURL, "redis url for redis stream")
+	f.String(prefix+".redis-stream", DefaultConsumerConfig.RedisStream, "redis stream name to read from")
+	f.String(prefix+".redis-group", DefaultConsumerConfig.RedisGroup, "redis stream consumer group name")
+}
+
+// Consumer implements a consumer for a Redis stream and provides a heartbeat
+// to indicate that it is alive.
+type Consumer[Request any, Response any] struct {
+	stopwaiter.StopWaiter
+	id     string
+	client redis.UniversalClient
+	cfg    *ConsumerConfig
+}
+
+type Message[Request any] struct {
+	ID    string
+	Value Request
+}
+
+func NewConsumer[Request any, Response any](ctx context.Context, cfg *ConsumerConfig) (*Consumer[Request, Response], error) {
+	if cfg.RedisURL == "" {
+		return nil, fmt.Errorf("redis url cannot be empty")
+	}
+	if cfg.RedisStream == "" {
+		return nil, fmt.Errorf("redis stream name cannot be empty")
+	}
+	if cfg.RedisGroup == "" {
+		return nil, fmt.Errorf("redis group name cannot be empty")
+	}
+	c, err := redisutil.RedisClientFromURL(cfg.RedisURL)
+	if err != nil {
+		return nil, err
+	}
+	consumer := &Consumer[Request, Response]{
+		id:     uuid.NewString(),
+		client: c,
+		cfg:    cfg,
+	}
+	return consumer, nil
+}
+
+// Start starts the consumer to iteratively perform heartbeat in configured intervals.
+func (c *Consumer[Request, Response]) Start(ctx context.Context) {
+	c.StopWaiter.Start(ctx, c)
+	c.StopWaiter.CallIteratively(
+		func(ctx context.Context) time.Duration {
+			c.heartBeat(ctx)
+			return c.cfg.KeepAliveTimeout / 10
+		},
+	)
+}
+
+func (c *Consumer[Request, Response]) StopAndWait() {
+	c.StopWaiter.StopAndWait()
+}
+
+func heartBeatKey(id string) string {
+	return fmt.Sprintf("consumer:%s:heartbeat", id)
+}
+
+func (c *Consumer[Request, Response]) heartBeatKey() string {
+	return heartBeatKey(c.id)
+}
+
+// heartBeat updates the heartBeat key indicating aliveness.
+func (c *Consumer[Request, Response]) heartBeat(ctx context.Context) {
+	if err := c.client.Set(ctx, c.heartBeatKey(), time.Now().UnixMilli(), 2*c.cfg.KeepAliveTimeout).Err(); err != nil {
+		l := log.Info
+		if ctx.Err() != nil {
+			l = log.Error
+		}
+		l("Updating heartbeat", "consumer", c.id, "error", err)
+	}
+}
+
+// Consume reads the next message from the stream, receiving only messages
+// that were never delivered to any other consumer.
+func (c *Consumer[Request, Response]) Consume(ctx context.Context) (*Message[Request], error) {
+	res, err := c.client.XReadGroup(ctx, &redis.XReadGroupArgs{
+		Group:    c.cfg.RedisGroup,
+		Consumer: c.id,
+		// Receive only messages that were never delivered to any other consumer,
+		// that is, only new messages.
+		Streams: []string{c.cfg.RedisStream, ">"},
+		Count:   1,
+		Block:   time.Millisecond, // 0 seems to block the read instead of immediately returning
+	}).Result()
+	if errors.Is(err, redis.Nil) {
+		return nil, nil
+	}
+	if err != nil {
+		return nil, fmt.Errorf("reading message for consumer: %q: %w", c.id, err)
+	}
+	if len(res) != 1 || len(res[0].Messages) != 1 {
+		return nil, fmt.Errorf("redis returned entries: %+v, for querying single message", res)
+	}
+	log.Debug(fmt.Sprintf("Consumer: %s consuming message: %s", c.id, res[0].Messages[0].ID))
+	var (
+		value    = res[0].Messages[0].Values[messageKey]
+		data, ok = (value).(string)
+	)
+	if !ok {
+		return nil, fmt.Errorf("casting request to string: %v", value)
+	}
+	var req Request
+	if err := json.Unmarshal([]byte(data), &req); err != nil {
+		return nil, fmt.Errorf("unmarshaling value: %v, error: %w", value, err)
+	}
+
+	return &Message[Request]{
+		ID:    res[0].Messages[0].ID,
+		Value: req,
+	}, nil
+}
+
+func (c *Consumer[Request, Response]) SetResult(ctx context.Context, messageID string, result Response) error {
+	resp, err := json.Marshal(result)
+	if err != nil {
+		return fmt.Errorf("marshaling result: %w", err)
+	}
+	acquired, err := c.client.SetNX(ctx, messageID, resp, c.cfg.ResponseEntryTimeout).Result()
+	if err != nil || !acquired {
+		return fmt.Errorf("setting result for message: %v, error: %w", messageID, err)
+	}
+	if _, err := c.client.XAck(ctx, c.cfg.RedisStream, c.cfg.RedisGroup, messageID).Result(); err != nil {
+		return fmt.Errorf("acking message: %v, error: %w", messageID, err)
+	}
+	return nil
+}
diff --git a/pubsub/producer.go b/pubsub/producer.go
new file mode 100644
index 0000000000..13a4553e2f
--- /dev/null
+++ b/pubsub/producer.go
@@ -0,0 +1,313 @@
+// Package pubsub implements a publisher/subscriber model (one to many).
+// During normal operation, the publisher returns a "Promise" when publishing a
+// message, which will return the response from a consumer when awaited.
+// If the consumer processing the request becomes inactive, the message is
+// re-inserted (if the EnableReproduce flag is enabled) and will be picked up by
+// another consumer.
+// We are assuming here that KeepAliveTimeout is set to some sensible value
+// and that once a consumer becomes inactive, it doesn't activate again without a restart.
+package pubsub
+
+import (
+	"context"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"sync"
+	"time"
+
+	"github.com/ethereum/go-ethereum/log"
+	"github.com/go-redis/redis/v8"
+	"github.com/google/uuid"
+	"github.com/offchainlabs/nitro/util/containers"
+	"github.com/offchainlabs/nitro/util/redisutil"
+	"github.com/offchainlabs/nitro/util/stopwaiter"
+	"github.com/spf13/pflag"
+)
+
+const (
+	messageKey   = "msg"
+	defaultGroup = "default_consumer_group"
+)
+
+type Producer[Request any, Response any] struct {
+	stopwaiter.StopWaiter
+	id     string
+	client redis.UniversalClient
+	cfg    *ProducerConfig
+
+	promisesLock sync.RWMutex
+	promises     map[string]*containers.Promise[Response]
+
+	// Used for running checks for pending messages with inactive consumers
+	// and checking responses from consumers iteratively for the first time when
+	// Produce is called.
+	once sync.Once
+}
+
+type ProducerConfig struct {
+	// When enabled, messages that are sent to consumers that later die before
+	// processing them will be re-inserted into the stream to be processed by
+	// another consumer.
+	EnableReproduce bool   `koanf:"enable-reproduce"`
+	RedisURL        string `koanf:"redis-url"`
+	// Redis stream name.
+	RedisStream string `koanf:"redis-stream"`
+	// Interval duration in which producer checks for pending messages delivered
+	// to the consumers that are currently inactive.
+	CheckPendingInterval time.Duration `koanf:"check-pending-interval"`
+	// Duration after which consumer is considered to be dead if heartbeat
+	// is not updated.
+	KeepAliveTimeout time.Duration `koanf:"keepalive-timeout"`
+	// Interval duration for checking the result set by consumers.
+	CheckResultInterval time.Duration `koanf:"check-result-interval"`
+	// Redis consumer group name.
+	RedisGroup string `koanf:"redis-group"`
+}
+
+var DefaultProducerConfig = &ProducerConfig{
+	EnableReproduce:      true,
+	RedisStream:          "",
+	RedisGroup:           "",
+	CheckPendingInterval: time.Second,
+	KeepAliveTimeout:     5 * time.Minute,
+	CheckResultInterval:  5 * time.Second,
+}
+
+var DefaultTestProducerConfig = &ProducerConfig{
+	EnableReproduce:      true,
+	RedisStream:          "",
+	RedisGroup:           "",
+	CheckPendingInterval: 10 * time.Millisecond,
+	KeepAliveTimeout:     100 * time.Millisecond,
+	CheckResultInterval:  5 * time.Millisecond,
+}
+
+func ProducerAddConfigAddOptions(prefix string, f *pflag.FlagSet) {
+	f.Bool(prefix+".enable-reproduce", DefaultProducerConfig.EnableReproduce, "when enabled, messages with dead consumer will be re-inserted into the stream")
+	f.String(prefix+".redis-url", DefaultProducerConfig.RedisURL, "redis url for redis stream")
+	f.Duration(prefix+".check-pending-interval", DefaultProducerConfig.CheckPendingInterval, "interval in which producer checks pending messages whether consumer processing them is inactive")
+	f.Duration(prefix+".keepalive-timeout", DefaultProducerConfig.KeepAliveTimeout, "timeout after which consumer is considered inactive if heartbeat wasn't performed")
+	f.String(prefix+".redis-stream", DefaultProducerConfig.RedisStream, "redis stream name to read from")
+	f.String(prefix+".redis-group", DefaultProducerConfig.RedisGroup, "redis stream consumer group name")
+}
+
+func NewProducer[Request any, Response any](cfg *ProducerConfig) (*Producer[Request, Response], error) {
+	if cfg.RedisURL == "" {
+		return nil, fmt.Errorf("redis url cannot be empty")
+	}
+	if cfg.RedisStream == "" {
+		return nil, fmt.Errorf("redis stream cannot be empty")
+	}
+	if cfg.RedisGroup == "" {
+		return nil, fmt.Errorf("redis group cannot be empty")
+	}
+	c, err := redisutil.RedisClientFromURL(cfg.RedisURL)
+	if err != nil {
+		return nil, err
+	}
+	return &Producer[Request, Response]{
+		id:       uuid.NewString(),
+		client:   c,
+		cfg:      cfg,
+		promises: make(map[string]*containers.Promise[Response]),
+	}, nil
+}
+
+func (p *Producer[Request, Response]) errorPromisesFor(msgs []*Message[Request]) {
+	p.promisesLock.Lock()
+	defer p.promisesLock.Unlock()
+	for _, msg := range msgs {
+		if promise, found := p.promises[msg.ID]; found {
+			promise.ProduceError(fmt.Errorf("internal error, consumer died while serving the request"))
+			delete(p.promises, msg.ID)
+		}
+	}
+}
+
+// checkAndReproduce re-inserts pending messages that were sent to consumers
+// that are currently inactive.
+func (p *Producer[Request, Response]) checkAndReproduce(ctx context.Context) time.Duration {
+	msgs, err := p.checkPending(ctx)
+	if err != nil {
+		log.Error("Checking pending messages", "error", err)
+		return p.cfg.CheckPendingInterval
+	}
+	if len(msgs) == 0 {
+		return p.cfg.CheckPendingInterval
+	}
+	if !p.cfg.EnableReproduce {
+		p.errorPromisesFor(msgs)
+		return p.cfg.CheckPendingInterval
+	}
+	acked := make(map[string]Request)
+	for _, msg := range msgs {
+		if _, err := p.client.XAck(ctx, p.cfg.RedisStream, p.cfg.RedisGroup, msg.ID).Result(); err != nil {
+			log.Error("ACKing message", "error", err)
+			continue
+		}
+		acked[msg.ID] = msg.Value
+	}
+	for k, v := range acked {
+		// Only re-insert messages that were removed from the pending list first.
+		_, err := p.reproduce(ctx, v, k)
+		if err != nil {
+			log.Error("Re-inserting pending messages with inactive consumers", "error", err)
+		}
+	}
+	return p.cfg.CheckPendingInterval
+}
+
+// checkResponses checks iteratively whether response for the promise is ready.
+func (p *Producer[Request, Response]) checkResponses(ctx context.Context) time.Duration {
+	p.promisesLock.Lock()
+	defer p.promisesLock.Unlock()
+	for id, promise := range p.promises {
+		res, err := p.client.Get(ctx, id).Result()
+		if err != nil {
+			if errors.Is(err, redis.Nil) {
+				continue
+			}
+			log.Error("Error reading value in redis", "key", id, "error", err)
+		}
+		var resp Response
+		if err := json.Unmarshal([]byte(res), &resp); err != nil {
+			log.Error("Error unmarshaling", "value", res, "error", err)
+			continue
+		}
+		promise.Produce(resp)
+		delete(p.promises, id)
+	}
+	return p.cfg.CheckResultInterval
+}
+
+func (p *Producer[Request, Response]) Start(ctx context.Context) {
+	p.StopWaiter.Start(ctx, p)
+}
+
+func (p *Producer[Request, Response]) promisesLen() int {
+	p.promisesLock.Lock()
+	defer p.promisesLock.Unlock()
+	return len(p.promises)
+}
+
+// reproduce is used when the Producer claims ownership of a pending message
+// that was sent to an inactive consumer and reinserts it into the stream,
+// so that it can seamlessly return the answer via the same promise.
+func (p *Producer[Request, Response]) reproduce(ctx context.Context, value Request, oldKey string) (*containers.Promise[Response], error) {
+	val, err := json.Marshal(value)
+	if err != nil {
+		return nil, fmt.Errorf("marshaling value: %w", err)
+	}
+	id, err := p.client.XAdd(ctx, &redis.XAddArgs{
+		Stream: p.cfg.RedisStream,
+		Values: map[string]any{messageKey: val},
+	}).Result()
+	if err != nil {
+		return nil, fmt.Errorf("adding values to redis: %w", err)
+	}
+	p.promisesLock.Lock()
+	defer p.promisesLock.Unlock()
+	promise := p.promises[oldKey]
+	if oldKey != "" && promise == nil {
+		// This will happen if the old consumer became inactive but then acked
+		// the message afterwards.
+		return nil, fmt.Errorf("error reproducing the message, could not find existing one")
+	}
+	if oldKey == "" || promise == nil {
+		pr := containers.NewPromise[Response](nil)
+		promise = &pr
+	}
+	delete(p.promises, oldKey)
+	p.promises[id] = promise
+	return promise, nil
+}
+
+func (p *Producer[Request, Response]) Produce(ctx context.Context, value Request) (*containers.Promise[Response], error) {
+	p.once.Do(func() {
+		p.StopWaiter.CallIteratively(p.checkAndReproduce)
+		p.StopWaiter.CallIteratively(p.checkResponses)
+	})
+	return p.reproduce(ctx, value, "")
+}
+
+// isConsumerAlive checks whether the consumer with the specified ID is alive.
+func (p *Producer[Request, Response]) isConsumerAlive(ctx context.Context, consumerID string) bool {
+	if _, err := p.client.Get(ctx, heartBeatKey(consumerID)).Int64(); err != nil {
+		return false
+	}
+	return true
+}
+
+func (p *Producer[Request, Response]) havePromiseFor(messageID string) bool {
+	p.promisesLock.Lock()
+	defer p.promisesLock.Unlock()
+	_, found := p.promises[messageID]
+	return found
+}
+
+func (p *Producer[Request, Response]) checkPending(ctx context.Context) ([]*Message[Request], error) {
+	pendingMessages, err := p.client.XPendingExt(ctx, &redis.XPendingExtArgs{
+		Stream: p.cfg.RedisStream,
+		Group:  p.cfg.RedisGroup,
+		Start:  "-",
+		End:    "+",
+		Count:  100,
+	}).Result()
+
+	if err != nil && !errors.Is(err, redis.Nil) {
+		return nil, fmt.Errorf("querying pending messages: %w", err)
+	}
+	if len(pendingMessages) == 0 {
+		return nil, nil
+	}
+	// IDs of the pending messages with inactive consumers.
+	var ids []string
+	active := make(map[string]bool)
+	for _, msg := range pendingMessages {
+		// Ignore messages not produced by this producer.
+		if !p.havePromiseFor(msg.ID) {
+			continue
+		}
+		alive, found := active[msg.Consumer]
+		if !found {
+			alive = p.isConsumerAlive(ctx, msg.Consumer)
+			active[msg.Consumer] = alive
+		}
+		if alive {
+			continue
+		}
+		ids = append(ids, msg.ID)
+	}
+	if len(ids) == 0 {
+		log.Trace("There are no pending messages with inactive consumers")
+		return nil, nil
+	}
+	log.Info("Attempting to claim", "messages", ids)
+	claimedMsgs, err := p.client.XClaim(ctx, &redis.XClaimArgs{
+		Stream:   p.cfg.RedisStream,
+		Group:    p.cfg.RedisGroup,
+		Consumer: p.id,
+		MinIdle:  p.cfg.KeepAliveTimeout,
+		Messages: ids,
+	}).Result()
+	if err != nil {
+		return nil, fmt.Errorf("claiming ownership on messages: %v, error: %w", ids, err)
+	}
+	var res []*Message[Request]
+	for _, msg := range claimedMsgs {
+		data, ok := (msg.Values[messageKey]).(string)
+		if !ok {
+			return nil, fmt.Errorf("casting request: %v to string", msg.Values[messageKey])
+		}
+		var req Request
+		if err := json.Unmarshal([]byte(data), &req); err != nil {
+			return nil, fmt.Errorf("unmarshaling value: %v, error: %w", msg.Values[messageKey], err)
+		}
+		res = append(res, &Message[Request]{
+			ID:    msg.ID,
+			Value: req,
+		})
+	}
+	return res, nil
+}
diff --git a/pubsub/pubsub_test.go b/pubsub/pubsub_test.go
new file mode 100644
index 0000000000..c8968b4e45
--- /dev/null
+++ b/pubsub/pubsub_test.go
@@ -0,0 +1,330 @@
+package pubsub
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"sort"
+	"testing"
+
+	"github.com/ethereum/go-ethereum/log"
+	"github.com/go-redis/redis/v8"
+	"github.com/google/go-cmp/cmp"
+	"github.com/google/uuid"
+	"github.com/offchainlabs/nitro/util/containers"
+	"github.com/offchainlabs/nitro/util/redisutil"
+)
+
+var (
+	consumersCount = 10
+	messagesCount  = 100
+)
+
+type testRequest struct {
+	Request string
+}
+
+type testResponse struct {
+	Response string
+}
+
+func createGroup(ctx context.Context, t *testing.T, streamName, groupName string, client redis.UniversalClient) {
+	t.Helper()
+	if _, err := client.XGroupCreateMkStream(ctx, streamName, groupName, "$").Result(); err != nil {
+		t.Fatalf("Error creating stream group: %v", err)
+	}
+}
+
+func destroyGroup(ctx context.Context, t *testing.T, streamName, groupName string, client redis.UniversalClient) {
+	t.Helper()
+	if _, err := client.XGroupDestroy(ctx, streamName, groupName).Result(); err != nil {
+		log.Debug("Error destroying a stream group", "error", err)
+	}
+}
+
+type configOpt interface {
+	apply(consCfg *ConsumerConfig, prodCfg *ProducerConfig)
+}
+
+type disableReproduce struct{}
+
+func (e *disableReproduce) apply(_ *ConsumerConfig, prodCfg *ProducerConfig) {
+	prodCfg.EnableReproduce = false
+}
+
+func producerCfg() *ProducerConfig {
+	return &ProducerConfig{
+		EnableReproduce:      DefaultTestProducerConfig.EnableReproduce,
+		CheckPendingInterval: DefaultTestProducerConfig.CheckPendingInterval,
+		KeepAliveTimeout:     DefaultTestProducerConfig.KeepAliveTimeout,
+		CheckResultInterval:  DefaultTestProducerConfig.CheckResultInterval,
+	}
+}
+
+func consumerCfg() *ConsumerConfig {
+	return &ConsumerConfig{
+		ResponseEntryTimeout: DefaultTestConsumerConfig.ResponseEntryTimeout,
+		KeepAliveTimeout:     DefaultTestConsumerConfig.KeepAliveTimeout,
+	}
+}
+
+func newProducerConsumers(ctx context.Context, t *testing.T, opts ...configOpt) (*Producer[testRequest, testResponse], []*Consumer[testRequest, testResponse]) {
+	t.Helper()
+	redisURL := redisutil.CreateTestRedis(ctx, t)
+	prodCfg, consCfg := producerCfg(), consumerCfg()
+	prodCfg.RedisURL, consCfg.RedisURL = redisURL, redisURL
+	streamName := uuid.NewString()
+	groupName := fmt.Sprintf("group_%s", streamName)
+	prodCfg.RedisGroup, consCfg.RedisGroup = groupName, groupName
+	prodCfg.RedisStream, consCfg.RedisStream = streamName, streamName
+	for _, o := range opts {
+		o.apply(consCfg, prodCfg)
+	}
+	producer, err := NewProducer[testRequest, testResponse](prodCfg)
+	if err != nil {
+		t.Fatalf("Error creating new producer: %v", err)
+	}
+
+	var consumers []*Consumer[testRequest, testResponse]
+	for i := 0; i < consumersCount; i++ {
+		c, err := NewConsumer[testRequest, testResponse](ctx, consCfg)
+		if err != nil {
+			t.Fatalf("Error creating new consumer: %v", err)
+		}
+		consumers = append(consumers, c)
+	}
+	createGroup(ctx, t, streamName, groupName, producer.client)
+	t.Cleanup(func() {
+		ctx := context.Background()
+		destroyGroup(ctx, t, streamName, groupName, producer.client)
+		var keys []string
+		for _, c := range consumers {
+			keys = append(keys, c.heartBeatKey())
+		}
+		if _, err := producer.client.Del(ctx, keys...).Result(); err != nil {
+			log.Debug("Error deleting heartbeat keys", "error", err)
+		}
+	})
+	return producer, consumers
+}
+
+func messagesMaps(n int) []map[string]string {
+	ret := make([]map[string]string, n)
+	for i := 0; i < n; i++ {
+		ret[i] = make(map[string]string)
+	}
+	return ret
+}
+
+func wantMessages(n int) []string {
+	var ret []string
+	for i := 0; i < n; i++ {
+		ret = append(ret, fmt.Sprintf("msg: %d", i))
+	}
+	sort.Strings(ret)
+	return ret
+}
+
+func flatten(responses [][]string) []string {
+	var ret []string
+	for _, v := range responses {
+		ret = append(ret, v...)
+	}
+	sort.Strings(ret)
+	return ret
+}
+
+func produceMessages(ctx context.Context, msgs []string, producer *Producer[testRequest, testResponse]) ([]*containers.Promise[testResponse], error) {
+	var promises []*containers.Promise[testResponse]
+	for i := 0; i < messagesCount; i++ {
+		promise, err := producer.Produce(ctx, testRequest{Request: msgs[i]})
+		if err != nil {
+			return nil, err
+		}
+		promises = append(promises, promise)
+	}
+	return promises, nil
+}
+
+func awaitResponses(ctx context.Context, promises []*containers.Promise[testResponse]) ([]string, error) {
+	var (
+		responses []string
+		errs      []error
+	)
+	for _, p := range promises {
+		res, err := p.Await(ctx)
+		if err != nil {
+			errs = append(errs, err)
+			continue
+		}
+		responses = append(responses, res.Response)
+	}
+	return responses, errors.Join(errs...)
+}
+
+// consume messages from every consumer except stopped ones.
+func consume(ctx context.Context, t *testing.T, consumers []*Consumer[testRequest, testResponse]) ([]map[string]string, [][]string) {
+	t.Helper()
+	gotMessages := messagesMaps(consumersCount)
+	wantResponses := make([][]string, consumersCount)
+	for idx := 0; idx < consumersCount; idx++ {
+		if consumers[idx].Stopped() {
+			continue
+		}
+		idx, c := idx, consumers[idx]
+		c.Start(ctx)
+		c.StopWaiter.LaunchThread(
+			func(ctx context.Context) {
+				for {
+					res, err := c.Consume(ctx)
+					if err != nil {
+						if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) {
+							t.Errorf("Consume() unexpected error: %v", err)
+							continue
+						}
+						return
+					}
+					if res == nil {
+						continue
+					}
+					gotMessages[idx][res.ID] = res.Value.Request
+					resp := fmt.Sprintf("result for: %v", res.ID)
+					if err := c.SetResult(ctx, res.ID, testResponse{Response: resp}); err != nil {
+						t.Errorf("Error setting a result: %v", err)
+					}
+					wantResponses[idx] = append(wantResponses[idx], resp)
+				}
+			})
+	}
+	return gotMessages, wantResponses
+}
+
+func TestRedisProduce(t *testing.T) {
+	t.Parallel()
+	for _, tc := range []struct {
+		name          string
+		killConsumers bool
+	}{
+		{
+			name:          "all consumers are active",
+			killConsumers: false,
+		},
+		{
+			name:          "some consumers killed, others should take over their work",
+			killConsumers: true,
+		},
+	} {
+		t.Run(tc.name, func(t *testing.T) {
+			ctx, cancel := context.WithCancel(context.Background())
+			defer cancel()
+			producer, consumers := newProducerConsumers(ctx, t)
+			producer.Start(ctx)
+			wantMsgs := wantMessages(messagesCount)
+			promises, err := produceMessages(ctx, wantMsgs, producer)
+			if err != nil {
+				t.Fatalf("Error producing messages: %v", err)
+			}
+			if tc.killConsumers {
+				// Consume messages in every third consumer but don't ack them to check
+				// that other consumers will claim ownership on those messages.
+				for i := 0; i < len(consumers); i += 3 {
+					if _, err := consumers[i].Consume(ctx); err != nil {
+						t.Errorf("Error consuming message: %v", err)
+					}
+					consumers[i].StopAndWait()
+				}
+			}
+			gotMessages, wantResponses := consume(ctx, t, consumers)
+			gotResponses, err := awaitResponses(ctx, promises)
+			if err != nil {
+				t.Fatalf("Error awaiting responses: %v", err)
+			}
+			producer.StopAndWait()
+			for _, c := range consumers {
+				c.StopWaiter.StopAndWait()
+			}
+			got, err := mergeValues(gotMessages)
+			if err != nil {
+				t.Fatalf("mergeMaps() unexpected error: %v", err)
+			}
+
+			if diff := cmp.Diff(wantMsgs, got); diff != "" {
+				t.Errorf("Unexpected diff (-want +got):\n%s\n", diff)
+			}
+			wantResp := flatten(wantResponses)
+			sort.Strings(gotResponses)
+			if diff := cmp.Diff(wantResp, gotResponses); diff != "" {
+				t.Errorf("Unexpected diff in responses:\n%s\n", diff)
+			}
+			if cnt := producer.promisesLen(); cnt != 0 {
+				t.Errorf("Producer still has %d unfulfilled promises", cnt)
+			}
+		})
+	}
+}
+
+func TestRedisReproduceDisabled(t *testing.T) {
+	t.Parallel()
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	producer, consumers := newProducerConsumers(ctx, t, &disableReproduce{})
+	producer.Start(ctx)
+	wantMsgs := wantMessages(messagesCount)
+	promises, err := produceMessages(ctx, wantMsgs, producer)
+	if err != nil {
+		t.Fatalf("Error producing messages: %v", err)
+	}
+
+	// Consume messages in every third consumer but don't ack them to check
+	// that other consumers will claim ownership on those messages.
+	for i := 0; i < len(consumers); i += 3 {
+		if _, err := consumers[i].Consume(ctx); err != nil {
+			t.Errorf("Error consuming message: %v", err)
+		}
+		consumers[i].StopAndWait()
+	}
+
+	gotMessages, _ := consume(ctx, t, consumers)
+	gotResponses, err := awaitResponses(ctx, promises)
+	if err == nil {
+		t.Fatalf("All promises were fulfilled with reproduce disabled and some consumers killed")
+	}
+	producer.StopAndWait()
+	for _, c := range consumers {
+		c.StopWaiter.StopAndWait()
+	}
+	got, err := mergeValues(gotMessages)
+	if err != nil {
+		t.Fatalf("mergeMaps() unexpected error: %v", err)
+	}
+	wantMsgCnt := messagesCount - ((consumersCount + 2) / 3)
+	if len(got) != wantMsgCnt {
+		t.Fatalf("Got: %d messages, want %d", len(got), wantMsgCnt)
+	}
+	if len(gotResponses) != wantMsgCnt {
+		t.Errorf("Got %d responses want: %d\n", len(gotResponses), wantMsgCnt)
+	}
+	if cnt := producer.promisesLen(); cnt != 0 {
+		t.Errorf("Producer still has %d unfulfilled promises", cnt)
+	}
+}
+
+// mergeValues merges maps from the slice and returns their values.
+// Returns an error if there exists a duplicate key.
+func mergeValues(messages []map[string]string) ([]string, error) {
+	res := make(map[string]any)
+	var ret []string
+	for _, m := range messages {
+		for k, v := range m {
+			if _, found := res[k]; found {
+				return nil, fmt.Errorf("duplicate key: %v", k)
+			}
+			res[k] = v
+			ret = append(ret, v)
+		}
+	}
+	sort.Strings(ret)
+	return ret, nil
+}
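A minimal end-to-end usage sketch of the API added above, not part of the patch itself. The request/response types, Redis URL, stream and group names are illustrative, and the consumer group is assumed to already exist (the tests create it with XGroupCreateMkStream).

package main

import (
	"context"
	"log"

	"github.com/offchainlabs/nitro/pubsub"
)

// ExampleRequest and ExampleResponse are hypothetical payload types; any
// JSON-serializable structs work.
type ExampleRequest struct{ Payload string }
type ExampleResponse struct{ Result string }

func main() {
	ctx := context.Background()

	// Producer side: configure stream/group and start the background loops
	// (pending-message reproduction and response polling start on first Produce).
	prodCfg := *pubsub.DefaultProducerConfig
	prodCfg.RedisURL = "redis://localhost:6379" // illustrative URL
	prodCfg.RedisStream, prodCfg.RedisGroup = "example-stream", "example-group"
	producer, err := pubsub.NewProducer[ExampleRequest, ExampleResponse](&prodCfg)
	if err != nil {
		log.Fatal(err)
	}
	producer.Start(ctx)

	// Consumer side: same stream/group; Start launches the heartbeat loop.
	consCfg := *pubsub.DefaultConsumerConfig
	consCfg.RedisURL = prodCfg.RedisURL
	consCfg.RedisStream, consCfg.RedisGroup = prodCfg.RedisStream, prodCfg.RedisGroup
	consumer, err := pubsub.NewConsumer[ExampleRequest, ExampleResponse](ctx, &consCfg)
	if err != nil {
		log.Fatal(err)
	}
	consumer.Start(ctx)

	// Worker loop: poll for messages and store the result under the message ID,
	// which is what the producer's checkResponses loop reads back.
	consumer.StopWaiter.LaunchThread(func(ctx context.Context) {
		for {
			if ctx.Err() != nil {
				return
			}
			msg, err := consumer.Consume(ctx)
			if err != nil || msg == nil {
				continue // redis.Nil is surfaced as (nil, nil); keep polling
			}
			if err := consumer.SetResult(ctx, msg.ID, ExampleResponse{Result: "handled: " + msg.Value.Payload}); err != nil {
				log.Println("setting result:", err)
			}
		}
	})

	// Produce returns a promise that resolves once a consumer sets the result.
	promise, err := producer.Produce(ctx, ExampleRequest{Payload: "hello"})
	if err != nil {
		log.Fatal(err)
	}
	resp, err := promise.Await(ctx)
	if err != nil {
		log.Fatal(err)
	}
	log.Println(resp.Result)
}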