Skip to content

Commit

Permalink
Add more tests and improve compact-u16
Browse files Browse the repository at this point in the history
  • Loading branch information
gagliardetto committed Apr 1, 2024
1 parent e3c7575 commit 79f49c5
Show file tree
Hide file tree
Showing 2 changed files with 171 additions and 20 deletions.
56 changes: 45 additions & 11 deletions compact-u16.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,66 +15,100 @@
package bin

import (
"fmt"
"io"
"math"
)

// EncodeCompactU16Length appends the "Compact-u16" encoding of ln to the slice
// pointed to by buf. Each output byte carries 7 bits of the value, least
// significant group first; the high bit (0x80) is set on every byte except the
// last to signal that more bytes follow.
//
// It returns an error, and leaves *buf untouched, when ln is negative or
// larger than math.MaxUint16.
//
// See https://docs.solana.com/developing/programming-model/transactions#compact-u16-format
// See https://github.com/solana-labs/solana/blob/2ef2b6daa05a7cff057e9d3ef95134cee3e4045d/web3.js/src/util/shortvec-encoding.ts
func EncodeCompactU16Length(buf *[]byte, ln int) error {
	if ln < 0 || ln > math.MaxUint16 {
		return fmt.Errorf("length %d out of range", ln)
	}
	rem := ln
	for {
		elem := uint8(rem & 0x7f)
		rem >>= 7
		if rem == 0 {
			// Final group: continuation bit stays clear.
			*buf = append(*buf, elem)
			return nil
		}
		// More groups follow: flag this byte with the continuation bit.
		*buf = append(*buf, elem|0x80)
	}
}

// DecodeCompactU16Length decodes a "Compact-u16" length from the start of the
// provided byte slice and returns only the decoded value.
//
// The number of bytes consumed and the error reported by DecodeCompactU16 are
// intentionally dropped because this signature has no way to surface them;
// callers that need either should call DecodeCompactU16 directly.
func DecodeCompactU16Length(bytes []byte) int {
	length, _, _ := DecodeCompactU16(bytes)
	return length
}
// _MAX_COMPACTU16_ENCODING_LENGTH is the largest number of bytes a Compact-u16
// encoding may occupy.
const _MAX_COMPACTU16_ENCODING_LENGTH = 3

// DecodeCompactU16 decodes a "Compact-u16" value from the start of bytes.
//
// It returns the decoded value, the number of bytes the encoding occupied, and
// an error. On any error both integers are 0. Errors are returned when:
//   - bytes ends before the encoding terminates (io.ErrUnexpectedEOF);
//   - a byte after the first is 0x00, i.e. a non-minimal ("alias") encoding;
//   - the third byte still has its continuation bit set;
//   - the decoded value exceeds math.MaxUint16.
func DecodeCompactU16(bytes []byte) (int, int, error) {
	ln := 0
	size := 0
	// The loop always leaves through a return or the break below: the last
	// permitted byte either terminates the encoding or is rejected.
	for size < _MAX_COMPACTU16_ENCODING_LENGTH {
		if size >= len(bytes) {
			return 0, 0, io.ErrUnexpectedEOF
		}
		elem := int(bytes[size])
		if elem == 0 && size != 0 {
			// A trailing zero byte adds no bits, so a shorter encoding exists.
			return 0, 0, fmt.Errorf("alias")
		}
		more := elem&0x80 != 0
		if more && size == _MAX_COMPACTU16_ENCODING_LENGTH-1 {
			return 0, 0, fmt.Errorf("byte three continues")
		}
		ln |= (elem & 0x7f) << (size * 7)
		size++
		if !more {
			break
		}
	}
	// Three 7-bit groups can carry up to 21 bits; reject anything beyond 16.
	if ln > math.MaxUint16 {
		return 0, 0, fmt.Errorf("invalid length: %d", ln)
	}
	return ln, size, nil
}

// DecodeCompactU16LengthFromByteReader decodes a "Compact-u16" length from the provided io.ByteReader.
func DecodeCompactU16LengthFromByteReader(reader io.ByteReader) (int, error) {
ln := 0
size := 0
for {
for nth_byte := 0; nth_byte < _MAX_COMPACTU16_ENCODING_LENGTH; nth_byte++ {
elemByte, err := reader.ReadByte()
if err != nil {
return 0, err
}
elem := int(elemByte)
if elem == 0 && nth_byte != 0 {
return 0, fmt.Errorf("alias")
}
if nth_byte >= _MAX_COMPACTU16_ENCODING_LENGTH {
return 0, fmt.Errorf("too long: %d", nth_byte+1)
} else if nth_byte == _MAX_COMPACTU16_ENCODING_LENGTH-1 && (elem&0x80) != 0 {
return 0, fmt.Errorf("byte three continues")
}
ln |= (elem & 0x7f) << (size * 7)
size += 1
if (elem & 0x80) == 0 {
break
}
}
// check for non-valid sizes
if size == 0 || size > _MAX_COMPACTU16_ENCODING_LENGTH {
return 0, fmt.Errorf("invalid size: %d", size)
}
// check for non-valid lengths
if ln < 0 || ln > math.MaxUint16 {
return 0, fmt.Errorf("invalid length: %d", ln)
}
return ln, nil
}
135 changes: 126 additions & 9 deletions compact-u16_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,17 @@ import (
)

func TestCompactU16(t *testing.T) {
candidates := []int{3, 0x7f, 0x7f + 1, 0x3fff, 0x3fff + 1}
candidates := []int{0, 1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 100, 1000, 10000, math.MaxUint16 - 1, math.MaxUint16}
for _, val := range candidates {
if val < 0 || val > math.MaxUint16 {
panic("value too large")
}
buf := make([]byte, 0)
EncodeCompactU16Length(&buf, val)
require.NoError(t, EncodeCompactU16Length(&buf, val))

buf = append(buf, []byte("hello world")...)
decoded := DecodeCompactU16Length(buf)
decoded, _, err := DecodeCompactU16(buf)
require.NoError(t, err)

require.Equal(t, val, decoded)
}
Expand All @@ -40,19 +44,34 @@ func TestCompactU16(t *testing.T) {
buf = append(buf, []byte("hello world")...)
{
decoded, err := DecodeCompactU16LengthFromByteReader(bytes.NewReader(buf))
if err != nil {
panic(err)
}
require.NoError(t, err)
require.Equal(t, val, decoded)
}
{
decoded, _, err := DecodeCompactU16(buf)
if err != nil {
panic(err)
}
require.NoError(t, err)
require.Equal(t, val, decoded)
}
}
{
// now test all from 0 to 0xffff
for i := 0; i < math.MaxUint16; i++ {
buf := make([]byte, 0)
EncodeCompactU16Length(&buf, i)

buf = append(buf, []byte("hello world")...)
{
decoded, err := DecodeCompactU16LengthFromByteReader(bytes.NewReader(buf))
require.NoError(t, err)
require.Equal(t, i, decoded)
}
{
decoded, _, err := DecodeCompactU16(buf)
require.NoError(t, err)
require.Equal(t, i, decoded)
}
}
}
}

func BenchmarkCompactU16(b *testing.B) {
Expand Down Expand Up @@ -102,3 +121,101 @@ func BenchmarkCompactU16Reader(b *testing.B) {
reader.SetPosition(0)
}
}

// encode_len returns the Compact-u16 encoding of len as a fresh byte slice.
// It panics if EncodeCompactU16Length reports an error.
func encode_len(len uint16) []byte {
	encoded := make([]byte, 0)
	if err := EncodeCompactU16Length(&encoded, int(len)); err != nil {
		panic(err)
	}
	return encoded
}

// assert_len_encoding verifies that len encodes to exactly buf, and that buf
// decodes back to len through both DecodeCompactU16 and
// DecodeCompactU16LengthFromByteReader.
func assert_len_encoding(t *testing.T, len uint16, buf []byte) {
	t.Helper()
	require.Equal(t, encode_len(len), buf, "unexpected usize encoding")
	decoded, _, err := DecodeCompactU16(buf)
	require.NoError(t, err)
	require.Equal(t, int(len), decoded)
	{
		// now try with a reader
		reader := bytes.NewReader(buf)
		out, err := DecodeCompactU16LengthFromByteReader(reader)
		// Check the error explicitly: a failed decode yields 0, which would
		// otherwise pass unnoticed when the expected length is 0.
		require.NoError(t, err)
		require.Equal(t, int(len), out)
	}
}

// TestShortVecEncodeLen checks a fixed table of lengths against their expected
// Compact-u16 byte encodings, covering one-, two- and three-byte forms.
func TestShortVecEncodeLen(t *testing.T) {
	vectors := []struct {
		length  uint16
		encoded []byte
	}{
		{0x0, []byte{0x0}},
		{0x7f, []byte{0x7f}},
		{0x80, []byte{0x80, 0x01}},
		{0xff, []byte{0xff, 0x01}},
		{0x100, []byte{0x80, 0x02}},
		{0x7fff, []byte{0xff, 0xff, 0x01}},
		{0xffff, []byte{0xff, 0xff, 0x03}},
	}
	for _, v := range vectors {
		assert_len_encoding(t, v.length, v.encoded)
	}
}

// assert_good_deserialized_value verifies that buf decodes without error to
// value through both DecodeCompactU16 and DecodeCompactU16LengthFromByteReader.
func assert_good_deserialized_value(t *testing.T, value uint16, buf []byte) {
	t.Helper()
	decoded, _, err := DecodeCompactU16(buf)
	require.NoError(t, err)
	require.Equal(t, int(value), decoded)
	{
		// now try with a reader
		reader := bytes.NewReader(buf)
		out, err := DecodeCompactU16LengthFromByteReader(reader)
		// Check the error explicitly: a failed decode yields 0, which would
		// otherwise pass unnoticed when the expected value is 0.
		require.NoError(t, err)
		require.Equal(t, int(value), out)
	}
}

// assert_bad_deserialized_value verifies that buf is rejected with an error by
// both DecodeCompactU16 and DecodeCompactU16LengthFromByteReader.
func assert_bad_deserialized_value(t *testing.T, buf []byte) {
	// Slice-based decoder first.
	_, _, sliceErr := DecodeCompactU16(buf)
	require.Error(t, sliceErr, "expected an error for bytes: %v", buf)

	// Then the reader-based decoder on the same bytes.
	_, readerErr := DecodeCompactU16LengthFromByteReader(bytes.NewReader(buf))
	require.Error(t, readerErr, "expected an error for bytes: %v", buf)
}

// TestDeserialize runs the decoders over a table of valid encodings with their
// expected values, followed by a table of encodings that must be rejected.
func TestDeserialize(t *testing.T) {
	valid := []struct {
		want uint16
		raw  []byte
	}{
		{0x0000, []byte{0x00}},
		{0x007f, []byte{0x7f}},
		{0x0080, []byte{0x80, 0x01}},
		{0x00ff, []byte{0xff, 0x01}},
		{0x0100, []byte{0x80, 0x02}},
		{0x07ff, []byte{0xff, 0x0f}},
		{0x3fff, []byte{0xff, 0x7f}},
		{0x4000, []byte{0x80, 0x80, 0x01}},
		{0xffff, []byte{0xff, 0xff, 0x03}},
	}
	for _, tc := range valid {
		assert_good_deserialized_value(t, tc.want, tc.raw)
	}

	invalid := [][]byte{
		// aliases (non-minimal encodings)
		{0x80, 0x00},       // 0x0000
		{0x80, 0x80, 0x00}, // 0x0000
		{0xff, 0x00},       // 0x007f
		{0xff, 0x80, 0x00}, // 0x007f
		{0x80, 0x81, 0x00}, // 0x0080
		{0xff, 0x81, 0x00}, // 0x00ff
		{0x80, 0x82, 0x00}, // 0x0100
		{0xff, 0x8f, 0x00}, // 0x07ff
		{0xff, 0xff, 0x00}, // 0x3fff

		// too short
		{},
		{0x80},

		// too long
		{0x80, 0x80, 0x80, 0x00},

		// too large
		{0x80, 0x80, 0x04}, // 0x0001_0000
		{0x80, 0x80, 0x06}, // 0x0001_8000
	}
	for _, raw := range invalid {
		assert_bad_deserialized_value(t, raw)
	}
}

0 comments on commit 79f49c5

Please sign in to comment.