From 4676a7cada754b0dd268b5f8cc0a24badf73d95f Mon Sep 17 00:00:00 2001 From: Nodar Ambroladze Date: Mon, 18 Mar 2024 23:27:39 +0100 Subject: [PATCH 01/33] Implement publisher subscriber library using redis streams --- go.mod | 6 +- go.sum | 8 +- pubsub/consumer.go | 198 ++++++++++++++++++++++++++++++++++++++++ pubsub/producer.go | 52 +++++++++++ pubsub/pubsub_test.go | 207 ++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 464 insertions(+), 7 deletions(-) create mode 100644 pubsub/consumer.go create mode 100644 pubsub/producer.go create mode 100644 pubsub/pubsub_test.go diff --git a/go.mod b/go.mod index cf9e61f9b9..0990bbd70d 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,7 @@ replace github.com/ethereum/go-ethereum => ./go-ethereum require ( github.com/Knetic/govaluate v3.0.1-0.20171022003610-9aa49832a739+incompatible github.com/Shopify/toxiproxy v2.1.4+incompatible - github.com/alicebob/miniredis/v2 v2.21.0 + github.com/alicebob/miniredis/v2 v2.32.1 github.com/andybalholm/brotli v1.0.4 github.com/aws/aws-sdk-go-v2 v1.16.4 github.com/aws/aws-sdk-go-v2/config v1.15.5 @@ -260,7 +260,7 @@ require ( github.com/whyrusleeping/go-keyspace v0.0.0-20160322163242-5b898ac5add1 // indirect github.com/whyrusleeping/multiaddr-filter v0.0.0-20160516205228-e903e4adabd7 // indirect github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect - github.com/yuin/gopher-lua v0.0.0-20210529063254-f4c35e4016d9 // indirect + github.com/yuin/gopher-lua v1.1.1 // indirect go.opencensus.io v0.24.0 // indirect go.opentelemetry.io/otel v1.7.0 // indirect go.opentelemetry.io/otel/exporters/jaeger v1.7.0 // indirect @@ -317,7 +317,7 @@ require ( github.com/go-redis/redis/v8 v8.11.4 github.com/go-stack/stack v1.8.1 // indirect github.com/golang/snappy v0.0.5-0.20220116011046-fa5810519dcb // indirect - github.com/google/uuid v1.3.1 // indirect + github.com/google/uuid v1.3.1 github.com/gorilla/websocket v1.5.0 // indirect github.com/hashicorp/go-bexpr v0.1.10 // indirect github.com/hashicorp/golang-lru v0.5.5-0.20210104140557-80c98217689d // indirect diff --git a/go.sum b/go.sum index f2b4c668cc..d589fb16e4 100644 --- a/go.sum +++ b/go.sum @@ -83,8 +83,8 @@ github.com/alexbrainman/goissue34681 v0.0.0-20191006012335-3fc7a47baff5 h1:iW0a5 github.com/alexbrainman/goissue34681 v0.0.0-20191006012335-3fc7a47baff5/go.mod h1:Y2QMoi1vgtOIfc+6DhrMOGkLoGzqSV2rKp4Sm+opsyA= github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a h1:HbKu58rmZpUGpz5+4FfNmIU+FmZg2P3Xaj2v2bfNWmk= github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a/go.mod h1:SGnFV6hVsYE877CKEZ6tDNTjaSXYUk6QqoIK6PrAtcc= -github.com/alicebob/miniredis/v2 v2.21.0 h1:CdmwIlKUWFBDS+4464GtQiQ0R1vpzOgu4Vnd74rBL7M= -github.com/alicebob/miniredis/v2 v2.21.0/go.mod h1:XNqvJdQJv5mSuVMc0ynneafpnL/zv52acZ6kqeS0t88= +github.com/alicebob/miniredis/v2 v2.32.1 h1:Bz7CciDnYSaa0mX5xODh6GUITRSx+cVhjNoOR4JssBo= +github.com/alicebob/miniredis/v2 v2.32.1/go.mod h1:AqkLNAfUm0K07J28hnAyyQKf/x0YkCY/g5DCtuL01Mw= github.com/allegro/bigcache v1.2.1-0.20190218064605-e24eb225f156 h1:eMwmnE/GDgah4HI848JfFxHt+iPb26b4zyfspmqY0/8= github.com/allegro/bigcache v1.2.1-0.20190218064605-e24eb225f156/go.mod h1:Cb/ax3seSYIx7SuZdm2G2xzfwmv3TPSk2ucNfQESPXM= github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY= @@ -1684,8 +1684,8 @@ github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9de github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.3.5/go.mod 
h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= -github.com/yuin/gopher-lua v0.0.0-20210529063254-f4c35e4016d9 h1:k/gmLsJDWwWqbLCur2yWnJzwQEKRcAHXo6seXGuSwWw= -github.com/yuin/gopher-lua v0.0.0-20210529063254-f4c35e4016d9/go.mod h1:E1AXubJBdNmFERAOucpDIxNzeGfLzg0mYh+UfMWdChA= +github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M= +github.com/yuin/gopher-lua v1.1.1/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw= go.etcd.io/bbolt v1.3.3/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= go.etcd.io/etcd v0.0.0-20191023171146-3cf2f69b5738/go.mod h1:dnLIgRNXwCJa5e+c6mIZCrds/GIG4ncV9HhK5PX7jPg= go.opencensus.io v0.18.0/go.mod h1:vKdFvxhtzZ9onBp9VKHK8z/sRpBMnKAsufL7wlDrCOA= diff --git a/pubsub/consumer.go b/pubsub/consumer.go new file mode 100644 index 0000000000..2978ef06b6 --- /dev/null +++ b/pubsub/consumer.go @@ -0,0 +1,198 @@ +package pubsub + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/ethereum/go-ethereum/log" + "github.com/go-redis/redis/v8" + "github.com/google/uuid" +) + +var ( + // Intervals in which consumer will update heartbeat. + KeepAliveInterval = 30 * time.Second + // Duration after which consumer is considered to be dead if heartbeat + // is not updated. + KeepAliveTimeout = 5 * time.Minute + // Key for locking pending messages. + pendingMessagesKey = "lock:pending" +) + +type Consumer struct { + id string + streamName string + groupName string + client *redis.Client +} + +type Message struct { + ID string + Value any +} + +func NewConsumer(ctx context.Context, id, streamName, url string) (*Consumer, error) { + c, err := clientFromURL(url) + if err != nil { + return nil, err + } + if id == "" { + id = uuid.NewString() + } + + consumer := &Consumer{ + id: id, + streamName: streamName, + groupName: "default", + client: c, + } + go consumer.keepAlive(ctx) + return consumer, nil +} + +func keepAliveKey(id string) string { + return fmt.Sprintf("consumer:%s:heartbeat", id) +} + +func (c *Consumer) keepAliveKey() string { + return keepAliveKey(c.id) +} + +// keepAlive polls in keepAliveIntervals and updates heartbeat entry for itself. +func (c *Consumer) keepAlive(ctx context.Context) { + log.Info("Consumer polling for heartbeat updates", "id", c.id) + for { + if err := c.client.Set(ctx, c.keepAliveKey(), time.Now().UnixMilli(), KeepAliveTimeout).Err(); err != nil { + log.Error("Updating heardbeat", "consumer", c.id, "error", err) + } + select { + case <-ctx.Done(): + log.Error("Error keeping alive", "error", ctx.Err()) + return + case <-time.After(KeepAliveInterval): + } + } +} + +// Consumer first checks it there exists pending message that is claimed by +// unresponsive consumer, if not then reads from the stream. +func (c *Consumer) Consume(ctx context.Context) (*Message, error) { + log.Debug("Attempting to consume a message", "consumer-id", c.id) + msg, err := c.checkPending(ctx) + if err != nil { + return nil, fmt.Errorf("consumer: %v checking pending messages with unavailable consumer: %w", c.id, err) + } + if msg != nil { + return msg, nil + } + res, err := c.client.XReadGroup(ctx, &redis.XReadGroupArgs{ + Group: c.groupName, + Consumer: c.id, + // Receive only messages that were never delivered to any other consumer, + // that is, only new messages. 
+ Streams: []string{c.streamName, ">"}, + Count: 1, + Block: time.Millisecond, // 0 seems to block the read instead of immediately returning + }).Result() + if errors.Is(err, redis.Nil) { + return nil, nil + } + if err != nil { + return nil, fmt.Errorf("reading message for consumer: %q: %w", c.id, err) + } + if len(res) != 1 || len(res[0].Messages) != 1 { + return nil, fmt.Errorf("redis returned entries: %+v, for querying single message", res) + } + log.Debug(fmt.Sprintf("Consumer: %s consuming message: %s", c.id, res[0].Messages[0].ID)) + return &Message{ + ID: res[0].Messages[0].ID, + Value: res[0].Messages[0].Values[msgKey], + }, nil +} + +func (c *Consumer) ACK(ctx context.Context, messageID string) error { + log.Info("ACKing message", "consumer-id", c.id, "message-sid", messageID) + _, err := c.client.XAck(ctx, c.streamName, c.groupName, messageID).Result() + return err +} + +// Check if a consumer is with specified ID is alive. +func (c *Consumer) isConsumerAlive(ctx context.Context, consumerID string) bool { + val, err := c.client.Get(ctx, keepAliveKey(consumerID)).Int64() + if err != nil { + return false + } + return time.Now().UnixMilli()-val < 2*int64(KeepAliveTimeout.Milliseconds()) +} + +func (c *Consumer) lockPending(ctx context.Context, consumerID string) bool { + acquired, err := c.client.SetNX(ctx, pendingMessagesKey, consumerID, KeepAliveInterval).Result() + if err != nil || !acquired { + return false + } + return true +} + +func (c *Consumer) unlockPending(ctx context.Context) { + log.Debug("Releasing lock", "consumer-id", c.id) + c.client.Del(ctx, pendingMessagesKey) + +} + +// checkPending lists pending messages, and checks unavailable consumers that +// have ownership on pending message. +// If such message and consumer exists, it claims ownership on it. +func (c *Consumer) checkPending(ctx context.Context) (*Message, error) { + // Locking pending list avoid the race where two instances query pending + // list and try to claim ownership on the same message. 
+ if !c.lockPending(ctx, c.id) { + return nil, nil + } + log.Info("Consumer acquired pending lock", "consumer=id", c.id) + defer c.unlockPending(ctx) + pendingMessages, err := c.client.XPendingExt(ctx, &redis.XPendingExtArgs{ + Stream: c.streamName, + Group: c.groupName, + Start: "-", + End: "+", + Count: 100, + }).Result() + log.Info("Pending messages", "consumer", c.id, "pendingMessages", pendingMessages, "error", err) + + if err != nil && !errors.Is(err, redis.Nil) { + return nil, fmt.Errorf("querying pending messages: %w", err) + } + if len(pendingMessages) == 0 { + return nil, nil + } + for _, msg := range pendingMessages { + if !c.isConsumerAlive(ctx, msg.Consumer) { + log.Debug("Consumer is not alive", "id", msg.Consumer) + msgs, err := c.client.XClaim(ctx, &redis.XClaimArgs{ + Stream: c.streamName, + Group: c.groupName, + Consumer: c.id, + MinIdle: KeepAliveTimeout, + Messages: []string{msg.ID}, + }).Result() + if err != nil { + log.Error("Error claiming ownership on message", "id", msg.ID, "consumer", c.id, "error", err) + continue + } + if len(msgs) != 1 { + log.Error("Attempted to claim ownership on single messsage", "id", msg.ID, "number of received messages", len(msgs)) + if len(msgs) == 0 { + continue + } + } + log.Info(fmt.Sprintf("Consumer: %s claimed ownership on message: %s", c.id, msgs[0].ID)) + return &Message{ + ID: msgs[0].ID, + Value: msgs[0].Values[msgKey], + }, nil + } + } + return nil, nil +} diff --git a/pubsub/producer.go b/pubsub/producer.go new file mode 100644 index 0000000000..37106d97ad --- /dev/null +++ b/pubsub/producer.go @@ -0,0 +1,52 @@ +package pubsub + +import ( + "context" + "fmt" + + "github.com/go-redis/redis/v8" +) + +const msgKey = "msg" + +// clientFromURL returns a redis client from url. +func clientFromURL(url string) (*redis.Client, error) { + if url == "" { + return nil, fmt.Errorf("empty redis url") + } + opts, err := redis.ParseURL(url) + if err != nil { + return nil, err + } + c := redis.NewClient(opts) + if c == nil { + return nil, fmt.Errorf("redis returned nil client") + } + return c, nil +} + +type Producer struct { + streamName string + client *redis.Client +} + +func NewProducer(streamName string, url string) (*Producer, error) { + c, err := clientFromURL(url) + if err != nil { + return nil, err + } + return &Producer{ + streamName: streamName, + client: c, + }, nil +} + +func (p *Producer) Produce(ctx context.Context, value any) error { + if _, err := p.client.XAdd(ctx, &redis.XAddArgs{ + Stream: p.streamName, + Values: map[string]any{msgKey: value}, + }).Result(); err != nil { + return fmt.Errorf("adding values to redis: %w", err) + } + return nil +} diff --git a/pubsub/pubsub_test.go b/pubsub/pubsub_test.go new file mode 100644 index 0000000000..2bf08b6a36 --- /dev/null +++ b/pubsub/pubsub_test.go @@ -0,0 +1,207 @@ +package pubsub + +import ( + "context" + "errors" + "fmt" + "os" + "sort" + "sync/atomic" + "testing" + "time" + + "github.com/ethereum/go-ethereum/log" + "github.com/go-redis/redis/v8" + "github.com/google/go-cmp/cmp" + "github.com/offchainlabs/nitro/util/redisutil" +) + +var ( + streamName = "validator_stream" + consumersCount = 10 + messagesCount = 100 +) + +type testConsumer struct { + consumer *Consumer + cancel context.CancelFunc +} + +func createGroup(ctx context.Context, t *testing.T, client *redis.Client) { + t.Helper() + _, err := client.XGroupCreateMkStream(ctx, streamName, "default", "$").Result() + if err != nil { + t.Fatalf("Error creating stream group: %v", err) + } +} + +func newProducerConsumers(ctx 
context.Context, t *testing.T) (*Producer, []*testConsumer) { + t.Helper() + tmpI, tmpT := KeepAliveInterval, KeepAliveTimeout + KeepAliveInterval, KeepAliveTimeout = 5*time.Millisecond, 30*time.Millisecond + t.Cleanup(func() { KeepAliveInterval, KeepAliveTimeout = tmpI, tmpT }) + + redisURL := redisutil.CreateTestRedis(ctx, t) + producer, err := NewProducer(streamName, redisURL) + if err != nil { + t.Fatalf("Error creating new producer: %v", err) + } + var ( + consumers []*testConsumer + ) + for i := 0; i < consumersCount; i++ { + consumerCtx, cancel := context.WithCancel(ctx) + c, err := NewConsumer(consumerCtx, fmt.Sprintf("consumer-%d", i), streamName, redisURL) + if err != nil { + t.Fatalf("Error creating new consumer: %v", err) + } + consumers = append(consumers, &testConsumer{ + consumer: c, + cancel: cancel, + }) + } + createGroup(ctx, t, producer.client) + return producer, consumers +} + +func messagesMap(n int) []map[string]any { + ret := make([]map[string]any, n) + for i := 0; i < n; i++ { + ret[i] = make(map[string]any) + } + return ret +} + +func TestProduce(t *testing.T) { + log.Root().SetHandler(log.LvlFilterHandler(log.LvlTrace, log.StreamHandler(os.Stderr, log.TerminalFormat(true)))) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + producer, consumers := newProducerConsumers(ctx, t) + consumerCtx, cancelConsumers := context.WithTimeout(ctx, time.Second) + gotMessages := messagesMap(consumersCount) + + for idx, c := range consumers { + idx, c := idx, c.consumer + go func() { + for { + res, err := c.Consume(consumerCtx) + if err != nil { + if !errors.Is(err, context.DeadlineExceeded) { + t.Errorf("Consume() unexpected error: %v", err) + } + return + } + gotMessages[idx][res.ID] = res.Value + c.ACK(consumerCtx, res.ID) + } + }() + } + + var want []any + for i := 0; i < messagesCount; i++ { + value := fmt.Sprintf("msg: %d", i) + want = append(want, value) + if err := producer.Produce(ctx, value); err != nil { + t.Errorf("Produce() unexpected error: %v", err) + } + } + time.Sleep(time.Second) + cancelConsumers() + got, err := mergeValues(gotMessages) + if err != nil { + t.Fatalf("mergeMaps() unexpected error: %v", err) + } + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("Unexpected diff (-want +got):\n%s\n", diff) + } +} + +func TestClaimingOwnership(t *testing.T) { + log.Root().SetHandler(log.LvlFilterHandler(log.LvlTrace, log.StreamHandler(os.Stderr, log.TerminalFormat(true)))) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + producer, consumers := newProducerConsumers(ctx, t) + consumerCtx, cancelConsumers := context.WithCancel(ctx) + gotMessages := messagesMap(consumersCount) + + // Consumer messages in every third consumer but don't ack them to check + // that other consumers will claim ownership on those messages. 
+ for i := 0; i < len(consumers); i += 3 { + consumers[i].cancel() + go consumers[i].consumer.Consume(context.Background()) + } + var total atomic.Uint64 + + for idx, c := range consumers { + idx, c := idx, c.consumer + go func() { + for { + if idx%3 == 0 { + continue + } + res, err := c.Consume(consumerCtx) + if err != nil { + if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) { + t.Errorf("Consume() unexpected error: %v", err) + continue + } + return + } + if res == nil { + continue + } + gotMessages[idx][res.ID] = res.Value + c.ACK(consumerCtx, res.ID) + total.Add(1) + } + }() + } + + var want []any + for i := 0; i < messagesCount; i++ { + value := fmt.Sprintf("msg: %d", i) + want = append(want, value) + if err := producer.Produce(ctx, value); err != nil { + t.Errorf("Produce() unexpected error: %v", err) + } + } + sort.Slice(want, func(i, j int) bool { + return fmt.Sprintf("%v", want[i]) < fmt.Sprintf("%v", want[j]) + }) + + for { + if total.Load() < uint64(messagesCount) { + time.Sleep(100 * time.Millisecond) + continue + } + break + } + cancelConsumers() + got, err := mergeValues(gotMessages) + if err != nil { + t.Fatalf("mergeMaps() unexpected error: %v", err) + } + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("Unexpected diff (-want +got):\n%s\n", diff) + } +} + +// mergeValues merges maps from the slice and returns their values. +// Returns and error if there exists duplicate key. +func mergeValues(messages []map[string]any) ([]any, error) { + res := make(map[string]any) + var ret []any + for _, m := range messages { + for k, v := range m { + if _, found := res[k]; found { + return nil, fmt.Errorf("duplicate key: %v", k) + } + res[k] = v + ret = append(ret, v) + } + } + sort.Slice(ret, func(i, j int) bool { + return fmt.Sprintf("%v", ret[i]) < fmt.Sprintf("%v", ret[j]) + }) + return ret, nil +} From 7abd266557ac0f281c775e5565411c5321b2b582 Mon Sep 17 00:00:00 2001 From: Nodar Ambroladze Date: Tue, 19 Mar 2024 09:34:08 +0100 Subject: [PATCH 02/33] Fix lint --- pubsub/pubsub_test.go | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/pubsub/pubsub_test.go b/pubsub/pubsub_test.go index 2bf08b6a36..8dbaa6f6e0 100644 --- a/pubsub/pubsub_test.go +++ b/pubsub/pubsub_test.go @@ -92,7 +92,9 @@ func TestProduce(t *testing.T) { return } gotMessages[idx][res.ID] = res.Value - c.ACK(consumerCtx, res.ID) + if err := c.ACK(consumerCtx, res.ID); err != nil { + t.Errorf("Error ACKing message: %v, error: %v", res.ID, err) + } } }() } @@ -127,8 +129,13 @@ func TestClaimingOwnership(t *testing.T) { // Consumer messages in every third consumer but don't ack them to check // that other consumers will claim ownership on those messages. 
for i := 0; i < len(consumers); i += 3 { + i := i consumers[i].cancel() - go consumers[i].consumer.Consume(context.Background()) + go func() { + if _, err := consumers[i].consumer.Consume(context.Background()); err != nil { + t.Errorf("Error consuming message: %v", err) + } + }() } var total atomic.Uint64 @@ -151,7 +158,9 @@ func TestClaimingOwnership(t *testing.T) { continue } gotMessages[idx][res.ID] = res.Value - c.ACK(consumerCtx, res.ID) + if err := c.ACK(consumerCtx, res.ID); err != nil { + t.Errorf("Error ACKing message: %v, error: %v", res.ID, err) + } total.Add(1) } }() From 651f26a96adaf1391246d2d704fafb470298d9f6 Mon Sep 17 00:00:00 2001 From: Nodar Ambroladze Date: Tue, 19 Mar 2024 11:18:43 +0100 Subject: [PATCH 03/33] Fix tests --- pubsub/pubsub_test.go | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/pubsub/pubsub_test.go b/pubsub/pubsub_test.go index 8dbaa6f6e0..753915fe8e 100644 --- a/pubsub/pubsub_test.go +++ b/pubsub/pubsub_test.go @@ -46,9 +46,7 @@ func newProducerConsumers(ctx context.Context, t *testing.T) (*Producer, []*test if err != nil { t.Fatalf("Error creating new producer: %v", err) } - var ( - consumers []*testConsumer - ) + var consumers []*testConsumer for i := 0; i < consumersCount; i++ { consumerCtx, cancel := context.WithCancel(ctx) c, err := NewConsumer(consumerCtx, fmt.Sprintf("consumer-%d", i), streamName, redisURL) @@ -72,6 +70,17 @@ func messagesMap(n int) []map[string]any { return ret } +func wantMessages(n int) []any { + var ret []any + for i := 0; i < n; i++ { + ret = append(ret, fmt.Sprintf("msg: %d", i)) + } + sort.Slice(ret, func(i, j int) bool { + return fmt.Sprintf("%v", ret[i]) < fmt.Sprintf("%v", ret[j]) + }) + return ret +} + func TestProduce(t *testing.T) { log.Root().SetHandler(log.LvlFilterHandler(log.LvlTrace, log.StreamHandler(os.Stderr, log.TerminalFormat(true)))) ctx, cancel := context.WithCancel(context.Background()) @@ -91,6 +100,9 @@ func TestProduce(t *testing.T) { } return } + if res == nil { + continue + } gotMessages[idx][res.ID] = res.Value if err := c.ACK(consumerCtx, res.ID); err != nil { t.Errorf("Error ACKing message: %v, error: %v", res.ID, err) @@ -99,10 +111,8 @@ func TestProduce(t *testing.T) { }() } - var want []any for i := 0; i < messagesCount; i++ { value := fmt.Sprintf("msg: %d", i) - want = append(want, value) if err := producer.Produce(ctx, value); err != nil { t.Errorf("Produce() unexpected error: %v", err) } @@ -113,6 +123,7 @@ func TestProduce(t *testing.T) { if err != nil { t.Fatalf("mergeMaps() unexpected error: %v", err) } + want := wantMessages(messagesCount) if diff := cmp.Diff(want, got); diff != "" { t.Errorf("Unexpected diff (-want +got):\n%s\n", diff) } @@ -166,17 +177,12 @@ func TestClaimingOwnership(t *testing.T) { }() } - var want []any for i := 0; i < messagesCount; i++ { value := fmt.Sprintf("msg: %d", i) - want = append(want, value) if err := producer.Produce(ctx, value); err != nil { t.Errorf("Produce() unexpected error: %v", err) } } - sort.Slice(want, func(i, j int) bool { - return fmt.Sprintf("%v", want[i]) < fmt.Sprintf("%v", want[j]) - }) for { if total.Load() < uint64(messagesCount) { @@ -190,6 +196,7 @@ func TestClaimingOwnership(t *testing.T) { if err != nil { t.Fatalf("mergeMaps() unexpected error: %v", err) } + want := wantMessages(messagesCount) if diff := cmp.Diff(want, got); diff != "" { t.Errorf("Unexpected diff (-want +got):\n%s\n", diff) } From 3ae8a7246170b55e4f8b26b1f4acc76fdf999cfb Mon Sep 17 00:00:00 2001 From: Nodar 
Ambroladze Date: Thu, 21 Mar 2024 17:00:58 +0100 Subject: [PATCH 04/33] Address comments --- pubsub/consumer.go | 58 +++++++++++++++++++++++++++------------------- 1 file changed, 34 insertions(+), 24 deletions(-) diff --git a/pubsub/consumer.go b/pubsub/consumer.go index 2978ef06b6..7ec19d22c2 100644 --- a/pubsub/consumer.go +++ b/pubsub/consumer.go @@ -65,11 +65,15 @@ func (c *Consumer) keepAlive(ctx context.Context) { log.Info("Consumer polling for heartbeat updates", "id", c.id) for { if err := c.client.Set(ctx, c.keepAliveKey(), time.Now().UnixMilli(), KeepAliveTimeout).Err(); err != nil { - log.Error("Updating heardbeat", "consumer", c.id, "error", err) + l := log.Error + if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) { + l = log.Info + } + l("Updating heardbeat", "consumer", c.id, "error", err) } select { case <-ctx.Done(): - log.Error("Error keeping alive", "error", ctx.Err()) + log.Info("Error keeping alive", "error", ctx.Err()) return case <-time.After(KeepAliveInterval): } @@ -167,32 +171,38 @@ func (c *Consumer) checkPending(ctx context.Context) (*Message, error) { if len(pendingMessages) == 0 { return nil, nil } + inactive := make(map[string]bool) for _, msg := range pendingMessages { - if !c.isConsumerAlive(ctx, msg.Consumer) { - log.Debug("Consumer is not alive", "id", msg.Consumer) - msgs, err := c.client.XClaim(ctx, &redis.XClaimArgs{ - Stream: c.streamName, - Group: c.groupName, - Consumer: c.id, - MinIdle: KeepAliveTimeout, - Messages: []string{msg.ID}, - }).Result() - if err != nil { - log.Error("Error claiming ownership on message", "id", msg.ID, "consumer", c.id, "error", err) + if inactive[msg.Consumer] { + continue + } + if c.isConsumerAlive(ctx, msg.Consumer) { + continue + } + inactive[msg.Consumer] = true + log.Info("Consumer is not alive", "id", msg.Consumer) + msgs, err := c.client.XClaim(ctx, &redis.XClaimArgs{ + Stream: c.streamName, + Group: c.groupName, + Consumer: c.id, + MinIdle: KeepAliveTimeout, + Messages: []string{msg.ID}, + }).Result() + if err != nil { + log.Error("Error claiming ownership on message", "id", msg.ID, "consumer", c.id, "error", err) + continue + } + if len(msgs) != 1 { + log.Error("Attempted to claim ownership on single messsage", "id", msg.ID, "number of received messages", len(msgs)) + if len(msgs) == 0 { continue } - if len(msgs) != 1 { - log.Error("Attempted to claim ownership on single messsage", "id", msg.ID, "number of received messages", len(msgs)) - if len(msgs) == 0 { - continue - } - } - log.Info(fmt.Sprintf("Consumer: %s claimed ownership on message: %s", c.id, msgs[0].ID)) - return &Message{ - ID: msgs[0].ID, - Value: msgs[0].Values[msgKey], - }, nil } + log.Info(fmt.Sprintf("Consumer: %s claimed ownership on message: %s", c.id, msgs[0].ID)) + return &Message{ + ID: msgs[0].ID, + Value: msgs[0].Values[msgKey], + }, nil } return nil, nil } From b28f3ac7720c47bfd09e88d236c81d03e935f65c Mon Sep 17 00:00:00 2001 From: Nodar Ambroladze Date: Fri, 22 Mar 2024 16:15:46 +0100 Subject: [PATCH 05/33] Implement config structs for producer/consumer --- pubsub/consumer.go | 82 ++++++++++++++++++++++++++++++------------- pubsub/producer.go | 12 +++++-- pubsub/pubsub_test.go | 17 ++++++--- 3 files changed, 79 insertions(+), 32 deletions(-) diff --git a/pubsub/consumer.go b/pubsub/consumer.go index 7ec19d22c2..43d9925452 100644 --- a/pubsub/consumer.go +++ b/pubsub/consumer.go @@ -9,23 +9,38 @@ import ( "github.com/ethereum/go-ethereum/log" "github.com/go-redis/redis/v8" "github.com/google/uuid" + 
"github.com/spf13/pflag" ) -var ( +const pendingMessagesKey = "lock:pending" + +type ConsumerConfig struct { // Intervals in which consumer will update heartbeat. - KeepAliveInterval = 30 * time.Second + KeepAliveInterval time.Duration `koanf:"keepalive-interval"` // Duration after which consumer is considered to be dead if heartbeat // is not updated. - KeepAliveTimeout = 5 * time.Minute - // Key for locking pending messages. - pendingMessagesKey = "lock:pending" -) + KeepAliveTimeout time.Duration `koanf:"keepalive-timeout"` + // Redis url for Redis streams and locks. + RedisURL string `koanf:"redis-url"` + // Redis stream name. + RedisStream string `koanf:"redis-stream"` +} + +func ConsumerConfigAddOptions(prefix string, f *pflag.FlagSet, cfg *ConsumerConfig) { + f.Duration(prefix+".keepalive-interval", 30*time.Second, "interval in which consumer will perform heartbeat") + f.Duration(prefix+".keepalive-timeout", 5*time.Minute, "timeout after which consumer is considered inactive if heartbeat wasn't performed") + f.String(prefix+".redis-url", "", "redis url for redis stream") + f.String(prefix+".redis-stream", "default", "redis stream name to read from") + f.String(prefix+".redis-group", "default", "redis stream consumer group name") +} type Consumer struct { - id string - streamName string - groupName string - client *redis.Client + id string + streamName string + groupName string + client *redis.Client + keepAliveInterval time.Duration + keepAliveTimeout time.Duration } type Message struct { @@ -33,25 +48,44 @@ type Message struct { Value any } -func NewConsumer(ctx context.Context, id, streamName, url string) (*Consumer, error) { - c, err := clientFromURL(url) +func NewConsumer(ctx context.Context, cfg *ConsumerConfig) (*Consumer, error) { + c, err := clientFromURL(cfg.RedisURL) if err != nil { return nil, err } - if id == "" { - id = uuid.NewString() - } + id := uuid.NewString() consumer := &Consumer{ - id: id, - streamName: streamName, - groupName: "default", - client: c, + id: id, + streamName: cfg.RedisStream, + groupName: "default", + client: c, + keepAliveInterval: cfg.KeepAliveInterval, + keepAliveTimeout: cfg.KeepAliveTimeout, } go consumer.keepAlive(ctx) return consumer, nil } +// func NewConsumer(ctx context.Context, id, streamName, url string) (*Consumer, error) { +// c, err := clientFromURL(url) +// if err != nil { +// return nil, err +// } +// if id == "" { +// id = uuid.NewString() +// } + +// consumer := &Consumer{ +// id: id, +// streamName: streamName, +// groupName: "default", +// client: c, +// } +// go consumer.keepAlive(ctx) +// return consumer, nil +// } + func keepAliveKey(id string) string { return fmt.Sprintf("consumer:%s:heartbeat", id) } @@ -64,7 +98,7 @@ func (c *Consumer) keepAliveKey() string { func (c *Consumer) keepAlive(ctx context.Context) { log.Info("Consumer polling for heartbeat updates", "id", c.id) for { - if err := c.client.Set(ctx, c.keepAliveKey(), time.Now().UnixMilli(), KeepAliveTimeout).Err(); err != nil { + if err := c.client.Set(ctx, c.keepAliveKey(), time.Now().UnixMilli(), c.keepAliveTimeout).Err(); err != nil { l := log.Error if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) { l = log.Info @@ -75,7 +109,7 @@ func (c *Consumer) keepAlive(ctx context.Context) { case <-ctx.Done(): log.Info("Error keeping alive", "error", ctx.Err()) return - case <-time.After(KeepAliveInterval): + case <-time.After(c.keepAliveTimeout): } } } @@ -128,11 +162,11 @@ func (c *Consumer) isConsumerAlive(ctx context.Context, 
consumerID string) bool if err != nil { return false } - return time.Now().UnixMilli()-val < 2*int64(KeepAliveTimeout.Milliseconds()) + return time.Now().UnixMilli()-val < 2*int64(c.keepAliveTimeout.Milliseconds()) } func (c *Consumer) lockPending(ctx context.Context, consumerID string) bool { - acquired, err := c.client.SetNX(ctx, pendingMessagesKey, consumerID, KeepAliveInterval).Result() + acquired, err := c.client.SetNX(ctx, pendingMessagesKey, consumerID, c.keepAliveInterval).Result() if err != nil || !acquired { return false } @@ -185,7 +219,7 @@ func (c *Consumer) checkPending(ctx context.Context) (*Message, error) { Stream: c.streamName, Group: c.groupName, Consumer: c.id, - MinIdle: KeepAliveTimeout, + MinIdle: c.keepAliveTimeout, Messages: []string{msg.ID}, }).Result() if err != nil { diff --git a/pubsub/producer.go b/pubsub/producer.go index 37106d97ad..685db110b3 100644 --- a/pubsub/producer.go +++ b/pubsub/producer.go @@ -30,13 +30,19 @@ type Producer struct { client *redis.Client } -func NewProducer(streamName string, url string) (*Producer, error) { - c, err := clientFromURL(url) +type ProducerConfig struct { + RedisURL string `koanf:"redis-url"` + // Redis stream name. + RedisStream string `koanf:"redis-stream"` +} + +func NewProducer(cfg *ProducerConfig) (*Producer, error) { + c, err := clientFromURL(cfg.RedisURL) if err != nil { return nil, err } return &Producer{ - streamName: streamName, + streamName: cfg.RedisStream, client: c, }, nil } diff --git a/pubsub/pubsub_test.go b/pubsub/pubsub_test.go index 753915fe8e..1e288505ab 100644 --- a/pubsub/pubsub_test.go +++ b/pubsub/pubsub_test.go @@ -37,19 +37,26 @@ func createGroup(ctx context.Context, t *testing.T, client *redis.Client) { func newProducerConsumers(ctx context.Context, t *testing.T) (*Producer, []*testConsumer) { t.Helper() - tmpI, tmpT := KeepAliveInterval, KeepAliveTimeout - KeepAliveInterval, KeepAliveTimeout = 5*time.Millisecond, 30*time.Millisecond - t.Cleanup(func() { KeepAliveInterval, KeepAliveTimeout = tmpI, tmpT }) + // tmpI, tmpT := KeepAliveInterval, KeepAliveTimeout + // KeepAliveInterval, KeepAliveTimeout = 5*time.Millisecond, 30*time.Millisecond + // t.Cleanup(func() { KeepAliveInterval, KeepAliveTimeout = tmpI, tmpT }) redisURL := redisutil.CreateTestRedis(ctx, t) - producer, err := NewProducer(streamName, redisURL) + producer, err := NewProducer(&ProducerConfig{RedisURL: redisURL, RedisStream: streamName}) if err != nil { t.Fatalf("Error creating new producer: %v", err) } var consumers []*testConsumer for i := 0; i < consumersCount; i++ { consumerCtx, cancel := context.WithCancel(ctx) - c, err := NewConsumer(consumerCtx, fmt.Sprintf("consumer-%d", i), streamName, redisURL) + c, err := NewConsumer(consumerCtx, + &ConsumerConfig{ + RedisURL: redisURL, + RedisStream: streamName, + KeepAliveInterval: 5 * time.Millisecond, + KeepAliveTimeout: 30 * time.Millisecond, + }, + ) if err != nil { t.Fatalf("Error creating new consumer: %v", err) } From 675c1c245f2f328c238ff1b471f78b87ef3f6366 Mon Sep 17 00:00:00 2001 From: Nodar Ambroladze Date: Mon, 25 Mar 2024 13:41:01 +0100 Subject: [PATCH 06/33] Drop commented out code, fix test --- pubsub/consumer.go | 21 +-------------------- pubsub/pubsub_test.go | 3 ++- 2 files changed, 3 insertions(+), 21 deletions(-) diff --git a/pubsub/consumer.go b/pubsub/consumer.go index 43d9925452..c016208664 100644 --- a/pubsub/consumer.go +++ b/pubsub/consumer.go @@ -67,25 +67,6 @@ func NewConsumer(ctx context.Context, cfg *ConsumerConfig) (*Consumer, error) { return 
consumer, nil } -// func NewConsumer(ctx context.Context, id, streamName, url string) (*Consumer, error) { -// c, err := clientFromURL(url) -// if err != nil { -// return nil, err -// } -// if id == "" { -// id = uuid.NewString() -// } - -// consumer := &Consumer{ -// id: id, -// streamName: streamName, -// groupName: "default", -// client: c, -// } -// go consumer.keepAlive(ctx) -// return consumer, nil -// } - func keepAliveKey(id string) string { return fmt.Sprintf("consumer:%s:heartbeat", id) } @@ -109,7 +90,7 @@ func (c *Consumer) keepAlive(ctx context.Context) { case <-ctx.Done(): log.Info("Error keeping alive", "error", ctx.Err()) return - case <-time.After(c.keepAliveTimeout): + case <-time.After(c.keepAliveInterval): } } } diff --git a/pubsub/pubsub_test.go b/pubsub/pubsub_test.go index 1e288505ab..eccf723f11 100644 --- a/pubsub/pubsub_test.go +++ b/pubsub/pubsub_test.go @@ -95,10 +95,11 @@ func TestProduce(t *testing.T) { producer, consumers := newProducerConsumers(ctx, t) consumerCtx, cancelConsumers := context.WithTimeout(ctx, time.Second) gotMessages := messagesMap(consumersCount) - for idx, c := range consumers { idx, c := idx, c.consumer go func() { + // Give some time to the consumers to do their heartbeat. + time.Sleep(2 * c.keepAliveInterval) for { res, err := c.Consume(consumerCtx) if err != nil { From 0f43f60e2a33c544240a415c19eaf13107654b65 Mon Sep 17 00:00:00 2001 From: Nodar Ambroladze Date: Tue, 26 Mar 2024 20:16:00 +0100 Subject: [PATCH 07/33] Use stopwaiter instead of go primitives --- pubsub/consumer.go | 153 +++++++++--------------------------------- pubsub/producer.go | 141 +++++++++++++++++++++++++++++++++++--- pubsub/pubsub_test.go | 151 ++++++++++++++++++++--------------------- 3 files changed, 238 insertions(+), 207 deletions(-) diff --git a/pubsub/consumer.go b/pubsub/consumer.go index c016208664..698e2e06f0 100644 --- a/pubsub/consumer.go +++ b/pubsub/consumer.go @@ -9,11 +9,10 @@ import ( "github.com/ethereum/go-ethereum/log" "github.com/go-redis/redis/v8" "github.com/google/uuid" + "github.com/offchainlabs/nitro/util/stopwaiter" "github.com/spf13/pflag" ) -const pendingMessagesKey = "lock:pending" - type ConsumerConfig struct { // Intervals in which consumer will update heartbeat. KeepAliveInterval time.Duration `koanf:"keepalive-interval"` @@ -24,6 +23,8 @@ type ConsumerConfig struct { RedisURL string `koanf:"redis-url"` // Redis stream name. RedisStream string `koanf:"redis-stream"` + // Redis consumer group name. 
+ RedisGroup string `koanf:"redis-group"` } func ConsumerConfigAddOptions(prefix string, f *pflag.FlagSet, cfg *ConsumerConfig) { @@ -31,10 +32,11 @@ func ConsumerConfigAddOptions(prefix string, f *pflag.FlagSet, cfg *ConsumerConf f.Duration(prefix+".keepalive-timeout", 5*time.Minute, "timeout after which consumer is considered inactive if heartbeat wasn't performed") f.String(prefix+".redis-url", "", "redis url for redis stream") f.String(prefix+".redis-stream", "default", "redis stream name to read from") - f.String(prefix+".redis-group", "default", "redis stream consumer group name") + f.String(prefix+".redis-group", defaultGroup, "redis stream consumer group name") } type Consumer struct { + stopwaiter.StopWaiter id string streamName string groupName string @@ -53,59 +55,53 @@ func NewConsumer(ctx context.Context, cfg *ConsumerConfig) (*Consumer, error) { if err != nil { return nil, err } - id := uuid.NewString() - consumer := &Consumer{ - id: id, + id: uuid.NewString(), streamName: cfg.RedisStream, - groupName: "default", + groupName: cfg.RedisGroup, client: c, keepAliveInterval: cfg.KeepAliveInterval, keepAliveTimeout: cfg.KeepAliveTimeout, } - go consumer.keepAlive(ctx) return consumer, nil } -func keepAliveKey(id string) string { +func (c *Consumer) Start(ctx context.Context) { + c.StopWaiter.Start(ctx, c) + c.StopWaiter.CallIteratively( + func(ctx context.Context) time.Duration { + c.heartBeat(ctx) + return c.keepAliveInterval + }, + ) +} + +func (c *Consumer) StopAndWait() { + c.StopWaiter.StopAndWait() +} + +func heartBeatKey(id string) string { return fmt.Sprintf("consumer:%s:heartbeat", id) } -func (c *Consumer) keepAliveKey() string { - return keepAliveKey(c.id) +func (c *Consumer) heartBeatKey() string { + return heartBeatKey(c.id) } -// keepAlive polls in keepAliveIntervals and updates heartbeat entry for itself. -func (c *Consumer) keepAlive(ctx context.Context) { - log.Info("Consumer polling for heartbeat updates", "id", c.id) - for { - if err := c.client.Set(ctx, c.keepAliveKey(), time.Now().UnixMilli(), c.keepAliveTimeout).Err(); err != nil { - l := log.Error - if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) { - l = log.Info - } - l("Updating heardbeat", "consumer", c.id, "error", err) - } - select { - case <-ctx.Done(): - log.Info("Error keeping alive", "error", ctx.Err()) - return - case <-time.After(c.keepAliveInterval): +// heartBeat updates the heartBeat key indicating aliveness. +func (c *Consumer) heartBeat(ctx context.Context) { + if err := c.client.Set(ctx, c.heartBeatKey(), time.Now().UnixMilli(), c.keepAliveTimeout).Err(); err != nil { + l := log.Error + if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) { + l = log.Info } + l("Updating heardbeat", "consumer", c.id, "error", err) } } // Consumer first checks it there exists pending message that is claimed by // unresponsive consumer, if not then reads from the stream. 
func (c *Consumer) Consume(ctx context.Context) (*Message, error) { - log.Debug("Attempting to consume a message", "consumer-id", c.id) - msg, err := c.checkPending(ctx) - if err != nil { - return nil, fmt.Errorf("consumer: %v checking pending messages with unavailable consumer: %w", c.id, err) - } - if msg != nil { - return msg, nil - } res, err := c.client.XReadGroup(ctx, &redis.XReadGroupArgs{ Group: c.groupName, Consumer: c.id, @@ -127,7 +123,7 @@ func (c *Consumer) Consume(ctx context.Context) (*Message, error) { log.Debug(fmt.Sprintf("Consumer: %s consuming message: %s", c.id, res[0].Messages[0].ID)) return &Message{ ID: res[0].Messages[0].ID, - Value: res[0].Messages[0].Values[msgKey], + Value: res[0].Messages[0].Values[messageKey], }, nil } @@ -136,88 +132,3 @@ func (c *Consumer) ACK(ctx context.Context, messageID string) error { _, err := c.client.XAck(ctx, c.streamName, c.groupName, messageID).Result() return err } - -// Check if a consumer is with specified ID is alive. -func (c *Consumer) isConsumerAlive(ctx context.Context, consumerID string) bool { - val, err := c.client.Get(ctx, keepAliveKey(consumerID)).Int64() - if err != nil { - return false - } - return time.Now().UnixMilli()-val < 2*int64(c.keepAliveTimeout.Milliseconds()) -} - -func (c *Consumer) lockPending(ctx context.Context, consumerID string) bool { - acquired, err := c.client.SetNX(ctx, pendingMessagesKey, consumerID, c.keepAliveInterval).Result() - if err != nil || !acquired { - return false - } - return true -} - -func (c *Consumer) unlockPending(ctx context.Context) { - log.Debug("Releasing lock", "consumer-id", c.id) - c.client.Del(ctx, pendingMessagesKey) - -} - -// checkPending lists pending messages, and checks unavailable consumers that -// have ownership on pending message. -// If such message and consumer exists, it claims ownership on it. -func (c *Consumer) checkPending(ctx context.Context) (*Message, error) { - // Locking pending list avoid the race where two instances query pending - // list and try to claim ownership on the same message. 
- if !c.lockPending(ctx, c.id) { - return nil, nil - } - log.Info("Consumer acquired pending lock", "consumer=id", c.id) - defer c.unlockPending(ctx) - pendingMessages, err := c.client.XPendingExt(ctx, &redis.XPendingExtArgs{ - Stream: c.streamName, - Group: c.groupName, - Start: "-", - End: "+", - Count: 100, - }).Result() - log.Info("Pending messages", "consumer", c.id, "pendingMessages", pendingMessages, "error", err) - - if err != nil && !errors.Is(err, redis.Nil) { - return nil, fmt.Errorf("querying pending messages: %w", err) - } - if len(pendingMessages) == 0 { - return nil, nil - } - inactive := make(map[string]bool) - for _, msg := range pendingMessages { - if inactive[msg.Consumer] { - continue - } - if c.isConsumerAlive(ctx, msg.Consumer) { - continue - } - inactive[msg.Consumer] = true - log.Info("Consumer is not alive", "id", msg.Consumer) - msgs, err := c.client.XClaim(ctx, &redis.XClaimArgs{ - Stream: c.streamName, - Group: c.groupName, - Consumer: c.id, - MinIdle: c.keepAliveTimeout, - Messages: []string{msg.ID}, - }).Result() - if err != nil { - log.Error("Error claiming ownership on message", "id", msg.ID, "consumer", c.id, "error", err) - continue - } - if len(msgs) != 1 { - log.Error("Attempted to claim ownership on single messsage", "id", msg.ID, "number of received messages", len(msgs)) - if len(msgs) == 0 { - continue - } - } - log.Info(fmt.Sprintf("Consumer: %s claimed ownership on message: %s", c.id, msgs[0].ID)) - return &Message{ - ID: msgs[0].ID, - Value: msgs[0].Values[msgKey], - }, nil - } - return nil, nil -} diff --git a/pubsub/producer.go b/pubsub/producer.go index 685db110b3..202ee69810 100644 --- a/pubsub/producer.go +++ b/pubsub/producer.go @@ -2,12 +2,20 @@ package pubsub import ( "context" + "errors" "fmt" + "time" + "github.com/ethereum/go-ethereum/log" "github.com/go-redis/redis/v8" + "github.com/google/uuid" + "github.com/offchainlabs/nitro/util/stopwaiter" ) -const msgKey = "msg" +const ( + messageKey = "msg" + defaultGroup = "default_consumer_group" +) // clientFromURL returns a redis client from url. func clientFromURL(url string) (*redis.Client, error) { @@ -26,14 +34,30 @@ func clientFromURL(url string) (*redis.Client, error) { } type Producer struct { - streamName string - client *redis.Client + stopwaiter.StopWaiter + id string + streamName string + client *redis.Client + groupName string + checkPendingInterval time.Duration + keepAliveInterval time.Duration + keepAliveTimeout time.Duration } type ProducerConfig struct { RedisURL string `koanf:"redis-url"` // Redis stream name. RedisStream string `koanf:"redis-stream"` + // Interval duration in which producer checks for pending messages delivered + // to the consumers that are currently inactive. + CheckPendingInterval time.Duration `koanf:"check-pending-interval"` + // Intervals in which consumer will update heartbeat. + KeepAliveInterval time.Duration `koanf:"keepalive-interval"` + // Duration after which consumer is considered to be dead if heartbeat + // is not updated. + KeepAliveTimeout time.Duration `koanf:"keepalive-timeout"` + // Redis consumer group name. 
+ RedisGroup string `koanf:"redis-group"` } func NewProducer(cfg *ProducerConfig) (*Producer, error) { @@ -42,17 +66,112 @@ func NewProducer(cfg *ProducerConfig) (*Producer, error) { return nil, err } return &Producer{ - streamName: cfg.RedisStream, - client: c, + id: uuid.NewString(), + streamName: cfg.RedisStream, + client: c, + groupName: cfg.RedisGroup, + checkPendingInterval: cfg.CheckPendingInterval, + keepAliveInterval: cfg.KeepAliveInterval, + keepAliveTimeout: cfg.KeepAliveTimeout, }, nil } -func (p *Producer) Produce(ctx context.Context, value any) error { - if _, err := p.client.XAdd(ctx, &redis.XAddArgs{ - Stream: p.streamName, - Values: map[string]any{msgKey: value}, - }).Result(); err != nil { - return fmt.Errorf("adding values to redis: %w", err) +func (p *Producer) Start(ctx context.Context) { + p.StopWaiter.Start(ctx, p) + p.StopWaiter.CallIteratively( + func(ctx context.Context) time.Duration { + msgs, err := p.checkPending(ctx) + if err != nil { + log.Error("Checking pending messages", "error", err) + return p.checkPendingInterval + } + if len(msgs) == 0 { + return p.checkPendingInterval + } + var acked []any + for _, msg := range msgs { + if _, err := p.client.XAck(ctx, p.streamName, p.groupName, msg.ID).Result(); err != nil { + log.Error("ACKing message", "error", err) + continue + } + acked = append(acked, msg.Value) + } + // Only re-insert messages that were removed the the pending list first. + if err := p.Produce(ctx, acked); err != nil { + log.Error("Re-inserting pending messages with inactive consumers", "error", err) + } + return p.checkPendingInterval + }, + ) +} + +func (p *Producer) Produce(ctx context.Context, values ...any) error { + if len(values) == 0 { + return nil + } + for _, value := range values { + log.Info("anodar producing", "value", value) + if _, err := p.client.XAdd(ctx, &redis.XAddArgs{ + Stream: p.streamName, + Values: map[string]any{messageKey: value}, + }).Result(); err != nil { + return fmt.Errorf("adding values to redis: %w", err) + } } return nil } + +// Check if a consumer is with specified ID is alive. +func (p *Producer) isConsumerAlive(ctx context.Context, consumerID string) bool { + val, err := p.client.Get(ctx, heartBeatKey(consumerID)).Int64() + if err != nil { + return false + } + return time.Now().UnixMilli()-val < 2*int64(p.keepAliveTimeout.Milliseconds()) +} + +func (p *Producer) checkPending(ctx context.Context) ([]*Message, error) { + pendingMessages, err := p.client.XPendingExt(ctx, &redis.XPendingExtArgs{ + Stream: p.streamName, + Group: p.groupName, + Start: "-", + End: "+", + Count: 100, + }).Result() + + if err != nil && !errors.Is(err, redis.Nil) { + return nil, fmt.Errorf("querying pending messages: %w", err) + } + if len(pendingMessages) == 0 { + return nil, nil + } + // IDs of the pending messages with inactive consumers. 
+ var ids []string + inactive := make(map[string]bool) + for _, msg := range pendingMessages { + if inactive[msg.Consumer] || p.isConsumerAlive(ctx, msg.Consumer) { + continue + } + inactive[msg.Consumer] = true + ids = append(ids, msg.ID) + } + log.Info("Attempting to claim", "messages", ids) + claimedMsgs, err := p.client.XClaim(ctx, &redis.XClaimArgs{ + Stream: p.streamName, + Group: p.groupName, + Consumer: p.id, + MinIdle: p.keepAliveTimeout, + Messages: ids, + }).Result() + if err != nil { + return nil, fmt.Errorf("claiming ownership on messages: %v, error: %v", ids, err) + } + var res []*Message + for _, msg := range claimedMsgs { + res = append(res, &Message{ + ID: msg.ID, + Value: msg.Values[messageKey], + }) + } + return res, nil +} diff --git a/pubsub/pubsub_test.go b/pubsub/pubsub_test.go index eccf723f11..f04f58593f 100644 --- a/pubsub/pubsub_test.go +++ b/pubsub/pubsub_test.go @@ -22,37 +22,36 @@ var ( messagesCount = 100 ) -type testConsumer struct { - consumer *Consumer - cancel context.CancelFunc -} - func createGroup(ctx context.Context, t *testing.T, client *redis.Client) { t.Helper() - _, err := client.XGroupCreateMkStream(ctx, streamName, "default", "$").Result() + _, err := client.XGroupCreateMkStream(ctx, streamName, defaultGroup, "$").Result() if err != nil { t.Fatalf("Error creating stream group: %v", err) } } -func newProducerConsumers(ctx context.Context, t *testing.T) (*Producer, []*testConsumer) { +func newProducerConsumers(ctx context.Context, t *testing.T) (*Producer, []*Consumer) { t.Helper() - // tmpI, tmpT := KeepAliveInterval, KeepAliveTimeout - // KeepAliveInterval, KeepAliveTimeout = 5*time.Millisecond, 30*time.Millisecond - // t.Cleanup(func() { KeepAliveInterval, KeepAliveTimeout = tmpI, tmpT }) - redisURL := redisutil.CreateTestRedis(ctx, t) - producer, err := NewProducer(&ProducerConfig{RedisURL: redisURL, RedisStream: streamName}) + producer, err := NewProducer( + &ProducerConfig{ + RedisURL: redisURL, + RedisStream: streamName, + RedisGroup: defaultGroup, + CheckPendingInterval: 10 * time.Millisecond, + KeepAliveInterval: 5 * time.Millisecond, + KeepAliveTimeout: 20 * time.Millisecond, + }) if err != nil { t.Fatalf("Error creating new producer: %v", err) } - var consumers []*testConsumer + var consumers []*Consumer for i := 0; i < consumersCount; i++ { - consumerCtx, cancel := context.WithCancel(ctx) - c, err := NewConsumer(consumerCtx, + c, err := NewConsumer(ctx, &ConsumerConfig{ RedisURL: redisURL, RedisStream: streamName, + RedisGroup: defaultGroup, KeepAliveInterval: 5 * time.Millisecond, KeepAliveTimeout: 30 * time.Millisecond, }, @@ -60,10 +59,7 @@ func newProducerConsumers(ctx context.Context, t *testing.T) (*Producer, []*test if err != nil { t.Fatalf("Error creating new consumer: %v", err) } - consumers = append(consumers, &testConsumer{ - consumer: c, - cancel: cancel, - }) + consumers = append(consumers, c) } createGroup(ctx, t, producer.client) return producer, consumers @@ -89,34 +85,32 @@ func wantMessages(n int) []any { } func TestProduce(t *testing.T) { - log.Root().SetHandler(log.LvlFilterHandler(log.LvlTrace, log.StreamHandler(os.Stderr, log.TerminalFormat(true)))) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + // log.Root().SetHandler(log.LvlFilterHandler(log.LvlTrace, log.StreamHandler(os.Stderr, log.TerminalFormat(true)))) + ctx := context.Background() producer, consumers := newProducerConsumers(ctx, t) - consumerCtx, cancelConsumers := context.WithTimeout(ctx, time.Second) gotMessages := 
messagesMap(consumersCount) for idx, c := range consumers { - idx, c := idx, c.consumer - go func() { - // Give some time to the consumers to do their heartbeat. - time.Sleep(2 * c.keepAliveInterval) - for { - res, err := c.Consume(consumerCtx) - if err != nil { - if !errors.Is(err, context.DeadlineExceeded) { - t.Errorf("Consume() unexpected error: %v", err) + idx, c := idx, c + c.Start(ctx) + c.StopWaiter.LaunchThread( + func(ctx context.Context) { + for { + res, err := c.Consume(ctx) + if err != nil { + if !errors.Is(err, context.Canceled) { + t.Errorf("Consume() unexpected error: %v", err) + } + return + } + if res == nil { + continue + } + gotMessages[idx][res.ID] = res.Value + if err := c.ACK(ctx, res.ID); err != nil { + t.Errorf("Error ACKing message: %v, error: %v", res.ID, err) } - return - } - if res == nil { - continue - } - gotMessages[idx][res.ID] = res.Value - if err := c.ACK(consumerCtx, res.ID); err != nil { - t.Errorf("Error ACKing message: %v, error: %v", res.ID, err) } - } - }() + }) } for i := 0; i < messagesCount; i++ { @@ -125,8 +119,12 @@ func TestProduce(t *testing.T) { t.Errorf("Produce() unexpected error: %v", err) } } - time.Sleep(time.Second) - cancelConsumers() + producer.StopWaiter.StopAndWait() + time.Sleep(50 * time.Millisecond) + for _, c := range consumers { + c.StopAndWait() + } + got, err := mergeValues(gotMessages) if err != nil { t.Fatalf("mergeMaps() unexpected error: %v", err) @@ -139,50 +137,51 @@ func TestProduce(t *testing.T) { func TestClaimingOwnership(t *testing.T) { log.Root().SetHandler(log.LvlFilterHandler(log.LvlTrace, log.StreamHandler(os.Stderr, log.TerminalFormat(true)))) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx := context.Background() producer, consumers := newProducerConsumers(ctx, t) - consumerCtx, cancelConsumers := context.WithCancel(ctx) + producer.Start(ctx) gotMessages := messagesMap(consumersCount) // Consumer messages in every third consumer but don't ack them to check // that other consumers will claim ownership on those messages. 
for i := 0; i < len(consumers); i += 3 { i := i - consumers[i].cancel() - go func() { - if _, err := consumers[i].consumer.Consume(context.Background()); err != nil { - t.Errorf("Error consuming message: %v", err) - } - }() + if _, err := consumers[i].Consume(ctx); err != nil { + t.Errorf("Error consuming message: %v", err) + } + consumers[i].StopAndWait() } var total atomic.Uint64 for idx, c := range consumers { - idx, c := idx, c.consumer - go func() { - for { - if idx%3 == 0 { - continue - } - res, err := c.Consume(consumerCtx) - if err != nil { - if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) { - t.Errorf("Consume() unexpected error: %v", err) + idx, c := idx, c + if !c.StopWaiter.Started() { + c.Start(ctx) + } + c.StopWaiter.LaunchThread( + func(ctx context.Context) { + for { + if idx%3 == 0 { continue } - return - } - if res == nil { - continue - } - gotMessages[idx][res.ID] = res.Value - if err := c.ACK(consumerCtx, res.ID); err != nil { - t.Errorf("Error ACKing message: %v, error: %v", res.ID, err) + res, err := c.Consume(ctx) + if err != nil { + if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) { + t.Errorf("Consume() unexpected error: %v", err) + continue + } + return + } + if res == nil { + continue + } + gotMessages[idx][res.ID] = res.Value + if err := c.ACK(ctx, res.ID); err != nil { + t.Errorf("Error ACKing message: %v, error: %v", res.ID, err) + } + total.Add(1) } - total.Add(1) - } - }() + }) } for i := 0; i < messagesCount; i++ { @@ -199,7 +198,9 @@ func TestClaimingOwnership(t *testing.T) { } break } - cancelConsumers() + for _, c := range consumers { + c.StopWaiter.StopAndWait() + } got, err := mergeValues(gotMessages) if err != nil { t.Fatalf("mergeMaps() unexpected error: %v", err) From 046fb251017b6b5dad5e020b190f2081c4d88890 Mon Sep 17 00:00:00 2001 From: Nodar Ambroladze Date: Tue, 26 Mar 2024 20:22:54 +0100 Subject: [PATCH 08/33] Fix linter error --- pubsub/producer.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pubsub/producer.go b/pubsub/producer.go index 202ee69810..c80d641a5a 100644 --- a/pubsub/producer.go +++ b/pubsub/producer.go @@ -164,7 +164,7 @@ func (p *Producer) checkPending(ctx context.Context) ([]*Message, error) { Messages: ids, }).Result() if err != nil { - return nil, fmt.Errorf("claiming ownership on messages: %v, error: %v", ids, err) + return nil, fmt.Errorf("claiming ownership on messages: %v, error: %w", ids, err) } var res []*Message for _, msg := range claimedMsgs { From a21e46a2e65a72acb75ff1f13b72c7c2ee5e3f27 Mon Sep 17 00:00:00 2001 From: Nodar Ambroladze Date: Tue, 26 Mar 2024 20:31:09 +0100 Subject: [PATCH 09/33] Drop logging in tests --- pubsub/producer.go | 2 +- pubsub/pubsub_test.go | 4 ---- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/pubsub/producer.go b/pubsub/producer.go index c80d641a5a..ad5b44e1ec 100644 --- a/pubsub/producer.go +++ b/pubsub/producer.go @@ -127,7 +127,7 @@ func (p *Producer) isConsumerAlive(ctx context.Context, consumerID string) bool if err != nil { return false } - return time.Now().UnixMilli()-val < 2*int64(p.keepAliveTimeout.Milliseconds()) + return time.Now().UnixMilli()-val < int64(p.keepAliveTimeout.Milliseconds()) } func (p *Producer) checkPending(ctx context.Context) ([]*Message, error) { diff --git a/pubsub/pubsub_test.go b/pubsub/pubsub_test.go index f04f58593f..e34b107e2a 100644 --- a/pubsub/pubsub_test.go +++ b/pubsub/pubsub_test.go @@ -4,13 +4,11 @@ import ( "context" "errors" "fmt" - "os" 
"sort" "sync/atomic" "testing" "time" - "github.com/ethereum/go-ethereum/log" "github.com/go-redis/redis/v8" "github.com/google/go-cmp/cmp" "github.com/offchainlabs/nitro/util/redisutil" @@ -85,7 +83,6 @@ func wantMessages(n int) []any { } func TestProduce(t *testing.T) { - // log.Root().SetHandler(log.LvlFilterHandler(log.LvlTrace, log.StreamHandler(os.Stderr, log.TerminalFormat(true)))) ctx := context.Background() producer, consumers := newProducerConsumers(ctx, t) gotMessages := messagesMap(consumersCount) @@ -136,7 +133,6 @@ func TestProduce(t *testing.T) { } func TestClaimingOwnership(t *testing.T) { - log.Root().SetHandler(log.LvlFilterHandler(log.LvlTrace, log.StreamHandler(os.Stderr, log.TerminalFormat(true)))) ctx := context.Background() producer, consumers := newProducerConsumers(ctx, t) producer.Start(ctx) From 07e4efe0864ccd49e45f2c8eb4dd1ca194588fa3 Mon Sep 17 00:00:00 2001 From: Nodar Ambroladze Date: Wed, 27 Mar 2024 12:34:50 +0100 Subject: [PATCH 10/33] Address comments --- pubsub/consumer.go | 42 ++++++++++++++++++---------------- pubsub/producer.go | 53 ++++++++++++++++++++----------------------- pubsub/pubsub_test.go | 1 - 3 files changed, 47 insertions(+), 49 deletions(-) diff --git a/pubsub/consumer.go b/pubsub/consumer.go index 698e2e06f0..133cf8fbbf 100644 --- a/pubsub/consumer.go +++ b/pubsub/consumer.go @@ -37,12 +37,13 @@ func ConsumerConfigAddOptions(prefix string, f *pflag.FlagSet, cfg *ConsumerConf type Consumer struct { stopwaiter.StopWaiter - id string - streamName string - groupName string - client *redis.Client - keepAliveInterval time.Duration - keepAliveTimeout time.Duration + id string + client *redis.Client + cfg *ConsumerConfig + // streamName string + // groupName string + // keepAliveInterval time.Duration + // keepAliveTimeout time.Duration } type Message struct { @@ -56,12 +57,13 @@ func NewConsumer(ctx context.Context, cfg *ConsumerConfig) (*Consumer, error) { return nil, err } consumer := &Consumer{ - id: uuid.NewString(), - streamName: cfg.RedisStream, - groupName: cfg.RedisGroup, - client: c, - keepAliveInterval: cfg.KeepAliveInterval, - keepAliveTimeout: cfg.KeepAliveTimeout, + id: uuid.NewString(), + client: c, + cfg: cfg, + // streamName: cfg.RedisStream, + // groupName: cfg.RedisGroup, + // keepAliveInterval: cfg.KeepAliveInterval, + // keepAliveTimeout: cfg.KeepAliveTimeout, } return consumer, nil } @@ -71,7 +73,7 @@ func (c *Consumer) Start(ctx context.Context) { c.StopWaiter.CallIteratively( func(ctx context.Context) time.Duration { c.heartBeat(ctx) - return c.keepAliveInterval + return c.cfg.KeepAliveInterval }, ) } @@ -90,10 +92,10 @@ func (c *Consumer) heartBeatKey() string { // heartBeat updates the heartBeat key indicating aliveness. func (c *Consumer) heartBeat(ctx context.Context) { - if err := c.client.Set(ctx, c.heartBeatKey(), time.Now().UnixMilli(), c.keepAliveTimeout).Err(); err != nil { - l := log.Error - if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) { - l = log.Info + if err := c.client.Set(ctx, c.heartBeatKey(), time.Now().UnixMilli(), 2*c.cfg.KeepAliveTimeout).Err(); err != nil { + l := log.Info + if ctx.Err() != nil { + l = log.Error } l("Updating heardbeat", "consumer", c.id, "error", err) } @@ -103,11 +105,11 @@ func (c *Consumer) heartBeat(ctx context.Context) { // unresponsive consumer, if not then reads from the stream. 
func (c *Consumer) Consume(ctx context.Context) (*Message, error) { res, err := c.client.XReadGroup(ctx, &redis.XReadGroupArgs{ - Group: c.groupName, + Group: c.cfg.RedisGroup, Consumer: c.id, // Receive only messages that were never delivered to any other consumer, // that is, only new messages. - Streams: []string{c.streamName, ">"}, + Streams: []string{c.cfg.RedisStream, ">"}, Count: 1, Block: time.Millisecond, // 0 seems to block the read instead of immediately returning }).Result() @@ -129,6 +131,6 @@ func (c *Consumer) Consume(ctx context.Context) (*Message, error) { func (c *Consumer) ACK(ctx context.Context, messageID string) error { log.Info("ACKing message", "consumer-id", c.id, "message-sid", messageID) - _, err := c.client.XAck(ctx, c.streamName, c.groupName, messageID).Result() + _, err := c.client.XAck(ctx, c.cfg.RedisStream, c.cfg.RedisGroup, messageID).Result() return err } diff --git a/pubsub/producer.go b/pubsub/producer.go index ad5b44e1ec..3ece2a7f6e 100644 --- a/pubsub/producer.go +++ b/pubsub/producer.go @@ -35,13 +35,13 @@ func clientFromURL(url string) (*redis.Client, error) { type Producer struct { stopwaiter.StopWaiter - id string - streamName string - client *redis.Client - groupName string - checkPendingInterval time.Duration - keepAliveInterval time.Duration - keepAliveTimeout time.Duration + id string + client *redis.Client + cfg *ProducerConfig + // streamName string + // groupName string + // checkPendingInterval time.Duration + // keepAliveTimeout time.Duration } type ProducerConfig struct { @@ -51,8 +51,6 @@ type ProducerConfig struct { // Interval duration in which producer checks for pending messages delivered // to the consumers that are currently inactive. CheckPendingInterval time.Duration `koanf:"check-pending-interval"` - // Intervals in which consumer will update heartbeat. - KeepAliveInterval time.Duration `koanf:"keepalive-interval"` // Duration after which consumer is considered to be dead if heartbeat // is not updated. 
KeepAliveTimeout time.Duration `koanf:"keepalive-timeout"` @@ -66,13 +64,13 @@ func NewProducer(cfg *ProducerConfig) (*Producer, error) { return nil, err } return &Producer{ - id: uuid.NewString(), - streamName: cfg.RedisStream, - client: c, - groupName: cfg.RedisGroup, - checkPendingInterval: cfg.CheckPendingInterval, - keepAliveInterval: cfg.KeepAliveInterval, - keepAliveTimeout: cfg.KeepAliveTimeout, + id: uuid.NewString(), + client: c, + cfg: cfg, + // streamName: cfg.RedisStream, + // groupName: cfg.RedisGroup, + // checkPendingInterval: cfg.CheckPendingInterval, + // keepAliveTimeout: cfg.KeepAliveTimeout, }, nil } @@ -83,14 +81,14 @@ func (p *Producer) Start(ctx context.Context) { msgs, err := p.checkPending(ctx) if err != nil { log.Error("Checking pending messages", "error", err) - return p.checkPendingInterval + return p.cfg.CheckPendingInterval } if len(msgs) == 0 { - return p.checkPendingInterval + return p.cfg.CheckPendingInterval } var acked []any for _, msg := range msgs { - if _, err := p.client.XAck(ctx, p.streamName, p.groupName, msg.ID).Result(); err != nil { + if _, err := p.client.XAck(ctx, p.cfg.RedisStream, p.cfg.RedisGroup, msg.ID).Result(); err != nil { log.Error("ACKing message", "error", err) continue } @@ -100,7 +98,7 @@ func (p *Producer) Start(ctx context.Context) { if err := p.Produce(ctx, acked); err != nil { log.Error("Re-inserting pending messages with inactive consumers", "error", err) } - return p.checkPendingInterval + return p.cfg.CheckPendingInterval }, ) } @@ -110,9 +108,8 @@ func (p *Producer) Produce(ctx context.Context, values ...any) error { return nil } for _, value := range values { - log.Info("anodar producing", "value", value) if _, err := p.client.XAdd(ctx, &redis.XAddArgs{ - Stream: p.streamName, + Stream: p.cfg.RedisStream, Values: map[string]any{messageKey: value}, }).Result(); err != nil { return fmt.Errorf("adding values to redis: %w", err) @@ -127,13 +124,13 @@ func (p *Producer) isConsumerAlive(ctx context.Context, consumerID string) bool if err != nil { return false } - return time.Now().UnixMilli()-val < int64(p.keepAliveTimeout.Milliseconds()) + return time.Now().UnixMilli()-val < int64(p.cfg.KeepAliveTimeout.Milliseconds()) } func (p *Producer) checkPending(ctx context.Context) ([]*Message, error) { pendingMessages, err := p.client.XPendingExt(ctx, &redis.XPendingExtArgs{ - Stream: p.streamName, - Group: p.groupName, + Stream: p.cfg.RedisStream, + Group: p.cfg.RedisGroup, Start: "-", End: "+", Count: 100, @@ -157,10 +154,10 @@ func (p *Producer) checkPending(ctx context.Context) ([]*Message, error) { } log.Info("Attempting to claim", "messages", ids) claimedMsgs, err := p.client.XClaim(ctx, &redis.XClaimArgs{ - Stream: p.streamName, - Group: p.groupName, + Stream: p.cfg.RedisStream, + Group: p.cfg.RedisGroup, Consumer: p.id, - MinIdle: p.keepAliveTimeout, + MinIdle: p.cfg.KeepAliveTimeout, Messages: ids, }).Result() if err != nil { diff --git a/pubsub/pubsub_test.go b/pubsub/pubsub_test.go index e34b107e2a..04b781e124 100644 --- a/pubsub/pubsub_test.go +++ b/pubsub/pubsub_test.go @@ -37,7 +37,6 @@ func newProducerConsumers(ctx context.Context, t *testing.T) (*Producer, []*Cons RedisStream: streamName, RedisGroup: defaultGroup, CheckPendingInterval: 10 * time.Millisecond, - KeepAliveInterval: 5 * time.Millisecond, KeepAliveTimeout: 20 * time.Millisecond, }) if err != nil { From 79411f9a44b4fc64e50b7a8ba3a8ee696baf0a23 Mon Sep 17 00:00:00 2001 From: Nodar Ambroladze Date: Wed, 27 Mar 2024 12:35:22 +0100 Subject: [PATCH 11/33] Drop 
commented out code --- pubsub/consumer.go | 8 -------- pubsub/producer.go | 8 -------- 2 files changed, 16 deletions(-) diff --git a/pubsub/consumer.go b/pubsub/consumer.go index 133cf8fbbf..86add35b5b 100644 --- a/pubsub/consumer.go +++ b/pubsub/consumer.go @@ -40,10 +40,6 @@ type Consumer struct { id string client *redis.Client cfg *ConsumerConfig - // streamName string - // groupName string - // keepAliveInterval time.Duration - // keepAliveTimeout time.Duration } type Message struct { @@ -60,10 +56,6 @@ func NewConsumer(ctx context.Context, cfg *ConsumerConfig) (*Consumer, error) { id: uuid.NewString(), client: c, cfg: cfg, - // streamName: cfg.RedisStream, - // groupName: cfg.RedisGroup, - // keepAliveInterval: cfg.KeepAliveInterval, - // keepAliveTimeout: cfg.KeepAliveTimeout, } return consumer, nil } diff --git a/pubsub/producer.go b/pubsub/producer.go index 3ece2a7f6e..1956f6d405 100644 --- a/pubsub/producer.go +++ b/pubsub/producer.go @@ -38,10 +38,6 @@ type Producer struct { id string client *redis.Client cfg *ProducerConfig - // streamName string - // groupName string - // checkPendingInterval time.Duration - // keepAliveTimeout time.Duration } type ProducerConfig struct { @@ -67,10 +63,6 @@ func NewProducer(cfg *ProducerConfig) (*Producer, error) { id: uuid.NewString(), client: c, cfg: cfg, - // streamName: cfg.RedisStream, - // groupName: cfg.RedisGroup, - // checkPendingInterval: cfg.CheckPendingInterval, - // keepAliveTimeout: cfg.KeepAliveTimeout, }, nil } From 862289cbc06ca3ffb881f0106a9dfae8303039d0 Mon Sep 17 00:00:00 2001 From: Nodar Ambroladze Date: Wed, 27 Mar 2024 12:55:41 +0100 Subject: [PATCH 12/33] Use redisutil package for creating redis client --- pubsub/consumer.go | 8 ++++++-- pubsub/producer.go | 24 ++++++------------------ pubsub/pubsub_test.go | 2 +- 3 files changed, 13 insertions(+), 21 deletions(-) diff --git a/pubsub/consumer.go b/pubsub/consumer.go index 86add35b5b..6eea541b22 100644 --- a/pubsub/consumer.go +++ b/pubsub/consumer.go @@ -9,6 +9,7 @@ import ( "github.com/ethereum/go-ethereum/log" "github.com/go-redis/redis/v8" "github.com/google/uuid" + "github.com/offchainlabs/nitro/util/redisutil" "github.com/offchainlabs/nitro/util/stopwaiter" "github.com/spf13/pflag" ) @@ -38,7 +39,7 @@ func ConsumerConfigAddOptions(prefix string, f *pflag.FlagSet, cfg *ConsumerConf type Consumer struct { stopwaiter.StopWaiter id string - client *redis.Client + client redis.UniversalClient cfg *ConsumerConfig } @@ -48,7 +49,10 @@ type Message struct { } func NewConsumer(ctx context.Context, cfg *ConsumerConfig) (*Consumer, error) { - c, err := clientFromURL(cfg.RedisURL) + if cfg.RedisURL == "" { + return nil, fmt.Errorf("redis url cannot be empty") + } + c, err := redisutil.RedisClientFromURL(cfg.RedisURL) if err != nil { return nil, err } diff --git a/pubsub/producer.go b/pubsub/producer.go index 1956f6d405..72dec203c5 100644 --- a/pubsub/producer.go +++ b/pubsub/producer.go @@ -9,6 +9,7 @@ import ( "github.com/ethereum/go-ethereum/log" "github.com/go-redis/redis/v8" "github.com/google/uuid" + "github.com/offchainlabs/nitro/util/redisutil" "github.com/offchainlabs/nitro/util/stopwaiter" ) @@ -17,26 +18,10 @@ const ( defaultGroup = "default_consumer_group" ) -// clientFromURL returns a redis client from url. 
-func clientFromURL(url string) (*redis.Client, error) { - if url == "" { - return nil, fmt.Errorf("empty redis url") - } - opts, err := redis.ParseURL(url) - if err != nil { - return nil, err - } - c := redis.NewClient(opts) - if c == nil { - return nil, fmt.Errorf("redis returned nil client") - } - return c, nil -} - type Producer struct { stopwaiter.StopWaiter id string - client *redis.Client + client redis.UniversalClient cfg *ProducerConfig } @@ -55,7 +40,10 @@ type ProducerConfig struct { } func NewProducer(cfg *ProducerConfig) (*Producer, error) { - c, err := clientFromURL(cfg.RedisURL) + if cfg.RedisURL == "" { + return nil, fmt.Errorf("empty redis url") + } + c, err := redisutil.RedisClientFromURL(cfg.RedisURL) if err != nil { return nil, err } diff --git a/pubsub/pubsub_test.go b/pubsub/pubsub_test.go index 04b781e124..778fae6995 100644 --- a/pubsub/pubsub_test.go +++ b/pubsub/pubsub_test.go @@ -20,7 +20,7 @@ var ( messagesCount = 100 ) -func createGroup(ctx context.Context, t *testing.T, client *redis.Client) { +func createGroup(ctx context.Context, t *testing.T, client redis.UniversalClient) { t.Helper() _, err := client.XGroupCreateMkStream(ctx, streamName, defaultGroup, "$").Result() if err != nil { From 260002243c606339eb6bad1a1ae9ecbc945b49d9 Mon Sep 17 00:00:00 2001 From: Nodar Ambroladze Date: Wed, 27 Mar 2024 16:48:14 +0100 Subject: [PATCH 13/33] Implement returning responses as container.Promise --- pubsub/consumer.go | 11 ++++++ pubsub/producer.go | 81 ++++++++++++++++++++++++++++++++----------- pubsub/pubsub_test.go | 78 +++++++++++++++++++++++++++++++++++------ 3 files changed, 139 insertions(+), 31 deletions(-) diff --git a/pubsub/consumer.go b/pubsub/consumer.go index 6eea541b22..38cb6031fa 100644 --- a/pubsub/consumer.go +++ b/pubsub/consumer.go @@ -36,6 +36,8 @@ func ConsumerConfigAddOptions(prefix string, f *pflag.FlagSet, cfg *ConsumerConf f.String(prefix+".redis-group", defaultGroup, "redis stream consumer group name") } +// Consumer implements a consumer for redis stream provides heartbeat to +// indicate it is alive. type Consumer struct { stopwaiter.StopWaiter id string @@ -64,6 +66,7 @@ func NewConsumer(ctx context.Context, cfg *ConsumerConfig) (*Consumer, error) { return consumer, nil } +// Start starts the consumer to iteratively perform heartbeat in configured intervals. 
func (c *Consumer) Start(ctx context.Context) { c.StopWaiter.Start(ctx, c) c.StopWaiter.CallIteratively( @@ -130,3 +133,11 @@ func (c *Consumer) ACK(ctx context.Context, messageID string) error { _, err := c.client.XAck(ctx, c.cfg.RedisStream, c.cfg.RedisGroup, messageID).Result() return err } + +func (c *Consumer) SetResult(ctx context.Context, messageID string, result string) error { + acquired, err := c.client.SetNX(ctx, messageID, result, c.cfg.KeepAliveTimeout).Result() + if err != nil || !acquired { + return fmt.Errorf("setting result for message: %v, error: %w", messageID, err) + } + return nil +} diff --git a/pubsub/producer.go b/pubsub/producer.go index 72dec203c5..7ac089b3df 100644 --- a/pubsub/producer.go +++ b/pubsub/producer.go @@ -4,11 +4,13 @@ import ( "context" "errors" "fmt" + "sync" "time" "github.com/ethereum/go-ethereum/log" "github.com/go-redis/redis/v8" "github.com/google/uuid" + "github.com/offchainlabs/nitro/util/containers" "github.com/offchainlabs/nitro/util/redisutil" "github.com/offchainlabs/nitro/util/stopwaiter" ) @@ -23,6 +25,9 @@ type Producer struct { id string client redis.UniversalClient cfg *ProducerConfig + + promisesLock sync.RWMutex + promises map[string]*containers.Promise[any] } type ProducerConfig struct { @@ -35,22 +40,25 @@ type ProducerConfig struct { // Duration after which consumer is considered to be dead if heartbeat // is not updated. KeepAliveTimeout time.Duration `koanf:"keepalive-timeout"` + // Interval duration for checking the result set by consumers. + CheckResultInterval time.Duration `koanf:"check-result-interval"` // Redis consumer group name. RedisGroup string `koanf:"redis-group"` } func NewProducer(cfg *ProducerConfig) (*Producer, error) { if cfg.RedisURL == "" { - return nil, fmt.Errorf("empty redis url") + return nil, fmt.Errorf("redis url cannot be empty") } c, err := redisutil.RedisClientFromURL(cfg.RedisURL) if err != nil { return nil, err } return &Producer{ - id: uuid.NewString(), - client: c, - cfg: cfg, + id: uuid.NewString(), + client: c, + cfg: cfg, + promises: make(map[string]*containers.Promise[any]), }, nil } @@ -66,36 +74,67 @@ func (p *Producer) Start(ctx context.Context) { if len(msgs) == 0 { return p.cfg.CheckPendingInterval } - var acked []any + acked := make(map[string]any) for _, msg := range msgs { if _, err := p.client.XAck(ctx, p.cfg.RedisStream, p.cfg.RedisGroup, msg.ID).Result(); err != nil { log.Error("ACKing message", "error", err) continue } - acked = append(acked, msg.Value) + acked[msg.ID] = msg.Value } - // Only re-insert messages that were removed the the pending list first. - if err := p.Produce(ctx, acked); err != nil { - log.Error("Re-inserting pending messages with inactive consumers", "error", err) + for k, v := range acked { + // Only re-insert messages that were removed the the pending list first. + _, err := p.reproduce(ctx, v, k) + if err != nil { + log.Error("Re-inserting pending messages with inactive consumers", "error", err) + } } return p.cfg.CheckPendingInterval }, ) + // Iteratively check whether result were returned for some queries. 
+ p.StopWaiter.CallIteratively(func(ctx context.Context) time.Duration { + p.promisesLock.Lock() + defer p.promisesLock.Unlock() + for id, promise := range p.promises { + res, err := p.client.Get(ctx, id).Result() + if err != nil { + if errors.Is(err, redis.Nil) { + continue + } + log.Error("Error reading value in redis", "key", id, "error", err) + } + promise.Produce(res) + delete(p.promises, id) + } + return p.cfg.CheckResultInterval + }) } -func (p *Producer) Produce(ctx context.Context, values ...any) error { - if len(values) == 0 { - return nil +// reproduce is used when Producer claims ownership on the pending +// message that was sent to inactive consumer and reinserts it into the stream, +// so that seamlessly return the answer in the same promise. +func (p *Producer) reproduce(ctx context.Context, value any, oldKey string) (*containers.Promise[any], error) { + id, err := p.client.XAdd(ctx, &redis.XAddArgs{ + Stream: p.cfg.RedisStream, + Values: map[string]any{messageKey: value}, + }).Result() + if err != nil { + return nil, fmt.Errorf("adding values to redis: %w", err) } - for _, value := range values { - if _, err := p.client.XAdd(ctx, &redis.XAddArgs{ - Stream: p.cfg.RedisStream, - Values: map[string]any{messageKey: value}, - }).Result(); err != nil { - return fmt.Errorf("adding values to redis: %w", err) - } + p.promisesLock.Lock() + defer p.promisesLock.Unlock() + promise := p.promises[oldKey] + if oldKey == "" || promise == nil { + p := containers.NewPromise[any](nil) + promise = &p } - return nil + p.promises[id] = promise + return promise, nil +} + +func (p *Producer) Produce(ctx context.Context, value any) (*containers.Promise[any], error) { + return p.reproduce(ctx, value, "") } // Check if a consumer is with specified ID is alive. 
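The promise plumbing above rests on a simple hand-off contract: the consumer stores its answer under the stream entry ID with SetNX (so a duplicate worker cannot overwrite it), and the producer polls that key, treating redis.Nil as "not answered yet" before resolving the waiting promise. A minimal standalone sketch of that contract with go-redis; the address, entry ID and timeout below are hypothetical values, not taken from the patch:

    package main

    import (
        "context"
        "fmt"
        "time"

        "github.com/go-redis/redis/v8"
    )

    func main() {
        ctx := context.Background()
        client := redis.NewClient(&redis.Options{Addr: "localhost:6379"}) // hypothetical address

        entryID := "1711550000000-0" // hypothetical stream entry ID as returned by XADD

        // Consumer side: store the result once; the TTL bounds how long it is kept.
        client.SetNX(ctx, entryID, "computed result", time.Minute)

        // Producer side: poll for the result; redis.Nil means the consumer has not answered yet.
        res, err := client.Get(ctx, entryID).Result()
        switch {
        case err == redis.Nil:
            fmt.Println("no result yet")
        case err != nil:
            fmt.Println("redis error:", err)
        default:
            fmt.Println("result:", res)
        }
    }

Keying the result by the XADD entry ID is also what lets reproduce() move a pending request to a fresh entry: the original promise is simply re-registered under the new ID.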
@@ -126,7 +165,7 @@ func (p *Producer) checkPending(ctx context.Context) ([]*Message, error) { var ids []string inactive := make(map[string]bool) for _, msg := range pendingMessages { - if inactive[msg.Consumer] || p.isConsumerAlive(ctx, msg.Consumer) { + if !inactive[msg.Consumer] || p.isConsumerAlive(ctx, msg.Consumer) { continue } inactive[msg.Consumer] = true diff --git a/pubsub/pubsub_test.go b/pubsub/pubsub_test.go index 778fae6995..23fe481777 100644 --- a/pubsub/pubsub_test.go +++ b/pubsub/pubsub_test.go @@ -11,6 +11,7 @@ import ( "github.com/go-redis/redis/v8" "github.com/google/go-cmp/cmp" + "github.com/offchainlabs/nitro/util/containers" "github.com/offchainlabs/nitro/util/redisutil" ) @@ -38,6 +39,7 @@ func newProducerConsumers(ctx context.Context, t *testing.T) (*Producer, []*Cons RedisGroup: defaultGroup, CheckPendingInterval: 10 * time.Millisecond, KeepAliveTimeout: 20 * time.Millisecond, + CheckResultInterval: 5 * time.Millisecond, }) if err != nil { t.Fatalf("Error creating new producer: %v", err) @@ -84,7 +86,9 @@ func wantMessages(n int) []any { func TestProduce(t *testing.T) { ctx := context.Background() producer, consumers := newProducerConsumers(ctx, t) + producer.Start(ctx) gotMessages := messagesMap(consumersCount) + wantResponses := make([][]string, len(consumers)) for idx, c := range consumers { idx, c := idx, c c.Start(ctx) @@ -105,18 +109,30 @@ func TestProduce(t *testing.T) { if err := c.ACK(ctx, res.ID); err != nil { t.Errorf("Error ACKing message: %v, error: %v", res.ID, err) } + if err := c.SetResult(ctx, res.ID, fmt.Sprintf("result for: %v", res.ID)); err != nil { + t.Errorf("Error setting a result: %v", err) + } + wantResponses[idx] = append(wantResponses[idx], fmt.Sprintf("result for: %v", res.ID)) } }) } + var gotResponses []string + for i := 0; i < messagesCount; i++ { value := fmt.Sprintf("msg: %d", i) - if err := producer.Produce(ctx, value); err != nil { + p, err := producer.Produce(ctx, value) + if err != nil { t.Errorf("Produce() unexpected error: %v", err) } + res, err := p.Await(ctx) + if err != nil { + t.Errorf("Await() unexpected error: %v", err) + } + gotResponses = append(gotResponses, fmt.Sprintf("%v", res)) } + producer.StopWaiter.StopAndWait() - time.Sleep(50 * time.Millisecond) for _, c := range consumers { c.StopAndWait() } @@ -129,6 +145,25 @@ func TestProduce(t *testing.T) { if diff := cmp.Diff(want, got); diff != "" { t.Errorf("Unexpected diff (-want +got):\n%s\n", diff) } + + wantResp := flatten(wantResponses) + sort.Slice(gotResponses, func(i, j int) bool { + return gotResponses[i] < gotResponses[j] + }) + if diff := cmp.Diff(wantResp, gotResponses); diff != "" { + t.Errorf("Unexpected diff in responses:\n%s\n", diff) + } +} + +func flatten(responses [][]string) []string { + var ret []string + for _, v := range responses { + ret = append(ret, v...) 
+ } + sort.Slice(ret, func(i, j int) bool { + return ret[i] < ret[j] + }) + return ret } func TestClaimingOwnership(t *testing.T) { @@ -148,17 +183,17 @@ func TestClaimingOwnership(t *testing.T) { } var total atomic.Uint64 - for idx, c := range consumers { - idx, c := idx, c - if !c.StopWaiter.Started() { - c.Start(ctx) + wantResponses := make([][]string, len(consumers)) + for idx := 0; idx < len(consumers); idx++ { + if idx%3 == 0 { + continue } + idx, c := idx, consumers[idx] + c.Start(ctx) c.StopWaiter.LaunchThread( func(ctx context.Context) { for { - if idx%3 == 0 { - continue - } + res, err := c.Consume(ctx) if err != nil { if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) { @@ -174,16 +209,32 @@ func TestClaimingOwnership(t *testing.T) { if err := c.ACK(ctx, res.ID); err != nil { t.Errorf("Error ACKing message: %v, error: %v", res.ID, err) } + if err := c.SetResult(ctx, res.ID, fmt.Sprintf("result for: %v", res.ID)); err != nil { + t.Errorf("Error setting a result: %v", err) + } + wantResponses[idx] = append(wantResponses[idx], fmt.Sprintf("result for: %v", res.ID)) total.Add(1) } }) } + var promises []*containers.Promise[any] for i := 0; i < messagesCount; i++ { value := fmt.Sprintf("msg: %d", i) - if err := producer.Produce(ctx, value); err != nil { + promise, err := producer.Produce(ctx, value) + if err != nil { t.Errorf("Produce() unexpected error: %v", err) } + promises = append(promises, promise) + } + var gotResponses []string + for _, p := range promises { + res, err := p.Await(ctx) + if err != nil { + t.Errorf("Await() unexpected error: %v", err) + continue + } + gotResponses = append(gotResponses, fmt.Sprintf("%v", res)) } for { @@ -204,6 +255,13 @@ func TestClaimingOwnership(t *testing.T) { if diff := cmp.Diff(want, got); diff != "" { t.Errorf("Unexpected diff (-want +got):\n%s\n", diff) } + WantResp := flatten(wantResponses) + sort.Slice(gotResponses, func(i, j int) bool { + return gotResponses[i] < gotResponses[j] + }) + if diff := cmp.Diff(WantResp, gotResponses); diff != "" { + t.Errorf("Unexpected diff in responses:\n%s\n", diff) + } } // mergeValues merges maps from the slice and returns their values. From eb6e63be3c22a7d56bdaddc31d765c6a86b510df Mon Sep 17 00:00:00 2001 From: Nodar Ambroladze Date: Wed, 27 Mar 2024 18:29:19 +0100 Subject: [PATCH 14/33] Add generics to the producer/consumer --- pubsub/consumer.go | 37 ++++++++++++++++++++------------ pubsub/producer.go | 50 ++++++++++++++++++++++++++++--------------- pubsub/pubsub_test.go | 43 +++++++++++++++++++++++++++---------- 3 files changed, 88 insertions(+), 42 deletions(-) diff --git a/pubsub/consumer.go b/pubsub/consumer.go index 38cb6031fa..b0a19c9a41 100644 --- a/pubsub/consumer.go +++ b/pubsub/consumer.go @@ -38,19 +38,19 @@ func ConsumerConfigAddOptions(prefix string, f *pflag.FlagSet, cfg *ConsumerConf // Consumer implements a consumer for redis stream provides heartbeat to // indicate it is alive. 
-type Consumer struct { +type Consumer[T Marshallable[T]] struct { stopwaiter.StopWaiter id string client redis.UniversalClient cfg *ConsumerConfig } -type Message struct { +type Message[T Marshallable[T]] struct { ID string - Value any + Value T } -func NewConsumer(ctx context.Context, cfg *ConsumerConfig) (*Consumer, error) { +func NewConsumer[T Marshallable[T]](ctx context.Context, cfg *ConsumerConfig) (*Consumer[T], error) { if cfg.RedisURL == "" { return nil, fmt.Errorf("redis url cannot be empty") } @@ -58,7 +58,7 @@ func NewConsumer(ctx context.Context, cfg *ConsumerConfig) (*Consumer, error) { if err != nil { return nil, err } - consumer := &Consumer{ + consumer := &Consumer[T]{ id: uuid.NewString(), client: c, cfg: cfg, @@ -67,7 +67,7 @@ func NewConsumer(ctx context.Context, cfg *ConsumerConfig) (*Consumer, error) { } // Start starts the consumer to iteratively perform heartbeat in configured intervals. -func (c *Consumer) Start(ctx context.Context) { +func (c *Consumer[T]) Start(ctx context.Context) { c.StopWaiter.Start(ctx, c) c.StopWaiter.CallIteratively( func(ctx context.Context) time.Duration { @@ -77,7 +77,7 @@ func (c *Consumer) Start(ctx context.Context) { ) } -func (c *Consumer) StopAndWait() { +func (c *Consumer[T]) StopAndWait() { c.StopWaiter.StopAndWait() } @@ -85,12 +85,12 @@ func heartBeatKey(id string) string { return fmt.Sprintf("consumer:%s:heartbeat", id) } -func (c *Consumer) heartBeatKey() string { +func (c *Consumer[T]) heartBeatKey() string { return heartBeatKey(c.id) } // heartBeat updates the heartBeat key indicating aliveness. -func (c *Consumer) heartBeat(ctx context.Context) { +func (c *Consumer[T]) heartBeat(ctx context.Context) { if err := c.client.Set(ctx, c.heartBeatKey(), time.Now().UnixMilli(), 2*c.cfg.KeepAliveTimeout).Err(); err != nil { l := log.Info if ctx.Err() != nil { @@ -102,7 +102,7 @@ func (c *Consumer) heartBeat(ctx context.Context) { // Consumer first checks it there exists pending message that is claimed by // unresponsive consumer, if not then reads from the stream. 
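The type parameter added here is constrained by the Marshallable interface introduced in producer.go below (Marshal() any / Unmarshal(val any) (T, error)). A hypothetical implementation for a structured payload, assuming JSON encoding and not part of the patch; Redis hands stream field values back as strings, hence the string assertion in Unmarshal:

    package pubsubexample

    import (
        "encoding/json"
        "fmt"
    )

    // jsonRequest is a hypothetical request type satisfying Marshallable[*jsonRequest].
    type jsonRequest struct {
        Block uint64 `json:"block"`
        Hash  string `json:"hash"`
    }

    // Marshal returns the JSON encoding as a string, the form in which Redis
    // stream field values are stored and read back.
    func (r *jsonRequest) Marshal() any {
        b, err := json.Marshal(r)
        if err != nil {
            return ""
        }
        return string(b)
    }

    // Unmarshal reconstructs the request from the value read out of the stream.
    func (r *jsonRequest) Unmarshal(val any) (*jsonRequest, error) {
        s, ok := val.(string)
        if !ok {
            return nil, fmt.Errorf("expected string value, got %T", val)
        }
        var out jsonRequest
        if err := json.Unmarshal([]byte(s), &out); err != nil {
            return nil, err
        }
        return &out, nil
    }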
-func (c *Consumer) Consume(ctx context.Context) (*Message, error) { +func (c *Consumer[T]) Consume(ctx context.Context) (*Message[T], error) { res, err := c.client.XReadGroup(ctx, &redis.XReadGroupArgs{ Group: c.cfg.RedisGroup, Consumer: c.id, @@ -122,19 +122,28 @@ func (c *Consumer) Consume(ctx context.Context) (*Message, error) { return nil, fmt.Errorf("redis returned entries: %+v, for querying single message", res) } log.Debug(fmt.Sprintf("Consumer: %s consuming message: %s", c.id, res[0].Messages[0].ID)) - return &Message{ + var ( + value = res[0].Messages[0].Values[messageKey] + tmp T + ) + val, err := tmp.Unmarshal(value) + if err != nil { + return nil, fmt.Errorf("unmarshaling value: %v, error: %v", value, err) + } + + return &Message[T]{ ID: res[0].Messages[0].ID, - Value: res[0].Messages[0].Values[messageKey], + Value: val, }, nil } -func (c *Consumer) ACK(ctx context.Context, messageID string) error { +func (c *Consumer[T]) ACK(ctx context.Context, messageID string) error { log.Info("ACKing message", "consumer-id", c.id, "message-sid", messageID) _, err := c.client.XAck(ctx, c.cfg.RedisStream, c.cfg.RedisGroup, messageID).Result() return err } -func (c *Consumer) SetResult(ctx context.Context, messageID string, result string) error { +func (c *Consumer[T]) SetResult(ctx context.Context, messageID string, result string) error { acquired, err := c.client.SetNX(ctx, messageID, result, c.cfg.KeepAliveTimeout).Result() if err != nil || !acquired { return fmt.Errorf("setting result for message: %v, error: %w", messageID, err) diff --git a/pubsub/producer.go b/pubsub/producer.go index 7ac089b3df..29bcd09b42 100644 --- a/pubsub/producer.go +++ b/pubsub/producer.go @@ -20,14 +20,19 @@ const ( defaultGroup = "default_consumer_group" ) -type Producer struct { +type Marshallable[T any] interface { + Marshal() any + Unmarshal(val any) (T, error) +} + +type Producer[T Marshallable[T]] struct { stopwaiter.StopWaiter id string client redis.UniversalClient cfg *ProducerConfig promisesLock sync.RWMutex - promises map[string]*containers.Promise[any] + promises map[string]*containers.Promise[T] } type ProducerConfig struct { @@ -46,7 +51,7 @@ type ProducerConfig struct { RedisGroup string `koanf:"redis-group"` } -func NewProducer(cfg *ProducerConfig) (*Producer, error) { +func NewProducer[T Marshallable[T]](cfg *ProducerConfig) (*Producer[T], error) { if cfg.RedisURL == "" { return nil, fmt.Errorf("redis url cannot be empty") } @@ -54,15 +59,15 @@ func NewProducer(cfg *ProducerConfig) (*Producer, error) { if err != nil { return nil, err } - return &Producer{ + return &Producer[T]{ id: uuid.NewString(), client: c, cfg: cfg, - promises: make(map[string]*containers.Promise[any]), + promises: make(map[string]*containers.Promise[T]), }, nil } -func (p *Producer) Start(ctx context.Context) { +func (p *Producer[T]) Start(ctx context.Context) { p.StopWaiter.Start(ctx, p) p.StopWaiter.CallIteratively( func(ctx context.Context) time.Duration { @@ -74,7 +79,7 @@ func (p *Producer) Start(ctx context.Context) { if len(msgs) == 0 { return p.cfg.CheckPendingInterval } - acked := make(map[string]any) + acked := make(map[string]T) for _, msg := range msgs { if _, err := p.client.XAck(ctx, p.cfg.RedisStream, p.cfg.RedisGroup, msg.ID).Result(); err != nil { log.Error("ACKing message", "error", err) @@ -104,7 +109,13 @@ func (p *Producer) Start(ctx context.Context) { } log.Error("Error reading value in redis", "key", id, "error", err) } - promise.Produce(res) + var tmp T + val, err := tmp.Unmarshal(res) + if err != 
nil { + log.Error("Error unmarshaling", "value", res, "error", err) + continue + } + promise.Produce(val) delete(p.promises, id) } return p.cfg.CheckResultInterval @@ -114,10 +125,10 @@ func (p *Producer) Start(ctx context.Context) { // reproduce is used when Producer claims ownership on the pending // message that was sent to inactive consumer and reinserts it into the stream, // so that seamlessly return the answer in the same promise. -func (p *Producer) reproduce(ctx context.Context, value any, oldKey string) (*containers.Promise[any], error) { +func (p *Producer[T]) reproduce(ctx context.Context, value T, oldKey string) (*containers.Promise[T], error) { id, err := p.client.XAdd(ctx, &redis.XAddArgs{ Stream: p.cfg.RedisStream, - Values: map[string]any{messageKey: value}, + Values: map[string]any{messageKey: value.Marshal()}, }).Result() if err != nil { return nil, fmt.Errorf("adding values to redis: %w", err) @@ -126,19 +137,19 @@ func (p *Producer) reproduce(ctx context.Context, value any, oldKey string) (*co defer p.promisesLock.Unlock() promise := p.promises[oldKey] if oldKey == "" || promise == nil { - p := containers.NewPromise[any](nil) + p := containers.NewPromise[T](nil) promise = &p } p.promises[id] = promise return promise, nil } -func (p *Producer) Produce(ctx context.Context, value any) (*containers.Promise[any], error) { +func (p *Producer[T]) Produce(ctx context.Context, value T) (*containers.Promise[T], error) { return p.reproduce(ctx, value, "") } // Check if a consumer is with specified ID is alive. -func (p *Producer) isConsumerAlive(ctx context.Context, consumerID string) bool { +func (p *Producer[T]) isConsumerAlive(ctx context.Context, consumerID string) bool { val, err := p.client.Get(ctx, heartBeatKey(consumerID)).Int64() if err != nil { return false @@ -146,7 +157,7 @@ func (p *Producer) isConsumerAlive(ctx context.Context, consumerID string) bool return time.Now().UnixMilli()-val < int64(p.cfg.KeepAliveTimeout.Milliseconds()) } -func (p *Producer) checkPending(ctx context.Context) ([]*Message, error) { +func (p *Producer[T]) checkPending(ctx context.Context) ([]*Message[T], error) { pendingMessages, err := p.client.XPendingExt(ctx, &redis.XPendingExtArgs{ Stream: p.cfg.RedisStream, Group: p.cfg.RedisGroup, @@ -182,11 +193,16 @@ func (p *Producer) checkPending(ctx context.Context) ([]*Message, error) { if err != nil { return nil, fmt.Errorf("claiming ownership on messages: %v, error: %w", ids, err) } - var res []*Message + var res []*Message[T] for _, msg := range claimedMsgs { - res = append(res, &Message{ + var tmp T + val, err := tmp.Unmarshal(msg.Values[messageKey]) + if err != nil { + return nil, fmt.Errorf("marshaling value: %v, error: %v", msg.Values[messageKey], err) + } + res = append(res, &Message[T]{ ID: msg.ID, - Value: msg.Values[messageKey], + Value: val, }) } return res, nil diff --git a/pubsub/pubsub_test.go b/pubsub/pubsub_test.go index 23fe481777..944253eefa 100644 --- a/pubsub/pubsub_test.go +++ b/pubsub/pubsub_test.go @@ -21,6 +21,27 @@ var ( messagesCount = 100 ) +type testResult struct { + val string +} + +func (r *testResult) Marshal() any { + return r.val +} + +func (r *testResult) Unmarshal(val any) (*testResult, error) { + return &testResult{ + val: val.(string), + }, nil +} + +func TestMarshal(t *testing.T) { + tr := &testResult{val: "myvalue"} + val, err := tr.Unmarshal(tr.Marshal()) + t.Errorf("val: %+v, err: %v", val, err) + +} + func createGroup(ctx context.Context, t *testing.T, client redis.UniversalClient) { t.Helper() _, err := 
client.XGroupCreateMkStream(ctx, streamName, defaultGroup, "$").Result() @@ -29,10 +50,10 @@ func createGroup(ctx context.Context, t *testing.T, client redis.UniversalClient } } -func newProducerConsumers(ctx context.Context, t *testing.T) (*Producer, []*Consumer) { +func newProducerConsumers(ctx context.Context, t *testing.T) (*Producer[*testResult], []*Consumer[*testResult]) { t.Helper() redisURL := redisutil.CreateTestRedis(ctx, t) - producer, err := NewProducer( + producer, err := NewProducer[*testResult]( &ProducerConfig{ RedisURL: redisURL, RedisStream: streamName, @@ -44,9 +65,9 @@ func newProducerConsumers(ctx context.Context, t *testing.T) (*Producer, []*Cons if err != nil { t.Fatalf("Error creating new producer: %v", err) } - var consumers []*Consumer + var consumers []*Consumer[*testResult] for i := 0; i < consumersCount; i++ { - c, err := NewConsumer(ctx, + c, err := NewConsumer[*testResult](ctx, &ConsumerConfig{ RedisURL: redisURL, RedisStream: streamName, @@ -105,7 +126,7 @@ func TestProduce(t *testing.T) { if res == nil { continue } - gotMessages[idx][res.ID] = res.Value + gotMessages[idx][res.ID] = res.Value.val if err := c.ACK(ctx, res.ID); err != nil { t.Errorf("Error ACKing message: %v, error: %v", res.ID, err) } @@ -120,7 +141,7 @@ func TestProduce(t *testing.T) { var gotResponses []string for i := 0; i < messagesCount; i++ { - value := fmt.Sprintf("msg: %d", i) + value := &testResult{val: fmt.Sprintf("msg: %d", i)} p, err := producer.Produce(ctx, value) if err != nil { t.Errorf("Produce() unexpected error: %v", err) @@ -129,7 +150,7 @@ func TestProduce(t *testing.T) { if err != nil { t.Errorf("Await() unexpected error: %v", err) } - gotResponses = append(gotResponses, fmt.Sprintf("%v", res)) + gotResponses = append(gotResponses, res.val) } producer.StopWaiter.StopAndWait() @@ -205,7 +226,7 @@ func TestClaimingOwnership(t *testing.T) { if res == nil { continue } - gotMessages[idx][res.ID] = res.Value + gotMessages[idx][res.ID] = res.Value.val if err := c.ACK(ctx, res.ID); err != nil { t.Errorf("Error ACKing message: %v, error: %v", res.ID, err) } @@ -218,9 +239,9 @@ func TestClaimingOwnership(t *testing.T) { }) } - var promises []*containers.Promise[any] + var promises []*containers.Promise[*testResult] for i := 0; i < messagesCount; i++ { - value := fmt.Sprintf("msg: %d", i) + value := &testResult{val: fmt.Sprintf("msg: %d", i)} promise, err := producer.Produce(ctx, value) if err != nil { t.Errorf("Produce() unexpected error: %v", err) @@ -234,7 +255,7 @@ func TestClaimingOwnership(t *testing.T) { t.Errorf("Await() unexpected error: %v", err) continue } - gotResponses = append(gotResponses, fmt.Sprintf("%v", res)) + gotResponses = append(gotResponses, res.val) } for { From f94c4545d57fd85122163a4d010d084ca95977b2 Mon Sep 17 00:00:00 2001 From: Nodar Ambroladze Date: Wed, 27 Mar 2024 18:33:55 +0100 Subject: [PATCH 15/33] Simplify tests --- pubsub/pubsub_test.go | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/pubsub/pubsub_test.go b/pubsub/pubsub_test.go index 944253eefa..ec4fb22059 100644 --- a/pubsub/pubsub_test.go +++ b/pubsub/pubsub_test.go @@ -85,16 +85,16 @@ func newProducerConsumers(ctx context.Context, t *testing.T) (*Producer[*testRes return producer, consumers } -func messagesMap(n int) []map[string]any { - ret := make([]map[string]any, n) +func messagesMaps(n int) []map[string]string { + ret := make([]map[string]string, n) for i := 0; i < n; i++ { - ret[i] = make(map[string]any) + ret[i] = make(map[string]string) } 
return ret } -func wantMessages(n int) []any { - var ret []any +func wantMessages(n int) []string { + var ret []string for i := 0; i < n; i++ { ret = append(ret, fmt.Sprintf("msg: %d", i)) } @@ -108,7 +108,7 @@ func TestProduce(t *testing.T) { ctx := context.Background() producer, consumers := newProducerConsumers(ctx, t) producer.Start(ctx) - gotMessages := messagesMap(consumersCount) + gotMessages := messagesMaps(consumersCount) wantResponses := make([][]string, len(consumers)) for idx, c := range consumers { idx, c := idx, c @@ -191,7 +191,7 @@ func TestClaimingOwnership(t *testing.T) { ctx := context.Background() producer, consumers := newProducerConsumers(ctx, t) producer.Start(ctx) - gotMessages := messagesMap(consumersCount) + gotMessages := messagesMaps(consumersCount) // Consumer messages in every third consumer but don't ack them to check // that other consumers will claim ownership on those messages. @@ -287,9 +287,9 @@ func TestClaimingOwnership(t *testing.T) { // mergeValues merges maps from the slice and returns their values. // Returns and error if there exists duplicate key. -func mergeValues(messages []map[string]any) ([]any, error) { +func mergeValues(messages []map[string]string) ([]string, error) { res := make(map[string]any) - var ret []any + var ret []string for _, m := range messages { for k, v := range m { if _, found := res[k]; found { @@ -300,7 +300,7 @@ func mergeValues(messages []map[string]any) ([]any, error) { } } sort.Slice(ret, func(i, j int) bool { - return fmt.Sprintf("%v", ret[i]) < fmt.Sprintf("%v", ret[j]) + return ret[i] < ret[j] }) return ret, nil } From 99b993990a4f0494f0f32b3a89a3334c59cebc65 Mon Sep 17 00:00:00 2001 From: Nodar Ambroladze Date: Wed, 27 Mar 2024 18:48:48 +0100 Subject: [PATCH 16/33] Fix linter error --- pubsub/consumer.go | 2 +- pubsub/producer.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pubsub/consumer.go b/pubsub/consumer.go index b0a19c9a41..f7c7ef1d37 100644 --- a/pubsub/consumer.go +++ b/pubsub/consumer.go @@ -128,7 +128,7 @@ func (c *Consumer[T]) Consume(ctx context.Context) (*Message[T], error) { ) val, err := tmp.Unmarshal(value) if err != nil { - return nil, fmt.Errorf("unmarshaling value: %v, error: %v", value, err) + return nil, fmt.Errorf("unmarshaling value: %v, error: %w", value, err) } return &Message[T]{ diff --git a/pubsub/producer.go b/pubsub/producer.go index 29bcd09b42..79edd9eba1 100644 --- a/pubsub/producer.go +++ b/pubsub/producer.go @@ -198,7 +198,7 @@ func (p *Producer[T]) checkPending(ctx context.Context) ([]*Message[T], error) { var tmp T val, err := tmp.Unmarshal(msg.Values[messageKey]) if err != nil { - return nil, fmt.Errorf("marshaling value: %v, error: %v", msg.Values[messageKey], err) + return nil, fmt.Errorf("marshaling value: %v, error: %w", msg.Values[messageKey], err) } res = append(res, &Message[T]{ ID: msg.ID, From b183881257d177d99674c4b32bf710846b368213 Mon Sep 17 00:00:00 2001 From: Nodar Ambroladze Date: Wed, 27 Mar 2024 18:57:20 +0100 Subject: [PATCH 17/33] Drop remnant test --- pubsub/pubsub_test.go | 7 ------- 1 file changed, 7 deletions(-) diff --git a/pubsub/pubsub_test.go b/pubsub/pubsub_test.go index ec4fb22059..4850166ba7 100644 --- a/pubsub/pubsub_test.go +++ b/pubsub/pubsub_test.go @@ -35,13 +35,6 @@ func (r *testResult) Unmarshal(val any) (*testResult, error) { }, nil } -func TestMarshal(t *testing.T) { - tr := &testResult{val: "myvalue"} - val, err := tr.Unmarshal(tr.Marshal()) - t.Errorf("val: %+v, err: %v", val, err) - -} - func createGroup(ctx 
context.Context, t *testing.T, client redis.UniversalClient) { t.Helper() _, err := client.XGroupCreateMkStream(ctx, streamName, defaultGroup, "$").Result() From 2a67624daad00767c9819cd09c0b02de1b7c298c Mon Sep 17 00:00:00 2001 From: Nodar Ambroladze Date: Mon, 1 Apr 2024 22:07:20 +0200 Subject: [PATCH 18/33] Address comments --- pubsub/consumer.go | 53 ++++++++++++++++++---------- pubsub/producer.go | 81 ++++++++++++++++++++++++++++++------------- pubsub/pubsub_test.go | 78 ++++++++++++++++++++--------------------- 3 files changed, 128 insertions(+), 84 deletions(-) diff --git a/pubsub/consumer.go b/pubsub/consumer.go index f7c7ef1d37..e013314e5b 100644 --- a/pubsub/consumer.go +++ b/pubsub/consumer.go @@ -15,8 +15,8 @@ import ( ) type ConsumerConfig struct { - // Intervals in which consumer will update heartbeat. - KeepAliveInterval time.Duration `koanf:"keepalive-interval"` + // Timeout of result entry in Redis. + ResponseEntryTimeout time.Duration `koanf:"response-entry-timeout"` // Duration after which consumer is considered to be dead if heartbeat // is not updated. KeepAliveTimeout time.Duration `koanf:"keepalive-timeout"` @@ -28,12 +28,26 @@ type ConsumerConfig struct { RedisGroup string `koanf:"redis-group"` } -func ConsumerConfigAddOptions(prefix string, f *pflag.FlagSet, cfg *ConsumerConfig) { - f.Duration(prefix+".keepalive-interval", 30*time.Second, "interval in which consumer will perform heartbeat") - f.Duration(prefix+".keepalive-timeout", 5*time.Minute, "timeout after which consumer is considered inactive if heartbeat wasn't performed") - f.String(prefix+".redis-url", "", "redis url for redis stream") - f.String(prefix+".redis-stream", "default", "redis stream name to read from") - f.String(prefix+".redis-group", defaultGroup, "redis stream consumer group name") +var DefaultConsumerConfig = &ConsumerConfig{ + ResponseEntryTimeout: time.Hour, + KeepAliveTimeout: 5 * time.Minute, + RedisStream: "default", + RedisGroup: defaultGroup, +} + +var DefaultTestConsumerConfig = &ConsumerConfig{ + RedisStream: "default", + RedisGroup: defaultGroup, + ResponseEntryTimeout: time.Minute, + KeepAliveTimeout: 30 * time.Millisecond, +} + +func ConsumerConfigAddOptions(prefix string, f *pflag.FlagSet) { + f.Duration(prefix+".response-entry-timeout", DefaultConsumerConfig.ResponseEntryTimeout, "timeout for response entry") + f.Duration(prefix+".keepalive-timeout", DefaultConsumerConfig.KeepAliveTimeout, "timeout after which consumer is considered inactive if heartbeat wasn't performed") + f.String(prefix+".redis-url", DefaultConsumerConfig.RedisURL, "redis url for redis stream") + f.String(prefix+".redis-stream", DefaultConsumerConfig.RedisStream, "redis stream name to read from") + f.String(prefix+".redis-group", DefaultConsumerConfig.RedisGroup, "redis stream consumer group name") } // Consumer implements a consumer for redis stream provides heartbeat to @@ -72,7 +86,7 @@ func (c *Consumer[T]) Start(ctx context.Context) { c.StopWaiter.CallIteratively( func(ctx context.Context) time.Duration { c.heartBeat(ctx) - return c.cfg.KeepAliveInterval + return c.cfg.KeepAliveTimeout / 10 }, ) } @@ -123,10 +137,14 @@ func (c *Consumer[T]) Consume(ctx context.Context) (*Message[T], error) { } log.Debug(fmt.Sprintf("Consumer: %s consuming message: %s", c.id, res[0].Messages[0].ID)) var ( - value = res[0].Messages[0].Values[messageKey] - tmp T + value = res[0].Messages[0].Values[messageKey] + data, ok = (value).(string) + tmp T ) - val, err := tmp.Unmarshal(value) + if !ok { + return nil, 
fmt.Errorf("casting request to string: %w", err) + } + val, err := tmp.Unmarshal([]byte(data)) if err != nil { return nil, fmt.Errorf("unmarshaling value: %v, error: %w", value, err) } @@ -137,16 +155,13 @@ func (c *Consumer[T]) Consume(ctx context.Context) (*Message[T], error) { }, nil } -func (c *Consumer[T]) ACK(ctx context.Context, messageID string) error { - log.Info("ACKing message", "consumer-id", c.id, "message-sid", messageID) - _, err := c.client.XAck(ctx, c.cfg.RedisStream, c.cfg.RedisGroup, messageID).Result() - return err -} - func (c *Consumer[T]) SetResult(ctx context.Context, messageID string, result string) error { - acquired, err := c.client.SetNX(ctx, messageID, result, c.cfg.KeepAliveTimeout).Result() + acquired, err := c.client.SetNX(ctx, messageID, result, c.cfg.ResponseEntryTimeout).Result() if err != nil || !acquired { return fmt.Errorf("setting result for message: %v, error: %w", messageID, err) } + if _, err := c.client.XAck(ctx, c.cfg.RedisStream, c.cfg.RedisGroup, messageID).Result(); err != nil { + return fmt.Errorf("acking message: %v, error: %w", messageID, err) + } return nil } diff --git a/pubsub/producer.go b/pubsub/producer.go index 79edd9eba1..006b84709f 100644 --- a/pubsub/producer.go +++ b/pubsub/producer.go @@ -13,6 +13,7 @@ import ( "github.com/offchainlabs/nitro/util/containers" "github.com/offchainlabs/nitro/util/redisutil" "github.com/offchainlabs/nitro/util/stopwaiter" + "github.com/spf13/pflag" ) const ( @@ -21,18 +22,18 @@ const ( ) type Marshallable[T any] interface { - Marshal() any - Unmarshal(val any) (T, error) + Marshal() []byte + Unmarshal(val []byte) (T, error) } -type Producer[T Marshallable[T]] struct { +type Producer[Request Marshallable[Request], Response Marshallable[Response]] struct { stopwaiter.StopWaiter id string client redis.UniversalClient cfg *ProducerConfig promisesLock sync.RWMutex - promises map[string]*containers.Promise[T] + promises map[string]*containers.Promise[Response] } type ProducerConfig struct { @@ -51,7 +52,31 @@ type ProducerConfig struct { RedisGroup string `koanf:"redis-group"` } -func NewProducer[T Marshallable[T]](cfg *ProducerConfig) (*Producer[T], error) { +var DefaultProducerConfig = &ProducerConfig{ + RedisStream: "default", + CheckPendingInterval: time.Second, + KeepAliveTimeout: 5 * time.Minute, + CheckResultInterval: 5 * time.Second, + RedisGroup: defaultGroup, +} + +var DefaultTestProducerConfig = &ProducerConfig{ + RedisStream: "default", + RedisGroup: defaultGroup, + CheckPendingInterval: 10 * time.Millisecond, + KeepAliveTimeout: 20 * time.Millisecond, + CheckResultInterval: 5 * time.Millisecond, +} + +func ProducerAddConfigAddOptions(prefix string, f *pflag.FlagSet) { + f.String(prefix+".redis-url", DefaultConsumerConfig.RedisURL, "redis url for redis stream") + f.Duration(prefix+".response-entry-timeout", DefaultConsumerConfig.ResponseEntryTimeout, "timeout for response entry") + f.Duration(prefix+".keepalive-timeout", DefaultConsumerConfig.KeepAliveTimeout, "timeout after which consumer is considered inactive if heartbeat wasn't performed") + f.String(prefix+".redis-stream", DefaultConsumerConfig.RedisStream, "redis stream name to read from") + f.String(prefix+".redis-group", DefaultConsumerConfig.RedisGroup, "redis stream consumer group name") +} + +func NewProducer[Request Marshallable[Request], Response Marshallable[Response]](cfg *ProducerConfig) (*Producer[Request, Response], error) { if cfg.RedisURL == "" { return nil, fmt.Errorf("redis url cannot be empty") } @@ -59,15 +84,15 @@ func 
NewProducer[T Marshallable[T]](cfg *ProducerConfig) (*Producer[T], error) { if err != nil { return nil, err } - return &Producer[T]{ + return &Producer[Request, Response]{ id: uuid.NewString(), client: c, cfg: cfg, - promises: make(map[string]*containers.Promise[T]), + promises: make(map[string]*containers.Promise[Response]), }, nil } -func (p *Producer[T]) Start(ctx context.Context) { +func (p *Producer[Request, Response]) Start(ctx context.Context) { p.StopWaiter.Start(ctx, p) p.StopWaiter.CallIteratively( func(ctx context.Context) time.Duration { @@ -79,7 +104,7 @@ func (p *Producer[T]) Start(ctx context.Context) { if len(msgs) == 0 { return p.cfg.CheckPendingInterval } - acked := make(map[string]T) + acked := make(map[string]Request) for _, msg := range msgs { if _, err := p.client.XAck(ctx, p.cfg.RedisStream, p.cfg.RedisGroup, msg.ID).Result(); err != nil { log.Error("ACKing message", "error", err) @@ -109,8 +134,8 @@ func (p *Producer[T]) Start(ctx context.Context) { } log.Error("Error reading value in redis", "key", id, "error", err) } - var tmp T - val, err := tmp.Unmarshal(res) + var tmp Response + val, err := tmp.Unmarshal([]byte(res)) if err != nil { log.Error("Error unmarshaling", "value", res, "error", err) continue @@ -125,7 +150,7 @@ func (p *Producer[T]) Start(ctx context.Context) { // reproduce is used when Producer claims ownership on the pending // message that was sent to inactive consumer and reinserts it into the stream, // so that seamlessly return the answer in the same promise. -func (p *Producer[T]) reproduce(ctx context.Context, value T, oldKey string) (*containers.Promise[T], error) { +func (p *Producer[Request, Response]) reproduce(ctx context.Context, value Request, oldKey string) (*containers.Promise[Response], error) { id, err := p.client.XAdd(ctx, &redis.XAddArgs{ Stream: p.cfg.RedisStream, Values: map[string]any{messageKey: value.Marshal()}, @@ -137,19 +162,19 @@ func (p *Producer[T]) reproduce(ctx context.Context, value T, oldKey string) (*c defer p.promisesLock.Unlock() promise := p.promises[oldKey] if oldKey == "" || promise == nil { - p := containers.NewPromise[T](nil) - promise = &p + pr := containers.NewPromise[Response](nil) + promise = &pr } p.promises[id] = promise return promise, nil } -func (p *Producer[T]) Produce(ctx context.Context, value T) (*containers.Promise[T], error) { +func (p *Producer[Request, Response]) Produce(ctx context.Context, value Request) (*containers.Promise[Response], error) { return p.reproduce(ctx, value, "") } // Check if a consumer is with specified ID is alive. -func (p *Producer[T]) isConsumerAlive(ctx context.Context, consumerID string) bool { +func (p *Producer[Request, Response]) isConsumerAlive(ctx context.Context, consumerID string) bool { val, err := p.client.Get(ctx, heartBeatKey(consumerID)).Int64() if err != nil { return false @@ -157,7 +182,7 @@ func (p *Producer[T]) isConsumerAlive(ctx context.Context, consumerID string) bo return time.Now().UnixMilli()-val < int64(p.cfg.KeepAliveTimeout.Milliseconds()) } -func (p *Producer[T]) checkPending(ctx context.Context) ([]*Message[T], error) { +func (p *Producer[Request, Response]) checkPending(ctx context.Context) ([]*Message[Request], error) { pendingMessages, err := p.client.XPendingExt(ctx, &redis.XPendingExtArgs{ Stream: p.cfg.RedisStream, Group: p.cfg.RedisGroup, @@ -174,12 +199,16 @@ func (p *Producer[T]) checkPending(ctx context.Context) ([]*Message[T], error) { } // IDs of the pending messages with inactive consumers. 
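checkPending pairs two stream commands: XPENDING lists entries that were delivered to a consumer but never acknowledged, and XCLAIM re-assigns those entries to a new owner once the original consumer's heartbeat has lapsed. A minimal standalone sketch of that pair with go-redis; the stream, group, consumer names and idle threshold are hypothetical:

    package main

    import (
        "context"
        "fmt"
        "time"

        "github.com/go-redis/redis/v8"
    )

    func main() {
        ctx := context.Background()
        client := redis.NewClient(&redis.Options{Addr: "localhost:6379"}) // hypothetical address

        // List delivered-but-unacknowledged entries in the group.
        pending, err := client.XPendingExt(ctx, &redis.XPendingExtArgs{
            Stream: "example_stream", // hypothetical names
            Group:  "example_group",
            Start:  "-",
            End:    "+",
            Count:  100,
        }).Result()
        if err != nil {
            fmt.Println("xpending:", err)
            return
        }

        var ids []string
        for _, p := range pending {
            ids = append(ids, p.ID)
        }
        if len(ids) == 0 {
            return
        }

        // Claim those entries for a new consumer, provided they have been idle long enough.
        claimed, err := client.XClaim(ctx, &redis.XClaimArgs{
            Stream:   "example_stream",
            Group:    "example_group",
            Consumer: "recovering_consumer",
            MinIdle:  30 * time.Second,
            Messages: ids,
        }).Result()
        if err != nil {
            fmt.Println("xclaim:", err)
            return
        }
        fmt.Println("claimed", len(claimed), "entries")
    }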
var ids []string - inactive := make(map[string]bool) + active := make(map[string]bool) for _, msg := range pendingMessages { - if !inactive[msg.Consumer] || p.isConsumerAlive(ctx, msg.Consumer) { + alive, found := active[msg.Consumer] + if !found { + alive = p.isConsumerAlive(ctx, msg.Consumer) + active[msg.Consumer] = alive + } + if alive { continue } - inactive[msg.Consumer] = true ids = append(ids, msg.ID) } log.Info("Attempting to claim", "messages", ids) @@ -193,14 +222,18 @@ func (p *Producer[T]) checkPending(ctx context.Context) ([]*Message[T], error) { if err != nil { return nil, fmt.Errorf("claiming ownership on messages: %v, error: %w", ids, err) } - var res []*Message[T] + var res []*Message[Request] for _, msg := range claimedMsgs { - var tmp T - val, err := tmp.Unmarshal(msg.Values[messageKey]) + data, ok := (msg.Values[messageKey]).([]byte) + if !ok { + return nil, fmt.Errorf("casting request to bytes: %w", err) + } + var tmp Request + val, err := tmp.Unmarshal(data) if err != nil { return nil, fmt.Errorf("marshaling value: %v, error: %w", msg.Values[messageKey], err) } - res = append(res, &Message[T]{ + res = append(res, &Message[Request]{ ID: msg.ID, Value: val, }) diff --git a/pubsub/pubsub_test.go b/pubsub/pubsub_test.go index 4850166ba7..77f2a87914 100644 --- a/pubsub/pubsub_test.go +++ b/pubsub/pubsub_test.go @@ -16,22 +16,36 @@ import ( ) var ( - streamName = "validator_stream" + streamName = DefaultTestProducerConfig.RedisStream consumersCount = 10 messagesCount = 100 ) -type testResult struct { - val string +type testRequest struct { + request string } -func (r *testResult) Marshal() any { - return r.val +func (r *testRequest) Marshal() []byte { + return []byte(r.request) } -func (r *testResult) Unmarshal(val any) (*testResult, error) { - return &testResult{ - val: val.(string), +func (r *testRequest) Unmarshal(val []byte) (*testRequest, error) { + return &testRequest{ + request: string(val), + }, nil +} + +type testResponse struct { + response string +} + +func (r *testResponse) Marshal() []byte { + return []byte(r.response) +} + +func (r *testResponse) Unmarshal(val []byte) (*testResponse, error) { + return &testResponse{ + response: string(val), }, nil } @@ -43,32 +57,20 @@ func createGroup(ctx context.Context, t *testing.T, client redis.UniversalClient } } -func newProducerConsumers(ctx context.Context, t *testing.T) (*Producer[*testResult], []*Consumer[*testResult]) { +func newProducerConsumers(ctx context.Context, t *testing.T) (*Producer[*testRequest, *testResponse], []*Consumer[*testRequest]) { t.Helper() redisURL := redisutil.CreateTestRedis(ctx, t) - producer, err := NewProducer[*testResult]( - &ProducerConfig{ - RedisURL: redisURL, - RedisStream: streamName, - RedisGroup: defaultGroup, - CheckPendingInterval: 10 * time.Millisecond, - KeepAliveTimeout: 20 * time.Millisecond, - CheckResultInterval: 5 * time.Millisecond, - }) + defaultProdCfg := DefaultTestProducerConfig + defaultProdCfg.RedisURL = redisURL + producer, err := NewProducer[*testRequest, *testResponse](defaultProdCfg) if err != nil { t.Fatalf("Error creating new producer: %v", err) } - var consumers []*Consumer[*testResult] + defaultCfg := DefaultTestConsumerConfig + defaultCfg.RedisURL = redisURL + var consumers []*Consumer[*testRequest] for i := 0; i < consumersCount; i++ { - c, err := NewConsumer[*testResult](ctx, - &ConsumerConfig{ - RedisURL: redisURL, - RedisStream: streamName, - RedisGroup: defaultGroup, - KeepAliveInterval: 5 * time.Millisecond, - KeepAliveTimeout: 30 * time.Millisecond, 
- }, - ) + c, err := NewConsumer[*testRequest](ctx, defaultCfg) if err != nil { t.Fatalf("Error creating new consumer: %v", err) } @@ -119,10 +121,7 @@ func TestProduce(t *testing.T) { if res == nil { continue } - gotMessages[idx][res.ID] = res.Value.val - if err := c.ACK(ctx, res.ID); err != nil { - t.Errorf("Error ACKing message: %v, error: %v", res.ID, err) - } + gotMessages[idx][res.ID] = res.Value.request if err := c.SetResult(ctx, res.ID, fmt.Sprintf("result for: %v", res.ID)); err != nil { t.Errorf("Error setting a result: %v", err) } @@ -134,7 +133,7 @@ func TestProduce(t *testing.T) { var gotResponses []string for i := 0; i < messagesCount; i++ { - value := &testResult{val: fmt.Sprintf("msg: %d", i)} + value := &testRequest{request: fmt.Sprintf("msg: %d", i)} p, err := producer.Produce(ctx, value) if err != nil { t.Errorf("Produce() unexpected error: %v", err) @@ -143,7 +142,7 @@ func TestProduce(t *testing.T) { if err != nil { t.Errorf("Await() unexpected error: %v", err) } - gotResponses = append(gotResponses, res.val) + gotResponses = append(gotResponses, res.response) } producer.StopWaiter.StopAndWait() @@ -219,10 +218,7 @@ func TestClaimingOwnership(t *testing.T) { if res == nil { continue } - gotMessages[idx][res.ID] = res.Value.val - if err := c.ACK(ctx, res.ID); err != nil { - t.Errorf("Error ACKing message: %v, error: %v", res.ID, err) - } + gotMessages[idx][res.ID] = res.Value.request if err := c.SetResult(ctx, res.ID, fmt.Sprintf("result for: %v", res.ID)); err != nil { t.Errorf("Error setting a result: %v", err) } @@ -232,9 +228,9 @@ func TestClaimingOwnership(t *testing.T) { }) } - var promises []*containers.Promise[*testResult] + var promises []*containers.Promise[*testResponse] for i := 0; i < messagesCount; i++ { - value := &testResult{val: fmt.Sprintf("msg: %d", i)} + value := &testRequest{request: fmt.Sprintf("msg: %d", i)} promise, err := producer.Produce(ctx, value) if err != nil { t.Errorf("Produce() unexpected error: %v", err) @@ -248,7 +244,7 @@ func TestClaimingOwnership(t *testing.T) { t.Errorf("Await() unexpected error: %v", err) continue } - gotResponses = append(gotResponses, res.val) + gotResponses = append(gotResponses, res.response) } for { From 0455d937ffdee50f122aeb570727440312f89598 Mon Sep 17 00:00:00 2001 From: Nodar Ambroladze Date: Tue, 2 Apr 2024 12:49:04 +0200 Subject: [PATCH 19/33] Address comments --- pubsub/consumer.go | 32 ++++++------ pubsub/producer.go | 115 ++++++++++++++++++++++++------------------ pubsub/pubsub_test.go | 16 +++--- 3 files changed, 90 insertions(+), 73 deletions(-) diff --git a/pubsub/consumer.go b/pubsub/consumer.go index e013314e5b..9c0edb5e9e 100644 --- a/pubsub/consumer.go +++ b/pubsub/consumer.go @@ -52,19 +52,19 @@ func ConsumerConfigAddOptions(prefix string, f *pflag.FlagSet) { // Consumer implements a consumer for redis stream provides heartbeat to // indicate it is alive. 
-type Consumer[T Marshallable[T]] struct { +type Consumer[Request Marshallable[Request], Response Marshallable[Response]] struct { stopwaiter.StopWaiter id string client redis.UniversalClient cfg *ConsumerConfig } -type Message[T Marshallable[T]] struct { +type Message[Request Marshallable[Request]] struct { ID string - Value T + Value Request } -func NewConsumer[T Marshallable[T]](ctx context.Context, cfg *ConsumerConfig) (*Consumer[T], error) { +func NewConsumer[Request Marshallable[Request], Response Marshallable[Response]](ctx context.Context, cfg *ConsumerConfig) (*Consumer[Request, Response], error) { if cfg.RedisURL == "" { return nil, fmt.Errorf("redis url cannot be empty") } @@ -72,7 +72,7 @@ func NewConsumer[T Marshallable[T]](ctx context.Context, cfg *ConsumerConfig) (* if err != nil { return nil, err } - consumer := &Consumer[T]{ + consumer := &Consumer[Request, Response]{ id: uuid.NewString(), client: c, cfg: cfg, @@ -81,7 +81,7 @@ func NewConsumer[T Marshallable[T]](ctx context.Context, cfg *ConsumerConfig) (* } // Start starts the consumer to iteratively perform heartbeat in configured intervals. -func (c *Consumer[T]) Start(ctx context.Context) { +func (c *Consumer[Request, Response]) Start(ctx context.Context) { c.StopWaiter.Start(ctx, c) c.StopWaiter.CallIteratively( func(ctx context.Context) time.Duration { @@ -91,7 +91,7 @@ func (c *Consumer[T]) Start(ctx context.Context) { ) } -func (c *Consumer[T]) StopAndWait() { +func (c *Consumer[Request, Response]) StopAndWait() { c.StopWaiter.StopAndWait() } @@ -99,12 +99,12 @@ func heartBeatKey(id string) string { return fmt.Sprintf("consumer:%s:heartbeat", id) } -func (c *Consumer[T]) heartBeatKey() string { +func (c *Consumer[Request, Response]) heartBeatKey() string { return heartBeatKey(c.id) } // heartBeat updates the heartBeat key indicating aliveness. -func (c *Consumer[T]) heartBeat(ctx context.Context) { +func (c *Consumer[Request, Response]) heartBeat(ctx context.Context) { if err := c.client.Set(ctx, c.heartBeatKey(), time.Now().UnixMilli(), 2*c.cfg.KeepAliveTimeout).Err(); err != nil { l := log.Info if ctx.Err() != nil { @@ -116,7 +116,7 @@ func (c *Consumer[T]) heartBeat(ctx context.Context) { // Consumer first checks it there exists pending message that is claimed by // unresponsive consumer, if not then reads from the stream. 
-func (c *Consumer[T]) Consume(ctx context.Context) (*Message[T], error) { +func (c *Consumer[Request, Response]) Consume(ctx context.Context) (*Message[Request], error) { res, err := c.client.XReadGroup(ctx, &redis.XReadGroupArgs{ Group: c.cfg.RedisGroup, Consumer: c.id, @@ -139,24 +139,24 @@ func (c *Consumer[T]) Consume(ctx context.Context) (*Message[T], error) { var ( value = res[0].Messages[0].Values[messageKey] data, ok = (value).(string) - tmp T + tmp Request ) if !ok { return nil, fmt.Errorf("casting request to string: %w", err) } - val, err := tmp.Unmarshal([]byte(data)) + req, err := tmp.Unmarshal([]byte(data)) if err != nil { return nil, fmt.Errorf("unmarshaling value: %v, error: %w", value, err) } - return &Message[T]{ + return &Message[Request]{ ID: res[0].Messages[0].ID, - Value: val, + Value: req, }, nil } -func (c *Consumer[T]) SetResult(ctx context.Context, messageID string, result string) error { - acquired, err := c.client.SetNX(ctx, messageID, result, c.cfg.ResponseEntryTimeout).Result() +func (c *Consumer[Request, Response]) SetResult(ctx context.Context, messageID string, result Response) error { + acquired, err := c.client.SetNX(ctx, messageID, result.Marshal(), c.cfg.ResponseEntryTimeout).Result() if err != nil || !acquired { return fmt.Errorf("setting result for message: %v, error: %w", messageID, err) } diff --git a/pubsub/producer.go b/pubsub/producer.go index 006b84709f..0e5c4475bd 100644 --- a/pubsub/producer.go +++ b/pubsub/producer.go @@ -34,6 +34,11 @@ type Producer[Request Marshallable[Request], Response Marshallable[Response]] st promisesLock sync.RWMutex promises map[string]*containers.Promise[Response] + + // Used for running checks for pending messages with inactive consumers + // and checking responses from consumers iteratively for the first time when + // Produce is called. + once sync.Once } type ProducerConfig struct { @@ -92,59 +97,61 @@ func NewProducer[Request Marshallable[Request], Response Marshallable[Response]] }, nil } -func (p *Producer[Request, Response]) Start(ctx context.Context) { - p.StopWaiter.Start(ctx, p) - p.StopWaiter.CallIteratively( - func(ctx context.Context) time.Duration { - msgs, err := p.checkPending(ctx) - if err != nil { - log.Error("Checking pending messages", "error", err) - return p.cfg.CheckPendingInterval - } - if len(msgs) == 0 { - return p.cfg.CheckPendingInterval - } - acked := make(map[string]Request) - for _, msg := range msgs { - if _, err := p.client.XAck(ctx, p.cfg.RedisStream, p.cfg.RedisGroup, msg.ID).Result(); err != nil { - log.Error("ACKing message", "error", err) - continue - } - acked[msg.ID] = msg.Value - } - for k, v := range acked { - // Only re-insert messages that were removed the the pending list first. - _, err := p.reproduce(ctx, v, k) - if err != nil { - log.Error("Re-inserting pending messages with inactive consumers", "error", err) - } - } - return p.cfg.CheckPendingInterval - }, - ) - // Iteratively check whether result were returned for some queries. 
- p.StopWaiter.CallIteratively(func(ctx context.Context) time.Duration { - p.promisesLock.Lock() - defer p.promisesLock.Unlock() - for id, promise := range p.promises { - res, err := p.client.Get(ctx, id).Result() - if err != nil { - if errors.Is(err, redis.Nil) { - continue - } - log.Error("Error reading value in redis", "key", id, "error", err) - } - var tmp Response - val, err := tmp.Unmarshal([]byte(res)) - if err != nil { - log.Error("Error unmarshaling", "value", res, "error", err) +// checkAndReproduce reproduce pending messages that were sent to consumers +// that are currently inactive. +func (p *Producer[Request, Response]) checkAndReproduce(ctx context.Context) time.Duration { + msgs, err := p.checkPending(ctx) + if err != nil { + log.Error("Checking pending messages", "error", err) + return p.cfg.CheckPendingInterval + } + if len(msgs) == 0 { + return p.cfg.CheckPendingInterval + } + acked := make(map[string]Request) + for _, msg := range msgs { + if _, err := p.client.XAck(ctx, p.cfg.RedisStream, p.cfg.RedisGroup, msg.ID).Result(); err != nil { + log.Error("ACKing message", "error", err) + continue + } + acked[msg.ID] = msg.Value + } + for k, v := range acked { + // Only re-insert messages that were removed the the pending list first. + _, err := p.reproduce(ctx, v, k) + if err != nil { + log.Error("Re-inserting pending messages with inactive consumers", "error", err) + } + } + return p.cfg.CheckPendingInterval +} + +// checkResponses checks iteratively whether response for the promise is ready. +func (p *Producer[Request, Response]) checkResponses(ctx context.Context) time.Duration { + p.promisesLock.Lock() + defer p.promisesLock.Unlock() + for id, promise := range p.promises { + res, err := p.client.Get(ctx, id).Result() + if err != nil { + if errors.Is(err, redis.Nil) { continue } - promise.Produce(val) - delete(p.promises, id) + log.Error("Error reading value in redis", "key", id, "error", err) } - return p.cfg.CheckResultInterval - }) + var tmp Response + val, err := tmp.Unmarshal([]byte(res)) + if err != nil { + log.Error("Error unmarshaling", "value", res, "error", err) + continue + } + promise.Produce(val) + delete(p.promises, id) + } + return p.cfg.CheckResultInterval +} + +func (p *Producer[Request, Response]) Start(ctx context.Context) { + p.StopWaiter.Start(ctx, p) } // reproduce is used when Producer claims ownership on the pending @@ -170,6 +177,10 @@ func (p *Producer[Request, Response]) reproduce(ctx context.Context, value Reque } func (p *Producer[Request, Response]) Produce(ctx context.Context, value Request) (*containers.Promise[Response], error) { + p.once.Do(func() { + p.StopWaiter.CallIteratively(p.checkAndReproduce) + p.StopWaiter.CallIteratively(p.checkResponses) + }) return p.reproduce(ctx, value, "") } @@ -211,6 +222,10 @@ func (p *Producer[Request, Response]) checkPending(ctx context.Context) ([]*Mess } ids = append(ids, msg.ID) } + if len(ids) == 0 { + log.Info("There are no pending messages with inactive consumers") + return nil, nil + } log.Info("Attempting to claim", "messages", ids) claimedMsgs, err := p.client.XClaim(ctx, &redis.XClaimArgs{ Stream: p.cfg.RedisStream, diff --git a/pubsub/pubsub_test.go b/pubsub/pubsub_test.go index 77f2a87914..e2976f3fdf 100644 --- a/pubsub/pubsub_test.go +++ b/pubsub/pubsub_test.go @@ -57,7 +57,7 @@ func createGroup(ctx context.Context, t *testing.T, client redis.UniversalClient } } -func newProducerConsumers(ctx context.Context, t *testing.T) (*Producer[*testRequest, *testResponse], []*Consumer[*testRequest]) 
{ +func newProducerConsumers(ctx context.Context, t *testing.T) (*Producer[*testRequest, *testResponse], []*Consumer[*testRequest, *testResponse]) { t.Helper() redisURL := redisutil.CreateTestRedis(ctx, t) defaultProdCfg := DefaultTestProducerConfig @@ -68,9 +68,9 @@ func newProducerConsumers(ctx context.Context, t *testing.T) (*Producer[*testReq } defaultCfg := DefaultTestConsumerConfig defaultCfg.RedisURL = redisURL - var consumers []*Consumer[*testRequest] + var consumers []*Consumer[*testRequest, *testResponse] for i := 0; i < consumersCount; i++ { - c, err := NewConsumer[*testRequest](ctx, defaultCfg) + c, err := NewConsumer[*testRequest, *testResponse](ctx, defaultCfg) if err != nil { t.Fatalf("Error creating new consumer: %v", err) } @@ -122,10 +122,11 @@ func TestProduce(t *testing.T) { continue } gotMessages[idx][res.ID] = res.Value.request - if err := c.SetResult(ctx, res.ID, fmt.Sprintf("result for: %v", res.ID)); err != nil { + resp := &testResponse{response: fmt.Sprintf("result for: %v", res.ID)} + if err := c.SetResult(ctx, res.ID, resp); err != nil { t.Errorf("Error setting a result: %v", err) } - wantResponses[idx] = append(wantResponses[idx], fmt.Sprintf("result for: %v", res.ID)) + wantResponses[idx] = append(wantResponses[idx], resp.response) } }) } @@ -219,10 +220,11 @@ func TestClaimingOwnership(t *testing.T) { continue } gotMessages[idx][res.ID] = res.Value.request - if err := c.SetResult(ctx, res.ID, fmt.Sprintf("result for: %v", res.ID)); err != nil { + resp := &testResponse{response: fmt.Sprintf("result for: %v", res.ID)} + if err := c.SetResult(ctx, res.ID, resp); err != nil { t.Errorf("Error setting a result: %v", err) } - wantResponses[idx] = append(wantResponses[idx], fmt.Sprintf("result for: %v", res.ID)) + wantResponses[idx] = append(wantResponses[idx], resp.response) total.Add(1) } }) From 378906e0098a534cf9f84956526a60497335f9e6 Mon Sep 17 00:00:00 2001 From: Nodar Ambroladze Date: Tue, 2 Apr 2024 17:37:36 +0200 Subject: [PATCH 20/33] Change Info to Trace --- pubsub/producer.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pubsub/producer.go b/pubsub/producer.go index 0e5c4475bd..19ee72530c 100644 --- a/pubsub/producer.go +++ b/pubsub/producer.go @@ -223,7 +223,7 @@ func (p *Producer[Request, Response]) checkPending(ctx context.Context) ([]*Mess ids = append(ids, msg.ID) } if len(ids) == 0 { - log.Info("There are no pending messages with inactive consumers") + log.Trace("There are no pending messages with inactive consumers") return nil, nil } log.Info("Attempting to claim", "messages", ids) From 33fae88f84c2037519a1a2b07ba929115776473f Mon Sep 17 00:00:00 2001 From: Nodar Ambroladze Date: Tue, 2 Apr 2024 17:40:28 +0200 Subject: [PATCH 21/33] Ignore messages not produced by this producer --- pubsub/producer.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pubsub/producer.go b/pubsub/producer.go index 19ee72530c..f467d87260 100644 --- a/pubsub/producer.go +++ b/pubsub/producer.go @@ -212,6 +212,10 @@ func (p *Producer[Request, Response]) checkPending(ctx context.Context) ([]*Mess var ids []string active := make(map[string]bool) for _, msg := range pendingMessages { + // Ignore messages not produced by this producer. 
+ if _, found := p.promises[msg.ID]; !found { + continue + } alive, found := active[msg.Consumer] if !found { alive = p.isConsumerAlive(ctx, msg.Consumer) From 5b5f709970dcaf2e10532deac6284a0ad0827003 Mon Sep 17 00:00:00 2001 From: Nodar Ambroladze Date: Tue, 2 Apr 2024 18:04:52 +0200 Subject: [PATCH 22/33] Address data race --- pubsub/producer.go | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/pubsub/producer.go b/pubsub/producer.go index f467d87260..a183cdbd7b 100644 --- a/pubsub/producer.go +++ b/pubsub/producer.go @@ -193,6 +193,13 @@ func (p *Producer[Request, Response]) isConsumerAlive(ctx context.Context, consu return time.Now().UnixMilli()-val < int64(p.cfg.KeepAliveTimeout.Milliseconds()) } +func (p *Producer[Request, Response]) havePromiseFor(messageID string) bool { + p.promisesLock.Lock() + defer p.promisesLock.Unlock() + _, found := p.promises[messageID] + return found +} + func (p *Producer[Request, Response]) checkPending(ctx context.Context) ([]*Message[Request], error) { pendingMessages, err := p.client.XPendingExt(ctx, &redis.XPendingExtArgs{ Stream: p.cfg.RedisStream, @@ -213,7 +220,7 @@ func (p *Producer[Request, Response]) checkPending(ctx context.Context) ([]*Mess active := make(map[string]bool) for _, msg := range pendingMessages { // Ignore messages not produced by this producer. - if _, found := p.promises[msg.ID]; !found { + if p.havePromiseFor(msg.ID) { continue } alive, found := active[msg.Consumer] From 0bd347ec405738a223195b9659747a3417397b80 Mon Sep 17 00:00:00 2001 From: Nodar Ambroladze Date: Thu, 4 Apr 2024 12:32:51 +0200 Subject: [PATCH 23/33] Implement option to error out on failed requests instead of requeueing them --- pubsub/producer.go | 41 ++++++--- pubsub/pubsub_test.go | 188 +++++++++++++++++++++++++++++------------- 2 files changed, 163 insertions(+), 66 deletions(-) diff --git a/pubsub/producer.go b/pubsub/producer.go index a183cdbd7b..6188b81dfa 100644 --- a/pubsub/producer.go +++ b/pubsub/producer.go @@ -42,7 +42,11 @@ type Producer[Request Marshallable[Request], Response Marshallable[Response]] st } type ProducerConfig struct { - RedisURL string `koanf:"redis-url"` + // When enabled, messages that are sent to consumers that later die before + // processing them, will be re-inserted into the stream to be proceesed by + // another consumer + EnableReproduce bool `koanf:"enable-reproduce"` + RedisURL string `koanf:"redis-url"` // Redis stream name. 
RedisStream string `koanf:"redis-stream"` // Interval duration in which producer checks for pending messages delivered @@ -58,6 +62,7 @@ type ProducerConfig struct { } var DefaultProducerConfig = &ProducerConfig{ + EnableReproduce: true, RedisStream: "default", CheckPendingInterval: time.Second, KeepAliveTimeout: 5 * time.Minute, @@ -66,6 +71,7 @@ var DefaultProducerConfig = &ProducerConfig{ } var DefaultTestProducerConfig = &ProducerConfig{ + EnableReproduce: true, RedisStream: "default", RedisGroup: defaultGroup, CheckPendingInterval: 10 * time.Millisecond, @@ -74,11 +80,12 @@ var DefaultTestProducerConfig = &ProducerConfig{ } func ProducerAddConfigAddOptions(prefix string, f *pflag.FlagSet) { - f.String(prefix+".redis-url", DefaultConsumerConfig.RedisURL, "redis url for redis stream") - f.Duration(prefix+".response-entry-timeout", DefaultConsumerConfig.ResponseEntryTimeout, "timeout for response entry") - f.Duration(prefix+".keepalive-timeout", DefaultConsumerConfig.KeepAliveTimeout, "timeout after which consumer is considered inactive if heartbeat wasn't performed") - f.String(prefix+".redis-stream", DefaultConsumerConfig.RedisStream, "redis stream name to read from") - f.String(prefix+".redis-group", DefaultConsumerConfig.RedisGroup, "redis stream consumer group name") + f.Bool(prefix+".enable-reproduce", DefaultProducerConfig.EnableReproduce, "when enabled, messages with dead consumer will be re-inserted into the stream") + f.String(prefix+".redis-url", DefaultProducerConfig.RedisURL, "redis url for redis stream") + f.Duration(prefix+".check-pending-interval", DefaultProducerConfig.CheckPendingInterval, "interval in which producer checks pending messages whether consumer processing them is inactive") + f.Duration(prefix+".keepalive-timeout", DefaultProducerConfig.KeepAliveTimeout, "timeout after which consumer is considered inactive if heartbeat wasn't performed") + f.String(prefix+".redis-stream", DefaultProducerConfig.RedisStream, "redis stream name to read from") + f.String(prefix+".redis-group", DefaultProducerConfig.RedisGroup, "redis stream consumer group name") } func NewProducer[Request Marshallable[Request], Response Marshallable[Response]](cfg *ProducerConfig) (*Producer[Request, Response], error) { @@ -97,6 +104,15 @@ func NewProducer[Request Marshallable[Request], Response Marshallable[Response]] }, nil } +func (p *Producer[Request, Response]) errorPromisesFor(msgs []*Message[Request]) { + p.promisesLock.Lock() + defer p.promisesLock.Unlock() + for _, msg := range msgs { + p.promises[msg.ID].ProduceError(fmt.Errorf("internal error, consumer died while serving the request")) + delete(p.promises, msg.ID) + } +} + // checkAndReproduce reproduce pending messages that were sent to consumers // that are currently inactive. 
func (p *Producer[Request, Response]) checkAndReproduce(ctx context.Context) time.Duration { @@ -108,6 +124,10 @@ func (p *Producer[Request, Response]) checkAndReproduce(ctx context.Context) tim if len(msgs) == 0 { return p.cfg.CheckPendingInterval } + if !p.cfg.EnableReproduce { + p.errorPromisesFor(msgs) + return p.cfg.CheckPendingInterval + } acked := make(map[string]Request) for _, msg := range msgs { if _, err := p.client.XAck(ctx, p.cfg.RedisStream, p.cfg.RedisGroup, msg.ID).Result(); err != nil { @@ -172,6 +192,7 @@ func (p *Producer[Request, Response]) reproduce(ctx context.Context, value Reque pr := containers.NewPromise[Response](nil) promise = &pr } + delete(p.promises, oldKey) p.promises[id] = promise return promise, nil } @@ -220,7 +241,7 @@ func (p *Producer[Request, Response]) checkPending(ctx context.Context) ([]*Mess active := make(map[string]bool) for _, msg := range pendingMessages { // Ignore messages not produced by this producer. - if p.havePromiseFor(msg.ID) { + if !p.havePromiseFor(msg.ID) { continue } alive, found := active[msg.Consumer] @@ -250,12 +271,12 @@ func (p *Producer[Request, Response]) checkPending(ctx context.Context) ([]*Mess } var res []*Message[Request] for _, msg := range claimedMsgs { - data, ok := (msg.Values[messageKey]).([]byte) + data, ok := (msg.Values[messageKey]).(string) if !ok { - return nil, fmt.Errorf("casting request to bytes: %w", err) + return nil, fmt.Errorf("casting request: %v to bytes", msg.Values[messageKey]) } var tmp Request - val, err := tmp.Unmarshal(data) + val, err := tmp.Unmarshal([]byte(data)) if err != nil { return nil, fmt.Errorf("marshaling value: %v, error: %w", msg.Values[messageKey], err) } diff --git a/pubsub/pubsub_test.go b/pubsub/pubsub_test.go index e2976f3fdf..c980ff29a9 100644 --- a/pubsub/pubsub_test.go +++ b/pubsub/pubsub_test.go @@ -4,11 +4,11 @@ import ( "context" "errors" "fmt" + "os" "sort" - "sync/atomic" "testing" - "time" + "github.com/ethereum/go-ethereum/log" "github.com/go-redis/redis/v8" "github.com/google/go-cmp/cmp" "github.com/offchainlabs/nitro/util/containers" @@ -57,20 +57,32 @@ func createGroup(ctx context.Context, t *testing.T, client redis.UniversalClient } } -func newProducerConsumers(ctx context.Context, t *testing.T) (*Producer[*testRequest, *testResponse], []*Consumer[*testRequest, *testResponse]) { +type configOpt interface { + apply(consCfg *ConsumerConfig, prodCfg *ProducerConfig) +} + +type disableReproduce struct{} + +func (e *disableReproduce) apply(_ *ConsumerConfig, prodCfg *ProducerConfig) { + prodCfg.EnableReproduce = false +} + +func newProducerConsumers(ctx context.Context, t *testing.T, opts ...configOpt) (*Producer[*testRequest, *testResponse], []*Consumer[*testRequest, *testResponse]) { t.Helper() redisURL := redisutil.CreateTestRedis(ctx, t) - defaultProdCfg := DefaultTestProducerConfig - defaultProdCfg.RedisURL = redisURL - producer, err := NewProducer[*testRequest, *testResponse](defaultProdCfg) + prodCfg, consCfg := DefaultTestProducerConfig, DefaultTestConsumerConfig + prodCfg.RedisURL, consCfg.RedisURL = redisURL, redisURL + for _, o := range opts { + o.apply(consCfg, prodCfg) + } + producer, err := NewProducer[*testRequest, *testResponse](prodCfg) if err != nil { t.Fatalf("Error creating new producer: %v", err) } - defaultCfg := DefaultTestConsumerConfig - defaultCfg.RedisURL = redisURL + var consumers []*Consumer[*testRequest, *testResponse] for i := 0; i < consumersCount; i++ { - c, err := NewConsumer[*testRequest, *testResponse](ctx, defaultCfg) + c, err := 
NewConsumer[*testRequest, *testResponse](ctx, consCfg) if err != nil { t.Fatalf("Error creating new consumer: %v", err) } @@ -99,7 +111,7 @@ func wantMessages(n int) []string { return ret } -func TestProduce(t *testing.T) { +func TestRedisProduce(t *testing.T) { ctx := context.Background() producer, consumers := newProducerConsumers(ctx, t) producer.Start(ctx) @@ -180,26 +192,41 @@ func flatten(responses [][]string) []string { return ret } -func TestClaimingOwnership(t *testing.T) { - ctx := context.Background() - producer, consumers := newProducerConsumers(ctx, t) - producer.Start(ctx) - gotMessages := messagesMaps(consumersCount) +func produceMessages(ctx context.Context, producer *Producer[*testRequest, *testResponse]) ([]*containers.Promise[*testResponse], error) { + var promises []*containers.Promise[*testResponse] + for i := 0; i < messagesCount; i++ { + value := &testRequest{request: fmt.Sprintf("msg: %d", i)} + promise, err := producer.Produce(ctx, value) + if err != nil { + return nil, err + } + promises = append(promises, promise) + } + return promises, nil +} - // Consumer messages in every third consumer but don't ack them to check - // that other consumers will claim ownership on those messages. - for i := 0; i < len(consumers); i += 3 { - i := i - if _, err := consumers[i].Consume(ctx); err != nil { - t.Errorf("Error consuming message: %v", err) +func awaitResponses(ctx context.Context, promises []*containers.Promise[*testResponse]) ([]string, error) { + var ( + responses []string + errs []error + ) + for _, p := range promises { + res, err := p.Await(ctx) + if err != nil { + errs = append(errs, err) + continue } - consumers[i].StopAndWait() + responses = append(responses, res.response) } - var total atomic.Uint64 + return responses, errors.Join(errs...) 
+} - wantResponses := make([][]string, len(consumers)) - for idx := 0; idx < len(consumers); idx++ { - if idx%3 == 0 { +func consume(ctx context.Context, t *testing.T, consumers []*Consumer[*testRequest, *testResponse], skipN int) ([]map[string]string, [][]string) { + t.Helper() + gotMessages := messagesMaps(consumersCount) + wantResponses := make([][]string, consumersCount) + for idx := 0; idx < consumersCount; idx++ { + if idx%skipN == 0 { continue } idx, c := idx, consumers[idx] @@ -225,36 +252,39 @@ func TestClaimingOwnership(t *testing.T) { t.Errorf("Error setting a result: %v", err) } wantResponses[idx] = append(wantResponses[idx], resp.response) - total.Add(1) } }) } + return gotMessages, wantResponses +} - var promises []*containers.Promise[*testResponse] - for i := 0; i < messagesCount; i++ { - value := &testRequest{request: fmt.Sprintf("msg: %d", i)} - promise, err := producer.Produce(ctx, value) - if err != nil { - t.Errorf("Produce() unexpected error: %v", err) - } - promises = append(promises, promise) +func TestRedisClaimingOwnership(t *testing.T) { + glogger := log.NewGlogHandler(log.StreamHandler(os.Stdout, log.TerminalFormat(false))) + glogger.Verbosity(log.LvlTrace) + log.Root().SetHandler(log.Handler(glogger)) + + ctx := context.Background() + producer, consumers := newProducerConsumers(ctx, t) + producer.Start(ctx) + promises, err := produceMessages(ctx, producer) + if err != nil { + t.Fatalf("Error producing messages: %v", err) } - var gotResponses []string - for _, p := range promises { - res, err := p.Await(ctx) - if err != nil { - t.Errorf("Await() unexpected error: %v", err) - continue + + // Consumer messages in every third consumer but don't ack them to check + // that other consumers will claim ownership on those messages. 
+ for i := 0; i < len(consumers); i += 3 { + i := i + if _, err := consumers[i].Consume(ctx); err != nil { + t.Errorf("Error consuming message: %v", err) } - gotResponses = append(gotResponses, res.response) + consumers[i].StopAndWait() } - for { - if total.Load() < uint64(messagesCount) { - time.Sleep(100 * time.Millisecond) - continue - } - break + gotMessages, wantResponses := consume(ctx, t, consumers, 3) + gotResponses, err := awaitResponses(ctx, promises) + if err != nil { + t.Fatalf("Error awaiting responses: %v", err) } for _, c := range consumers { c.StopWaiter.StopAndWait() @@ -267,13 +297,61 @@ func TestClaimingOwnership(t *testing.T) { if diff := cmp.Diff(want, got); diff != "" { t.Errorf("Unexpected diff (-want +got):\n%s\n", diff) } - WantResp := flatten(wantResponses) - sort.Slice(gotResponses, func(i, j int) bool { - return gotResponses[i] < gotResponses[j] - }) - if diff := cmp.Diff(WantResp, gotResponses); diff != "" { + wantResp := flatten(wantResponses) + sort.Strings(gotResponses) + if diff := cmp.Diff(wantResp, gotResponses); diff != "" { t.Errorf("Unexpected diff in responses:\n%s\n", diff) } + if cnt := len(producer.promises); cnt != 0 { + t.Errorf("Producer still has %d unfullfilled promises", cnt) + } +} + +func TestRedisClaimingOwnershipReproduceDisabled(t *testing.T) { + glogger := log.NewGlogHandler(log.StreamHandler(os.Stdout, log.TerminalFormat(false))) + glogger.Verbosity(log.LvlTrace) + log.Root().SetHandler(log.Handler(glogger)) + + ctx := context.Background() + producer, consumers := newProducerConsumers(ctx, t, &disableReproduce{}) + producer.Start(ctx) + promises, err := produceMessages(ctx, producer) + if err != nil { + t.Fatalf("Error producing messages: %v", err) + } + + // Consumer messages in every third consumer but don't ack them to check + // that other consumers will claim ownership on those messages. + for i := 0; i < len(consumers); i += 3 { + i := i + if _, err := consumers[i].Consume(ctx); err != nil { + t.Errorf("Error consuming message: %v", err) + } + consumers[i].StopAndWait() + } + + gotMessages, _ := consume(ctx, t, consumers, 3) + gotResponses, err := awaitResponses(ctx, promises) + if err == nil { + t.Fatalf("All promises were fullfilled with reproduce disabled and some consumers killed") + } + for _, c := range consumers { + c.StopWaiter.StopAndWait() + } + got, err := mergeValues(gotMessages) + if err != nil { + t.Fatalf("mergeMaps() unexpected error: %v", err) + } + wantMsgCnt := messagesCount - (consumersCount / 3) - (consumersCount % 3) + if len(got) != wantMsgCnt { + t.Fatalf("Got: %d messages, want %d", len(got), wantMsgCnt) + } + if len(gotResponses) != wantMsgCnt { + t.Errorf("Got %d responses want: %d\n", len(gotResponses), wantMsgCnt) + } + if cnt := len(producer.promises); cnt != 0 { + t.Errorf("Producer still has %d unfullfilled promises", cnt) + } } // mergeValues merges maps from the slice and returns their values. 
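A usage note on the EnableReproduce switch introduced in this patch: with the flag off, a pending entry whose consumer has stopped heartbeating is not re-inserted into the stream; its promise is failed with the "consumer died while serving the request" error, so the failure surfaces directly on Await. The following is a minimal sketch, not part of the patch, of a caller opting into that behaviour. It reuses the Marshallable testRequest/testResponse helpers from this test file, leaves stream and group at the package defaults, and assumes the consumer group exists and consumers (like the ones in these tests) are running against the same stream.

    // Sketch only: fail-fast production when reproduce is disabled.
    func produceFailFast(ctx context.Context, redisURL string) error {
        cfg := *DefaultProducerConfig
        cfg.RedisURL = redisURL
        cfg.EnableReproduce = false // error out instead of re-queueing on dead consumers
        producer, err := NewProducer[*testRequest, *testResponse](&cfg)
        if err != nil {
            return err
        }
        producer.Start(ctx)
        defer producer.StopWaiter.StopAndWait()
        promise, err := producer.Produce(ctx, &testRequest{request: "job-1"})
        if err != nil {
            return err
        }
        // If the consumer that picked this entry up goes silent past
        // KeepAliveTimeout, Await returns the producer's error instead of
        // blocking until another consumer reprocesses the message.
        res, err := promise.Await(ctx)
        if err != nil {
            return err
        }
        log.Info("got response", "response", res.response)
        return nil
    }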
@@ -290,8 +368,6 @@ func mergeValues(messages []map[string]string) ([]string, error) { ret = append(ret, v) } } - sort.Slice(ret, func(i, j int) bool { - return ret[i] < ret[j] - }) + sort.Strings(ret) return ret, nil } From c8101c2ede3dd5fa03de96165937b0571d14d010 Mon Sep 17 00:00:00 2001 From: Nodar Ambroladze Date: Thu, 4 Apr 2024 13:06:33 +0200 Subject: [PATCH 24/33] Change generics to be any instead of Marshallable, introduce generic Marshaller --- pubsub/consumer.go | 15 ++++++---- pubsub/producer.go | 20 ++++++++------ pubsub/pubsub_test.go | 64 +++++++++++++++++++------------------------ 3 files changed, 48 insertions(+), 51 deletions(-) diff --git a/pubsub/consumer.go b/pubsub/consumer.go index 9c0edb5e9e..8ae5bcb6b7 100644 --- a/pubsub/consumer.go +++ b/pubsub/consumer.go @@ -52,19 +52,21 @@ func ConsumerConfigAddOptions(prefix string, f *pflag.FlagSet) { // Consumer implements a consumer for redis stream provides heartbeat to // indicate it is alive. -type Consumer[Request Marshallable[Request], Response Marshallable[Response]] struct { +type Consumer[Request any, Response any] struct { stopwaiter.StopWaiter id string client redis.UniversalClient cfg *ConsumerConfig + mReq Marshaller[Request] + mResp Marshaller[Response] } -type Message[Request Marshallable[Request]] struct { +type Message[Request any] struct { ID string Value Request } -func NewConsumer[Request Marshallable[Request], Response Marshallable[Response]](ctx context.Context, cfg *ConsumerConfig) (*Consumer[Request, Response], error) { +func NewConsumer[Request any, Response any](ctx context.Context, cfg *ConsumerConfig, mReq Marshaller[Request], mResp Marshaller[Response]) (*Consumer[Request, Response], error) { if cfg.RedisURL == "" { return nil, fmt.Errorf("redis url cannot be empty") } @@ -76,6 +78,8 @@ func NewConsumer[Request Marshallable[Request], Response Marshallable[Response]] id: uuid.NewString(), client: c, cfg: cfg, + mReq: mReq, + mResp: mResp, } return consumer, nil } @@ -139,12 +143,11 @@ func (c *Consumer[Request, Response]) Consume(ctx context.Context) (*Message[Req var ( value = res[0].Messages[0].Values[messageKey] data, ok = (value).(string) - tmp Request ) if !ok { return nil, fmt.Errorf("casting request to string: %w", err) } - req, err := tmp.Unmarshal([]byte(data)) + req, err := c.mReq.Unmarshal([]byte(data)) if err != nil { return nil, fmt.Errorf("unmarshaling value: %v, error: %w", value, err) } @@ -156,7 +159,7 @@ func (c *Consumer[Request, Response]) Consume(ctx context.Context) (*Message[Req } func (c *Consumer[Request, Response]) SetResult(ctx context.Context, messageID string, result Response) error { - acquired, err := c.client.SetNX(ctx, messageID, result.Marshal(), c.cfg.ResponseEntryTimeout).Result() + acquired, err := c.client.SetNX(ctx, messageID, c.mResp.Marshal(result), c.cfg.ResponseEntryTimeout).Result() if err != nil || !acquired { return fmt.Errorf("setting result for message: %v, error: %w", messageID, err) } diff --git a/pubsub/producer.go b/pubsub/producer.go index 6188b81dfa..4569316b4d 100644 --- a/pubsub/producer.go +++ b/pubsub/producer.go @@ -21,16 +21,18 @@ const ( defaultGroup = "default_consumer_group" ) -type Marshallable[T any] interface { - Marshal() []byte +type Marshaller[T any] interface { + Marshal(T) []byte Unmarshal(val []byte) (T, error) } -type Producer[Request Marshallable[Request], Response Marshallable[Response]] struct { +type Producer[Request any, Response any] struct { stopwaiter.StopWaiter id string client redis.UniversalClient cfg 
*ProducerConfig + mReq Marshaller[Request] + mResp Marshaller[Response] promisesLock sync.RWMutex promises map[string]*containers.Promise[Response] @@ -88,7 +90,7 @@ func ProducerAddConfigAddOptions(prefix string, f *pflag.FlagSet) { f.String(prefix+".redis-group", DefaultProducerConfig.RedisGroup, "redis stream consumer group name") } -func NewProducer[Request Marshallable[Request], Response Marshallable[Response]](cfg *ProducerConfig) (*Producer[Request, Response], error) { +func NewProducer[Request any, Response any](cfg *ProducerConfig, mReq Marshaller[Request], mResp Marshaller[Response]) (*Producer[Request, Response], error) { if cfg.RedisURL == "" { return nil, fmt.Errorf("redis url cannot be empty") } @@ -100,6 +102,8 @@ func NewProducer[Request Marshallable[Request], Response Marshallable[Response]] id: uuid.NewString(), client: c, cfg: cfg, + mReq: mReq, + mResp: mResp, promises: make(map[string]*containers.Promise[Response]), }, nil } @@ -158,8 +162,7 @@ func (p *Producer[Request, Response]) checkResponses(ctx context.Context) time.D } log.Error("Error reading value in redis", "key", id, "error", err) } - var tmp Response - val, err := tmp.Unmarshal([]byte(res)) + val, err := p.mResp.Unmarshal([]byte(res)) if err != nil { log.Error("Error unmarshaling", "value", res, "error", err) continue @@ -180,7 +183,7 @@ func (p *Producer[Request, Response]) Start(ctx context.Context) { func (p *Producer[Request, Response]) reproduce(ctx context.Context, value Request, oldKey string) (*containers.Promise[Response], error) { id, err := p.client.XAdd(ctx, &redis.XAddArgs{ Stream: p.cfg.RedisStream, - Values: map[string]any{messageKey: value.Marshal()}, + Values: map[string]any{messageKey: p.mReq.Marshal(value)}, }).Result() if err != nil { return nil, fmt.Errorf("adding values to redis: %w", err) @@ -275,8 +278,7 @@ func (p *Producer[Request, Response]) checkPending(ctx context.Context) ([]*Mess if !ok { return nil, fmt.Errorf("casting request: %v to bytes", msg.Values[messageKey]) } - var tmp Request - val, err := tmp.Unmarshal([]byte(data)) + val, err := p.mReq.Unmarshal([]byte(data)) if err != nil { return nil, fmt.Errorf("marshaling value: %v, error: %w", msg.Values[messageKey], err) } diff --git a/pubsub/pubsub_test.go b/pubsub/pubsub_test.go index c980ff29a9..095d59db3b 100644 --- a/pubsub/pubsub_test.go +++ b/pubsub/pubsub_test.go @@ -21,32 +21,24 @@ var ( messagesCount = 100 ) -type testRequest struct { - request string -} +type testRequestMarshaller struct{} -func (r *testRequest) Marshal() []byte { - return []byte(r.request) +func (t *testRequestMarshaller) Marshal(val string) []byte { + return []byte(val) } -func (r *testRequest) Unmarshal(val []byte) (*testRequest, error) { - return &testRequest{ - request: string(val), - }, nil +func (t *testRequestMarshaller) Unmarshal(val []byte) (string, error) { + return string(val), nil } -type testResponse struct { - response string -} +type testResponseMarshaller struct{} -func (r *testResponse) Marshal() []byte { - return []byte(r.response) +func (t *testResponseMarshaller) Marshal(val string) []byte { + return []byte(val) } -func (r *testResponse) Unmarshal(val []byte) (*testResponse, error) { - return &testResponse{ - response: string(val), - }, nil +func (t *testResponseMarshaller) Unmarshal(val []byte) (string, error) { + return string(val), nil } func createGroup(ctx context.Context, t *testing.T, client redis.UniversalClient) { @@ -67,7 +59,7 @@ func (e *disableReproduce) apply(_ *ConsumerConfig, prodCfg *ProducerConfig) { 
prodCfg.EnableReproduce = false } -func newProducerConsumers(ctx context.Context, t *testing.T, opts ...configOpt) (*Producer[*testRequest, *testResponse], []*Consumer[*testRequest, *testResponse]) { +func newProducerConsumers(ctx context.Context, t *testing.T, opts ...configOpt) (*Producer[string, string], []*Consumer[string, string]) { t.Helper() redisURL := redisutil.CreateTestRedis(ctx, t) prodCfg, consCfg := DefaultTestProducerConfig, DefaultTestConsumerConfig @@ -75,14 +67,14 @@ func newProducerConsumers(ctx context.Context, t *testing.T, opts ...configOpt) for _, o := range opts { o.apply(consCfg, prodCfg) } - producer, err := NewProducer[*testRequest, *testResponse](prodCfg) + producer, err := NewProducer[string, string](prodCfg, &testRequestMarshaller{}, &testResponseMarshaller{}) if err != nil { t.Fatalf("Error creating new producer: %v", err) } - var consumers []*Consumer[*testRequest, *testResponse] + var consumers []*Consumer[string, string] for i := 0; i < consumersCount; i++ { - c, err := NewConsumer[*testRequest, *testResponse](ctx, consCfg) + c, err := NewConsumer[string, string](ctx, consCfg, &testRequestMarshaller{}, &testResponseMarshaller{}) if err != nil { t.Fatalf("Error creating new consumer: %v", err) } @@ -133,12 +125,12 @@ func TestRedisProduce(t *testing.T) { if res == nil { continue } - gotMessages[idx][res.ID] = res.Value.request - resp := &testResponse{response: fmt.Sprintf("result for: %v", res.ID)} + gotMessages[idx][res.ID] = res.Value + resp := fmt.Sprintf("result for: %v", res.ID) if err := c.SetResult(ctx, res.ID, resp); err != nil { t.Errorf("Error setting a result: %v", err) } - wantResponses[idx] = append(wantResponses[idx], resp.response) + wantResponses[idx] = append(wantResponses[idx], resp) } }) } @@ -146,7 +138,7 @@ func TestRedisProduce(t *testing.T) { var gotResponses []string for i := 0; i < messagesCount; i++ { - value := &testRequest{request: fmt.Sprintf("msg: %d", i)} + value := fmt.Sprintf("msg: %d", i) p, err := producer.Produce(ctx, value) if err != nil { t.Errorf("Produce() unexpected error: %v", err) @@ -155,7 +147,7 @@ func TestRedisProduce(t *testing.T) { if err != nil { t.Errorf("Await() unexpected error: %v", err) } - gotResponses = append(gotResponses, res.response) + gotResponses = append(gotResponses, res) } producer.StopWaiter.StopAndWait() @@ -192,10 +184,10 @@ func flatten(responses [][]string) []string { return ret } -func produceMessages(ctx context.Context, producer *Producer[*testRequest, *testResponse]) ([]*containers.Promise[*testResponse], error) { - var promises []*containers.Promise[*testResponse] +func produceMessages(ctx context.Context, producer *Producer[string, string]) ([]*containers.Promise[string], error) { + var promises []*containers.Promise[string] for i := 0; i < messagesCount; i++ { - value := &testRequest{request: fmt.Sprintf("msg: %d", i)} + value := fmt.Sprintf("msg: %d", i) promise, err := producer.Produce(ctx, value) if err != nil { return nil, err @@ -205,7 +197,7 @@ func produceMessages(ctx context.Context, producer *Producer[*testRequest, *test return promises, nil } -func awaitResponses(ctx context.Context, promises []*containers.Promise[*testResponse]) ([]string, error) { +func awaitResponses(ctx context.Context, promises []*containers.Promise[string]) ([]string, error) { var ( responses []string errs []error @@ -216,12 +208,12 @@ func awaitResponses(ctx context.Context, promises []*containers.Promise[*testRes errs = append(errs, err) continue } - responses = append(responses, res.response) + 
responses = append(responses, res) } return responses, errors.Join(errs...) } -func consume(ctx context.Context, t *testing.T, consumers []*Consumer[*testRequest, *testResponse], skipN int) ([]map[string]string, [][]string) { +func consume(ctx context.Context, t *testing.T, consumers []*Consumer[string, string], skipN int) ([]map[string]string, [][]string) { t.Helper() gotMessages := messagesMaps(consumersCount) wantResponses := make([][]string, consumersCount) @@ -246,12 +238,12 @@ func consume(ctx context.Context, t *testing.T, consumers []*Consumer[*testReque if res == nil { continue } - gotMessages[idx][res.ID] = res.Value.request - resp := &testResponse{response: fmt.Sprintf("result for: %v", res.ID)} + gotMessages[idx][res.ID] = res.Value + resp := fmt.Sprintf("result for: %v", res.ID) if err := c.SetResult(ctx, res.ID, resp); err != nil { t.Errorf("Error setting a result: %v", err) } - wantResponses[idx] = append(wantResponses[idx], resp.response) + wantResponses[idx] = append(wantResponses[idx], resp) } }) } From 972b0302ec45e071d0537b234402b913c74f89d7 Mon Sep 17 00:00:00 2001 From: Nodar Ambroladze Date: Thu, 4 Apr 2024 17:02:50 +0200 Subject: [PATCH 25/33] Drop glogger in tests --- pubsub/pubsub_test.go | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/pubsub/pubsub_test.go b/pubsub/pubsub_test.go index 095d59db3b..f62005b2cd 100644 --- a/pubsub/pubsub_test.go +++ b/pubsub/pubsub_test.go @@ -4,11 +4,9 @@ import ( "context" "errors" "fmt" - "os" "sort" "testing" - "github.com/ethereum/go-ethereum/log" "github.com/go-redis/redis/v8" "github.com/google/go-cmp/cmp" "github.com/offchainlabs/nitro/util/containers" @@ -104,6 +102,7 @@ func wantMessages(n int) []string { } func TestRedisProduce(t *testing.T) { + t.Parallel() ctx := context.Background() producer, consumers := newProducerConsumers(ctx, t) producer.Start(ctx) @@ -213,6 +212,7 @@ func awaitResponses(ctx context.Context, promises []*containers.Promise[string]) return responses, errors.Join(errs...) } +// consume messages from every consumer except every skipNth. 
func consume(ctx context.Context, t *testing.T, consumers []*Consumer[string, string], skipN int) ([]map[string]string, [][]string) { t.Helper() gotMessages := messagesMaps(consumersCount) @@ -251,10 +251,6 @@ func consume(ctx context.Context, t *testing.T, consumers []*Consumer[string, st } func TestRedisClaimingOwnership(t *testing.T) { - glogger := log.NewGlogHandler(log.StreamHandler(os.Stdout, log.TerminalFormat(false))) - glogger.Verbosity(log.LvlTrace) - log.Root().SetHandler(log.Handler(glogger)) - ctx := context.Background() producer, consumers := newProducerConsumers(ctx, t) producer.Start(ctx) @@ -300,10 +296,6 @@ func TestRedisClaimingOwnership(t *testing.T) { } func TestRedisClaimingOwnershipReproduceDisabled(t *testing.T) { - glogger := log.NewGlogHandler(log.StreamHandler(os.Stdout, log.TerminalFormat(false))) - glogger.Verbosity(log.LvlTrace) - log.Root().SetHandler(log.Handler(glogger)) - ctx := context.Background() producer, consumers := newProducerConsumers(ctx, t, &disableReproduce{}) producer.Start(ctx) From 1edbd6885e672eb228d1e016e5dabc10dd4beeb8 Mon Sep 17 00:00:00 2001 From: Nodar Ambroladze Date: Thu, 4 Apr 2024 17:03:22 +0200 Subject: [PATCH 26/33] Drop remnant code --- pubsub/pubsub_test.go | 1 - 1 file changed, 1 deletion(-) diff --git a/pubsub/pubsub_test.go b/pubsub/pubsub_test.go index f62005b2cd..11d8d1d14a 100644 --- a/pubsub/pubsub_test.go +++ b/pubsub/pubsub_test.go @@ -102,7 +102,6 @@ func wantMessages(n int) []string { } func TestRedisProduce(t *testing.T) { - t.Parallel() ctx := context.Background() producer, consumers := newProducerConsumers(ctx, t) producer.Start(ctx) From 0db255f42af5d6f02e11be6124a2e933331bbf4c Mon Sep 17 00:00:00 2001 From: Nodar Ambroladze Date: Thu, 4 Apr 2024 17:21:57 +0200 Subject: [PATCH 27/33] Make tests parallel --- pubsub/pubsub_test.go | 33 ++++++++++++++++++++++++++++----- 1 file changed, 28 insertions(+), 5 deletions(-) diff --git a/pubsub/pubsub_test.go b/pubsub/pubsub_test.go index 11d8d1d14a..5b83923692 100644 --- a/pubsub/pubsub_test.go +++ b/pubsub/pubsub_test.go @@ -9,12 +9,12 @@ import ( "github.com/go-redis/redis/v8" "github.com/google/go-cmp/cmp" + "github.com/google/uuid" "github.com/offchainlabs/nitro/util/containers" "github.com/offchainlabs/nitro/util/redisutil" ) var ( - streamName = DefaultTestProducerConfig.RedisStream consumersCount = 10 messagesCount = 100 ) @@ -39,9 +39,9 @@ func (t *testResponseMarshaller) Unmarshal(val []byte) (string, error) { return string(val), nil } -func createGroup(ctx context.Context, t *testing.T, client redis.UniversalClient) { +func createGroup(ctx context.Context, t *testing.T, streamName, groupName string, client redis.UniversalClient) { t.Helper() - _, err := client.XGroupCreateMkStream(ctx, streamName, defaultGroup, "$").Result() + _, err := client.XGroupCreateMkStream(ctx, streamName, groupName, "$").Result() if err != nil { t.Fatalf("Error creating stream group: %v", err) } @@ -57,11 +57,31 @@ func (e *disableReproduce) apply(_ *ConsumerConfig, prodCfg *ProducerConfig) { prodCfg.EnableReproduce = false } +func producerCfg() *ProducerConfig { + return &ProducerConfig{ + EnableReproduce: DefaultTestProducerConfig.EnableReproduce, + CheckPendingInterval: DefaultTestProducerConfig.CheckPendingInterval, + KeepAliveTimeout: DefaultTestProducerConfig.KeepAliveTimeout, + CheckResultInterval: DefaultTestProducerConfig.CheckResultInterval, + } +} + +func consumerCfg() *ConsumerConfig { + return &ConsumerConfig{ + ResponseEntryTimeout: 
DefaultTestConsumerConfig.ResponseEntryTimeout, + KeepAliveTimeout: DefaultTestConsumerConfig.KeepAliveTimeout, + } +} + func newProducerConsumers(ctx context.Context, t *testing.T, opts ...configOpt) (*Producer[string, string], []*Consumer[string, string]) { t.Helper() redisURL := redisutil.CreateTestRedis(ctx, t) - prodCfg, consCfg := DefaultTestProducerConfig, DefaultTestConsumerConfig + prodCfg, consCfg := producerCfg(), consumerCfg() prodCfg.RedisURL, consCfg.RedisURL = redisURL, redisURL + streamName := uuid.NewString() + groupName := fmt.Sprintf("group_%s", streamName) + prodCfg.RedisGroup, consCfg.RedisGroup = groupName, groupName + prodCfg.RedisStream, consCfg.RedisStream = streamName, streamName for _, o := range opts { o.apply(consCfg, prodCfg) } @@ -78,7 +98,7 @@ func newProducerConsumers(ctx context.Context, t *testing.T, opts ...configOpt) } consumers = append(consumers, c) } - createGroup(ctx, t, producer.client) + createGroup(ctx, t, streamName, groupName, producer.client) return producer, consumers } @@ -102,6 +122,7 @@ func wantMessages(n int) []string { } func TestRedisProduce(t *testing.T) { + t.Parallel() ctx := context.Background() producer, consumers := newProducerConsumers(ctx, t) producer.Start(ctx) @@ -250,6 +271,7 @@ func consume(ctx context.Context, t *testing.T, consumers []*Consumer[string, st } func TestRedisClaimingOwnership(t *testing.T) { + t.Parallel() ctx := context.Background() producer, consumers := newProducerConsumers(ctx, t) producer.Start(ctx) @@ -295,6 +317,7 @@ func TestRedisClaimingOwnership(t *testing.T) { } func TestRedisClaimingOwnershipReproduceDisabled(t *testing.T) { + t.Parallel() ctx := context.Background() producer, consumers := newProducerConsumers(ctx, t, &disableReproduce{}) producer.Start(ctx) From 9d450af222dafb89fc04af125e091015c31bb4a9 Mon Sep 17 00:00:00 2001 From: Nodar Ambroladze Date: Thu, 4 Apr 2024 17:40:42 +0200 Subject: [PATCH 28/33] Fix data race --- pubsub/producer.go | 6 ++++++ pubsub/pubsub_test.go | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/pubsub/producer.go b/pubsub/producer.go index 4569316b4d..99c4c33438 100644 --- a/pubsub/producer.go +++ b/pubsub/producer.go @@ -177,6 +177,12 @@ func (p *Producer[Request, Response]) Start(ctx context.Context) { p.StopWaiter.Start(ctx, p) } +func (p *Producer[Request, Response]) promisesLen() int { + p.promisesLock.Lock() + defer p.promisesLock.Unlock() + return len(p.promises) +} + // reproduce is used when Producer claims ownership on the pending // message that was sent to inactive consumer and reinserts it into the stream, // so that seamlessly return the answer in the same promise. 
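For context on the race this patch fixes: the test goroutine was reading len(producer.promises) directly while the producer's background threads (checkAndReproduce and checkResponses, both launched via StopWaiter) delete entries from the same map under promisesLock, which the race detector reports. The new promisesLen accessor routes the read through the same lock. Since promisesLock is an RWMutex, a read-only accessor could equally take the read lock; a sketch of that variant follows, shown only as an alternative and not what the patch ships.

    // Hypothetical read-locked variant of promisesLen; the patch itself takes
    // the exclusive lock, which is also correct, just slightly more conservative.
    func (p *Producer[Request, Response]) promisesLenRLocked() int {
        p.promisesLock.RLock()
        defer p.promisesLock.RUnlock()
        return len(p.promises)
    }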
diff --git a/pubsub/pubsub_test.go b/pubsub/pubsub_test.go index 5b83923692..f872f8abfb 100644 --- a/pubsub/pubsub_test.go +++ b/pubsub/pubsub_test.go @@ -311,7 +311,7 @@ func TestRedisClaimingOwnership(t *testing.T) { if diff := cmp.Diff(wantResp, gotResponses); diff != "" { t.Errorf("Unexpected diff in responses:\n%s\n", diff) } - if cnt := len(producer.promises); cnt != 0 { + if cnt := producer.promisesLen(); cnt != 0 { t.Errorf("Producer still has %d unfullfilled promises", cnt) } } From 8da1e86dac0b31321ba37a03e960f52356da7419 Mon Sep 17 00:00:00 2001 From: Nodar Ambroladze Date: Fri, 5 Apr 2024 12:32:54 +0200 Subject: [PATCH 29/33] Cleanup tests --- pubsub/pubsub_test.go | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/pubsub/pubsub_test.go b/pubsub/pubsub_test.go index f872f8abfb..ce920757f6 100644 --- a/pubsub/pubsub_test.go +++ b/pubsub/pubsub_test.go @@ -47,6 +47,14 @@ func createGroup(ctx context.Context, t *testing.T, streamName, groupName string } } +func destroyGroup(ctx context.Context, t *testing.T, streamName, groupName string, client redis.UniversalClient) { + t.Helper() + _, err := client.XGroupDestroy(ctx, streamName, groupName).Result() + if err != nil { + t.Fatalf("Error creating stream group: %v", err) + } +} + type configOpt interface { apply(consCfg *ConsumerConfig, prodCfg *ProducerConfig) } @@ -99,6 +107,16 @@ func newProducerConsumers(ctx context.Context, t *testing.T, opts ...configOpt) consumers = append(consumers, c) } createGroup(ctx, t, streamName, groupName, producer.client) + t.Cleanup(func() { + destroyGroup(ctx, t, streamName, groupName, producer.client) + var keys []string + for _, c := range consumers { + keys = append(keys, c.heartBeatKey()) + } + if _, err := producer.client.Del(ctx, keys...).Result(); err != nil { + t.Fatalf("Error deleting heartbeat keys: %v\n", err) + } + }) return producer, consumers } @@ -355,7 +373,7 @@ func TestRedisClaimingOwnershipReproduceDisabled(t *testing.T) { if len(gotResponses) != wantMsgCnt { t.Errorf("Got %d responses want: %d\n", len(gotResponses), wantMsgCnt) } - if cnt := len(producer.promises); cnt != 0 { + if cnt := producer.promisesLen(); cnt != 0 { t.Errorf("Producer still has %d unfullfilled promises", cnt) } } From 590ec7beaa9f6abfa399b4f0be0b52f7c2c5accc Mon Sep 17 00:00:00 2001 From: Nodar Ambroladze Date: Tue, 9 Apr 2024 08:54:00 +0200 Subject: [PATCH 30/33] Address comments --- pubsub/producer.go | 19 +++- pubsub/pubsub_test.go | 213 ++++++++++++++++-------------------------- 2 files changed, 94 insertions(+), 138 deletions(-) diff --git a/pubsub/producer.go b/pubsub/producer.go index 99c4c33438..49a5266321 100644 --- a/pubsub/producer.go +++ b/pubsub/producer.go @@ -1,3 +1,11 @@ +// Package pubsub implements publisher/subscriber model (one to many). +// During normal operation, publisher returns "Promise" when publishing a +// message, which will return resposne from consumer when awaited. +// If the consumer processing the request becomes inactive, message is +// re-inserted (if EnableReproduce flag is enabled), and will be picked up by +// another consumer. +// We are assuming here that keeepAliveTimeout is set to some sensible value +// and once consumer becomes inactive, it doesn't activate without restart. 
package pubsub import ( @@ -37,7 +45,7 @@ type Producer[Request any, Response any] struct { promisesLock sync.RWMutex promises map[string]*containers.Promise[Response] - // Used for running checks for pending messages with inactive consumers + // Used for running checks for pending messages with inactive consumers // and checking responses from consumers iteratively for the first time when // Produce is called. once sync.Once @@ -112,8 +120,10 @@ func (p *Producer[Request, Response]) errorPromisesFor(msgs []*Message[Request]) p.promisesLock.Lock() defer p.promisesLock.Unlock() for _, msg := range msgs { - p.promises[msg.ID].ProduceError(fmt.Errorf("internal error, consumer died while serving the request")) - delete(p.promises, msg.ID) + if msg != nil { + p.promises[msg.ID].ProduceError(fmt.Errorf("internal error, consumer died while serving the request")) + delete(p.promises, msg.ID) + } } } @@ -197,6 +207,9 @@ func (p *Producer[Request, Response]) reproduce(ctx context.Context, value Reque p.promisesLock.Lock() defer p.promisesLock.Unlock() promise := p.promises[oldKey] + if oldKey != "" && promise == nil { + return nil, fmt.Errorf("errror reproducing the message, could not find existing one") + } if oldKey == "" || promise == nil { pr := containers.NewPromise[Response](nil) promise = &pr diff --git a/pubsub/pubsub_test.go b/pubsub/pubsub_test.go index ce920757f6..22d8782ba5 100644 --- a/pubsub/pubsub_test.go +++ b/pubsub/pubsub_test.go @@ -7,6 +7,7 @@ import ( "sort" "testing" + "github.com/ethereum/go-ethereum/log" "github.com/go-redis/redis/v8" "github.com/google/go-cmp/cmp" "github.com/google/uuid" @@ -41,17 +42,15 @@ func (t *testResponseMarshaller) Unmarshal(val []byte) (string, error) { func createGroup(ctx context.Context, t *testing.T, streamName, groupName string, client redis.UniversalClient) { t.Helper() - _, err := client.XGroupCreateMkStream(ctx, streamName, groupName, "$").Result() - if err != nil { + if _, err := client.XGroupCreateMkStream(ctx, streamName, groupName, "$").Result(); err != nil { t.Fatalf("Error creating stream group: %v", err) } } func destroyGroup(ctx context.Context, t *testing.T, streamName, groupName string, client redis.UniversalClient) { t.Helper() - _, err := client.XGroupDestroy(ctx, streamName, groupName).Result() - if err != nil { - t.Fatalf("Error creating stream group: %v", err) + if _, err := client.XGroupDestroy(ctx, streamName, groupName).Result(); err != nil { + log.Debug("Error creating stream group: %v", err) } } @@ -108,13 +107,14 @@ func newProducerConsumers(ctx context.Context, t *testing.T, opts ...configOpt) } createGroup(ctx, t, streamName, groupName, producer.client) t.Cleanup(func() { + ctx := context.Background() destroyGroup(ctx, t, streamName, groupName, producer.client) var keys []string for _, c := range consumers { keys = append(keys, c.heartBeatKey()) } if _, err := producer.client.Del(ctx, keys...).Result(); err != nil { - t.Fatalf("Error deleting heartbeat keys: %v\n", err) + log.Debug("Error deleting heartbeat keys", "error", err) } }) return producer, consumers @@ -133,99 +133,23 @@ func wantMessages(n int) []string { for i := 0; i < n; i++ { ret = append(ret, fmt.Sprintf("msg: %d", i)) } - sort.Slice(ret, func(i, j int) bool { - return fmt.Sprintf("%v", ret[i]) < fmt.Sprintf("%v", ret[j]) - }) + sort.Strings(ret) return ret } -func TestRedisProduce(t *testing.T) { - t.Parallel() - ctx := context.Background() - producer, consumers := newProducerConsumers(ctx, t) - producer.Start(ctx) - gotMessages := 
messagesMaps(consumersCount) - wantResponses := make([][]string, len(consumers)) - for idx, c := range consumers { - idx, c := idx, c - c.Start(ctx) - c.StopWaiter.LaunchThread( - func(ctx context.Context) { - for { - res, err := c.Consume(ctx) - if err != nil { - if !errors.Is(err, context.Canceled) { - t.Errorf("Consume() unexpected error: %v", err) - } - return - } - if res == nil { - continue - } - gotMessages[idx][res.ID] = res.Value - resp := fmt.Sprintf("result for: %v", res.ID) - if err := c.SetResult(ctx, res.ID, resp); err != nil { - t.Errorf("Error setting a result: %v", err) - } - wantResponses[idx] = append(wantResponses[idx], resp) - } - }) - } - - var gotResponses []string - - for i := 0; i < messagesCount; i++ { - value := fmt.Sprintf("msg: %d", i) - p, err := producer.Produce(ctx, value) - if err != nil { - t.Errorf("Produce() unexpected error: %v", err) - } - res, err := p.Await(ctx) - if err != nil { - t.Errorf("Await() unexpected error: %v", err) - } - gotResponses = append(gotResponses, res) - } - - producer.StopWaiter.StopAndWait() - for _, c := range consumers { - c.StopAndWait() - } - - got, err := mergeValues(gotMessages) - if err != nil { - t.Fatalf("mergeMaps() unexpected error: %v", err) - } - want := wantMessages(messagesCount) - if diff := cmp.Diff(want, got); diff != "" { - t.Errorf("Unexpected diff (-want +got):\n%s\n", diff) - } - - wantResp := flatten(wantResponses) - sort.Slice(gotResponses, func(i, j int) bool { - return gotResponses[i] < gotResponses[j] - }) - if diff := cmp.Diff(wantResp, gotResponses); diff != "" { - t.Errorf("Unexpected diff in responses:\n%s\n", diff) - } -} - func flatten(responses [][]string) []string { var ret []string for _, v := range responses { ret = append(ret, v...) } - sort.Slice(ret, func(i, j int) bool { - return ret[i] < ret[j] - }) + sort.Strings(ret) return ret } -func produceMessages(ctx context.Context, producer *Producer[string, string]) ([]*containers.Promise[string], error) { +func produceMessages(ctx context.Context, msgs []string, producer *Producer[string, string]) ([]*containers.Promise[string], error) { var promises []*containers.Promise[string] for i := 0; i < messagesCount; i++ { - value := fmt.Sprintf("msg: %d", i) - promise, err := producer.Produce(ctx, value) + promise, err := producer.Produce(ctx, msgs[i]) if err != nil { return nil, err } @@ -250,13 +174,13 @@ func awaitResponses(ctx context.Context, promises []*containers.Promise[string]) return responses, errors.Join(errs...) } -// consume messages from every consumer except every skipNth. -func consume(ctx context.Context, t *testing.T, consumers []*Consumer[string, string], skipN int) ([]map[string]string, [][]string) { +// consume messages from every consumer except stopped ones. 
+func consume(ctx context.Context, t *testing.T, consumers []*Consumer[string, string]) ([]map[string]string, [][]string) { t.Helper() gotMessages := messagesMaps(consumersCount) wantResponses := make([][]string, consumersCount) for idx := 0; idx < consumersCount; idx++ { - if idx%skipN == 0 { + if consumers[idx].Stopped() { continue } idx, c := idx, consumers[idx] @@ -288,58 +212,78 @@ func consume(ctx context.Context, t *testing.T, consumers []*Consumer[string, st return gotMessages, wantResponses } -func TestRedisClaimingOwnership(t *testing.T) { +func TestRedisProduce(t *testing.T) { t.Parallel() - ctx := context.Background() - producer, consumers := newProducerConsumers(ctx, t) - producer.Start(ctx) - promises, err := produceMessages(ctx, producer) - if err != nil { - t.Fatalf("Error producing messages: %v", err) - } + for _, tc := range []struct { + name string + killConsumers bool + }{ + { + name: "all consumers are active", + killConsumers: false, + }, + { + name: "some consumers killed, others should take over their work", + killConsumers: false, + }, + } { + t.Run(tc.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + producer, consumers := newProducerConsumers(ctx, t) + producer.Start(ctx) + wantMsgs := wantMessages(messagesCount) + promises, err := produceMessages(ctx, wantMsgs, producer) + if err != nil { + t.Fatalf("Error producing messages: %v", err) + } + if tc.killConsumers { + // Consumer messages in every third consumer but don't ack them to check + // that other consumers will claim ownership on those messages. + for i := 0; i < len(consumers); i += 3 { + if _, err := consumers[i].Consume(ctx); err != nil { + t.Errorf("Error consuming message: %v", err) + } + consumers[i].StopAndWait() + } - // Consumer messages in every third consumer but don't ack them to check - // that other consumers will claim ownership on those messages. 
- for i := 0; i < len(consumers); i += 3 { - i := i - if _, err := consumers[i].Consume(ctx); err != nil { - t.Errorf("Error consuming message: %v", err) - } - consumers[i].StopAndWait() - } + } + gotMessages, wantResponses := consume(ctx, t, consumers) + gotResponses, err := awaitResponses(ctx, promises) + if err != nil { + t.Fatalf("Error awaiting responses: %v", err) + } + for _, c := range consumers { + c.StopWaiter.StopAndWait() + } + got, err := mergeValues(gotMessages) + if err != nil { + t.Fatalf("mergeMaps() unexpected error: %v", err) + } - gotMessages, wantResponses := consume(ctx, t, consumers, 3) - gotResponses, err := awaitResponses(ctx, promises) - if err != nil { - t.Fatalf("Error awaiting responses: %v", err) - } - for _, c := range consumers { - c.StopWaiter.StopAndWait() - } - got, err := mergeValues(gotMessages) - if err != nil { - t.Fatalf("mergeMaps() unexpected error: %v", err) - } - want := wantMessages(messagesCount) - if diff := cmp.Diff(want, got); diff != "" { - t.Errorf("Unexpected diff (-want +got):\n%s\n", diff) - } - wantResp := flatten(wantResponses) - sort.Strings(gotResponses) - if diff := cmp.Diff(wantResp, gotResponses); diff != "" { - t.Errorf("Unexpected diff in responses:\n%s\n", diff) - } - if cnt := producer.promisesLen(); cnt != 0 { - t.Errorf("Producer still has %d unfullfilled promises", cnt) + if diff := cmp.Diff(wantMsgs, got); diff != "" { + t.Errorf("Unexpected diff (-want +got):\n%s\n", diff) + } + wantResp := flatten(wantResponses) + sort.Strings(gotResponses) + if diff := cmp.Diff(wantResp, gotResponses); diff != "" { + t.Errorf("Unexpected diff in responses:\n%s\n", diff) + } + if cnt := producer.promisesLen(); cnt != 0 { + t.Errorf("Producer still has %d unfullfilled promises", cnt) + } + }) } } -func TestRedisClaimingOwnershipReproduceDisabled(t *testing.T) { +func TestRedisReproduceDisabled(t *testing.T) { t.Parallel() - ctx := context.Background() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() producer, consumers := newProducerConsumers(ctx, t, &disableReproduce{}) producer.Start(ctx) - promises, err := produceMessages(ctx, producer) + wantMsgs := wantMessages(messagesCount) + promises, err := produceMessages(ctx, wantMsgs, producer) if err != nil { t.Fatalf("Error producing messages: %v", err) } @@ -347,14 +291,13 @@ func TestRedisClaimingOwnershipReproduceDisabled(t *testing.T) { // Consumer messages in every third consumer but don't ack them to check // that other consumers will claim ownership on those messages. 
From 6b24516f41828a8fba6ef59b89fa4bdb9b035abf Mon Sep 17 00:00:00 2001
From: Nodar Ambroladze
Date: Fri, 12 Apr 2024 11:25:26 +0200
Subject: [PATCH 31/33] Drop generic marshaller, implement jsonMarshaller instead

---
 pubsub/consumer.go    | 10 ++++-----
 pubsub/producer.go    | 45 +++++++++++++++++++++++++++++-----------
 pubsub/pubsub_test.go | 48 +++++++++++++++++--------------------------
 3 files changed, 57 insertions(+), 46 deletions(-)

diff --git a/pubsub/consumer.go b/pubsub/consumer.go
index 8ae5bcb6b7..b117215831 100644
--- a/pubsub/consumer.go
+++ b/pubsub/consumer.go
@@ -57,8 +57,8 @@ type Consumer[Request any, Response any] struct {
 	id     string
 	client redis.UniversalClient
 	cfg    *ConsumerConfig
-	mReq   Marshaller[Request]
-	mResp  Marshaller[Response]
+	mReq   jsonMarshaller[Request]
+	mResp  jsonMarshaller[Response]
 }
 
 type Message[Request any] struct {
@@ -66,7 +66,7 @@ type Message[Request any] struct {
 	Value Request
 }
 
-func NewConsumer[Request any, Response any](ctx context.Context, cfg *ConsumerConfig, mReq Marshaller[Request], mResp Marshaller[Response]) (*Consumer[Request, Response], error) {
+func NewConsumer[Request any, Response any](ctx context.Context, cfg *ConsumerConfig) (*Consumer[Request, Response], error) {
 	if cfg.RedisURL == "" {
 		return nil, fmt.Errorf("redis url cannot be empty")
 	}
@@ -78,8 +78,8 @@ func NewConsumer[Request any, Response any](ctx context.Context, cfg *ConsumerCo
 		id:     uuid.NewString(),
 		client: c,
 		cfg:    cfg,
-		mReq:   mReq,
-		mResp:  mResp,
+		mReq:   jsonMarshaller[Request]{},
+		mResp:  jsonMarshaller[Response]{},
 	}
 	return consumer, nil
 }
diff --git a/pubsub/producer.go b/pubsub/producer.go
index 49a5266321..6118af88c4 100644
--- a/pubsub/producer.go
+++ b/pubsub/producer.go
@@ -10,6 +10,7 @@ package pubsub
 
 import (
 	"context"
+	"encoding/json"
 	"errors"
 	"fmt"
 	"sync"
@@ -29,9 +30,27 @@ const (
 	defaultGroup = "default_consumer_group"
 )
 
-type Marshaller[T any] interface {
-	Marshal(T) []byte
-	Unmarshal(val []byte) (T, error)
+// Generic marshaller for Request and Response generic types.
+// Note: unexported fields will be silently ignored.
+type jsonMarshaller[T any] struct{}
+
+// Marshal marshals generic type object with json marshal.
+func (m jsonMarshaller[T]) Marshal(v T) []byte {
+	data, err := json.Marshal(v)
+	if err != nil {
+		log.Error("error marshaling", "value", v, "error", err)
+		return nil
+	}
+	return data
+}
+
+// Unmarshal converts a JSON byte slice back to the generic type object.
+func (j jsonMarshaller[T]) Unmarshal(val []byte) (T, error) {
+	var v T
+	if err := json.Unmarshal(val, &v); err != nil {
+		return v, err
+	}
+	return v, nil
 }
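
The "unexported fields will be silently ignored" note is worth spelling out: encoding/json only serializes exported fields, so a request or response type with lowercase fields would round-trip through the stream as zero values without any error. A small illustration, assuming a hypothetical payload type that is not part of this patch:

    package main

    import (
        "encoding/json"
        "fmt"
    )

    // Hypothetical payload: only the exported field survives Marshal/Unmarshal.
    type payload struct {
        Kept    string
        dropped string
    }

    func main() {
        in := payload{Kept: "visible", dropped: "silently ignored"}
        data, _ := json.Marshal(in) // {"Kept":"visible"}
        var out payload
        _ = json.Unmarshal(data, &out)
        fmt.Printf("%+v\n", out) // {Kept:visible dropped:}
    }
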
 
 type Producer[Request any, Response any] struct {
 	stopwaiter.StopWaiter
 	id     string
 	client redis.UniversalClient
 	cfg    *ProducerConfig
-	mReq   Marshaller[Request]
-	mResp  Marshaller[Response]
+	mReq   jsonMarshaller[Request]
+	mResp  jsonMarshaller[Response]
 
 	promisesLock sync.RWMutex
 	promises     map[string]*containers.Promise[Response]
@@ -85,7 +104,7 @@ var DefaultTestProducerConfig = &ProducerConfig{
 	RedisStream:          "default",
 	RedisGroup:           defaultGroup,
 	CheckPendingInterval: 10 * time.Millisecond,
-	KeepAliveTimeout:     20 * time.Millisecond,
+	KeepAliveTimeout:     100 * time.Millisecond,
 	CheckResultInterval:  5 * time.Millisecond,
 }
 
@@ -98,7 +117,7 @@ func ProducerAddConfigAddOptions(prefix string, f *pflag.FlagSet) {
 	f.String(prefix+".redis-group", DefaultProducerConfig.RedisGroup, "redis stream consumer group name")
 }
 
-func NewProducer[Request any, Response any](cfg *ProducerConfig, mReq Marshaller[Request], mResp Marshaller[Response]) (*Producer[Request, Response], error) {
+func NewProducer[Request any, Response any](cfg *ProducerConfig) (*Producer[Request, Response], error) {
 	if cfg.RedisURL == "" {
 		return nil, fmt.Errorf("redis url cannot be empty")
 	}
@@ -110,8 +129,8 @@ func NewProducer[Request any, Response any](cfg *ProducerConfig) (*Producer[Requ
 		id:       uuid.NewString(),
 		client:   c,
 		cfg:      cfg,
-		mReq:     mReq,
-		mResp:    mResp,
+		mReq:     jsonMarshaller[Request]{},
+		mResp:    jsonMarshaller[Response]{},
 		promises: make(map[string]*containers.Promise[Response]),
 	}, nil
 }
@@ -120,8 +139,8 @@ func (p *Producer[Request, Response]) errorPromisesFor(msgs []*Message[Request])
 	p.promisesLock.Lock()
 	defer p.promisesLock.Unlock()
 	for _, msg := range msgs {
-		if msg != nil {
-			p.promises[msg.ID].ProduceError(fmt.Errorf("internal error, consumer died while serving the request"))
+		if promise, found := p.promises[msg.ID]; found {
+			promise.ProduceError(fmt.Errorf("internal error, consumer died while serving the request"))
 			delete(p.promises, msg.ID)
 		}
 	}
@@ -208,7 +227,9 @@ func (p *Producer[Request, Response]) reproduce(ctx context.Context, value Reque
 	defer p.promisesLock.Unlock()
 	promise := p.promises[oldKey]
 	if oldKey != "" && promise == nil {
-		return nil, fmt.Errorf("errror reproducing the message, could not find existing one")
+		// This will happen if the old consumer became inactive but then ack_d
+		// the message afterwards.
+		return nil, fmt.Errorf("error reproducing the message, could not find existing one")
 	}
 	if oldKey == "" || promise == nil {
 		pr := containers.NewPromise[Response](nil)
diff --git a/pubsub/pubsub_test.go b/pubsub/pubsub_test.go
index 22d8782ba5..c8968b4e45 100644
--- a/pubsub/pubsub_test.go
+++ b/pubsub/pubsub_test.go
@@ -20,24 +20,12 @@ var (
 	messagesCount = 100
 )
 
-type testRequestMarshaller struct{}
-
-func (t *testRequestMarshaller) Marshal(val string) []byte {
-	return []byte(val)
-}
-
-func (t *testRequestMarshaller) Unmarshal(val []byte) (string, error) {
-	return string(val), nil
-}
-
-type testResponseMarshaller struct{}
-
-func (t *testResponseMarshaller) Marshal(val string) []byte {
-	return []byte(val)
+type testRequest struct {
+	Request string
 }
 
-func (t *testResponseMarshaller) Unmarshal(val []byte) (string, error) {
-	return string(val), nil
+type testResponse struct {
+	Response string
 }
 
 func createGroup(ctx context.Context, t *testing.T, streamName, groupName string, client redis.UniversalClient) {
@@ -50,7 +38,7 @@ func destroyGroup(ctx context.Context, t *testing.T, streamName, groupName string, client redis.UniversalClient) {
 	t.Helper()
 	if _, err := client.XGroupDestroy(ctx, streamName, groupName).Result(); err != nil {
-		log.Debug("Error creating stream group: %v", err)
+		log.Debug("Error destroying a stream group", "error", err)
 	}
 }
 
@@ -80,7 +68,7 @@ func consumerCfg() *ConsumerConfig {
 	}
 }
 
-func newProducerConsumers(ctx context.Context, t *testing.T, opts ...configOpt) (*Producer[string, string], []*Consumer[string, string]) {
+func newProducerConsumers(ctx context.Context, t *testing.T, opts ...configOpt) (*Producer[testRequest, testResponse], []*Consumer[testRequest, testResponse]) {
 	t.Helper()
 	redisURL := redisutil.CreateTestRedis(ctx, t)
 	prodCfg, consCfg := producerCfg(), consumerCfg()
@@ -92,14 +80,14 @@ func newProducerConsumers(ctx context.Context, t *testing.T, opts ...configOpt)
 	for _, o := range opts {
 		o.apply(consCfg, prodCfg)
 	}
-	producer, err := NewProducer[string, string](prodCfg, &testRequestMarshaller{}, &testResponseMarshaller{})
+	producer, err := NewProducer[testRequest, testResponse](prodCfg)
 	if err != nil {
 		t.Fatalf("Error creating new producer: %v", err)
 	}
 
-	var consumers []*Consumer[string, string]
+	var consumers []*Consumer[testRequest, testResponse]
 	for i := 0; i < consumersCount; i++ {
-		c, err := NewConsumer[string, string](ctx, consCfg, &testRequestMarshaller{}, &testResponseMarshaller{})
+		c, err := NewConsumer[testRequest, testResponse](ctx, consCfg)
 		if err != nil {
 			t.Fatalf("Error creating new consumer: %v", err)
 		}
@@ -146,10 +134,10 @@ func flatten(responses [][]string) []string {
 	return ret
 }
 
-func produceMessages(ctx context.Context, msgs []string, producer *Producer[string, string]) ([]*containers.Promise[string], error) {
-	var promises []*containers.Promise[string]
+func produceMessages(ctx context.Context, msgs []string, producer *Producer[testRequest, testResponse]) ([]*containers.Promise[testResponse], error) {
+	var promises []*containers.Promise[testResponse]
 	for i := 0; i < messagesCount; i++ {
-		promise, err := producer.Produce(ctx, msgs[i])
+		promise, err := producer.Produce(ctx, testRequest{Request: msgs[i]})
 		if err != nil {
 			return nil, err
 		}
@@ -158,7 +146,7 @@ func produceMessages(ctx context.Context, msgs []string, producer *Producer[stri
 	return promises, nil
 }
 
-func awaitResponses(ctx context.Context, promises []*containers.Promise[string]) ([]string, error) {
+func awaitResponses(ctx context.Context, promises []*containers.Promise[testResponse]) ([]string, error) {
 	var (
 		responses []string
 		errs      []error
@@ -169,13 +157,13 @@ func awaitResponses(ctx context.Context, promises []*containers.Promise[string])
 			errs = append(errs, err)
 			continue
 		}
-		responses = append(responses, res)
+		responses = append(responses, res.Response)
 	}
 	return responses, errors.Join(errs...)
 }
 
 // consume messages from every consumer except stopped ones.
-func consume(ctx context.Context, t *testing.T, consumers []*Consumer[string, string]) ([]map[string]string, [][]string) {
+func consume(ctx context.Context, t *testing.T, consumers []*Consumer[testRequest, testResponse]) ([]map[string]string, [][]string) {
 	t.Helper()
 	gotMessages := messagesMaps(consumersCount)
 	wantResponses := make([][]string, consumersCount)
@@ -200,9 +188,9 @@ func consume(ctx context.Context, t *testing.T, consumers []*Consumer[string, st
 			if res == nil {
 				continue
 			}
-			gotMessages[idx][res.ID] = res.Value
+			gotMessages[idx][res.ID] = res.Value.Request
 			resp := fmt.Sprintf("result for: %v", res.ID)
-			if err := c.SetResult(ctx, res.ID, resp); err != nil {
+			if err := c.SetResult(ctx, res.ID, testResponse{Response: resp}); err != nil {
 				t.Errorf("Error setting a result: %v", err)
 			}
 			wantResponses[idx] = append(wantResponses[idx], resp)
@@ -253,6 +241,7 @@ func TestRedisProduce(t *testing.T) {
 		if err != nil {
 			t.Fatalf("Error awaiting responses: %v", err)
 		}
+		producer.StopAndWait()
 		for _, c := range consumers {
 			c.StopWaiter.StopAndWait()
 		}
@@ -302,6 +291,7 @@ func TestRedisReproduceDisabled(t *testing.T) {
 	if err == nil {
 		t.Fatalf("All promises were fullfilled with reproduce disabled and some consumers killed")
 	}
+	producer.StopAndWait()
 	for _, c := range consumers {
 		c.StopWaiter.StopAndWait()
 	}

From 92a7e3d7c085d32367461b9413e9c4c73a89d647 Mon Sep 17 00:00:00 2001
From: Nodar Ambroladze
Date: Mon, 15 Apr 2024 10:11:16 +0200
Subject: [PATCH 32/33] drop generic marshaller

---
 pubsub/consumer.go | 29 +++++++++++++--------
 pubsub/producer.go | 64 +++++++++++++++++-----------------------------
 2 files changed, 41 insertions(+), 52 deletions(-)

diff --git a/pubsub/consumer.go b/pubsub/consumer.go
index b117215831..7e21246d01 100644
--- a/pubsub/consumer.go
+++ b/pubsub/consumer.go
@@ -2,6 +2,7 @@ package pubsub
 
 import (
 	"context"
+	"encoding/json"
 	"errors"
 	"fmt"
 	"time"
@@ -31,13 +32,13 @@ type ConsumerConfig struct {
 var DefaultConsumerConfig = &ConsumerConfig{
 	ResponseEntryTimeout: time.Hour,
 	KeepAliveTimeout:     5 * time.Minute,
-	RedisStream:          "default",
-	RedisGroup:           defaultGroup,
+	RedisStream:          "",
+	RedisGroup:           "",
 }
 
 var DefaultTestConsumerConfig = &ConsumerConfig{
-	RedisStream:          "default",
-	RedisGroup:           defaultGroup,
+	RedisStream:          "test_stream",
+	RedisGroup:           "test_group",
 	ResponseEntryTimeout: time.Minute,
 	KeepAliveTimeout:     30 * time.Millisecond,
 }
@@ -57,8 +58,6 @@ type Consumer[Request any, Response any] struct {
 	id     string
 	client redis.UniversalClient
 	cfg    *ConsumerConfig
-	mReq   jsonMarshaller[Request]
-	mResp  jsonMarshaller[Response]
 }
 
 type Message[Request any] struct {
@@ -70,6 +69,12 @@ func NewConsumer[Request any, Response any](ctx context.Context, cfg *ConsumerCo
 	if cfg.RedisURL == "" {
 		return nil, fmt.Errorf("redis url cannot be empty")
 	}
+	if cfg.RedisStream == "" {
+		return nil, fmt.Errorf("redis stream name cannot be empty")
+	}
+	if cfg.RedisGroup == "" {
+		return nil, fmt.Errorf("redis group name cannot be empty")
+	}
 	c, err := redisutil.RedisClientFromURL(cfg.RedisURL)
 	if err != nil {
 		return nil, err
@@ -78,8 +83,6 @@ func NewConsumer[Request any, Response any](ctx context.Context, cfg *ConsumerCo
 		id:     uuid.NewString(),
 		client: c,
 		cfg:    cfg,
-		mReq:   jsonMarshaller[Request]{},
-		mResp:  jsonMarshaller[Response]{},
 	}
 	return consumer, nil
 }
@@ -147,8 +150,8 @@ func (c *Consumer[Request, Response]) Consume(ctx context.Context) (*Message[Req
 	if !ok {
 		return nil, fmt.Errorf("casting request to string: %w", err)
 	}
-	req, err := c.mReq.Unmarshal([]byte(data))
-	if err != nil {
+	var req Request
+	if err := json.Unmarshal([]byte(data), &req); err != nil {
 		return nil, fmt.Errorf("unmarshaling value: %v, error: %w", value, err)
 	}
 
@@ -159,7 +162,11 @@ func (c *Consumer[Request, Response]) Consume(ctx context.Context) (*Message[Req
 }
 
 func (c *Consumer[Request, Response]) SetResult(ctx context.Context, messageID string, result Response) error {
-	acquired, err := c.client.SetNX(ctx, messageID, c.mResp.Marshal(result), c.cfg.ResponseEntryTimeout).Result()
+	resp, err := json.Marshal(result)
+	if err != nil {
+		return fmt.Errorf("marshaling result: %w", err)
+	}
+	acquired, err := c.client.SetNX(ctx, messageID, resp, c.cfg.ResponseEntryTimeout).Result()
 	if err != nil || !acquired {
 		return fmt.Errorf("setting result for message: %v, error: %w", messageID, err)
 	}
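
For orientation, the response handshake that this hunk converts to plain encoding/json is just two Redis key operations: the consumer stores the marshalled result under the stream message ID with SetNX and a TTL, and the producer later polls the same key with Get until it appears. A rough, self-contained sketch of that handshake with go-redis; the address, key, and result type are illustrative, not identifiers from the package:

    package main

    import (
        "context"
        "encoding/json"
        "time"

        "github.com/go-redis/redis/v8"
    )

    type result struct{ Response string }

    func main() {
        ctx := context.Background()
        client := redis.NewClient(&redis.Options{Addr: "localhost:6379"}) // assumed local Redis

        msgID := "1712345678901-0" // hypothetical stream entry ID
        resp, _ := json.Marshal(result{Response: "done"})

        // Consumer side: write the result once, with an expiry so stale entries are dropped.
        client.SetNX(ctx, msgID, resp, time.Minute)

        // Producer side: poll the key; redis.Nil means the result is not ready yet.
        raw, err := client.Get(ctx, msgID).Result()
        if err == nil {
            var r result
            _ = json.Unmarshal([]byte(raw), &r) // r.Response == "done"
        } else if err == redis.Nil {
            // not there yet; retry on the next CheckResultInterval tick
        }
    }
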
diff --git a/pubsub/producer.go b/pubsub/producer.go
index 6118af88c4..13a4553e2f 100644
--- a/pubsub/producer.go
+++ b/pubsub/producer.go
@@ -30,36 +30,11 @@ const (
 	defaultGroup = "default_consumer_group"
 )
 
-// Generic marshaller for Request and Response generic types.
-// Note: unexported fields will be silently ignored.
-type jsonMarshaller[T any] struct{}
-
-// Marshal marshals generic type object with json marshal.
-func (m jsonMarshaller[T]) Marshal(v T) []byte {
-	data, err := json.Marshal(v)
-	if err != nil {
-		log.Error("error marshaling", "value", v, "error", err)
-		return nil
-	}
-	return data
-}
-
-// Unmarshal converts a JSON byte slice back to the generic type object.
-func (j jsonMarshaller[T]) Unmarshal(val []byte) (T, error) {
-	var v T
-	if err := json.Unmarshal(val, &v); err != nil {
-		return v, err
-	}
-	return v, nil
-}
-
 type Producer[Request any, Response any] struct {
 	stopwaiter.StopWaiter
 	id     string
 	client redis.UniversalClient
 	cfg    *ProducerConfig
-	mReq   jsonMarshaller[Request]
-	mResp  jsonMarshaller[Response]
 
 	promisesLock sync.RWMutex
 	promises     map[string]*containers.Promise[Response]
@@ -92,17 +67,17 @@ type ProducerConfig struct {
 var DefaultProducerConfig = &ProducerConfig{
 	EnableReproduce:      true,
-	RedisStream:          "default",
+	RedisStream:          "",
+	RedisGroup:           "",
 	CheckPendingInterval: time.Second,
 	KeepAliveTimeout:     5 * time.Minute,
 	CheckResultInterval:  5 * time.Second,
-	RedisGroup:           defaultGroup,
 }
 
 var DefaultTestProducerConfig = &ProducerConfig{
 	EnableReproduce:      true,
-	RedisStream:          "default",
-	RedisGroup:           defaultGroup,
+	RedisStream:          "",
+	RedisGroup:           "",
 	CheckPendingInterval: 10 * time.Millisecond,
 	KeepAliveTimeout:     100 * time.Millisecond,
 	CheckResultInterval:  5 * time.Millisecond,
@@ -121,6 +96,12 @@ func NewProducer[Request any, Response any](cfg *ProducerConfig) (*Producer[Requ
 	if cfg.RedisURL == "" {
 		return nil, fmt.Errorf("redis url cannot be empty")
 	}
+	if cfg.RedisStream == "" {
+		return nil, fmt.Errorf("redis stream cannot be empty")
+	}
+	if cfg.RedisGroup == "" {
+		return nil, fmt.Errorf("redis group cannot be empty")
+	}
 	c, err := redisutil.RedisClientFromURL(cfg.RedisURL)
 	if err != nil {
 		return nil, err
@@ -129,8 +110,6 @@ func NewProducer[Request any, Response any](cfg *ProducerConfig) (*Producer[Requ
 		id:       uuid.NewString(),
 		client:   c,
 		cfg:      cfg,
-		mReq:     jsonMarshaller[Request]{},
-		mResp:    jsonMarshaller[Response]{},
 		promises: make(map[string]*containers.Promise[Response]),
 	}, nil
 }
@@ -191,12 +170,12 @@ func (p *Producer[Request, Response]) checkResponses(ctx context.Context) time.D
 			}
 			log.Error("Error reading value in redis", "key", id, "error", err)
 		}
-		val, err := p.mResp.Unmarshal([]byte(res))
-		if err != nil {
+		var resp Response
+		if err := json.Unmarshal([]byte(res), &resp); err != nil {
 			log.Error("Error unmarshaling", "value", res, "error", err)
 			continue
 		}
-		promise.Produce(val)
+		promise.Produce(resp)
 		delete(p.promises, id)
 	}
 	return p.cfg.CheckResultInterval
@@ -216,9 +195,13 @@ func (p *Producer[Request, Response]) promisesLen() int {
 // message that was sent to inactive consumer and reinserts it into the stream,
 // so that seamlessly return the answer in the same promise.
 func (p *Producer[Request, Response]) reproduce(ctx context.Context, value Request, oldKey string) (*containers.Promise[Response], error) {
+	val, err := json.Marshal(value)
+	if err != nil {
+		return nil, fmt.Errorf("marshaling value: %w", err)
+	}
 	id, err := p.client.XAdd(ctx, &redis.XAddArgs{
 		Stream: p.cfg.RedisStream,
-		Values: map[string]any{messageKey: p.mReq.Marshal(value)},
+		Values: map[string]any{messageKey: val},
 	}).Result()
 	if err != nil {
 		return nil, fmt.Errorf("adding values to redis: %w", err)
@@ -250,11 +233,10 @@ func (p *Producer[Request, Response]) Produce(ctx context.Context, value Request
 
 // Check if a consumer is with specified ID is alive.
 func (p *Producer[Request, Response]) isConsumerAlive(ctx context.Context, consumerID string) bool {
-	val, err := p.client.Get(ctx, heartBeatKey(consumerID)).Int64()
-	if err != nil {
+	if _, err := p.client.Get(ctx, heartBeatKey(consumerID)).Int64(); err != nil {
 		return false
 	}
-	return time.Now().UnixMilli()-val < int64(p.cfg.KeepAliveTimeout.Milliseconds())
+	return true
 }
 
 func (p *Producer[Request, Response]) havePromiseFor(messageID string) bool {
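
Returning true whenever the heartbeat key still exists is enough here only because the key is written with an expiry no longer than the liveness window, so Redis itself removes it once the consumer stops refreshing it. A toy sketch of that mechanism, using an illustrative key and TTL rather than the package's actual heartbeat format:

    package main

    import (
        "context"
        "fmt"
        "time"

        "github.com/go-redis/redis/v8"
    )

    func main() {
        ctx := context.Background()
        client := redis.NewClient(&redis.Options{Addr: "localhost:6379"}) // assumed local Redis

        key := "heartbeat:example-consumer" // illustrative key name

        // A live consumer keeps refreshing the key with a short TTL.
        client.Set(ctx, key, "ok", 100*time.Millisecond)
        fmt.Println(client.Get(ctx, key).Err() == nil) // true: key present, consumer counts as alive

        // Once refreshes stop, the TTL lapses and the same existence check fails.
        time.Sleep(150 * time.Millisecond)
        fmt.Println(client.Get(ctx, key).Err() == redis.Nil) // true: key expired, consumer looks dead
    }
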
@@ -318,13 +300,13 @@ func (p *Producer[Request, Response]) checkPending(ctx context.Context) ([]*Mess
 		if !ok {
 			return nil, fmt.Errorf("casting request: %v to bytes", msg.Values[messageKey])
 		}
-		val, err := p.mReq.Unmarshal([]byte(data))
-		if err != nil {
+		var req Request
+		if err := json.Unmarshal([]byte(data), &req); err != nil {
 			return nil, fmt.Errorf("marshaling value: %v, error: %w", msg.Values[messageKey], err)
 		}
 		res = append(res, &Message[Request]{
 			ID:    msg.ID,
-			Value: val,
+			Value: req,
 		})
 	}
 	return res, nil

From 0180a2b7761bedd8ee5c236d9cf276fb251e7bc1 Mon Sep 17 00:00:00 2001
From: Nodar Ambroladze
Date: Mon, 15 Apr 2024 10:13:11 +0200
Subject: [PATCH 33/33] don't set redis group/stream name in test config either

---
 pubsub/consumer.go | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pubsub/consumer.go b/pubsub/consumer.go
index 7e21246d01..3de313f120 100644
--- a/pubsub/consumer.go
+++ b/pubsub/consumer.go
@@ -37,8 +37,8 @@ var DefaultConsumerConfig = &ConsumerConfig{
 }
 
 var DefaultTestConsumerConfig = &ConsumerConfig{
-	RedisStream:          "test_stream",
-	RedisGroup:           "test_group",
+	RedisStream:          "",
+	RedisGroup:           "",
 	ResponseEntryTimeout: time.Minute,
 	KeepAliveTimeout:     30 * time.Millisecond,
 }
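
Taken together, the last three patches leave the API in a state where stream and group names must be supplied explicitly (the defaults are empty and rejected by both constructors) and request/response payloads are plain structs encoded as JSON. A rough usage sketch under those assumptions; the module import path, config values, request/response types, and the promise's Await call are illustrative rather than taken verbatim from the patches:

    package main

    import (
        "context"
        "fmt"
        "log"

        "github.com/offchainlabs/nitro/pubsub" // assumed module path
    )

    // Illustrative payloads; fields are exported so JSON keeps them.
    type addRequest struct{ A, B int }
    type addResponse struct{ Sum int }

    func main() {
        ctx := context.Background()

        prodCfg := *pubsub.DefaultProducerConfig
        prodCfg.RedisURL = "redis://localhost:6379" // assumed local Redis
        prodCfg.RedisStream = "example_stream"      // must be non-empty now
        prodCfg.RedisGroup = "example_group"        // must be non-empty now

        producer, err := pubsub.NewProducer[addRequest, addResponse](&prodCfg)
        if err != nil {
            log.Fatal(err)
        }
        producer.Start(ctx)

        promise, err := producer.Produce(ctx, addRequest{A: 2, B: 3})
        if err != nil {
            log.Fatal(err)
        }
        // A consumer on the same stream would Consume the request and call
        // SetResult(ctx, msg.ID, addResponse{Sum: 5}); the producer's background
        // poller then resolves the promise.
        resp, err := promise.Await(ctx) // Await is assumed from the containers.Promise API used in the tests
        if err != nil {
            log.Fatal(err)
        }
        fmt.Printf("%+v\n", resp)
    }
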