Skip to content

Commit

Permalink
types: Add EncodePtrCast and DecodePtrCast
Browse files Browse the repository at this point in the history
  • Loading branch information
lukechampine authored and ChrisSchinnerl committed Nov 20, 2024
1 parent 51068a3 commit 4a72a16
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 0 deletions.
28 changes: 28 additions & 0 deletions types/encoding.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,18 @@ func EncodePtr[T any, P interface {
}
}

// EncodePtrCast encodes a pointer to an object by casting it to V.
func EncodePtrCast[V interface {
Cast() T
EncoderTo
}, T any](e *Encoder, p *T) {
e.WriteBool(p != nil)
if p != nil {
vp := *(*V)(unsafe.Pointer(p))
vp.EncodeTo(e)
}
}

// EncodeSlice encodes a slice of objects that implement EncoderTo.
func EncodeSlice[T EncoderTo](e *Encoder, s []T) {
e.WriteUint64(uint64(len(s)))
Expand Down Expand Up @@ -254,6 +266,22 @@ func DecodePtr[T any, TP interface {
}
}

// DecodePtrCast decodes a pointer to an object by casting it to V.
func DecodePtrCast[T interface {
Cast() V
}, TP interface {
*T
DecoderFrom
}, V any](d *Decoder, p **V) {
tp := (**T)(unsafe.Pointer(p))
if d.ReadBool() {
*tp = new(T)
TP(*tp).DecodeFrom(d)
} else {
*tp = nil
}
}

// DecodeSlice decodes a length-prefixed slice of type T, containing values read
// from the decoder.
func DecodeSlice[T any, DF interface {
Expand Down
20 changes: 20 additions & 0 deletions types/encoding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,26 @@ import (
"lukechampine.com/frand"
)

func TestEncodePtrCast(t *testing.T) {
var buf bytes.Buffer
e := types.NewEncoder(&buf)
c := types.Siacoins(1)
types.EncodePtrCast[types.V1Currency](e, &c)
types.EncodePtrCast[types.V2Currency](e, &c)
types.EncodePtrCast[types.V2Currency](e, nil)
e.Flush()
var c1, c2, c3 *types.Currency
d := types.NewBufDecoder(buf.Bytes())
types.DecodePtrCast[types.V1Currency](d, &c1)
types.DecodePtrCast[types.V2Currency](d, &c2)
types.DecodePtrCast[types.V2Currency](d, &c3)
if err := d.Err(); err != nil {
t.Fatal(err)
} else if *c1 != c || *c2 != c || c3 != nil {
t.Fatal("mismatch:", c1, c2, c3)
}
}

func TestEncodeSlice(t *testing.T) {
txns := multiproofTxns(10, 10)
var buf bytes.Buffer
Expand Down

0 comments on commit 4a72a16

Please sign in to comment.