Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
anodar committed Apr 1, 2024
1 parent b183881 commit 2a67624
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 84 deletions.
53 changes: 34 additions & 19 deletions pubsub/consumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ import (
)

type ConsumerConfig struct {
// Intervals in which consumer will update heartbeat.
KeepAliveInterval time.Duration `koanf:"keepalive-interval"`
// Timeout of result entry in Redis.
ResponseEntryTimeout time.Duration `koanf:"response-entry-timeout"`
// Duration after which consumer is considered to be dead if heartbeat
// is not updated.
KeepAliveTimeout time.Duration `koanf:"keepalive-timeout"`
Expand All @@ -28,12 +28,26 @@ type ConsumerConfig struct {
RedisGroup string `koanf:"redis-group"`
}

func ConsumerConfigAddOptions(prefix string, f *pflag.FlagSet, cfg *ConsumerConfig) {
f.Duration(prefix+".keepalive-interval", 30*time.Second, "interval in which consumer will perform heartbeat")
f.Duration(prefix+".keepalive-timeout", 5*time.Minute, "timeout after which consumer is considered inactive if heartbeat wasn't performed")
f.String(prefix+".redis-url", "", "redis url for redis stream")
f.String(prefix+".redis-stream", "default", "redis stream name to read from")
f.String(prefix+".redis-group", defaultGroup, "redis stream consumer group name")
var DefaultConsumerConfig = &ConsumerConfig{
ResponseEntryTimeout: time.Hour,
KeepAliveTimeout: 5 * time.Minute,
RedisStream: "default",
RedisGroup: defaultGroup,
}

var DefaultTestConsumerConfig = &ConsumerConfig{
RedisStream: "default",
RedisGroup: defaultGroup,
ResponseEntryTimeout: time.Minute,
KeepAliveTimeout: 30 * time.Millisecond,
}

func ConsumerConfigAddOptions(prefix string, f *pflag.FlagSet) {
f.Duration(prefix+".response-entry-timeout", DefaultConsumerConfig.ResponseEntryTimeout, "timeout for response entry")
f.Duration(prefix+".keepalive-timeout", DefaultConsumerConfig.KeepAliveTimeout, "timeout after which consumer is considered inactive if heartbeat wasn't performed")
f.String(prefix+".redis-url", DefaultConsumerConfig.RedisURL, "redis url for redis stream")
f.String(prefix+".redis-stream", DefaultConsumerConfig.RedisStream, "redis stream name to read from")
f.String(prefix+".redis-group", DefaultConsumerConfig.RedisGroup, "redis stream consumer group name")
}

// Consumer implements a consumer for redis stream provides heartbeat to
Expand Down Expand Up @@ -72,7 +86,7 @@ func (c *Consumer[T]) Start(ctx context.Context) {
c.StopWaiter.CallIteratively(
func(ctx context.Context) time.Duration {
c.heartBeat(ctx)
return c.cfg.KeepAliveInterval
return c.cfg.KeepAliveTimeout / 10
},
)
}
Expand Down Expand Up @@ -123,10 +137,14 @@ func (c *Consumer[T]) Consume(ctx context.Context) (*Message[T], error) {
}
log.Debug(fmt.Sprintf("Consumer: %s consuming message: %s", c.id, res[0].Messages[0].ID))
var (
value = res[0].Messages[0].Values[messageKey]
tmp T
value = res[0].Messages[0].Values[messageKey]
data, ok = (value).(string)
tmp T
)
val, err := tmp.Unmarshal(value)
if !ok {
return nil, fmt.Errorf("casting request to string: %w", err)
}
val, err := tmp.Unmarshal([]byte(data))
if err != nil {
return nil, fmt.Errorf("unmarshaling value: %v, error: %w", value, err)
}
Expand All @@ -137,16 +155,13 @@ func (c *Consumer[T]) Consume(ctx context.Context) (*Message[T], error) {
}, nil
}

func (c *Consumer[T]) ACK(ctx context.Context, messageID string) error {
log.Info("ACKing message", "consumer-id", c.id, "message-sid", messageID)
_, err := c.client.XAck(ctx, c.cfg.RedisStream, c.cfg.RedisGroup, messageID).Result()
return err
}

func (c *Consumer[T]) SetResult(ctx context.Context, messageID string, result string) error {
acquired, err := c.client.SetNX(ctx, messageID, result, c.cfg.KeepAliveTimeout).Result()
acquired, err := c.client.SetNX(ctx, messageID, result, c.cfg.ResponseEntryTimeout).Result()
if err != nil || !acquired {
return fmt.Errorf("setting result for message: %v, error: %w", messageID, err)
}
if _, err := c.client.XAck(ctx, c.cfg.RedisStream, c.cfg.RedisGroup, messageID).Result(); err != nil {
return fmt.Errorf("acking message: %v, error: %w", messageID, err)
}
return nil
}
81 changes: 57 additions & 24 deletions pubsub/producer.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/offchainlabs/nitro/util/containers"
"github.com/offchainlabs/nitro/util/redisutil"
"github.com/offchainlabs/nitro/util/stopwaiter"
"github.com/spf13/pflag"
)

const (
Expand All @@ -21,18 +22,18 @@ const (
)

type Marshallable[T any] interface {
Marshal() any
Unmarshal(val any) (T, error)
Marshal() []byte
Unmarshal(val []byte) (T, error)
}

type Producer[T Marshallable[T]] struct {
type Producer[Request Marshallable[Request], Response Marshallable[Response]] struct {
stopwaiter.StopWaiter
id string
client redis.UniversalClient
cfg *ProducerConfig

promisesLock sync.RWMutex
promises map[string]*containers.Promise[T]
promises map[string]*containers.Promise[Response]
}

type ProducerConfig struct {
Expand All @@ -51,23 +52,47 @@ type ProducerConfig struct {
RedisGroup string `koanf:"redis-group"`
}

func NewProducer[T Marshallable[T]](cfg *ProducerConfig) (*Producer[T], error) {
var DefaultProducerConfig = &ProducerConfig{
RedisStream: "default",
CheckPendingInterval: time.Second,
KeepAliveTimeout: 5 * time.Minute,
CheckResultInterval: 5 * time.Second,
RedisGroup: defaultGroup,
}

var DefaultTestProducerConfig = &ProducerConfig{
RedisStream: "default",
RedisGroup: defaultGroup,
CheckPendingInterval: 10 * time.Millisecond,
KeepAliveTimeout: 20 * time.Millisecond,
CheckResultInterval: 5 * time.Millisecond,
}

func ProducerAddConfigAddOptions(prefix string, f *pflag.FlagSet) {
f.String(prefix+".redis-url", DefaultConsumerConfig.RedisURL, "redis url for redis stream")
f.Duration(prefix+".response-entry-timeout", DefaultConsumerConfig.ResponseEntryTimeout, "timeout for response entry")
f.Duration(prefix+".keepalive-timeout", DefaultConsumerConfig.KeepAliveTimeout, "timeout after which consumer is considered inactive if heartbeat wasn't performed")
f.String(prefix+".redis-stream", DefaultConsumerConfig.RedisStream, "redis stream name to read from")
f.String(prefix+".redis-group", DefaultConsumerConfig.RedisGroup, "redis stream consumer group name")
}

func NewProducer[Request Marshallable[Request], Response Marshallable[Response]](cfg *ProducerConfig) (*Producer[Request, Response], error) {
if cfg.RedisURL == "" {
return nil, fmt.Errorf("redis url cannot be empty")
}
c, err := redisutil.RedisClientFromURL(cfg.RedisURL)
if err != nil {
return nil, err
}
return &Producer[T]{
return &Producer[Request, Response]{
id: uuid.NewString(),
client: c,
cfg: cfg,
promises: make(map[string]*containers.Promise[T]),
promises: make(map[string]*containers.Promise[Response]),
}, nil
}

func (p *Producer[T]) Start(ctx context.Context) {
func (p *Producer[Request, Response]) Start(ctx context.Context) {
p.StopWaiter.Start(ctx, p)
p.StopWaiter.CallIteratively(
func(ctx context.Context) time.Duration {
Expand All @@ -79,7 +104,7 @@ func (p *Producer[T]) Start(ctx context.Context) {
if len(msgs) == 0 {
return p.cfg.CheckPendingInterval
}
acked := make(map[string]T)
acked := make(map[string]Request)
for _, msg := range msgs {
if _, err := p.client.XAck(ctx, p.cfg.RedisStream, p.cfg.RedisGroup, msg.ID).Result(); err != nil {
log.Error("ACKing message", "error", err)
Expand Down Expand Up @@ -109,8 +134,8 @@ func (p *Producer[T]) Start(ctx context.Context) {
}
log.Error("Error reading value in redis", "key", id, "error", err)
}
var tmp T
val, err := tmp.Unmarshal(res)
var tmp Response
val, err := tmp.Unmarshal([]byte(res))
if err != nil {
log.Error("Error unmarshaling", "value", res, "error", err)
continue
Expand All @@ -125,7 +150,7 @@ func (p *Producer[T]) Start(ctx context.Context) {
// reproduce is used when Producer claims ownership on the pending
// message that was sent to inactive consumer and reinserts it into the stream,
// so that seamlessly return the answer in the same promise.
func (p *Producer[T]) reproduce(ctx context.Context, value T, oldKey string) (*containers.Promise[T], error) {
func (p *Producer[Request, Response]) reproduce(ctx context.Context, value Request, oldKey string) (*containers.Promise[Response], error) {
id, err := p.client.XAdd(ctx, &redis.XAddArgs{
Stream: p.cfg.RedisStream,
Values: map[string]any{messageKey: value.Marshal()},
Expand All @@ -137,27 +162,27 @@ func (p *Producer[T]) reproduce(ctx context.Context, value T, oldKey string) (*c
defer p.promisesLock.Unlock()
promise := p.promises[oldKey]
if oldKey == "" || promise == nil {
p := containers.NewPromise[T](nil)
promise = &p
pr := containers.NewPromise[Response](nil)
promise = &pr
}
p.promises[id] = promise
return promise, nil
}

func (p *Producer[T]) Produce(ctx context.Context, value T) (*containers.Promise[T], error) {
func (p *Producer[Request, Response]) Produce(ctx context.Context, value Request) (*containers.Promise[Response], error) {
return p.reproduce(ctx, value, "")
}

// Check if a consumer is with specified ID is alive.
func (p *Producer[T]) isConsumerAlive(ctx context.Context, consumerID string) bool {
func (p *Producer[Request, Response]) isConsumerAlive(ctx context.Context, consumerID string) bool {
val, err := p.client.Get(ctx, heartBeatKey(consumerID)).Int64()
if err != nil {
return false
}
return time.Now().UnixMilli()-val < int64(p.cfg.KeepAliveTimeout.Milliseconds())
}

func (p *Producer[T]) checkPending(ctx context.Context) ([]*Message[T], error) {
func (p *Producer[Request, Response]) checkPending(ctx context.Context) ([]*Message[Request], error) {
pendingMessages, err := p.client.XPendingExt(ctx, &redis.XPendingExtArgs{
Stream: p.cfg.RedisStream,
Group: p.cfg.RedisGroup,
Expand All @@ -174,12 +199,16 @@ func (p *Producer[T]) checkPending(ctx context.Context) ([]*Message[T], error) {
}
// IDs of the pending messages with inactive consumers.
var ids []string
inactive := make(map[string]bool)
active := make(map[string]bool)
for _, msg := range pendingMessages {
if !inactive[msg.Consumer] || p.isConsumerAlive(ctx, msg.Consumer) {
alive, found := active[msg.Consumer]
if !found {
alive = p.isConsumerAlive(ctx, msg.Consumer)
active[msg.Consumer] = alive
}
if alive {
continue
}
inactive[msg.Consumer] = true
ids = append(ids, msg.ID)
}
log.Info("Attempting to claim", "messages", ids)
Expand All @@ -193,14 +222,18 @@ func (p *Producer[T]) checkPending(ctx context.Context) ([]*Message[T], error) {
if err != nil {
return nil, fmt.Errorf("claiming ownership on messages: %v, error: %w", ids, err)
}
var res []*Message[T]
var res []*Message[Request]
for _, msg := range claimedMsgs {
var tmp T
val, err := tmp.Unmarshal(msg.Values[messageKey])
data, ok := (msg.Values[messageKey]).([]byte)
if !ok {
return nil, fmt.Errorf("casting request to bytes: %w", err)
}
var tmp Request
val, err := tmp.Unmarshal(data)
if err != nil {
return nil, fmt.Errorf("marshaling value: %v, error: %w", msg.Values[messageKey], err)
}
res = append(res, &Message[T]{
res = append(res, &Message[Request]{
ID: msg.ID,
Value: val,
})
Expand Down
Loading

0 comments on commit 2a67624

Please sign in to comment.