From c3e479f3c5693920d5b4399e9ee76035291f56f3 Mon Sep 17 00:00:00 2001 From: gagliardetto Date: Sat, 22 Oct 2022 18:59:21 +0200 Subject: [PATCH] Refactor borsh optionality and tags --- decoder.go | 32 +++++++++++++++++++++++++- decoder_bin.go | 8 +++---- decoder_borsh.go | 34 ++++++++++++++++++++++------ decoder_compact-u16.go | 8 +++---- encoder.go | 18 +++++++++++++++ encoder_bin.go | 8 +++---- encoder_borsh.go | 29 ++++++++++++++++++------ encoder_compact-u16.go | 8 +++---- parse_test.go | 11 +++++---- options.go => tags-options.go | 42 +++++++++++++++++++++++------------ parse.go => tags-parser.go | 22 ++++++++++++++---- variant.go | 8 ++++--- 12 files changed, 170 insertions(+), 58 deletions(-) rename options.go => tags-options.go (71%) rename parse.go => tags-parser.go (80%) diff --git a/decoder.go b/decoder.go index 983438a..29e4500 100644 --- a/decoder.go +++ b/decoder.go @@ -345,6 +345,10 @@ func (dec *Decoder) ReadTypeID() (out TypeID, err error) { return TypeIDFromBytes(discriminator), nil } +func (dec *Decoder) ReadDiscriminator() (out TypeID, err error) { + return dec.ReadTypeID() +} + func (dec *Decoder) Peek(n int) (out []byte, err error) { if n < 0 { err = fmt.Errorf("n not valid: %d", n) @@ -373,6 +377,33 @@ func (dec *Decoder) ReadCompactU16() (out int, err error) { return } +func (dec *Decoder) ReadOption() (out bool, err error) { + b, err := dec.ReadByte() + if err != nil { + return false, fmt.Errorf("decode: read option, %w", err) + } + out = b != 0 + if traceEnabled { + zlog.Debug("decode: read option", zap.Bool("val", out)) + } + return +} + +func (dec *Decoder) ReadCOption() (out bool, err error) { + b, err := dec.ReadUint32(LE) + if err != nil { + return false, fmt.Errorf("decode: read c-option, %w", err) + } + if b > 1 { + return false, fmt.Errorf("decode: read c-option, invalid value: %d", b) + } + out = b != 0 + if traceEnabled { + zlog.Debug("decode: read c-option", zap.Bool("val", out)) + } + return +} + func (dec *Decoder) ReadByte() (out byte, err error) { if dec.Remaining() < TypeSize.Byte { err = fmt.Errorf("required [1] byte, remaining [%d]", dec.Remaining()) @@ -394,7 +425,6 @@ func (dec *Decoder) ReadBool() (out bool, err error) { } b, err := dec.ReadByte() - if err != nil { err = fmt.Errorf("readBool, %s", err) } diff --git a/decoder_bin.go b/decoder_bin.go index 14e6108..ca58676 100644 --- a/decoder_bin.go +++ b/decoder_bin.go @@ -47,7 +47,7 @@ func (dec *Decoder) decodeBin(rv reflect.Value, opt *option) (err error) { } dec.currentFieldOpt = opt - unmarshaler, rv := indirect(rv, opt.isOptional()) + unmarshaler, rv := indirect(rv, opt.is_Optional()) if traceEnabled { zlog.Debug("decode: type", @@ -57,7 +57,7 @@ func (dec *Decoder) decodeBin(rv reflect.Value, opt *option) (err error) { ) } - if opt.isOptional() { + if opt.is_Optional() { isPresent, e := dec.ReadUint32(binary.LittleEndian) if e != nil { err = fmt.Errorf("decode: %s isPresent, %s", rv.Type().String(), e) @@ -316,8 +316,8 @@ func (dec *Decoder) decodeStructBin(rt reflect.Type, rv reflect.Value) (err erro } option := &option{ - OptionalField: fieldTag.Optional, - Order: fieldTag.Order, + is_OptionalField: fieldTag.Option, + Order: fieldTag.Order, } if s, ok := sizeOfMap[structField.Name]; ok { diff --git a/decoder_borsh.go b/decoder_borsh.go index 81927ab..bcac383 100644 --- a/decoder_borsh.go +++ b/decoder_borsh.go @@ -47,7 +47,7 @@ func (dec *Decoder) decodeBorsh(rv reflect.Value, opt *option) (err error) { } dec.currentFieldOpt = opt - unmarshaler, rv := indirect(rv, opt.isOptional()) + unmarshaler, rv := indirect(rv, opt.is_Optional() || opt.is_COptional()) if traceEnabled { zlog.Debug("decode: type", @@ -57,14 +57,33 @@ func (dec *Decoder) decodeBorsh(rv reflect.Value, opt *option) (err error) { ) } - if opt.isOptional() { - isPresent, e := dec.ReadByte() + if opt.is_Optional() { + isPresent, e := dec.ReadOption() if e != nil { err = fmt.Errorf("decode: %t isPresent, %s", rv.Type(), e) return } - if isPresent == 0 { + if !isPresent { + if traceEnabled { + zlog.Debug("decode: skipping optional value", zap.Stringer("type", rv.Kind())) + } + + rv.Set(reflect.Zero(rv.Type())) + return + } + + // we have ptr here we should not go get the element + unmarshaler, rv = indirect(rv, false) + } + if opt.is_COptional() { + isPresent, e := dec.ReadCOption() + if e != nil { + err = fmt.Errorf("decode: %t isPresent, %s", rv.Type(), e) + return + } + + if !isPresent { if traceEnabled { zlog.Debug("decode: skipping optional value", zap.Stringer("type", rv.Kind())) } @@ -77,7 +96,7 @@ func (dec *Decoder) decodeBorsh(rv reflect.Value, opt *option) (err error) { unmarshaler, rv = indirect(rv, false) } // Reset optionality so it won't propagate to child types: - opt = opt.clone().setIsOptional(false) + opt = opt.clone().set_Optional(false).set_COptional(false) if unmarshaler != nil { if traceEnabled { @@ -371,8 +390,9 @@ func (dec *Decoder) decodeStructBorsh(rt reflect.Type, rv reflect.Value) (err er } option := &option{ - OptionalField: fieldTag.Optional, - Order: fieldTag.Order, + is_OptionalField: fieldTag.Option, + is_COptionalField: fieldTag.COption, + Order: fieldTag.Order, } if s, ok := sizeOfMap[structField.Name]; ok { diff --git a/decoder_compact-u16.go b/decoder_compact-u16.go index a5f3d38..13396dd 100644 --- a/decoder_compact-u16.go +++ b/decoder_compact-u16.go @@ -46,7 +46,7 @@ func (dec *Decoder) decodeCompactU16(rv reflect.Value, opt *option) (err error) } dec.currentFieldOpt = opt - unmarshaler, rv := indirect(rv, opt.isOptional()) + unmarshaler, rv := indirect(rv, opt.is_Optional()) if traceEnabled { zlog.Debug("decode: type", @@ -56,7 +56,7 @@ func (dec *Decoder) decodeCompactU16(rv reflect.Value, opt *option) (err error) ) } - if opt.isOptional() { + if opt.is_Optional() { isPresent, e := dec.ReadByte() if e != nil { err = fmt.Errorf("decode: %t isPresent, %s", rv.Type(), e) @@ -315,8 +315,8 @@ func (dec *Decoder) decodeStructCompactU16(rt reflect.Type, rv reflect.Value) (e } option := &option{ - OptionalField: fieldTag.Optional, - Order: fieldTag.Order, + is_OptionalField: fieldTag.Option, + Order: fieldTag.Order, } if s, ok := sizeOfMap[structField.Name]; ok { diff --git a/encoder.go b/encoder.go index fa2f939..c5f6427 100644 --- a/encoder.go +++ b/encoder.go @@ -166,6 +166,24 @@ func (e *Encoder) WriteByte(b byte) (err error) { return e.toWriter([]byte{b}) } +func (e *Encoder) WriteOption(b bool) (err error) { + if traceEnabled { + zlog.Debug("encode: write option", zap.Bool("val", b)) + } + return e.WriteBool(b) +} + +func (e *Encoder) WriteCOption(b bool) (err error) { + if traceEnabled { + zlog.Debug("encode: write c-option", zap.Bool("val", b)) + } + var num uint32 + if b { + num = 1 + } + return e.WriteUint32(num, LE) +} + func (e *Encoder) WriteBool(b bool) (err error) { if traceEnabled { zlog.Debug("encode: write bool", zap.Bool("val", b)) diff --git a/encoder_bin.go b/encoder_bin.go index 585b41e..4ba95ec 100644 --- a/encoder_bin.go +++ b/encoder_bin.go @@ -38,7 +38,7 @@ func (e *Encoder) encodeBin(rv reflect.Value, opt *option) (err error) { ) } - if opt.isOptional() { + if opt.is_Optional() { if rv.IsZero() { if traceEnabled { zlog.Debug("encode: skipping optional value with", zap.Stringer("type", rv.Kind())) @@ -50,7 +50,7 @@ func (e *Encoder) encodeBin(rv reflect.Value, opt *option) (err error) { return err } // The optionality has been used; stop its propagation: - opt.setIsOptional(false) + opt.set_Optional(false) } if isZero(rv) { @@ -237,8 +237,8 @@ func (e *Encoder) encodeStructBin(rt reflect.Type, rv reflect.Value) (err error) } option := &option{ - OptionalField: fieldTag.Optional, - Order: fieldTag.Order, + is_OptionalField: fieldTag.Option, + Order: fieldTag.Order, } if s, ok := sizeOfMap[structField.Name]; ok { diff --git a/encoder_borsh.go b/encoder_borsh.go index 75ebf59..93b8a12 100644 --- a/encoder_borsh.go +++ b/encoder_borsh.go @@ -76,22 +76,36 @@ func (e *Encoder) encodeBorsh(rv reflect.Value, opt *option) (err error) { ) } - if opt.isOptional() { + if opt.is_Optional() { if rv.IsZero() { if traceEnabled { zlog.Debug("encode: skipping optional value with", zap.Stringer("type", rv.Kind())) } - return e.WriteBool(false) + return e.WriteOption(false) } - err := e.WriteBool(true) + err := e.WriteOption(true) if err != nil { return err } // The optionality has been used; stop its propagation: - opt.setIsOptional(false) + opt.set_Optional(false) + } + if opt.is_COptional() { + if rv.IsZero() { + if traceEnabled { + zlog.Debug("encode: skipping optional value with", zap.Stringer("type", rv.Kind())) + } + return e.WriteCOption(false) + } + err := e.WriteCOption(true) + if err != nil { + return err + } + // The optionality has been used; stop its propagation: + opt.set_COptional(false) } // Reset optionality so it won't propagate to child types: - opt = opt.clone().setIsOptional(false) + opt = opt.clone().set_Optional(false).set_COptional(false) if isZero(rv) { return nil @@ -327,8 +341,9 @@ func (e *Encoder) encodeStructBorsh(rt reflect.Type, rv reflect.Value) (err erro } option := &option{ - OptionalField: fieldTag.Optional, - Order: fieldTag.Order, + is_OptionalField: fieldTag.Option, + is_COptionalField: fieldTag.COption, + Order: fieldTag.Order, } if s, ok := sizeOfMap[structField.Name]; ok { diff --git a/encoder_compact-u16.go b/encoder_compact-u16.go index 488dda6..2616177 100644 --- a/encoder_compact-u16.go +++ b/encoder_compact-u16.go @@ -37,7 +37,7 @@ func (e *Encoder) encodeCompactU16(rv reflect.Value, opt *option) (err error) { ) } - if opt.isOptional() { + if opt.is_Optional() { if rv.IsZero() { if traceEnabled { zlog.Debug("encode: skipping optional value with", zap.Stringer("type", rv.Kind())) @@ -49,7 +49,7 @@ func (e *Encoder) encodeCompactU16(rv reflect.Value, opt *option) (err error) { return err } // The optionality has been used; stop its propagation: - opt.setIsOptional(false) + opt.set_Optional(false) } if isZero(rv) { @@ -235,8 +235,8 @@ func (e *Encoder) encodeStructCompactU16(rt reflect.Type, rv reflect.Value) (err } option := &option{ - OptionalField: fieldTag.Optional, - Order: fieldTag.Order, + is_OptionalField: fieldTag.Option, + Order: fieldTag.Order, } if s, ok := sizeOfMap[structField.Name]; ok { diff --git a/parse_test.go b/parse_test.go index f13d73f..1c7c114 100644 --- a/parse_test.go +++ b/parse_test.go @@ -55,17 +55,17 @@ func Test_parseFieldTag(t *testing.T) { name: "with a optional", tag: `bin:"optional"`, expectValue: &fieldTag{ - Order: binary.LittleEndian, - Optional: true, + Order: binary.LittleEndian, + Option: true, }, }, { name: "with a optional and size of", tag: `bin:"optional sizeof=Nodes"`, expectValue: &fieldTag{ - Order: binary.LittleEndian, - Optional: true, - SizeOf: "Nodes", + Order: binary.LittleEndian, + Option: true, + SizeOf: "Nodes", }, }, } @@ -75,5 +75,4 @@ func Test_parseFieldTag(t *testing.T) { assert.Equal(t, test.expectValue, parseFieldTag(reflect.StructTag(test.tag))) }) } - } diff --git a/options.go b/tags-options.go similarity index 71% rename from options.go rename to tags-options.go index 42c310a..e0da9e6 100644 --- a/options.go +++ b/tags-options.go @@ -20,34 +20,42 @@ package bin import "encoding/binary" type option struct { - OptionalField bool - SizeOfSlice *int - Order binary.ByteOrder + is_OptionalField bool + is_COptionalField bool + SizeOfSlice *int + Order binary.ByteOrder } -var LE binary.ByteOrder = binary.LittleEndian -var BE binary.ByteOrder = binary.BigEndian +var ( + LE binary.ByteOrder = binary.LittleEndian + BE binary.ByteOrder = binary.BigEndian +) var defaultByteOrder = binary.LittleEndian func newDefaultOption() *option { return &option{ - OptionalField: false, - Order: defaultByteOrder, + is_OptionalField: false, + Order: defaultByteOrder, } } func (o *option) clone() *option { out := &option{ - OptionalField: o.OptionalField, - SizeOfSlice: o.SizeOfSlice, - Order: o.Order, + is_OptionalField: o.is_OptionalField, + is_COptionalField: o.is_COptionalField, + SizeOfSlice: o.SizeOfSlice, + Order: o.Order, } return out } -func (o *option) isOptional() bool { - return o.OptionalField +func (o *option) is_Optional() bool { + return o.is_OptionalField +} + +func (o *option) is_COptional() bool { + return o.is_COptionalField } func (o *option) hasSizeOfSlice() bool { @@ -62,8 +70,14 @@ func (o *option) setSizeOfSlice(size int) *option { o.SizeOfSlice = &size return o } -func (o *option) setIsOptional(isOptional bool) *option { - o.OptionalField = isOptional + +func (o *option) set_Optional(isOptional bool) *option { + o.is_OptionalField = isOptional + return o +} + +func (o *option) set_COptional(isCOptional bool) *option { + o.is_COptionalField = isCOptional return o } diff --git a/parse.go b/tags-parser.go similarity index 80% rename from parse.go rename to tags-parser.go index 8e7ab38..39f8e3a 100644 --- a/parse.go +++ b/tags-parser.go @@ -27,12 +27,22 @@ type fieldTag struct { SizeOf string Skip bool Order binary.ByteOrder - Optional bool + Option bool + COption bool BinaryExtension bool IsBorshEnum bool } +func isIn(s string, candidates ...string) bool { + for _, c := range candidates { + if s == c { + return true + } + } + return false +} + func parseFieldTag(tag reflect.StructTag) *fieldTag { t := &fieldTag{ Order: defaultByteOrder, @@ -46,12 +56,16 @@ func parseFieldTag(tag reflect.StructTag) *fieldTag { t.Order = binary.BigEndian } else if s == "little" { t.Order = binary.LittleEndian - } else if s == "optional" { - t.Optional = true + } else if isIn(s, "optional", "option") { + t.Option = true + } else if isIn(s, "coption") { + t.COption = true } else if s == "binary_extension" { t.BinaryExtension = true - } else if s == "-" { + } else if isIn(s, "-", "skip") { t.Skip = true + } else if isIn(s, "enum") { + t.IsBorshEnum = true } } diff --git a/variant.go b/variant.go index 7cbe40b..386a655 100644 --- a/variant.go +++ b/variant.go @@ -260,8 +260,10 @@ func (d *VariantDefinition) TypeID(name string) TypeID { return id } -type VariantImplFactory = func() interface{} -type OnVariant = func(impl interface{}) error +type ( + VariantImplFactory = func() interface{} + OnVariant = func(impl interface{}) error +) type BaseVariant struct { TypeID TypeID @@ -279,7 +281,7 @@ func (a *BaseVariant) Obtain(def *VariantDefinition) (typeID TypeID, typeName st return a.TypeID, def.typeIDToName[a.TypeID], a.Impl } -func (a *BaseVariant) MarshalJSON(def *VariantDefinition) ([]byte, error) { +func (a BaseVariant) MarshalJSON(def *VariantDefinition) ([]byte, error) { typeName, found := def.typeIDToName[a.TypeID] if !found { return nil, fmt.Errorf("type %d is not know by variant definition", a.TypeID)