// BySize encodes every element of [in] with [encode], in order, and hands the encoded payloads to [yield] grouped
// into batches whose combined byte length never exceeds [maxSizeBytes].
//
// A batch is flushed as soon as it reaches exactly [maxSizeBytes], or just before an item that would push it over
// the limit; whatever is still pending once the input is exhausted is flushed last. A yielded batch is never
// appended to afterwards, because a new slice is started after every flush.
//
// Processing stops at the first failure. BySize returns:
//   - a wrapped error naming the item index if [encode] fails,
//   - an error naming the item index if a single encoded item is larger than [maxSizeBytes],
//   - the error from [yield], unchanged, if a flush fails.
func BySize[T any](in []T, maxSizeBytes int, encode func(T) ([]byte, error), yield func([][]byte) error) error {
	var (
		pending      [][]byte // encoded items waiting to be flushed
		pendingBytes int      // combined length of everything in pending
	)

	for idx, elem := range in {
		encoded, err := encode(elem)
		if err != nil {
			return fmt.Errorf("failed to encode item %d: %w", idx, err)
		}

		size := len(encoded)
		if size > maxSizeBytes {
			return fmt.Errorf("item %d is larger (%d bytes) than maxSizeBytes (%d bytes)", idx, size, maxSizeBytes)
		}

		total := pendingBytes + size
		switch {
		case total > maxSizeBytes:
			// The item does not fit: flush what we have and let the item open the next batch.
			if err := yield(pending); err != nil {
				return err
			}
			pending, pendingBytes = [][]byte{encoded}, size
		case total == maxSizeBytes:
			// The item fills the batch exactly: flush it together with the pending items.
			if err := yield(append(pending, encoded)); err != nil {
				return err
			}
			pending, pendingBytes = [][]byte{}, 0
		default:
			// Still room left: keep accumulating.
			pending, pendingBytes = append(pending, encoded), total
		}
	}

	if len(pending) == 0 {
		return nil
	}
	return yield(pending)
}
+1,98 @@ +package batch + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestBySize(t *testing.T) { + goodEncoder := func(value string) ([]byte, error) { + return []byte(value), nil + } + + panicEncoder := func(value string) ([]byte, error) { + panic("should not be called") + } + + badEncoder := func(value string) ([]byte, error) { + return nil, fmt.Errorf("failed to encode %q", value) + } + + testBySize := func(in []string, maxSizeBytes int, encoder func(value string) ([]byte, error)) ([][][]byte, error) { + batches := [][][]byte{} + err := BySize(in, maxSizeBytes, encoder, func(batch [][]byte) error { batches = append(batches, batch); return nil }) + return batches, err + } + + badYield := func(batch [][]byte) error { + out := make([]string, len(batch)) + for i, bytes := range batch { + out[i] = string(bytes) + } + return fmt.Errorf("yield failed for %v", out) + } + + { + // Empty slice: + batches, err := testBySize([]string{}, 10, panicEncoder) + assert.NoError(t, err) + assert.Empty(t, batches) + } + { + // Non-empty slice + bad encoder: + _, err := testBySize([]string{"foo", "bar"}, 10, badEncoder) + assert.ErrorContains(t, err, `failed to encode item 0: failed to encode "foo"`) + } + { + // Non-empty slice + two items that are < maxSizeBytes + yield returns error. + err := BySize([]string{"foo", "bar"}, 10, goodEncoder, badYield) + assert.ErrorContains(t, err, "yield failed for [foo bar]") + } + { + // Non-empty slice + two items that are = maxSizeBytes + yield returns error. + err := BySize([]string{"foo", "bar"}, 6, goodEncoder, badYield) + assert.ErrorContains(t, err, "yield failed for [foo bar]") + } + { + // Non-empty slice + two items that are > maxSizeBytes + yield returns error. 
+ err := BySize([]string{"foo", "bar-baz"}, 8, goodEncoder, badYield) + assert.ErrorContains(t, err, "yield failed for [foo]") + } + { + // Non-empty slice + item is larger than maxSizeBytes: + _, err := testBySize([]string{"foo", "i-am-23-characters-long", "bar"}, 20, goodEncoder) + assert.ErrorContains(t, err, "item 1 is larger (23 bytes) than maxSizeBytes (20 bytes)") + } + { + // Non-empty slice + item equal to maxSizeBytes: + batches, err := testBySize([]string{"foo", "i-am-23-characters-long", "bar"}, 23, goodEncoder) + assert.NoError(t, err) + assert.Len(t, batches, 3) + assert.Equal(t, [][]byte{[]byte("foo")}, batches[0]) + assert.Equal(t, [][]byte{[]byte("i-am-23-characters-long")}, batches[1]) + assert.Equal(t, [][]byte{[]byte("bar")}, batches[2]) + } + { + // Non-empty slice + one item: + batches, err := testBySize([]string{"foo"}, 100, goodEncoder) + assert.NoError(t, err) + assert.Len(t, batches, 1) + assert.Equal(t, [][]byte{[]byte("foo")}, batches[0]) + } + { + // Non-empty slice + all items exactly fit into one batch: + batches, err := testBySize([]string{"foo", "bar", "baz", "qux"}, 12, goodEncoder) + assert.NoError(t, err) + assert.Len(t, batches, 1) + assert.Equal(t, [][]byte{[]byte("foo"), []byte("bar"), []byte("baz"), []byte("qux")}, batches[0]) + } + { + // Non-empty slice + all items exactly fit into just under one batch: + batches, err := testBySize([]string{"foo", "bar", "baz", "qux"}, 13, goodEncoder) + assert.NoError(t, err) + assert.Len(t, batches, 1) + assert.Equal(t, [][]byte{[]byte("foo"), []byte("bar"), []byte("baz"), []byte("qux")}, batches[0]) + } +}