Skip to content

Commit

Permalink
Fix optional and pointers
Browse files Browse the repository at this point in the history
  • Loading branch information
gagliardetto committed Sep 10, 2021
1 parent 2e9d59e commit 54a9b30
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 4 deletions.
89 changes: 89 additions & 0 deletions borsh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,95 @@ import (
"github.com/stretchr/testify/require"
)

type OptionalPointerFields struct {
Good uint8
Arr *Arr `bin:"optional"`
}

type Arr []string

func TestOptionWithPointer(t *testing.T) {
// nil (optional not present)
{
buf := new(bytes.Buffer)
enc := NewBorshEncoder(buf)
val := OptionalPointerFields{
Good: 9,
// Will be decoded as nil pointer.
Arr: nil,
}
require.NoError(t, enc.Encode(val))
require.Equal(t,
concatByteSlices(
[]byte{9},
[]byte{0},
),
buf.Bytes())
{
dec := NewBorshDecoder(buf.Bytes())
var got OptionalPointerFields
require.NoError(t, dec.Decode(&got))
require.Equal(t, val, got)
}
}
// optional is present but has zero elements
{
buf := new(bytes.Buffer)
enc := NewBorshEncoder(buf)
val := OptionalPointerFields{
Good: 9,
// Will be decoded as pointer to nil Arr.
Arr: &Arr{},
}
require.NoError(t, enc.Encode(val))
require.Equal(t,
concatByteSlices(
[]byte{9},
[]byte{1},
[]byte{0, 0, 0, 0},
),
buf.Bytes(),
)
{
dec := NewBorshDecoder(buf.Bytes())
var got OptionalPointerFields
require.NoError(t, dec.Decode(&got))
// an empty slice is decoded as nil.
po := (Arr)(nil)
val.Arr = &po
require.Equal(t,
val, got)
}
}
// optional is present and has elements
{
buf := new(bytes.Buffer)
enc := NewBorshEncoder(buf)
val := OptionalPointerFields{
Good: 9,
Arr: &Arr{"foo"},
}
require.NoError(t, enc.Encode(val))
require.Equal(t,
concatByteSlices(
[]byte{9},
[]byte{1},
[]byte{1, 0, 0, 0},

[]byte{3, 0, 0, 0},
[]byte("foo"),
),
buf.Bytes(),
)
{
dec := NewBorshDecoder(buf.Bytes())
var got OptionalPointerFields
require.NoError(t, dec.Decode(&got))
require.Equal(t, val, got)
}
}
}

type StructWithComplexPeculiarEnums struct {
Complex2NotSet ComplexEnumPointers
Complex2PtrNotSet *ComplexEnumPointers
Expand Down
3 changes: 2 additions & 1 deletion decoder_borsh.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ func (dec *Decoder) decodeBorsh(rv reflect.Value, opt *option) (err error) {
)
}

// TODO: is `rv.Kind() == reflect.Ptr` correct here???
if opt.isOptional() {
isPresent, e := dec.ReadByte()
if e != nil {
Expand All @@ -59,6 +58,8 @@ func (dec *Decoder) decodeBorsh(rv reflect.Value, opt *option) (err error) {
// we have ptr here we should not go get the element
unmarshaler, rv = indirect(rv, false)
}
// Reset optionality so it won't propagate to child types:
opt = opt.clone().setIsOptional(false)

if unmarshaler != nil {
if traceEnabled {
Expand Down
8 changes: 5 additions & 3 deletions encoder_borsh.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func (e *Encoder) encodeBorsh(rv reflect.Value, opt *option) (err error) {
}

if opt.isOptional() {
if rv.IsZero() || (rv.Kind() == reflect.Ptr && rv.IsNil()) {
if rv.IsZero() {
if traceEnabled {
zlog.Debug("encode: skipping optional value with", zap.Stringer("type", rv.Kind()))
}
Expand All @@ -32,6 +32,8 @@ func (e *Encoder) encodeBorsh(rv reflect.Value, opt *option) (err error) {
}
e.WriteBool(true)
}
// Reset optionality so it won't propagate to child types:
opt = opt.clone().setIsOptional(false)

if isZero(rv) {
return nil
Expand Down Expand Up @@ -78,8 +80,8 @@ func (e *Encoder) encodeBorsh(rv reflect.Value, opt *option) (err error) {
return e.WriteBool(rv.Bool())
case reflect.Ptr:
if rv.IsNil() {
// el := reflect.New(rv.Type().Elem()).Elem()
// return e.encodeBorsh(el, nil)
el := reflect.New(rv.Type().Elem()).Elem()
return e.encodeBorsh(el, nil)
} else {
return e.encodeBorsh(rv.Elem(), nil)
}
Expand Down

0 comments on commit 54a9b30

Please sign in to comment.