Skip to content

Commit

Permalink
Add batch.BySize function (#735)
Browse files Browse the repository at this point in the history
  • Loading branch information
nathan-artie authored Jun 17, 2024
1 parent 9f1c5e2 commit b75b4ac
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 0 deletions.
48 changes: 48 additions & 0 deletions lib/batch/batch.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package batch

import "fmt"

// BySize encodes every element of [in] with [encode] and hands the encoded elements to [yield] in batches whose
// combined byte length never exceeds [maxSizeBytes]. Input order is preserved within and across batches.
//
// A batch is emitted as soon as it is exactly full; otherwise it is emitted when the next element would overflow it, or
// once the input is exhausted. [yield] is never called with an empty batch, and each batch is a freshly allocated slice.
//
// BySize stops at the first failure. It returns an error if [encode] fails (wrapped together with the element's index),
// if a single encoded element is larger than [maxSizeBytes], or if [yield] fails (returned unchanged).
func BySize[T any](in []T, maxSizeBytes int, encode func(T) ([]byte, error), yield func([][]byte) error) error {
	var (
		pending      [][]byte // encoded elements collected for the batch under construction
		pendingBytes int      // sum of len() over everything in pending
	)

	for idx, elem := range in {
		encoded, err := encode(elem)
		if err != nil {
			return fmt.Errorf("failed to encode item %d: %w", idx, err)
		}

		size := len(encoded)
		if size > maxSizeBytes {
			// An element that cannot fit even into an empty batch can never be delivered.
			return fmt.Errorf("item %d is larger (%d bytes) than maxSizeBytes (%d bytes)", idx, size, maxSizeBytes)
		}

		switch total := pendingBytes + size; {
		case total > maxSizeBytes:
			// The element does not fit: emit what has been collected and start the next batch with it.
			if err := yield(pending); err != nil {
				return err
			}
			pending, pendingBytes = [][]byte{encoded}, size
		case total == maxSizeBytes:
			// The element fills the batch exactly: emit it right away and start over empty.
			pending = append(pending, encoded)
			if err := yield(pending); err != nil {
				return err
			}
			pending, pendingBytes = nil, 0
		default:
			// Still room to spare: keep collecting.
			pending = append(pending, encoded)
			pendingBytes = total
		}
	}

	// Emit the trailing, partially filled batch, if there is one.
	if len(pending) == 0 {
		return nil
	}
	return yield(pending)
}
98 changes: 98 additions & 0 deletions lib/batch/batch_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
package batch

import (
"fmt"
"testing"

"github.com/stretchr/testify/assert"
)

// TestBySize covers BySize with string inputs: the empty input, encoder failures, yield failures on each of the three
// accumulation branches, oversized items, and the batch boundaries around maxSizeBytes.
func TestBySize(t *testing.T) {
	// encodeOK turns a string into its raw bytes.
	encodeOK := func(value string) ([]byte, error) {
		return []byte(value), nil
	}

	// encodeFail rejects every value.
	encodeFail := func(value string) ([]byte, error) {
		return nil, fmt.Errorf("failed to encode %q", value)
	}

	// encodePanic guards scenarios in which the encoder must never run.
	encodePanic := func(value string) ([]byte, error) {
		panic("should not be called")
	}

	// failingYield rejects every batch and reports the batch contents in its error.
	failingYield := func(batch [][]byte) error {
		var items []string
		for _, encoded := range batch {
			items = append(items, string(encoded))
		}
		return fmt.Errorf("yield failed for %v", items)
	}

	// collect runs BySize and records every batch that is yielded.
	collect := func(in []string, maxSizeBytes int, encoder func(value string) ([]byte, error)) ([][][]byte, error) {
		collected := make([][][]byte, 0)
		record := func(batch [][]byte) error {
			collected = append(collected, batch)
			return nil
		}
		err := BySize(in, maxSizeBytes, encoder, record)
		return collected, err
	}

	// asBatch builds the expected form of a batch from plain strings.
	asBatch := func(items ...string) [][]byte {
		out := make([][]byte, len(items))
		for i, item := range items {
			out[i] = []byte(item)
		}
		return out
	}

	{
		// No input: the encoder is never invoked and nothing is yielded.
		got, err := collect([]string{}, 10, encodePanic)
		assert.NoError(t, err)
		assert.Empty(t, got)
	}
	{
		// The encoder fails on the very first item.
		_, err := collect([]string{"foo", "bar"}, 10, encodeFail)
		assert.ErrorContains(t, err, `failed to encode item 0: failed to encode "foo"`)
	}
	{
		// Two items totalling less than maxSizeBytes: the yield error comes from the trailing flush.
		err := BySize([]string{"foo", "bar"}, 10, encodeOK, failingYield)
		assert.ErrorContains(t, err, "yield failed for [foo bar]")
	}
	{
		// Two items totalling exactly maxSizeBytes: the yield error comes from the exact-fit flush.
		err := BySize([]string{"foo", "bar"}, 6, encodeOK, failingYield)
		assert.ErrorContains(t, err, "yield failed for [foo bar]")
	}
	{
		// Two items totalling more than maxSizeBytes: the yield error comes from the overflow flush.
		err := BySize([]string{"foo", "bar-baz"}, 8, encodeOK, failingYield)
		assert.ErrorContains(t, err, "yield failed for [foo]")
	}
	{
		// A single item that exceeds maxSizeBytes is rejected.
		_, err := collect([]string{"foo", "i-am-23-characters-long", "bar"}, 20, encodeOK)
		assert.ErrorContains(t, err, "item 1 is larger (23 bytes) than maxSizeBytes (20 bytes)")
	}
	{
		// A single item that is exactly maxSizeBytes gets a batch of its own.
		got, err := collect([]string{"foo", "i-am-23-characters-long", "bar"}, 23, encodeOK)
		assert.NoError(t, err)
		assert.Len(t, got, 3)
		assert.Equal(t, asBatch("foo"), got[0])
		assert.Equal(t, asBatch("i-am-23-characters-long"), got[1])
		assert.Equal(t, asBatch("bar"), got[2])
	}
	{
		// A lone small item ends up in a single batch.
		got, err := collect([]string{"foo"}, 100, encodeOK)
		assert.NoError(t, err)
		assert.Len(t, got, 1)
		assert.Equal(t, asBatch("foo"), got[0])
	}
	{
		// All items together fill one batch exactly.
		got, err := collect([]string{"foo", "bar", "baz", "qux"}, 12, encodeOK)
		assert.NoError(t, err)
		assert.Len(t, got, 1)
		assert.Equal(t, asBatch("foo", "bar", "baz", "qux"), got[0])
	}
	{
		// All items together fit into one batch with a single byte to spare.
		got, err := collect([]string{"foo", "bar", "baz", "qux"}, 13, encodeOK)
		assert.NoError(t, err)
		assert.Len(t, got, 1)
		assert.Equal(t, asBatch("foo", "bar", "baz", "qux"), got[0])
	}
}

0 comments on commit b75b4ac

Please sign in to comment.