From c8101c2ede3dd5fa03de96165937b0571d14d010 Mon Sep 17 00:00:00 2001 From: Nodar Ambroladze Date: Thu, 4 Apr 2024 13:06:33 +0200 Subject: [PATCH] Change generics to be any instead of Marshallable, introduce generic Marshaller --- pubsub/consumer.go | 15 ++++++---- pubsub/producer.go | 20 ++++++++------ pubsub/pubsub_test.go | 64 +++++++++++++++++++------------------------ 3 files changed, 48 insertions(+), 51 deletions(-) diff --git a/pubsub/consumer.go b/pubsub/consumer.go index 9c0edb5e9e..8ae5bcb6b7 100644 --- a/pubsub/consumer.go +++ b/pubsub/consumer.go @@ -52,19 +52,21 @@ func ConsumerConfigAddOptions(prefix string, f *pflag.FlagSet) { // Consumer implements a consumer for redis stream provides heartbeat to // indicate it is alive. -type Consumer[Request Marshallable[Request], Response Marshallable[Response]] struct { +type Consumer[Request any, Response any] struct { stopwaiter.StopWaiter id string client redis.UniversalClient cfg *ConsumerConfig + mReq Marshaller[Request] + mResp Marshaller[Response] } -type Message[Request Marshallable[Request]] struct { +type Message[Request any] struct { ID string Value Request } -func NewConsumer[Request Marshallable[Request], Response Marshallable[Response]](ctx context.Context, cfg *ConsumerConfig) (*Consumer[Request, Response], error) { +func NewConsumer[Request any, Response any](ctx context.Context, cfg *ConsumerConfig, mReq Marshaller[Request], mResp Marshaller[Response]) (*Consumer[Request, Response], error) { if cfg.RedisURL == "" { return nil, fmt.Errorf("redis url cannot be empty") } @@ -76,6 +78,8 @@ func NewConsumer[Request Marshallable[Request], Response Marshallable[Response]] id: uuid.NewString(), client: c, cfg: cfg, + mReq: mReq, + mResp: mResp, } return consumer, nil } @@ -139,12 +143,11 @@ func (c *Consumer[Request, Response]) Consume(ctx context.Context) (*Message[Req var ( value = res[0].Messages[0].Values[messageKey] data, ok = (value).(string) - tmp Request ) if !ok { return nil, fmt.Errorf("casting request to string: %w", err) } - req, err := tmp.Unmarshal([]byte(data)) + req, err := c.mReq.Unmarshal([]byte(data)) if err != nil { return nil, fmt.Errorf("unmarshaling value: %v, error: %w", value, err) } @@ -156,7 +159,7 @@ func (c *Consumer[Request, Response]) Consume(ctx context.Context) (*Message[Req } func (c *Consumer[Request, Response]) SetResult(ctx context.Context, messageID string, result Response) error { - acquired, err := c.client.SetNX(ctx, messageID, result.Marshal(), c.cfg.ResponseEntryTimeout).Result() + acquired, err := c.client.SetNX(ctx, messageID, c.mResp.Marshal(result), c.cfg.ResponseEntryTimeout).Result() if err != nil || !acquired { return fmt.Errorf("setting result for message: %v, error: %w", messageID, err) } diff --git a/pubsub/producer.go b/pubsub/producer.go index 6188b81dfa..4569316b4d 100644 --- a/pubsub/producer.go +++ b/pubsub/producer.go @@ -21,16 +21,18 @@ const ( defaultGroup = "default_consumer_group" ) -type Marshallable[T any] interface { - Marshal() []byte +type Marshaller[T any] interface { + Marshal(T) []byte Unmarshal(val []byte) (T, error) } -type Producer[Request Marshallable[Request], Response Marshallable[Response]] struct { +type Producer[Request any, Response any] struct { stopwaiter.StopWaiter id string client redis.UniversalClient cfg *ProducerConfig + mReq Marshaller[Request] + mResp Marshaller[Response] promisesLock sync.RWMutex promises map[string]*containers.Promise[Response] @@ -88,7 +90,7 @@ func ProducerAddConfigAddOptions(prefix string, f *pflag.FlagSet) { f.String(prefix+".redis-group", DefaultProducerConfig.RedisGroup, "redis stream consumer group name") } -func NewProducer[Request Marshallable[Request], Response Marshallable[Response]](cfg *ProducerConfig) (*Producer[Request, Response], error) { +func NewProducer[Request any, Response any](cfg *ProducerConfig, mReq Marshaller[Request], mResp Marshaller[Response]) (*Producer[Request, Response], error) { if cfg.RedisURL == "" { return nil, fmt.Errorf("redis url cannot be empty") } @@ -100,6 +102,8 @@ func NewProducer[Request Marshallable[Request], Response Marshallable[Response]] id: uuid.NewString(), client: c, cfg: cfg, + mReq: mReq, + mResp: mResp, promises: make(map[string]*containers.Promise[Response]), }, nil } @@ -158,8 +162,7 @@ func (p *Producer[Request, Response]) checkResponses(ctx context.Context) time.D } log.Error("Error reading value in redis", "key", id, "error", err) } - var tmp Response - val, err := tmp.Unmarshal([]byte(res)) + val, err := p.mResp.Unmarshal([]byte(res)) if err != nil { log.Error("Error unmarshaling", "value", res, "error", err) continue @@ -180,7 +183,7 @@ func (p *Producer[Request, Response]) Start(ctx context.Context) { func (p *Producer[Request, Response]) reproduce(ctx context.Context, value Request, oldKey string) (*containers.Promise[Response], error) { id, err := p.client.XAdd(ctx, &redis.XAddArgs{ Stream: p.cfg.RedisStream, - Values: map[string]any{messageKey: value.Marshal()}, + Values: map[string]any{messageKey: p.mReq.Marshal(value)}, }).Result() if err != nil { return nil, fmt.Errorf("adding values to redis: %w", err) @@ -275,8 +278,7 @@ func (p *Producer[Request, Response]) checkPending(ctx context.Context) ([]*Mess if !ok { return nil, fmt.Errorf("casting request: %v to bytes", msg.Values[messageKey]) } - var tmp Request - val, err := tmp.Unmarshal([]byte(data)) + val, err := p.mReq.Unmarshal([]byte(data)) if err != nil { return nil, fmt.Errorf("marshaling value: %v, error: %w", msg.Values[messageKey], err) } diff --git a/pubsub/pubsub_test.go b/pubsub/pubsub_test.go index c980ff29a9..095d59db3b 100644 --- a/pubsub/pubsub_test.go +++ b/pubsub/pubsub_test.go @@ -21,32 +21,24 @@ var ( messagesCount = 100 ) -type testRequest struct { - request string -} +type testRequestMarshaller struct{} -func (r *testRequest) Marshal() []byte { - return []byte(r.request) +func (t *testRequestMarshaller) Marshal(val string) []byte { + return []byte(val) } -func (r *testRequest) Unmarshal(val []byte) (*testRequest, error) { - return &testRequest{ - request: string(val), - }, nil +func (t *testRequestMarshaller) Unmarshal(val []byte) (string, error) { + return string(val), nil } -type testResponse struct { - response string -} +type testResponseMarshaller struct{} -func (r *testResponse) Marshal() []byte { - return []byte(r.response) +func (t *testResponseMarshaller) Marshal(val string) []byte { + return []byte(val) } -func (r *testResponse) Unmarshal(val []byte) (*testResponse, error) { - return &testResponse{ - response: string(val), - }, nil +func (t *testResponseMarshaller) Unmarshal(val []byte) (string, error) { + return string(val), nil } func createGroup(ctx context.Context, t *testing.T, client redis.UniversalClient) { @@ -67,7 +59,7 @@ func (e *disableReproduce) apply(_ *ConsumerConfig, prodCfg *ProducerConfig) { prodCfg.EnableReproduce = false } -func newProducerConsumers(ctx context.Context, t *testing.T, opts ...configOpt) (*Producer[*testRequest, *testResponse], []*Consumer[*testRequest, *testResponse]) { +func newProducerConsumers(ctx context.Context, t *testing.T, opts ...configOpt) (*Producer[string, string], []*Consumer[string, string]) { t.Helper() redisURL := redisutil.CreateTestRedis(ctx, t) prodCfg, consCfg := DefaultTestProducerConfig, DefaultTestConsumerConfig @@ -75,14 +67,14 @@ func newProducerConsumers(ctx context.Context, t *testing.T, opts ...configOpt) for _, o := range opts { o.apply(consCfg, prodCfg) } - producer, err := NewProducer[*testRequest, *testResponse](prodCfg) + producer, err := NewProducer[string, string](prodCfg, &testRequestMarshaller{}, &testResponseMarshaller{}) if err != nil { t.Fatalf("Error creating new producer: %v", err) } - var consumers []*Consumer[*testRequest, *testResponse] + var consumers []*Consumer[string, string] for i := 0; i < consumersCount; i++ { - c, err := NewConsumer[*testRequest, *testResponse](ctx, consCfg) + c, err := NewConsumer[string, string](ctx, consCfg, &testRequestMarshaller{}, &testResponseMarshaller{}) if err != nil { t.Fatalf("Error creating new consumer: %v", err) } @@ -133,12 +125,12 @@ func TestRedisProduce(t *testing.T) { if res == nil { continue } - gotMessages[idx][res.ID] = res.Value.request - resp := &testResponse{response: fmt.Sprintf("result for: %v", res.ID)} + gotMessages[idx][res.ID] = res.Value + resp := fmt.Sprintf("result for: %v", res.ID) if err := c.SetResult(ctx, res.ID, resp); err != nil { t.Errorf("Error setting a result: %v", err) } - wantResponses[idx] = append(wantResponses[idx], resp.response) + wantResponses[idx] = append(wantResponses[idx], resp) } }) } @@ -146,7 +138,7 @@ func TestRedisProduce(t *testing.T) { var gotResponses []string for i := 0; i < messagesCount; i++ { - value := &testRequest{request: fmt.Sprintf("msg: %d", i)} + value := fmt.Sprintf("msg: %d", i) p, err := producer.Produce(ctx, value) if err != nil { t.Errorf("Produce() unexpected error: %v", err) @@ -155,7 +147,7 @@ func TestRedisProduce(t *testing.T) { if err != nil { t.Errorf("Await() unexpected error: %v", err) } - gotResponses = append(gotResponses, res.response) + gotResponses = append(gotResponses, res) } producer.StopWaiter.StopAndWait() @@ -192,10 +184,10 @@ func flatten(responses [][]string) []string { return ret } -func produceMessages(ctx context.Context, producer *Producer[*testRequest, *testResponse]) ([]*containers.Promise[*testResponse], error) { - var promises []*containers.Promise[*testResponse] +func produceMessages(ctx context.Context, producer *Producer[string, string]) ([]*containers.Promise[string], error) { + var promises []*containers.Promise[string] for i := 0; i < messagesCount; i++ { - value := &testRequest{request: fmt.Sprintf("msg: %d", i)} + value := fmt.Sprintf("msg: %d", i) promise, err := producer.Produce(ctx, value) if err != nil { return nil, err @@ -205,7 +197,7 @@ func produceMessages(ctx context.Context, producer *Producer[*testRequest, *test return promises, nil } -func awaitResponses(ctx context.Context, promises []*containers.Promise[*testResponse]) ([]string, error) { +func awaitResponses(ctx context.Context, promises []*containers.Promise[string]) ([]string, error) { var ( responses []string errs []error @@ -216,12 +208,12 @@ func awaitResponses(ctx context.Context, promises []*containers.Promise[*testRes errs = append(errs, err) continue } - responses = append(responses, res.response) + responses = append(responses, res) } return responses, errors.Join(errs...) } -func consume(ctx context.Context, t *testing.T, consumers []*Consumer[*testRequest, *testResponse], skipN int) ([]map[string]string, [][]string) { +func consume(ctx context.Context, t *testing.T, consumers []*Consumer[string, string], skipN int) ([]map[string]string, [][]string) { t.Helper() gotMessages := messagesMaps(consumersCount) wantResponses := make([][]string, consumersCount) @@ -246,12 +238,12 @@ func consume(ctx context.Context, t *testing.T, consumers []*Consumer[*testReque if res == nil { continue } - gotMessages[idx][res.ID] = res.Value.request - resp := &testResponse{response: fmt.Sprintf("result for: %v", res.ID)} + gotMessages[idx][res.ID] = res.Value + resp := fmt.Sprintf("result for: %v", res.ID) if err := c.SetResult(ctx, res.ID, resp); err != nil { t.Errorf("Error setting a result: %v", err) } - wantResponses[idx] = append(wantResponses[idx], resp.response) + wantResponses[idx] = append(wantResponses[idx], resp) } }) }