Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
anodar committed Apr 9, 2024
1 parent 5c52884 commit 590ec7b
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 138 deletions.
19 changes: 16 additions & 3 deletions pubsub/producer.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
// Package pubsub implements publisher/subscriber model (one to many).
// During normal operation, publisher returns "Promise" when publishing a
// message, which will return resposne from consumer when awaited.
// If the consumer processing the request becomes inactive, message is
// re-inserted (if EnableReproduce flag is enabled), and will be picked up by
// another consumer.
// We are assuming here that keeepAliveTimeout is set to some sensible value
// and once consumer becomes inactive, it doesn't activate without restart.
package pubsub

import (
Expand Down Expand Up @@ -37,7 +45,7 @@ type Producer[Request any, Response any] struct {
promisesLock sync.RWMutex
promises map[string]*containers.Promise[Response]

// Used for running checks for pending messages with inactive consumers
// Used for running checks for pending messages with inactive consumers
// and checking responses from consumers iteratively for the first time when
// Produce is called.
once sync.Once
Expand Down Expand Up @@ -112,8 +120,10 @@ func (p *Producer[Request, Response]) errorPromisesFor(msgs []*Message[Request])
p.promisesLock.Lock()
defer p.promisesLock.Unlock()
for _, msg := range msgs {
p.promises[msg.ID].ProduceError(fmt.Errorf("internal error, consumer died while serving the request"))
delete(p.promises, msg.ID)
if msg != nil {
p.promises[msg.ID].ProduceError(fmt.Errorf("internal error, consumer died while serving the request"))
delete(p.promises, msg.ID)
}
}
}

Expand Down Expand Up @@ -197,6 +207,9 @@ func (p *Producer[Request, Response]) reproduce(ctx context.Context, value Reque
p.promisesLock.Lock()
defer p.promisesLock.Unlock()
promise := p.promises[oldKey]
if oldKey != "" && promise == nil {
return nil, fmt.Errorf("errror reproducing the message, could not find existing one")
}
if oldKey == "" || promise == nil {
pr := containers.NewPromise[Response](nil)
promise = &pr
Expand Down
213 changes: 78 additions & 135 deletions pubsub/pubsub_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"sort"
"testing"

"github.com/ethereum/go-ethereum/log"
"github.com/go-redis/redis/v8"
"github.com/google/go-cmp/cmp"
"github.com/google/uuid"
Expand Down Expand Up @@ -41,17 +42,15 @@ func (t *testResponseMarshaller) Unmarshal(val []byte) (string, error) {

func createGroup(ctx context.Context, t *testing.T, streamName, groupName string, client redis.UniversalClient) {
t.Helper()
_, err := client.XGroupCreateMkStream(ctx, streamName, groupName, "$").Result()
if err != nil {
if _, err := client.XGroupCreateMkStream(ctx, streamName, groupName, "$").Result(); err != nil {
t.Fatalf("Error creating stream group: %v", err)
}
}

func destroyGroup(ctx context.Context, t *testing.T, streamName, groupName string, client redis.UniversalClient) {
t.Helper()
_, err := client.XGroupDestroy(ctx, streamName, groupName).Result()
if err != nil {
t.Fatalf("Error creating stream group: %v", err)
if _, err := client.XGroupDestroy(ctx, streamName, groupName).Result(); err != nil {
log.Debug("Error creating stream group: %v", err)
}
}

Expand Down Expand Up @@ -108,13 +107,14 @@ func newProducerConsumers(ctx context.Context, t *testing.T, opts ...configOpt)
}
createGroup(ctx, t, streamName, groupName, producer.client)
t.Cleanup(func() {
ctx := context.Background()
destroyGroup(ctx, t, streamName, groupName, producer.client)
var keys []string
for _, c := range consumers {
keys = append(keys, c.heartBeatKey())
}
if _, err := producer.client.Del(ctx, keys...).Result(); err != nil {
t.Fatalf("Error deleting heartbeat keys: %v\n", err)
log.Debug("Error deleting heartbeat keys", "error", err)
}
})
return producer, consumers
Expand All @@ -133,99 +133,23 @@ func wantMessages(n int) []string {
for i := 0; i < n; i++ {
ret = append(ret, fmt.Sprintf("msg: %d", i))
}
sort.Slice(ret, func(i, j int) bool {
return fmt.Sprintf("%v", ret[i]) < fmt.Sprintf("%v", ret[j])
})
sort.Strings(ret)
return ret
}

func TestRedisProduce(t *testing.T) {
t.Parallel()
ctx := context.Background()
producer, consumers := newProducerConsumers(ctx, t)
producer.Start(ctx)
gotMessages := messagesMaps(consumersCount)
wantResponses := make([][]string, len(consumers))
for idx, c := range consumers {
idx, c := idx, c
c.Start(ctx)
c.StopWaiter.LaunchThread(
func(ctx context.Context) {
for {
res, err := c.Consume(ctx)
if err != nil {
if !errors.Is(err, context.Canceled) {
t.Errorf("Consume() unexpected error: %v", err)
}
return
}
if res == nil {
continue
}
gotMessages[idx][res.ID] = res.Value
resp := fmt.Sprintf("result for: %v", res.ID)
if err := c.SetResult(ctx, res.ID, resp); err != nil {
t.Errorf("Error setting a result: %v", err)
}
wantResponses[idx] = append(wantResponses[idx], resp)
}
})
}

var gotResponses []string

for i := 0; i < messagesCount; i++ {
value := fmt.Sprintf("msg: %d", i)
p, err := producer.Produce(ctx, value)
if err != nil {
t.Errorf("Produce() unexpected error: %v", err)
}
res, err := p.Await(ctx)
if err != nil {
t.Errorf("Await() unexpected error: %v", err)
}
gotResponses = append(gotResponses, res)
}

producer.StopWaiter.StopAndWait()
for _, c := range consumers {
c.StopAndWait()
}

got, err := mergeValues(gotMessages)
if err != nil {
t.Fatalf("mergeMaps() unexpected error: %v", err)
}
want := wantMessages(messagesCount)
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("Unexpected diff (-want +got):\n%s\n", diff)
}

wantResp := flatten(wantResponses)
sort.Slice(gotResponses, func(i, j int) bool {
return gotResponses[i] < gotResponses[j]
})
if diff := cmp.Diff(wantResp, gotResponses); diff != "" {
t.Errorf("Unexpected diff in responses:\n%s\n", diff)
}
}

func flatten(responses [][]string) []string {
var ret []string
for _, v := range responses {
ret = append(ret, v...)
}
sort.Slice(ret, func(i, j int) bool {
return ret[i] < ret[j]
})
sort.Strings(ret)
return ret
}

func produceMessages(ctx context.Context, producer *Producer[string, string]) ([]*containers.Promise[string], error) {
func produceMessages(ctx context.Context, msgs []string, producer *Producer[string, string]) ([]*containers.Promise[string], error) {
var promises []*containers.Promise[string]
for i := 0; i < messagesCount; i++ {
value := fmt.Sprintf("msg: %d", i)
promise, err := producer.Produce(ctx, value)
promise, err := producer.Produce(ctx, msgs[i])
if err != nil {
return nil, err
}
Expand All @@ -250,13 +174,13 @@ func awaitResponses(ctx context.Context, promises []*containers.Promise[string])
return responses, errors.Join(errs...)
}

// consume messages from every consumer except every skipNth.
func consume(ctx context.Context, t *testing.T, consumers []*Consumer[string, string], skipN int) ([]map[string]string, [][]string) {
// consume messages from every consumer except stopped ones.
func consume(ctx context.Context, t *testing.T, consumers []*Consumer[string, string]) ([]map[string]string, [][]string) {
t.Helper()
gotMessages := messagesMaps(consumersCount)
wantResponses := make([][]string, consumersCount)
for idx := 0; idx < consumersCount; idx++ {
if idx%skipN == 0 {
if consumers[idx].Stopped() {
continue
}
idx, c := idx, consumers[idx]
Expand Down Expand Up @@ -288,73 +212,92 @@ func consume(ctx context.Context, t *testing.T, consumers []*Consumer[string, st
return gotMessages, wantResponses
}

func TestRedisClaimingOwnership(t *testing.T) {
func TestRedisProduce(t *testing.T) {
t.Parallel()
ctx := context.Background()
producer, consumers := newProducerConsumers(ctx, t)
producer.Start(ctx)
promises, err := produceMessages(ctx, producer)
if err != nil {
t.Fatalf("Error producing messages: %v", err)
}
for _, tc := range []struct {
name string
killConsumers bool
}{
{
name: "all consumers are active",
killConsumers: false,
},
{
name: "some consumers killed, others should take over their work",
killConsumers: false,
},
} {
t.Run(tc.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
producer, consumers := newProducerConsumers(ctx, t)
producer.Start(ctx)
wantMsgs := wantMessages(messagesCount)
promises, err := produceMessages(ctx, wantMsgs, producer)
if err != nil {
t.Fatalf("Error producing messages: %v", err)
}
if tc.killConsumers {
// Consumer messages in every third consumer but don't ack them to check
// that other consumers will claim ownership on those messages.
for i := 0; i < len(consumers); i += 3 {
if _, err := consumers[i].Consume(ctx); err != nil {
t.Errorf("Error consuming message: %v", err)
}
consumers[i].StopAndWait()
}

// Consumer messages in every third consumer but don't ack them to check
// that other consumers will claim ownership on those messages.
for i := 0; i < len(consumers); i += 3 {
i := i
if _, err := consumers[i].Consume(ctx); err != nil {
t.Errorf("Error consuming message: %v", err)
}
consumers[i].StopAndWait()
}
}
gotMessages, wantResponses := consume(ctx, t, consumers)
gotResponses, err := awaitResponses(ctx, promises)
if err != nil {
t.Fatalf("Error awaiting responses: %v", err)
}
for _, c := range consumers {
c.StopWaiter.StopAndWait()
}
got, err := mergeValues(gotMessages)
if err != nil {
t.Fatalf("mergeMaps() unexpected error: %v", err)
}

gotMessages, wantResponses := consume(ctx, t, consumers, 3)
gotResponses, err := awaitResponses(ctx, promises)
if err != nil {
t.Fatalf("Error awaiting responses: %v", err)
}
for _, c := range consumers {
c.StopWaiter.StopAndWait()
}
got, err := mergeValues(gotMessages)
if err != nil {
t.Fatalf("mergeMaps() unexpected error: %v", err)
}
want := wantMessages(messagesCount)
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("Unexpected diff (-want +got):\n%s\n", diff)
}
wantResp := flatten(wantResponses)
sort.Strings(gotResponses)
if diff := cmp.Diff(wantResp, gotResponses); diff != "" {
t.Errorf("Unexpected diff in responses:\n%s\n", diff)
}
if cnt := producer.promisesLen(); cnt != 0 {
t.Errorf("Producer still has %d unfullfilled promises", cnt)
if diff := cmp.Diff(wantMsgs, got); diff != "" {
t.Errorf("Unexpected diff (-want +got):\n%s\n", diff)
}
wantResp := flatten(wantResponses)
sort.Strings(gotResponses)
if diff := cmp.Diff(wantResp, gotResponses); diff != "" {
t.Errorf("Unexpected diff in responses:\n%s\n", diff)
}
if cnt := producer.promisesLen(); cnt != 0 {
t.Errorf("Producer still has %d unfullfilled promises", cnt)
}
})
}
}

func TestRedisClaimingOwnershipReproduceDisabled(t *testing.T) {
func TestRedisReproduceDisabled(t *testing.T) {
t.Parallel()
ctx := context.Background()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
producer, consumers := newProducerConsumers(ctx, t, &disableReproduce{})
producer.Start(ctx)
promises, err := produceMessages(ctx, producer)
wantMsgs := wantMessages(messagesCount)
promises, err := produceMessages(ctx, wantMsgs, producer)
if err != nil {
t.Fatalf("Error producing messages: %v", err)
}

// Consumer messages in every third consumer but don't ack them to check
// that other consumers will claim ownership on those messages.
for i := 0; i < len(consumers); i += 3 {
i := i
if _, err := consumers[i].Consume(ctx); err != nil {
t.Errorf("Error consuming message: %v", err)
}
consumers[i].StopAndWait()
}

gotMessages, _ := consume(ctx, t, consumers, 3)
gotMessages, _ := consume(ctx, t, consumers)
gotResponses, err := awaitResponses(ctx, promises)
if err == nil {
t.Fatalf("All promises were fullfilled with reproduce disabled and some consumers killed")
Expand All @@ -366,7 +309,7 @@ func TestRedisClaimingOwnershipReproduceDisabled(t *testing.T) {
if err != nil {
t.Fatalf("mergeMaps() unexpected error: %v", err)
}
wantMsgCnt := messagesCount - (consumersCount / 3) - (consumersCount % 3)
wantMsgCnt := messagesCount - ((consumersCount + 2) / 3)
if len(got) != wantMsgCnt {
t.Fatalf("Got: %d messages, want %d", len(got), wantMsgCnt)
}
Expand Down

0 comments on commit 590ec7b

Please sign in to comment.