diff --git a/pubsub/consumer.go b/pubsub/consumer.go index 7a5078ee00..c9590de8e6 100644 --- a/pubsub/consumer.go +++ b/pubsub/consumer.go @@ -57,14 +57,13 @@ func NewConsumer[Request any, Response any](client redis.UniversalClient, stream if streamName == "" { return nil, fmt.Errorf("redis stream name cannot be empty") } - consumer := &Consumer[Request, Response]{ + return &Consumer[Request, Response]{ id: uuid.NewString(), client: client, redisStream: streamName, redisGroup: streamName, // There is 1-1 mapping of redis stream and consumer group. cfg: cfg, - } - return consumer, nil + }, nil } // Start starts the consumer to iteratively perform heartbeat in configured intervals. @@ -80,6 +79,7 @@ func (c *Consumer[Request, Response]) Start(ctx context.Context) { func (c *Consumer[Request, Response]) StopAndWait() { c.StopWaiter.StopAndWait() + c.deleteHeartBeat(c.GetParentContext()) } func heartBeatKey(id string) string { @@ -90,6 +90,17 @@ func (c *Consumer[Request, Response]) heartBeatKey() string { return heartBeatKey(c.id) } +// deleteHeartBeat deletes the heartbeat to indicate it is being shut down. +func (c *Consumer[Request, Response]) deleteHeartBeat(ctx context.Context) { + if err := c.client.Del(ctx, c.heartBeatKey()).Err(); err != nil { + l := log.Info + if ctx.Err() != nil { + l = log.Error + } + l("Deleting heardbeat", "consumer", c.id, "error", err) + } +} + // heartBeat updates the heartBeat key indicating aliveness. func (c *Consumer[Request, Response]) heartBeat(ctx context.Context) { if err := c.client.Set(ctx, c.heartBeatKey(), time.Now().UnixMilli(), 2*c.cfg.KeepAliveTimeout).Err(); err != nil { diff --git a/pubsub/pubsub_test.go b/pubsub/pubsub_test.go index 31f6d9e20a..72504602e3 100644 --- a/pubsub/pubsub_test.go +++ b/pubsub/pubsub_test.go @@ -4,8 +4,10 @@ import ( "context" "errors" "fmt" + "os" "sort" "testing" + "time" "github.com/ethereum/go-ethereum/log" "github.com/go-redis/redis/v8" @@ -201,6 +203,7 @@ func consume(ctx context.Context, t *testing.T, consumers []*Consumer[testReques } func TestRedisProduce(t *testing.T) { + log.SetDefault(log.NewLogger(log.NewTerminalHandlerWithLevel(os.Stderr, log.LevelTrace, true))) t.Parallel() for _, tc := range []struct { name string @@ -212,7 +215,7 @@ func TestRedisProduce(t *testing.T) { }, { name: "some consumers killed, others should take over their work", - killConsumers: false, + killConsumers: true, }, } { t.Run(tc.name, func(t *testing.T) { @@ -229,6 +232,7 @@ func TestRedisProduce(t *testing.T) { // Consumer messages in every third consumer but don't ack them to check // that other consumers will claim ownership on those messages. for i := 0; i < len(consumers); i += 3 { + consumers[i].Start(ctx) if _, err := consumers[i].Consume(ctx); err != nil { t.Errorf("Error consuming message: %v", err) } @@ -236,6 +240,7 @@ func TestRedisProduce(t *testing.T) { } } + time.Sleep(time.Second) gotMessages, wantResponses := consume(ctx, t, consumers) gotResponses, err := awaitResponses(ctx, promises) if err != nil { @@ -243,7 +248,7 @@ func TestRedisProduce(t *testing.T) { } producer.StopAndWait() for _, c := range consumers { - c.StopWaiter.StopAndWait() + c.StopAndWait() } got, err := mergeValues(gotMessages) if err != nil { @@ -280,6 +285,7 @@ func TestRedisReproduceDisabled(t *testing.T) { // Consumer messages in every third consumer but don't ack them to check // that other consumers will claim ownership on those messages. for i := 0; i < len(consumers); i += 3 { + consumers[i].Start(ctx) if _, err := consumers[i].Consume(ctx); err != nil { t.Errorf("Error consuming message: %v", err) } diff --git a/system_tests/block_validator_test.go b/system_tests/block_validator_test.go index dfd892a079..debd6d4c7c 100644 --- a/system_tests/block_validator_test.go +++ b/system_tests/block_validator_test.go @@ -72,7 +72,7 @@ func testBlockValidatorSimple(t *testing.T, dasModeString string, workloadLoops redisURL := "" if useRedisStreams { redisURL = redisutil.CreateTestRedis(ctx, t) - validatorConfig.BlockValidator.RedisValidationClientConfig = redis.DefaultValidationClientConfig + validatorConfig.BlockValidator.RedisValidationClientConfig = redis.TestValidationClientConfig validatorConfig.BlockValidator.RedisValidationClientConfig.RedisURL = redisURL }