Skip to content

Commit

Permalink
Merge pull request #2309 from OffchainLabs/validator-gracefull-shutdown
Browse files Browse the repository at this point in the history
Gracefully shutdown consumer on interrupts [NIT-2507]
  • Loading branch information
anodar authored May 23, 2024
2 parents ed3f9e0 + 717b00b commit 35bd2aa
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 6 deletions.
17 changes: 14 additions & 3 deletions pubsub/consumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,13 @@ func NewConsumer[Request any, Response any](client redis.UniversalClient, stream
if streamName == "" {
return nil, fmt.Errorf("redis stream name cannot be empty")
}
consumer := &Consumer[Request, Response]{
return &Consumer[Request, Response]{
id: uuid.NewString(),
client: client,
redisStream: streamName,
redisGroup: streamName, // There is 1-1 mapping of redis stream and consumer group.
cfg: cfg,
}
return consumer, nil
}, nil
}

// Start starts the consumer to iteratively perform heartbeat in configured intervals.
Expand All @@ -80,6 +79,7 @@ func (c *Consumer[Request, Response]) Start(ctx context.Context) {

func (c *Consumer[Request, Response]) StopAndWait() {
c.StopWaiter.StopAndWait()
c.deleteHeartBeat(c.GetParentContext())
}

func heartBeatKey(id string) string {
Expand All @@ -90,6 +90,17 @@ func (c *Consumer[Request, Response]) heartBeatKey() string {
return heartBeatKey(c.id)
}

// deleteHeartBeat deletes the heartbeat to indicate it is being shut down.
func (c *Consumer[Request, Response]) deleteHeartBeat(ctx context.Context) {
if err := c.client.Del(ctx, c.heartBeatKey()).Err(); err != nil {
l := log.Info
if ctx.Err() != nil {
l = log.Error
}
l("Deleting heardbeat", "consumer", c.id, "error", err)
}
}

// heartBeat updates the heartBeat key indicating aliveness.
func (c *Consumer[Request, Response]) heartBeat(ctx context.Context) {
if err := c.client.Set(ctx, c.heartBeatKey(), time.Now().UnixMilli(), 2*c.cfg.KeepAliveTimeout).Err(); err != nil {
Expand Down
10 changes: 8 additions & 2 deletions pubsub/pubsub_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@ import (
"context"
"errors"
"fmt"
"os"
"sort"
"testing"
"time"

"github.com/ethereum/go-ethereum/log"
"github.com/go-redis/redis/v8"
Expand Down Expand Up @@ -201,6 +203,7 @@ func consume(ctx context.Context, t *testing.T, consumers []*Consumer[testReques
}

func TestRedisProduce(t *testing.T) {
log.SetDefault(log.NewLogger(log.NewTerminalHandlerWithLevel(os.Stderr, log.LevelTrace, true)))
t.Parallel()
for _, tc := range []struct {
name string
Expand All @@ -212,7 +215,7 @@ func TestRedisProduce(t *testing.T) {
},
{
name: "some consumers killed, others should take over their work",
killConsumers: false,
killConsumers: true,
},
} {
t.Run(tc.name, func(t *testing.T) {
Expand All @@ -229,21 +232,23 @@ func TestRedisProduce(t *testing.T) {
// Consumer messages in every third consumer but don't ack them to check
// that other consumers will claim ownership on those messages.
for i := 0; i < len(consumers); i += 3 {
consumers[i].Start(ctx)
if _, err := consumers[i].Consume(ctx); err != nil {
t.Errorf("Error consuming message: %v", err)
}
consumers[i].StopAndWait()
}

}
time.Sleep(time.Second)
gotMessages, wantResponses := consume(ctx, t, consumers)
gotResponses, err := awaitResponses(ctx, promises)
if err != nil {
t.Fatalf("Error awaiting responses: %v", err)
}
producer.StopAndWait()
for _, c := range consumers {
c.StopWaiter.StopAndWait()
c.StopAndWait()
}
got, err := mergeValues(gotMessages)
if err != nil {
Expand Down Expand Up @@ -280,6 +285,7 @@ func TestRedisReproduceDisabled(t *testing.T) {
// Consumer messages in every third consumer but don't ack them to check
// that other consumers will claim ownership on those messages.
for i := 0; i < len(consumers); i += 3 {
consumers[i].Start(ctx)
if _, err := consumers[i].Consume(ctx); err != nil {
t.Errorf("Error consuming message: %v", err)
}
Expand Down
2 changes: 1 addition & 1 deletion system_tests/block_validator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func testBlockValidatorSimple(t *testing.T, dasModeString string, workloadLoops
redisURL := ""
if useRedisStreams {
redisURL = redisutil.CreateTestRedis(ctx, t)
validatorConfig.BlockValidator.RedisValidationClientConfig = redis.DefaultValidationClientConfig
validatorConfig.BlockValidator.RedisValidationClientConfig = redis.TestValidationClientConfig
validatorConfig.BlockValidator.RedisValidationClientConfig.RedisURL = redisURL
}

Expand Down

0 comments on commit 35bd2aa

Please sign in to comment.