diff --git a/zstd_bulk.go b/zstd_bulk.go
new file mode 100644
index 0000000..fadfe5d
--- /dev/null
+++ b/zstd_bulk.go
@@ -0,0 +1,181 @@
+package zstd
+
+/*
+#include "zstd.h"
+*/
+import "C"
+import (
+	"errors"
+	"runtime"
+	"unsafe"
+)
+
+var (
+	// ErrEmptyDictionary is returned when the given dictionary is empty
+	ErrEmptyDictionary = errors.New("Dictionary is empty")
+	// ErrBadDictionary is returned when cannot load the given dictionary
+	ErrBadDictionary = errors.New("Cannot load dictionary")
+	// ErrContentSize is returned when cannot determine the content size
+	ErrContentSize = errors.New("Cannot determine the content size")
+	// errBulkNoContext is returned when zstd cannot allocate a compression or decompression context
+	errBulkNoContext = errors.New("Cannot allocate context")
+)
+
+// BulkProcessor implements Bulk processing dictionary API.
+// When compressing multiple messages or blocks using the same dictionary,
+// it's recommended to digest the dictionary only once, since it's a costly operation.
+// NewBulkProcessor() will create a state from digesting a dictionary.
+// The resulting state can be used for future compression/decompression operations with very limited startup cost.
+// BulkProcessor can be created once and shared by multiple goroutines concurrently, since its usage is read-only.
+// The state will be freed when gc cleans up BulkProcessor.
+type BulkProcessor struct {
+	cDict *C.struct_ZSTD_CDict_s
+	dDict *C.struct_ZSTD_DDict_s
+}
+
+// NewBulkProcessor creates a new BulkProcessor with a pre-trained dictionary and compression level.
+// It returns ErrEmptyDictionary when `dictionary` is nil or empty, and ErrBadDictionary when
+// zstd cannot build the compression or the decompression dictionary from it.
+func NewBulkProcessor(dictionary []byte, compressionLevel int) (*BulkProcessor, error) {
+	if len(dictionary) < 1 {
+		return nil, ErrEmptyDictionary
+	}
+
+	cDict := C.ZSTD_createCDict(
+		unsafe.Pointer(&dictionary[0]),
+		C.size_t(len(dictionary)),
+		C.int(compressionLevel),
+	)
+	if cDict == nil {
+		return nil, ErrBadDictionary
+	}
+	dDict := C.ZSTD_createDDict(
+		unsafe.Pointer(&dictionary[0]),
+		C.size_t(len(dictionary)),
+	)
+	if dDict == nil {
+		// Free the compression dictionary right away instead of leaving it to the gc.
+		C.ZSTD_freeCDict(cDict)
+		return nil, ErrBadDictionary
+	}
+
+	// The finalizer is attached only once both dictionaries exist, so it always frees a complete state.
+	p := &BulkProcessor{cDict: cDict, dDict: dDict}
+	runtime.SetFinalizer(p, finalizeBulkProcessor)
+	return p, nil
+}
+
+// Compress compresses `src` into `dst` with the dictionary given when creating the BulkProcessor.
+// If you have a buffer to use, you can pass it to prevent allocation.
+// If it is too small, or if nil is passed, a new buffer will be allocated and returned.
+func (p *BulkProcessor) Compress(dst, src []byte) ([]byte, error) {
+	bound := CompressBound(len(src))
+	if cap(dst) >= bound {
+		dst = dst[0:bound]
+	} else {
+		dst = make([]byte, bound)
+	}
+
+	cctx := C.ZSTD_createCCtx()
+	if cctx == nil {
+		return nil, errBulkNoContext
+	}
+	defer C.ZSTD_freeCCtx(cctx)
+
+	// We need unsafe.Pointer(&src[0]) in the Cgo call to avoid "Go pointer to Go pointer" panics.
+	// This means we need to special case empty input. See:
+	// https://github.com/golang/go/issues/14210#issuecomment-346402945
+	var cWritten C.size_t
+	if len(src) == 0 {
+		cWritten = C.ZSTD_compress_usingCDict(
+			cctx,
+			unsafe.Pointer(&dst[0]),
+			C.size_t(len(dst)),
+			unsafe.Pointer(nil),
+			C.size_t(len(src)),
+			p.cDict,
+		)
+	} else {
+		cWritten = C.ZSTD_compress_usingCDict(
+			cctx,
+			unsafe.Pointer(&dst[0]),
+			C.size_t(len(dst)),
+			unsafe.Pointer(&src[0]),
+			C.size_t(len(src)),
+			p.cDict,
+		)
+	}
+	// The finalizer of p frees p.cDict: keep p reachable until the C call has returned.
+	runtime.KeepAlive(p)
+
+	written := int(cWritten)
+	if err := getError(written); err != nil {
+		return nil, err
+	}
+	return dst[:written], nil
+}
+
+// Decompress decompresses `src` into `dst` with the dictionary given when creating the BulkProcessor.
+// If you have a buffer to use, you can pass it to prevent allocation.
+// If it is too small, or if nil is passed, a new buffer will be allocated and returned.
+// It returns ErrEmptySlice when `src` is empty, and ErrContentSize when ZSTD_getFrameContentSize
+// reports an error, an unknown size, or a size that does not fit in an int.
+func (p *BulkProcessor) Decompress(dst, src []byte) ([]byte, error) {
+	if len(src) == 0 {
+		return nil, ErrEmptySlice
+	}
+	contentSize := uint64(C.ZSTD_getFrameContentSize(unsafe.Pointer(&src[0]), C.size_t(len(src))))
+	if contentSize == C.ZSTD_CONTENTSIZE_ERROR || contentSize == C.ZSTD_CONTENTSIZE_UNKNOWN {
+		return nil, ErrContentSize
+	}
+	// The size is read from the frame header, i.e. from the input itself. Reject values that
+	// do not fit in an int: they would go negative below and make the reslice of dst panic.
+	if contentSize > uint64(^uint(0)>>1) {
+		return nil, ErrContentSize
+	}
+	size := int(contentSize)
+
+	if cap(dst) >= size {
+		dst = dst[0:size]
+	} else {
+		dst = make([]byte, size)
+	}
+
+	if size == 0 {
+		return dst, nil
+	}
+
+	dctx := C.ZSTD_createDCtx()
+	if dctx == nil {
+		return nil, errBulkNoContext
+	}
+	defer C.ZSTD_freeDCtx(dctx)
+
+	cWritten := C.ZSTD_decompress_usingDDict(
+		dctx,
+		unsafe.Pointer(&dst[0]),
+		C.size_t(size),
+		unsafe.Pointer(&src[0]),
+		C.size_t(len(src)),
+		p.dDict,
+	)
+	// The finalizer of p frees p.dDict: keep p reachable until the C call has returned.
+	runtime.KeepAlive(p)
+
+	written := int(cWritten)
+	if err := getError(written); err != nil {
+		return nil, err
+	}
+
+	return dst[:written], nil
+}
+
+// finalizeBulkProcessor frees compression and decompression dictionaries from memory
+func finalizeBulkProcessor(p *BulkProcessor) {
+	if p.cDict != nil {
+		C.ZSTD_freeCDict(p.cDict)
+	}
+	if p.dDict != nil {
+		C.ZSTD_freeDDict(p.dDict)
+	}
+}
diff --git 
a/zstd_bullk_test.go b/zstd_bullk_test.go
new file mode 100644
index 0000000..eeba156
--- /dev/null
+++ b/zstd_bullk_test.go
@@ -0,0 +1,261 @@
+package zstd
+
+import (
+	"bytes"
+	"encoding/base64"
+	"math/rand"
+	"regexp"
+	"strings"
+	"testing"
+)
+
+// dictBase64 is the base64 encoded dictionary used by the tests; whitespace is stripped before decoding.
+var dictBase64 = `
+	N6Qw7IsuFDIdENCSQjr//////4+QlekuNkmXbUBIkIDiVRX7H4AzAFCgQCFCO9oHAAAEQEuSikaK
+	Dg51OYghBYgBAAAAAAAAAAAAAAAAAAAAANQVpmRQGQAAAAAAAAAAAAAAAAABAAAABAAAAAgAAABo
+	ZWxwIEpvaW4gZW5naW5lZXJzIGVuZ2luZWVycyBmdXR1cmUgbG92ZSB0aGF0IGFyZWlsZGluZyB1
+	c2UgaGVscCBoZWxwIHVzaGVyIEpvaW4gdXNlIGxvdmUgdXMgSm9pbiB1bmQgaW4gdXNoZXIgdXNo
+	ZXIgYSBwbGF0Zm9ybSB1c2UgYW5kIGZ1dHVyZQ==`
+
+var (
+	// dict is the decoded dictionary, filled in by init.
+	dict []byte
+	// compressedPayload is a fixed sentence compressed with dict, filled in by init.
+	compressedPayload []byte
+)
+
+func init() {
+	var err error
+	dict, err = base64.StdEncoding.DecodeString(regexp.MustCompile(`\s+`).ReplaceAllString(dictBase64, ""))
+	if err != nil {
+		panic("failed to create dictionary: " + err.Error())
+	}
+	p, err := NewBulkProcessor(dict, BestSpeed)
+	if err != nil {
+		panic("failed to create bulk processor: " + err.Error())
+	}
+	compressedPayload, err = p.Compress(nil, []byte("We're building a platform that engineers love to use. Join us, and help usher in the future."))
+	if err != nil {
+		panic("failed to compress payload: " + err.Error())
+	}
+}
+
+// newBulkProcessor builds a BulkProcessor for a test or benchmark and aborts it on failure.
+func newBulkProcessor(t testing.TB, dict []byte, level int) *BulkProcessor {
+	p, err := NewBulkProcessor(dict, level)
+	if err != nil {
+		t.Fatalf("failed to create a BulkProcessor: %v", err)
+	}
+	return p
+}
+
+// getRandomText returns 10 to 109 entries picked at random from a fixed word list, joined by spaces.
+func getRandomText() string {
+	words := []string{"We", "are", "building", "a platform", "that", "engineers", "love", "to", "use", "Join", "us", "and", "help", "usher", "in", "the", "future"}
+	wordCount := 10 + rand.Intn(100) // 10 - 109
+	result := make([]string, 0, wordCount)
+	for i := 0; i < wordCount; i++ {
+		result = append(result, words[rand.Intn(len(words))])
+	}
+
+	return strings.Join(result, " ")
+}
+
+func TestBulkDictionary(t *testing.T) {
+	if len(dict) < 1 {
+		t.Error("dictionary is empty")
+	}
+}
+
+func TestBulkCompressAndDecompress(t *testing.T) {
+	p := newBulkProcessor(t, dict, BestSpeed)
+	for i := 0; i < 100; i++ {
+		payload := []byte(getRandomText())
+
+		compressed, err := p.Compress(nil, payload)
+		if err != nil {
+			t.Fatalf("failed to compress: %v", err)
+		}
+
+		uncompressed, err := p.Decompress(nil, compressed)
+		if err != nil {
+			t.Fatalf("failed to decompress: %v", err)
+		}
+
+		if !bytes.Equal(payload, uncompressed) {
+			t.Error("uncompressed payload didn't match")
+		}
+	}
+}
+
+func TestBulkEmptyOrNilDictionary(t *testing.T) {
+	p, err := NewBulkProcessor(nil, BestSpeed)
+	if p != nil {
+		t.Error("nil is expected")
+	}
+	if err != ErrEmptyDictionary {
+		t.Error("ErrEmptyDictionary is expected")
+	}
+
+	p, err = NewBulkProcessor([]byte{}, BestSpeed)
+	if p != nil {
+		t.Error("nil is expected")
+	}
+	if err != ErrEmptyDictionary {
+		t.Error("ErrEmptyDictionary is expected")
+	}
+}
+
+func TestBulkCompressEmptyOrNilContent(t *testing.T) {
+	p := newBulkProcessor(t, dict, BestSpeed)
+	compressed, err := p.Compress(nil, nil)
+	if err != nil {
+		t.Errorf("failed to compress: %v", err)
+	}
+	if len(compressed) < 4 {
+		t.Error("magic number doesn't exist")
+	}
+
+	compressed, err = p.Compress(nil, []byte{})
+	if err != nil {
+		t.Errorf("failed to compress: %v", err)
+	}
+	if len(compressed) < 4 {
+		t.Error("magic number doesn't exist")
+	}
+}
+
+func TestBulkCompressIntoGivenDestination(t *testing.T) {
+	p := newBulkProcessor(t, dict, BestSpeed)
+	dst := make([]byte, 100000)
+	compressed, err := p.Compress(dst, []byte(getRandomText()))
+	if err != nil {
+		t.Fatalf("failed to compress: %v", err)
+	}
+	if len(compressed) < 4 {
+		// Fatal: the address comparison below indexes compressed[0].
+		t.Fatal("magic number doesn't exist")
+	}
+	if &dst[0] != &compressed[0] {
+		t.Error("'dst' and 'compressed' are not the same object")
+	}
+}
+
+func TestBulkCompressNotEnoughDestination(t *testing.T) {
+	p := newBulkProcessor(t, dict, BestSpeed)
+	dst := make([]byte, 1)
+	compressed, err := p.Compress(dst, []byte(getRandomText()))
+	if err != nil {
+		t.Fatalf("failed to compress: %v", err)
+	}
+	if len(compressed) < 4 {
+		// Fatal: the address comparison below indexes compressed[0].
+		t.Fatal("magic number doesn't exist")
+	}
+	if &dst[0] == &compressed[0] {
+		t.Error("'dst' and 'compressed' are the same object")
+	}
+}
+
+func TestBulkDecompressIntoGivenDestination(t *testing.T) {
+	p := newBulkProcessor(t, dict, BestSpeed)
+	dst := make([]byte, 100000)
+	decompressed, err := p.Decompress(dst, compressedPayload)
+	if err != nil {
+		t.Fatalf("failed to decompress: %v", err)
+	}
+	if len(decompressed) == 0 {
+		t.Fatal("decompressed payload is empty")
+	}
+	if &dst[0] != &decompressed[0] {
+		t.Error("'dst' and 'decompressed' are not the same object")
+	}
+}
+
+func TestBulkDecompressNotEnoughDestination(t *testing.T) {
+	p := newBulkProcessor(t, dict, BestSpeed)
+	dst := make([]byte, 1)
+	decompressed, err := p.Decompress(dst, compressedPayload)
+	if err != nil {
+		t.Fatalf("failed to decompress: %v", err)
+	}
+	if len(decompressed) == 0 {
+		t.Fatal("decompressed payload is empty")
+	}
+	if &dst[0] == &decompressed[0] {
+		t.Error("'dst' and 'decompressed' are the same object")
+	}
+}
+
+func TestBulkDecompressEmptyOrNilContent(t *testing.T) {
+	p := newBulkProcessor(t, dict, BestSpeed)
+	decompressed, err := p.Decompress(nil, nil)
+	if err != ErrEmptySlice {
+		t.Error("ErrEmptySlice is expected")
+	}
+	if decompressed != nil {
+		t.Error("nil is expected")
+	}
+
+	decompressed, err = p.Decompress(nil, []byte{})
+	if err != ErrEmptySlice {
+		t.Error("ErrEmptySlice is expected")
+	}
+	if decompressed != nil {
+		t.Error("nil is expected")
+	}
+}
+
+func TestBulkCompressAndDecompressInReverseOrder(t *testing.T) {
+	p := newBulkProcessor(t, dict, BestSpeed)
+	const count = 100
+	payloads := make([][]byte, 0, count)
+	compressedPayloads := make([][]byte, 0, count)
+	for i := 0; i < count; i++ {
+		payloads = append(payloads, []byte(getRandomText()))
+
+		compressed, err := p.Compress(nil, payloads[i])
+		if err != nil {
+			t.Fatalf("failed to compress: %v", err)
+		}
+		compressedPayloads = append(compressedPayloads, compressed)
+	}
+
+	for i := count - 1; i >= 0; i-- {
+		uncompressed, err := p.Decompress(nil, compressedPayloads[i])
+		if err != nil {
+			t.Fatalf("failed to decompress: %v", err)
+		}
+
+		if !bytes.Equal(payloads[i], uncompressed) {
+			t.Error("uncompressed payload didn't match")
+		}
+	}
+}
+
+// BenchmarkBulkCompress-8 780148 1505 ns/op 61.14 MB/s 208 B/op 5 allocs/op
+func BenchmarkBulkCompress(b *testing.B) {
+	p := newBulkProcessor(b, dict, BestSpeed)
+
+	payload := []byte("We're building a platform that engineers love to use. Join us, and help usher in the future.")
+	b.SetBytes(int64(len(payload)))
+	for n := 0; n < b.N; n++ {
+		_, err := p.Compress(nil, payload)
+		if err != nil {
+			b.Fatalf("failed to compress: %v", err)
+		}
+	}
+}
+
+// BenchmarkBulkDecompress-8 817425 1412 ns/op 40.37 MB/s 192 B/op 7 allocs/op
+func BenchmarkBulkDecompress(b *testing.B) {
+	p := newBulkProcessor(b, dict, BestSpeed)
+
+	b.SetBytes(int64(len(compressedPayload)))
+	for n := 0; n < b.N; n++ {
+		_, err := p.Decompress(nil, compressedPayload)
+		if err != nil {
+			b.Fatalf("failed to decompress: %v", err)
+		}
+	}
+}