Skip to content

Commit

Permalink
Change generics to be any instead of Marshallable, introduce generic …
Browse files Browse the repository at this point in the history
…Marshaller
  • Loading branch information
anodar committed Apr 4, 2024
1 parent 0bd347e commit c8101c2
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 51 deletions.
15 changes: 9 additions & 6 deletions pubsub/consumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,19 +52,21 @@ func ConsumerConfigAddOptions(prefix string, f *pflag.FlagSet) {

// Consumer implements a consumer for redis stream provides heartbeat to
// indicate it is alive.
type Consumer[Request Marshallable[Request], Response Marshallable[Response]] struct {
type Consumer[Request any, Response any] struct {
stopwaiter.StopWaiter
id string
client redis.UniversalClient
cfg *ConsumerConfig
mReq Marshaller[Request]
mResp Marshaller[Response]
}

type Message[Request Marshallable[Request]] struct {
type Message[Request any] struct {
ID string
Value Request
}

func NewConsumer[Request Marshallable[Request], Response Marshallable[Response]](ctx context.Context, cfg *ConsumerConfig) (*Consumer[Request, Response], error) {
func NewConsumer[Request any, Response any](ctx context.Context, cfg *ConsumerConfig, mReq Marshaller[Request], mResp Marshaller[Response]) (*Consumer[Request, Response], error) {
if cfg.RedisURL == "" {
return nil, fmt.Errorf("redis url cannot be empty")
}
Expand All @@ -76,6 +78,8 @@ func NewConsumer[Request Marshallable[Request], Response Marshallable[Response]]
id: uuid.NewString(),
client: c,
cfg: cfg,
mReq: mReq,
mResp: mResp,
}
return consumer, nil
}
Expand Down Expand Up @@ -139,12 +143,11 @@ func (c *Consumer[Request, Response]) Consume(ctx context.Context) (*Message[Req
var (
value = res[0].Messages[0].Values[messageKey]
data, ok = (value).(string)
tmp Request
)
if !ok {
return nil, fmt.Errorf("casting request to string: %w", err)
}
req, err := tmp.Unmarshal([]byte(data))
req, err := c.mReq.Unmarshal([]byte(data))
if err != nil {
return nil, fmt.Errorf("unmarshaling value: %v, error: %w", value, err)
}
Expand All @@ -156,7 +159,7 @@ func (c *Consumer[Request, Response]) Consume(ctx context.Context) (*Message[Req
}

func (c *Consumer[Request, Response]) SetResult(ctx context.Context, messageID string, result Response) error {
acquired, err := c.client.SetNX(ctx, messageID, result.Marshal(), c.cfg.ResponseEntryTimeout).Result()
acquired, err := c.client.SetNX(ctx, messageID, c.mResp.Marshal(result), c.cfg.ResponseEntryTimeout).Result()
if err != nil || !acquired {
return fmt.Errorf("setting result for message: %v, error: %w", messageID, err)
}
Expand Down
20 changes: 11 additions & 9 deletions pubsub/producer.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,18 @@ const (
defaultGroup = "default_consumer_group"
)

type Marshallable[T any] interface {
Marshal() []byte
type Marshaller[T any] interface {
Marshal(T) []byte
Unmarshal(val []byte) (T, error)
}

type Producer[Request Marshallable[Request], Response Marshallable[Response]] struct {
type Producer[Request any, Response any] struct {
stopwaiter.StopWaiter
id string
client redis.UniversalClient
cfg *ProducerConfig
mReq Marshaller[Request]
mResp Marshaller[Response]

promisesLock sync.RWMutex
promises map[string]*containers.Promise[Response]
Expand Down Expand Up @@ -88,7 +90,7 @@ func ProducerAddConfigAddOptions(prefix string, f *pflag.FlagSet) {
f.String(prefix+".redis-group", DefaultProducerConfig.RedisGroup, "redis stream consumer group name")
}

func NewProducer[Request Marshallable[Request], Response Marshallable[Response]](cfg *ProducerConfig) (*Producer[Request, Response], error) {
func NewProducer[Request any, Response any](cfg *ProducerConfig, mReq Marshaller[Request], mResp Marshaller[Response]) (*Producer[Request, Response], error) {
if cfg.RedisURL == "" {
return nil, fmt.Errorf("redis url cannot be empty")
}
Expand All @@ -100,6 +102,8 @@ func NewProducer[Request Marshallable[Request], Response Marshallable[Response]]
id: uuid.NewString(),
client: c,
cfg: cfg,
mReq: mReq,
mResp: mResp,
promises: make(map[string]*containers.Promise[Response]),
}, nil
}
Expand Down Expand Up @@ -158,8 +162,7 @@ func (p *Producer[Request, Response]) checkResponses(ctx context.Context) time.D
}
log.Error("Error reading value in redis", "key", id, "error", err)
}
var tmp Response
val, err := tmp.Unmarshal([]byte(res))
val, err := p.mResp.Unmarshal([]byte(res))
if err != nil {
log.Error("Error unmarshaling", "value", res, "error", err)
continue
Expand All @@ -180,7 +183,7 @@ func (p *Producer[Request, Response]) Start(ctx context.Context) {
func (p *Producer[Request, Response]) reproduce(ctx context.Context, value Request, oldKey string) (*containers.Promise[Response], error) {
id, err := p.client.XAdd(ctx, &redis.XAddArgs{
Stream: p.cfg.RedisStream,
Values: map[string]any{messageKey: value.Marshal()},
Values: map[string]any{messageKey: p.mReq.Marshal(value)},
}).Result()
if err != nil {
return nil, fmt.Errorf("adding values to redis: %w", err)
Expand Down Expand Up @@ -275,8 +278,7 @@ func (p *Producer[Request, Response]) checkPending(ctx context.Context) ([]*Mess
if !ok {
return nil, fmt.Errorf("casting request: %v to bytes", msg.Values[messageKey])
}
var tmp Request
val, err := tmp.Unmarshal([]byte(data))
val, err := p.mReq.Unmarshal([]byte(data))
if err != nil {
return nil, fmt.Errorf("marshaling value: %v, error: %w", msg.Values[messageKey], err)
}
Expand Down
64 changes: 28 additions & 36 deletions pubsub/pubsub_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,32 +21,24 @@ var (
messagesCount = 100
)

type testRequest struct {
request string
}
type testRequestMarshaller struct{}

func (r *testRequest) Marshal() []byte {
return []byte(r.request)
func (t *testRequestMarshaller) Marshal(val string) []byte {
return []byte(val)
}

func (r *testRequest) Unmarshal(val []byte) (*testRequest, error) {
return &testRequest{
request: string(val),
}, nil
func (t *testRequestMarshaller) Unmarshal(val []byte) (string, error) {
return string(val), nil
}

type testResponse struct {
response string
}
type testResponseMarshaller struct{}

func (r *testResponse) Marshal() []byte {
return []byte(r.response)
func (t *testResponseMarshaller) Marshal(val string) []byte {
return []byte(val)
}

func (r *testResponse) Unmarshal(val []byte) (*testResponse, error) {
return &testResponse{
response: string(val),
}, nil
func (t *testResponseMarshaller) Unmarshal(val []byte) (string, error) {
return string(val), nil
}

func createGroup(ctx context.Context, t *testing.T, client redis.UniversalClient) {
Expand All @@ -67,22 +59,22 @@ func (e *disableReproduce) apply(_ *ConsumerConfig, prodCfg *ProducerConfig) {
prodCfg.EnableReproduce = false
}

func newProducerConsumers(ctx context.Context, t *testing.T, opts ...configOpt) (*Producer[*testRequest, *testResponse], []*Consumer[*testRequest, *testResponse]) {
func newProducerConsumers(ctx context.Context, t *testing.T, opts ...configOpt) (*Producer[string, string], []*Consumer[string, string]) {
t.Helper()
redisURL := redisutil.CreateTestRedis(ctx, t)
prodCfg, consCfg := DefaultTestProducerConfig, DefaultTestConsumerConfig
prodCfg.RedisURL, consCfg.RedisURL = redisURL, redisURL
for _, o := range opts {
o.apply(consCfg, prodCfg)
}
producer, err := NewProducer[*testRequest, *testResponse](prodCfg)
producer, err := NewProducer[string, string](prodCfg, &testRequestMarshaller{}, &testResponseMarshaller{})
if err != nil {
t.Fatalf("Error creating new producer: %v", err)
}

var consumers []*Consumer[*testRequest, *testResponse]
var consumers []*Consumer[string, string]
for i := 0; i < consumersCount; i++ {
c, err := NewConsumer[*testRequest, *testResponse](ctx, consCfg)
c, err := NewConsumer[string, string](ctx, consCfg, &testRequestMarshaller{}, &testResponseMarshaller{})
if err != nil {
t.Fatalf("Error creating new consumer: %v", err)
}
Expand Down Expand Up @@ -133,20 +125,20 @@ func TestRedisProduce(t *testing.T) {
if res == nil {
continue
}
gotMessages[idx][res.ID] = res.Value.request
resp := &testResponse{response: fmt.Sprintf("result for: %v", res.ID)}
gotMessages[idx][res.ID] = res.Value
resp := fmt.Sprintf("result for: %v", res.ID)
if err := c.SetResult(ctx, res.ID, resp); err != nil {
t.Errorf("Error setting a result: %v", err)
}
wantResponses[idx] = append(wantResponses[idx], resp.response)
wantResponses[idx] = append(wantResponses[idx], resp)
}
})
}

var gotResponses []string

for i := 0; i < messagesCount; i++ {
value := &testRequest{request: fmt.Sprintf("msg: %d", i)}
value := fmt.Sprintf("msg: %d", i)
p, err := producer.Produce(ctx, value)
if err != nil {
t.Errorf("Produce() unexpected error: %v", err)
Expand All @@ -155,7 +147,7 @@ func TestRedisProduce(t *testing.T) {
if err != nil {
t.Errorf("Await() unexpected error: %v", err)
}
gotResponses = append(gotResponses, res.response)
gotResponses = append(gotResponses, res)
}

producer.StopWaiter.StopAndWait()
Expand Down Expand Up @@ -192,10 +184,10 @@ func flatten(responses [][]string) []string {
return ret
}

func produceMessages(ctx context.Context, producer *Producer[*testRequest, *testResponse]) ([]*containers.Promise[*testResponse], error) {
var promises []*containers.Promise[*testResponse]
func produceMessages(ctx context.Context, producer *Producer[string, string]) ([]*containers.Promise[string], error) {
var promises []*containers.Promise[string]
for i := 0; i < messagesCount; i++ {
value := &testRequest{request: fmt.Sprintf("msg: %d", i)}
value := fmt.Sprintf("msg: %d", i)
promise, err := producer.Produce(ctx, value)
if err != nil {
return nil, err
Expand All @@ -205,7 +197,7 @@ func produceMessages(ctx context.Context, producer *Producer[*testRequest, *test
return promises, nil
}

func awaitResponses(ctx context.Context, promises []*containers.Promise[*testResponse]) ([]string, error) {
func awaitResponses(ctx context.Context, promises []*containers.Promise[string]) ([]string, error) {
var (
responses []string
errs []error
Expand All @@ -216,12 +208,12 @@ func awaitResponses(ctx context.Context, promises []*containers.Promise[*testRes
errs = append(errs, err)
continue
}
responses = append(responses, res.response)
responses = append(responses, res)
}
return responses, errors.Join(errs...)
}

func consume(ctx context.Context, t *testing.T, consumers []*Consumer[*testRequest, *testResponse], skipN int) ([]map[string]string, [][]string) {
func consume(ctx context.Context, t *testing.T, consumers []*Consumer[string, string], skipN int) ([]map[string]string, [][]string) {
t.Helper()
gotMessages := messagesMaps(consumersCount)
wantResponses := make([][]string, consumersCount)
Expand All @@ -246,12 +238,12 @@ func consume(ctx context.Context, t *testing.T, consumers []*Consumer[*testReque
if res == nil {
continue
}
gotMessages[idx][res.ID] = res.Value.request
resp := &testResponse{response: fmt.Sprintf("result for: %v", res.ID)}
gotMessages[idx][res.ID] = res.Value
resp := fmt.Sprintf("result for: %v", res.ID)
if err := c.SetResult(ctx, res.ID, resp); err != nil {
t.Errorf("Error setting a result: %v", err)
}
wantResponses[idx] = append(wantResponses[idx], resp.response)
wantResponses[idx] = append(wantResponses[idx], resp)
}
})
}
Expand Down

0 comments on commit c8101c2

Please sign in to comment.