Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gracefully shutdown consumer on interrupts [NIT-2507] #2309

Merged
merged 11 commits into from
May 23, 2024
27 changes: 24 additions & 3 deletions pubsub/consumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/json"
"errors"
"fmt"
"sync/atomic"
"time"

"github.com/ethereum/go-ethereum/log"
Expand Down Expand Up @@ -46,6 +47,9 @@ type Consumer[Request any, Response any] struct {
redisStream string
redisGroup string
cfg *ConsumerConfig
// terminating indicates whether interrupt was received, in which case
// consumer should clean up for graceful shutdown.
terminating atomic.Bool
anodar marked this conversation as resolved.
Show resolved Hide resolved
}

type Message[Request any] struct {
Expand All @@ -57,28 +61,33 @@ func NewConsumer[Request any, Response any](client redis.UniversalClient, stream
if streamName == "" {
return nil, fmt.Errorf("redis stream name cannot be empty")
}
consumer := &Consumer[Request, Response]{
return &Consumer[Request, Response]{
id: uuid.NewString(),
client: client,
redisStream: streamName,
redisGroup: streamName, // There is 1-1 mapping of redis stream and consumer group.
cfg: cfg,
}
return consumer, nil
terminating: atomic.Bool{},
}, nil
}

// Start starts the consumer's StopWaiter and registers an iterative task
// that refreshes the consumer's heartbeat key.
//
// While the consumer is running, each iteration updates the heartbeat and
// asks to be called again after a tenth of cfg.KeepAliveTimeout. Once the
// terminating flag has been set (see deleteHeartBeat), the heartbeat is no
// longer refreshed, so the key that was deleted for shutdown is not written
// again.
func (c *Consumer[Request, Response]) Start(ctx context.Context) {
	c.StopWaiter.Start(ctx, c)
	c.StopWaiter.CallIteratively(
		func(ctx context.Context) time.Duration {
			// Skip the update only when shutdown is in progress; the previous
			// negated check suppressed heartbeats for live consumers instead.
			if c.terminating.Load() {
				log.Trace("Consumer is terminating, stopping heartbeat update")
				// Nothing left to do; use a long delay until the StopWaiter
				// is stopped.
				return time.Hour
			}
			c.heartBeat(ctx)
			return c.cfg.KeepAliveTimeout / 10
		},
	)
}

// StopAndWait begins shutdown of the consumer. It first deletes the heartbeat
// key (which also marks the consumer as terminating) using the context
// returned by GetParentContext, and then delegates to the StopWaiter's
// StopAndWait.
func (c *Consumer[Request, Response]) StopAndWait() {
	parentCtx := c.GetParentContext()
	c.deleteHeartBeat(parentCtx)
	c.StopWaiter.StopAndWait()
}

Expand All @@ -90,6 +99,18 @@ func (c *Consumer[Request, Response]) heartBeatKey() string {
return heartBeatKey(c.id)
}

// deleteHeartBeat deletes the heartbeat to indicate it is being shut down.
func (c *Consumer[Request, Response]) deleteHeartBeat(ctx context.Context) {
c.terminating.Store(true)
anodar marked this conversation as resolved.
Show resolved Hide resolved
if err := c.client.Del(ctx, c.heartBeatKey()).Err(); err != nil {
l := log.Info
if ctx.Err() != nil {
l = log.Error
}
l("Deleting heardbeat", "consumer", c.id, "error", err)
}
}

// heartBeat updates the heartBeat key indicating aliveness.
func (c *Consumer[Request, Response]) heartBeat(ctx context.Context) {
if err := c.client.Set(ctx, c.heartBeatKey(), time.Now().UnixMilli(), 2*c.cfg.KeepAliveTimeout).Err(); err != nil {
Expand Down
10 changes: 8 additions & 2 deletions pubsub/pubsub_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@ import (
"context"
"errors"
"fmt"
"os"
"sort"
"testing"
"time"

"github.com/ethereum/go-ethereum/log"
"github.com/go-redis/redis/v8"
Expand Down Expand Up @@ -201,6 +203,7 @@ func consume(ctx context.Context, t *testing.T, consumers []*Consumer[testReques
}

func TestRedisProduce(t *testing.T) {
log.SetDefault(log.NewLogger(log.NewTerminalHandlerWithLevel(os.Stderr, log.LevelTrace, true)))
t.Parallel()
for _, tc := range []struct {
name string
Expand All @@ -212,7 +215,7 @@ func TestRedisProduce(t *testing.T) {
},
{
name: "some consumers killed, others should take over their work",
killConsumers: false,
killConsumers: true,
},
} {
t.Run(tc.name, func(t *testing.T) {
Expand All @@ -229,21 +232,23 @@ func TestRedisProduce(t *testing.T) {
// Consume messages in every third consumer but don't ack them to check
// that other consumers will claim ownership on those messages.
for i := 0; i < len(consumers); i += 3 {
consumers[i].Start(ctx)
if _, err := consumers[i].Consume(ctx); err != nil {
t.Errorf("Error consuming message: %v", err)
}
consumers[i].StopAndWait()
}

}
time.Sleep(time.Second)
gotMessages, wantResponses := consume(ctx, t, consumers)
gotResponses, err := awaitResponses(ctx, promises)
if err != nil {
t.Fatalf("Error awaiting responses: %v", err)
}
producer.StopAndWait()
for _, c := range consumers {
c.StopWaiter.StopAndWait()
c.StopAndWait()
}
got, err := mergeValues(gotMessages)
if err != nil {
Expand Down Expand Up @@ -280,6 +285,7 @@ func TestRedisReproduceDisabled(t *testing.T) {
// Consume messages in every third consumer but don't ack them to check
// that other consumers will claim ownership on those messages.
for i := 0; i < len(consumers); i += 3 {
consumers[i].Start(ctx)
if _, err := consumers[i].Consume(ctx); err != nil {
t.Errorf("Error consuming message: %v", err)
}
Expand Down
Loading