Skip to content

Commit

Permalink
Gracefully shutdown consumer on interrupts
Browse files Browse the repository at this point in the history
  • Loading branch information
anodar committed May 16, 2024
1 parent a195604 commit 4591061
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 4 deletions.
48 changes: 45 additions & 3 deletions pubsub/consumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ import (
"encoding/json"
"errors"
"fmt"
"os"
"os/signal"
"sync/atomic"
"syscall"
"time"

"github.com/ethereum/go-ethereum/log"
Expand Down Expand Up @@ -46,6 +50,10 @@ type Consumer[Request any, Response any] struct {
redisStream string
redisGroup string
cfg *ConsumerConfig
// terminating indicates whether interrupt was received, in which case
// consumer should clean up for graceful shutdown.
terminating atomic.Bool
signals chan os.Signal
}

type Message[Request any] struct {
Expand All @@ -57,29 +65,51 @@ func NewConsumer[Request any, Response any](client redis.UniversalClient, stream
if streamName == "" {
return nil, fmt.Errorf("redis stream name cannot be empty")
}
consumer := &Consumer[Request, Response]{
return &Consumer[Request, Response]{
id: uuid.NewString(),
client: client,
redisStream: streamName,
redisGroup: streamName, // There is 1-1 mapping of redis stream and consumer group.
cfg: cfg,
}
return consumer, nil
terminating: atomic.Bool{},
signals: make(chan os.Signal, 1),
}, nil
}

// Start starts the consumer to iteratively perform heartbeat in configured intervals.
func (c *Consumer[Request, Response]) Start(ctx context.Context) {
c.StopWaiter.Start(ctx, c)
c.listenForInterrupt()
c.StopWaiter.CallIteratively(
func(ctx context.Context) time.Duration {
if !c.terminating.Load() {
log.Trace("Consumer is terminating, stopping heartbeat update")
return time.Hour
}
c.heartBeat(ctx)
return c.cfg.KeepAliveTimeout / 10
},
)
}

// listenForInterrupt launches a thread that notifies the channel when interrupt
// is received.
func (c *Consumer[Request, Response]) listenForInterrupt() {
signal.Notify(c.signals, syscall.SIGINT, syscall.SIGTERM)
c.StopWaiter.LaunchThread(func(ctx context.Context) {
select {
case sig := <-c.signals:
log.Info("Received interrup", "signal", sig.String())
case <-ctx.Done():
log.Info("Context is done", "error", ctx.Err())
}
c.deleteHeartBeat(ctx)
})
}

func (c *Consumer[Request, Response]) StopAndWait() {
c.StopWaiter.StopAndWait()
c.deleteHeartBeat(c.GetContext())
}

func heartBeatKey(id string) string {
Expand All @@ -90,6 +120,18 @@ func (c *Consumer[Request, Response]) heartBeatKey() string {
return heartBeatKey(c.id)
}

// deleteHeartBeat deletes the heartbeat to indicate it is being shut down.
func (c *Consumer[Request, Response]) deleteHeartBeat(ctx context.Context) {
c.terminating.Store(true)
if err := c.client.Del(ctx, c.heartBeatKey()).Err(); err != nil {
l := log.Info
if ctx.Err() != nil {
l = log.Error
}
l("Deleting heardbeat", "consumer", c.id, "error", err)
}
}

// heartBeat updates the heartBeat key indicating aliveness.
func (c *Consumer[Request, Response]) heartBeat(ctx context.Context) {
if err := c.client.Set(ctx, c.heartBeatKey(), time.Now().UnixMilli(), 2*c.cfg.KeepAliveTimeout).Err(); err != nil {
Expand Down
8 changes: 7 additions & 1 deletion pubsub/pubsub_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"os"
"sort"
"testing"

Expand Down Expand Up @@ -232,7 +233,12 @@ func TestRedisProduce(t *testing.T) {
if _, err := consumers[i].Consume(ctx); err != nil {
t.Errorf("Error consuming message: %v", err)
}
consumers[i].StopAndWait()
// Terminate half of the consumers, send interrupt to others.
if i%2 == 0 {
consumers[i].StopAndWait()
} else {
consumers[i].signals <- os.Interrupt
}
}

}
Expand Down

0 comments on commit 4591061

Please sign in to comment.