From 2a67624daad00767c9819cd09c0b02de1b7c298c Mon Sep 17 00:00:00 2001 From: Nodar Ambroladze Date: Mon, 1 Apr 2024 22:07:20 +0200 Subject: [PATCH] Address comments --- pubsub/consumer.go | 53 ++++++++++++++++++---------- pubsub/producer.go | 81 ++++++++++++++++++++++++++++++------------- pubsub/pubsub_test.go | 78 ++++++++++++++++++++--------------------- 3 files changed, 128 insertions(+), 84 deletions(-) diff --git a/pubsub/consumer.go b/pubsub/consumer.go index f7c7ef1d37..e013314e5b 100644 --- a/pubsub/consumer.go +++ b/pubsub/consumer.go @@ -15,8 +15,8 @@ import ( ) type ConsumerConfig struct { - // Intervals in which consumer will update heartbeat. - KeepAliveInterval time.Duration `koanf:"keepalive-interval"` + // Timeout of result entry in Redis. + ResponseEntryTimeout time.Duration `koanf:"response-entry-timeout"` // Duration after which consumer is considered to be dead if heartbeat // is not updated. KeepAliveTimeout time.Duration `koanf:"keepalive-timeout"` @@ -28,12 +28,26 @@ type ConsumerConfig struct { RedisGroup string `koanf:"redis-group"` } -func ConsumerConfigAddOptions(prefix string, f *pflag.FlagSet, cfg *ConsumerConfig) { - f.Duration(prefix+".keepalive-interval", 30*time.Second, "interval in which consumer will perform heartbeat") - f.Duration(prefix+".keepalive-timeout", 5*time.Minute, "timeout after which consumer is considered inactive if heartbeat wasn't performed") - f.String(prefix+".redis-url", "", "redis url for redis stream") - f.String(prefix+".redis-stream", "default", "redis stream name to read from") - f.String(prefix+".redis-group", defaultGroup, "redis stream consumer group name") +var DefaultConsumerConfig = &ConsumerConfig{ + ResponseEntryTimeout: time.Hour, + KeepAliveTimeout: 5 * time.Minute, + RedisStream: "default", + RedisGroup: defaultGroup, +} + +var DefaultTestConsumerConfig = &ConsumerConfig{ + RedisStream: "default", + RedisGroup: defaultGroup, + ResponseEntryTimeout: time.Minute, + KeepAliveTimeout: 30 * time.Millisecond, +} + +func ConsumerConfigAddOptions(prefix string, f *pflag.FlagSet) { + f.Duration(prefix+".response-entry-timeout", DefaultConsumerConfig.ResponseEntryTimeout, "timeout for response entry") + f.Duration(prefix+".keepalive-timeout", DefaultConsumerConfig.KeepAliveTimeout, "timeout after which consumer is considered inactive if heartbeat wasn't performed") + f.String(prefix+".redis-url", DefaultConsumerConfig.RedisURL, "redis url for redis stream") + f.String(prefix+".redis-stream", DefaultConsumerConfig.RedisStream, "redis stream name to read from") + f.String(prefix+".redis-group", DefaultConsumerConfig.RedisGroup, "redis stream consumer group name") } // Consumer implements a consumer for redis stream provides heartbeat to @@ -72,7 +86,7 @@ func (c *Consumer[T]) Start(ctx context.Context) { c.StopWaiter.CallIteratively( func(ctx context.Context) time.Duration { c.heartBeat(ctx) - return c.cfg.KeepAliveInterval + return c.cfg.KeepAliveTimeout / 10 }, ) } @@ -123,10 +137,14 @@ func (c *Consumer[T]) Consume(ctx context.Context) (*Message[T], error) { } log.Debug(fmt.Sprintf("Consumer: %s consuming message: %s", c.id, res[0].Messages[0].ID)) var ( - value = res[0].Messages[0].Values[messageKey] - tmp T + value = res[0].Messages[0].Values[messageKey] + data, ok = (value).(string) + tmp T ) - val, err := tmp.Unmarshal(value) + if !ok { + return nil, fmt.Errorf("casting request to string: %w", err) + } + val, err := tmp.Unmarshal([]byte(data)) if err != nil { return nil, fmt.Errorf("unmarshaling value: %v, error: %w", value, err) } @@ -137,16 +155,13 @@ func (c *Consumer[T]) Consume(ctx context.Context) (*Message[T], error) { }, nil } -func (c *Consumer[T]) ACK(ctx context.Context, messageID string) error { - log.Info("ACKing message", "consumer-id", c.id, "message-sid", messageID) - _, err := c.client.XAck(ctx, c.cfg.RedisStream, c.cfg.RedisGroup, messageID).Result() - return err -} - func (c *Consumer[T]) SetResult(ctx context.Context, messageID string, result string) error { - acquired, err := c.client.SetNX(ctx, messageID, result, c.cfg.KeepAliveTimeout).Result() + acquired, err := c.client.SetNX(ctx, messageID, result, c.cfg.ResponseEntryTimeout).Result() if err != nil || !acquired { return fmt.Errorf("setting result for message: %v, error: %w", messageID, err) } + if _, err := c.client.XAck(ctx, c.cfg.RedisStream, c.cfg.RedisGroup, messageID).Result(); err != nil { + return fmt.Errorf("acking message: %v, error: %w", messageID, err) + } return nil } diff --git a/pubsub/producer.go b/pubsub/producer.go index 79edd9eba1..006b84709f 100644 --- a/pubsub/producer.go +++ b/pubsub/producer.go @@ -13,6 +13,7 @@ import ( "github.com/offchainlabs/nitro/util/containers" "github.com/offchainlabs/nitro/util/redisutil" "github.com/offchainlabs/nitro/util/stopwaiter" + "github.com/spf13/pflag" ) const ( @@ -21,18 +22,18 @@ const ( ) type Marshallable[T any] interface { - Marshal() any - Unmarshal(val any) (T, error) + Marshal() []byte + Unmarshal(val []byte) (T, error) } -type Producer[T Marshallable[T]] struct { +type Producer[Request Marshallable[Request], Response Marshallable[Response]] struct { stopwaiter.StopWaiter id string client redis.UniversalClient cfg *ProducerConfig promisesLock sync.RWMutex - promises map[string]*containers.Promise[T] + promises map[string]*containers.Promise[Response] } type ProducerConfig struct { @@ -51,7 +52,31 @@ type ProducerConfig struct { RedisGroup string `koanf:"redis-group"` } -func NewProducer[T Marshallable[T]](cfg *ProducerConfig) (*Producer[T], error) { +var DefaultProducerConfig = &ProducerConfig{ + RedisStream: "default", + CheckPendingInterval: time.Second, + KeepAliveTimeout: 5 * time.Minute, + CheckResultInterval: 5 * time.Second, + RedisGroup: defaultGroup, +} + +var DefaultTestProducerConfig = &ProducerConfig{ + RedisStream: "default", + RedisGroup: defaultGroup, + CheckPendingInterval: 10 * time.Millisecond, + KeepAliveTimeout: 20 * time.Millisecond, + CheckResultInterval: 5 * time.Millisecond, +} + +func ProducerAddConfigAddOptions(prefix string, f *pflag.FlagSet) { + f.String(prefix+".redis-url", DefaultConsumerConfig.RedisURL, "redis url for redis stream") + f.Duration(prefix+".response-entry-timeout", DefaultConsumerConfig.ResponseEntryTimeout, "timeout for response entry") + f.Duration(prefix+".keepalive-timeout", DefaultConsumerConfig.KeepAliveTimeout, "timeout after which consumer is considered inactive if heartbeat wasn't performed") + f.String(prefix+".redis-stream", DefaultConsumerConfig.RedisStream, "redis stream name to read from") + f.String(prefix+".redis-group", DefaultConsumerConfig.RedisGroup, "redis stream consumer group name") +} + +func NewProducer[Request Marshallable[Request], Response Marshallable[Response]](cfg *ProducerConfig) (*Producer[Request, Response], error) { if cfg.RedisURL == "" { return nil, fmt.Errorf("redis url cannot be empty") } @@ -59,15 +84,15 @@ func NewProducer[T Marshallable[T]](cfg *ProducerConfig) (*Producer[T], error) { if err != nil { return nil, err } - return &Producer[T]{ + return &Producer[Request, Response]{ id: uuid.NewString(), client: c, cfg: cfg, - promises: make(map[string]*containers.Promise[T]), + promises: make(map[string]*containers.Promise[Response]), }, nil } -func (p *Producer[T]) Start(ctx context.Context) { +func (p *Producer[Request, Response]) Start(ctx context.Context) { p.StopWaiter.Start(ctx, p) p.StopWaiter.CallIteratively( func(ctx context.Context) time.Duration { @@ -79,7 +104,7 @@ func (p *Producer[T]) Start(ctx context.Context) { if len(msgs) == 0 { return p.cfg.CheckPendingInterval } - acked := make(map[string]T) + acked := make(map[string]Request) for _, msg := range msgs { if _, err := p.client.XAck(ctx, p.cfg.RedisStream, p.cfg.RedisGroup, msg.ID).Result(); err != nil { log.Error("ACKing message", "error", err) @@ -109,8 +134,8 @@ func (p *Producer[T]) Start(ctx context.Context) { } log.Error("Error reading value in redis", "key", id, "error", err) } - var tmp T - val, err := tmp.Unmarshal(res) + var tmp Response + val, err := tmp.Unmarshal([]byte(res)) if err != nil { log.Error("Error unmarshaling", "value", res, "error", err) continue @@ -125,7 +150,7 @@ func (p *Producer[T]) Start(ctx context.Context) { // reproduce is used when Producer claims ownership on the pending // message that was sent to inactive consumer and reinserts it into the stream, // so that seamlessly return the answer in the same promise. -func (p *Producer[T]) reproduce(ctx context.Context, value T, oldKey string) (*containers.Promise[T], error) { +func (p *Producer[Request, Response]) reproduce(ctx context.Context, value Request, oldKey string) (*containers.Promise[Response], error) { id, err := p.client.XAdd(ctx, &redis.XAddArgs{ Stream: p.cfg.RedisStream, Values: map[string]any{messageKey: value.Marshal()}, @@ -137,19 +162,19 @@ func (p *Producer[T]) reproduce(ctx context.Context, value T, oldKey string) (*c defer p.promisesLock.Unlock() promise := p.promises[oldKey] if oldKey == "" || promise == nil { - p := containers.NewPromise[T](nil) - promise = &p + pr := containers.NewPromise[Response](nil) + promise = &pr } p.promises[id] = promise return promise, nil } -func (p *Producer[T]) Produce(ctx context.Context, value T) (*containers.Promise[T], error) { +func (p *Producer[Request, Response]) Produce(ctx context.Context, value Request) (*containers.Promise[Response], error) { return p.reproduce(ctx, value, "") } // Check if a consumer is with specified ID is alive. -func (p *Producer[T]) isConsumerAlive(ctx context.Context, consumerID string) bool { +func (p *Producer[Request, Response]) isConsumerAlive(ctx context.Context, consumerID string) bool { val, err := p.client.Get(ctx, heartBeatKey(consumerID)).Int64() if err != nil { return false @@ -157,7 +182,7 @@ func (p *Producer[T]) isConsumerAlive(ctx context.Context, consumerID string) bo return time.Now().UnixMilli()-val < int64(p.cfg.KeepAliveTimeout.Milliseconds()) } -func (p *Producer[T]) checkPending(ctx context.Context) ([]*Message[T], error) { +func (p *Producer[Request, Response]) checkPending(ctx context.Context) ([]*Message[Request], error) { pendingMessages, err := p.client.XPendingExt(ctx, &redis.XPendingExtArgs{ Stream: p.cfg.RedisStream, Group: p.cfg.RedisGroup, @@ -174,12 +199,16 @@ func (p *Producer[T]) checkPending(ctx context.Context) ([]*Message[T], error) { } // IDs of the pending messages with inactive consumers. var ids []string - inactive := make(map[string]bool) + active := make(map[string]bool) for _, msg := range pendingMessages { - if !inactive[msg.Consumer] || p.isConsumerAlive(ctx, msg.Consumer) { + alive, found := active[msg.Consumer] + if !found { + alive = p.isConsumerAlive(ctx, msg.Consumer) + active[msg.Consumer] = alive + } + if alive { continue } - inactive[msg.Consumer] = true ids = append(ids, msg.ID) } log.Info("Attempting to claim", "messages", ids) @@ -193,14 +222,18 @@ func (p *Producer[T]) checkPending(ctx context.Context) ([]*Message[T], error) { if err != nil { return nil, fmt.Errorf("claiming ownership on messages: %v, error: %w", ids, err) } - var res []*Message[T] + var res []*Message[Request] for _, msg := range claimedMsgs { - var tmp T - val, err := tmp.Unmarshal(msg.Values[messageKey]) + data, ok := (msg.Values[messageKey]).([]byte) + if !ok { + return nil, fmt.Errorf("casting request to bytes: %w", err) + } + var tmp Request + val, err := tmp.Unmarshal(data) if err != nil { return nil, fmt.Errorf("marshaling value: %v, error: %w", msg.Values[messageKey], err) } - res = append(res, &Message[T]{ + res = append(res, &Message[Request]{ ID: msg.ID, Value: val, }) diff --git a/pubsub/pubsub_test.go b/pubsub/pubsub_test.go index 4850166ba7..77f2a87914 100644 --- a/pubsub/pubsub_test.go +++ b/pubsub/pubsub_test.go @@ -16,22 +16,36 @@ import ( ) var ( - streamName = "validator_stream" + streamName = DefaultTestProducerConfig.RedisStream consumersCount = 10 messagesCount = 100 ) -type testResult struct { - val string +type testRequest struct { + request string } -func (r *testResult) Marshal() any { - return r.val +func (r *testRequest) Marshal() []byte { + return []byte(r.request) } -func (r *testResult) Unmarshal(val any) (*testResult, error) { - return &testResult{ - val: val.(string), +func (r *testRequest) Unmarshal(val []byte) (*testRequest, error) { + return &testRequest{ + request: string(val), + }, nil +} + +type testResponse struct { + response string +} + +func (r *testResponse) Marshal() []byte { + return []byte(r.response) +} + +func (r *testResponse) Unmarshal(val []byte) (*testResponse, error) { + return &testResponse{ + response: string(val), }, nil } @@ -43,32 +57,20 @@ func createGroup(ctx context.Context, t *testing.T, client redis.UniversalClient } } -func newProducerConsumers(ctx context.Context, t *testing.T) (*Producer[*testResult], []*Consumer[*testResult]) { +func newProducerConsumers(ctx context.Context, t *testing.T) (*Producer[*testRequest, *testResponse], []*Consumer[*testRequest]) { t.Helper() redisURL := redisutil.CreateTestRedis(ctx, t) - producer, err := NewProducer[*testResult]( - &ProducerConfig{ - RedisURL: redisURL, - RedisStream: streamName, - RedisGroup: defaultGroup, - CheckPendingInterval: 10 * time.Millisecond, - KeepAliveTimeout: 20 * time.Millisecond, - CheckResultInterval: 5 * time.Millisecond, - }) + defaultProdCfg := DefaultTestProducerConfig + defaultProdCfg.RedisURL = redisURL + producer, err := NewProducer[*testRequest, *testResponse](defaultProdCfg) if err != nil { t.Fatalf("Error creating new producer: %v", err) } - var consumers []*Consumer[*testResult] + defaultCfg := DefaultTestConsumerConfig + defaultCfg.RedisURL = redisURL + var consumers []*Consumer[*testRequest] for i := 0; i < consumersCount; i++ { - c, err := NewConsumer[*testResult](ctx, - &ConsumerConfig{ - RedisURL: redisURL, - RedisStream: streamName, - RedisGroup: defaultGroup, - KeepAliveInterval: 5 * time.Millisecond, - KeepAliveTimeout: 30 * time.Millisecond, - }, - ) + c, err := NewConsumer[*testRequest](ctx, defaultCfg) if err != nil { t.Fatalf("Error creating new consumer: %v", err) } @@ -119,10 +121,7 @@ func TestProduce(t *testing.T) { if res == nil { continue } - gotMessages[idx][res.ID] = res.Value.val - if err := c.ACK(ctx, res.ID); err != nil { - t.Errorf("Error ACKing message: %v, error: %v", res.ID, err) - } + gotMessages[idx][res.ID] = res.Value.request if err := c.SetResult(ctx, res.ID, fmt.Sprintf("result for: %v", res.ID)); err != nil { t.Errorf("Error setting a result: %v", err) } @@ -134,7 +133,7 @@ func TestProduce(t *testing.T) { var gotResponses []string for i := 0; i < messagesCount; i++ { - value := &testResult{val: fmt.Sprintf("msg: %d", i)} + value := &testRequest{request: fmt.Sprintf("msg: %d", i)} p, err := producer.Produce(ctx, value) if err != nil { t.Errorf("Produce() unexpected error: %v", err) @@ -143,7 +142,7 @@ func TestProduce(t *testing.T) { if err != nil { t.Errorf("Await() unexpected error: %v", err) } - gotResponses = append(gotResponses, res.val) + gotResponses = append(gotResponses, res.response) } producer.StopWaiter.StopAndWait() @@ -219,10 +218,7 @@ func TestClaimingOwnership(t *testing.T) { if res == nil { continue } - gotMessages[idx][res.ID] = res.Value.val - if err := c.ACK(ctx, res.ID); err != nil { - t.Errorf("Error ACKing message: %v, error: %v", res.ID, err) - } + gotMessages[idx][res.ID] = res.Value.request if err := c.SetResult(ctx, res.ID, fmt.Sprintf("result for: %v", res.ID)); err != nil { t.Errorf("Error setting a result: %v", err) } @@ -232,9 +228,9 @@ func TestClaimingOwnership(t *testing.T) { }) } - var promises []*containers.Promise[*testResult] + var promises []*containers.Promise[*testResponse] for i := 0; i < messagesCount; i++ { - value := &testResult{val: fmt.Sprintf("msg: %d", i)} + value := &testRequest{request: fmt.Sprintf("msg: %d", i)} promise, err := producer.Produce(ctx, value) if err != nil { t.Errorf("Produce() unexpected error: %v", err) @@ -248,7 +244,7 @@ func TestClaimingOwnership(t *testing.T) { t.Errorf("Await() unexpected error: %v", err) continue } - gotResponses = append(gotResponses, res.val) + gotResponses = append(gotResponses, res.response) } for {