Skip to content

Commit

Permalink
all: generate unsafe 2D array calls to avoid type chaos in generated …
Browse files Browse the repository at this point in the history
…code
  • Loading branch information
karalabe committed Jul 10, 2024
1 parent 61784a4 commit db92a3f
Show file tree
Hide file tree
Showing 10 changed files with 80 additions and 63 deletions.
21 changes: 7 additions & 14 deletions cmd/sszgen/opset.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func (p *parseContext) resolveBitlistOpset(tags *sizeTag) (opset, error) {
}, nil
}

func (p *parseContext) resolveArrayOpset(typ types.Type, name string, size int, tags *sizeTag) (opset, error) {
func (p *parseContext) resolveArrayOpset(typ types.Type, size int, tags *sizeTag) (opset, error) {
switch typ := typ.(type) {
case *types.Basic:
// Sanity check a few tag constraints relevant for all arrays of basic types
Expand Down Expand Up @@ -157,21 +157,17 @@ func (p *parseContext) resolveArrayOpset(typ types.Type, name string, size int,
return nil, fmt.Errorf("unsupported array item basic type: %s", typ)
}
case *types.Array:
return p.resolveArrayOfArrayOpset(typ.Elem(), name, size, typ.String(), int(typ.Len()), tags)
return p.resolveArrayOfArrayOpset(typ.Elem(), size, int(typ.Len()), tags)

case *types.Named:
// For named arrays, we need to pass the name too for the generics
if t, ok := typ.Underlying().(*types.Array); ok {
return p.resolveArrayOfArrayOpset(t.Elem(), name, size, typ.Obj().Name(), int(t.Len()), tags)
}
return p.resolveArrayOpset(typ.Underlying(), name, size, tags)
return p.resolveArrayOpset(typ.Underlying(), size, tags)

default:
return nil, fmt.Errorf("unsupported array item type: %s", typ)
}
}

func (p *parseContext) resolveArrayOfArrayOpset(typ types.Type, outerType string, outerSize int, innerType string, innerSize int, tags *sizeTag) (opset, error) {
func (p *parseContext) resolveArrayOfArrayOpset(typ types.Type, outerSize, innerSize int, tags *sizeTag) (opset, error) {
switch typ := typ.(type) {
case *types.Basic:
// Sanity check a few tag constraints relevant for all arrays of basic types
Expand All @@ -189,13 +185,10 @@ func (p *parseContext) resolveArrayOfArrayOpset(typ types.Type, outerType string
return nil, fmt.Errorf("array of array of byte basic type tag conflict: field is [%d, %d] bytes, tag wants %v bytes", outerSize, innerSize, tags.size)
}
}
if outerType == "" {
outerType = fmt.Sprintf("[%d]%s", outerSize, innerType)
}
return &opsetStatic{
fmt.Sprintf("DefineArrayOfStaticBytes[%s, %s]({{.Codec}}, &{{.Field}})", outerType, innerType),
fmt.Sprintf("EncodeArrayOfStaticBytes[%s, %s]({{.Codec}}, &{{.Field}})", outerType, innerType),
fmt.Sprintf("DecodeArrayOfStaticBytes[%s, %s]({{.Codec}}, &{{.Field}})", outerType, innerType),
"DefineUnsafeArrayOfStaticBytes({{.Codec}}, {{.Field}}[:])",
"EncodeUnsafeArrayOfStaticBytes({{.Codec}}, {{.Field}}[:])",
"DecodeUnsafeArrayOfStaticBytes({{.Codec}}, {{.Field}}[:])",
[]int{outerSize, innerSize},
}, nil
default:
Expand Down
5 changes: 1 addition & 4 deletions cmd/sszgen/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,16 +73,13 @@ func (p *parseContext) resolveOpset(typ types.Type, tags *sizeTag) (opset, error
if isBitlist(typ) {
return p.resolveBitlistOpset(tags)
}
if nt, ok := typ.Underlying().(*types.Array); ok {
return p.resolveArrayOpset(nt.Elem(), t.Obj().Name(), int(nt.Len()), tags)
}
return p.resolveOpset(t.Underlying(), tags)

case *types.Basic:
return p.resolveBasicOpset(t, tags)

case *types.Array:
return p.resolveArrayOpset(t.Elem(), "", int(t.Len()), tags)
return p.resolveArrayOpset(t.Elem(), int(t.Len()), tags)

case *types.Slice:
return p.resolveSliceOpset(t.Elem(), tags)
Expand Down
14 changes: 13 additions & 1 deletion codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,9 +196,21 @@ func DefineArrayOfStaticBytes[T commonBytesArrayLengths[U], U commonBytesLengths
DecodeArrayOfStaticBytes[T, U](c.dec, blobs)
}

// DefineUnsafeArrayOfStaticBytes defines the next field as a static array of
// static binary blobs. This method operates on plain slices of byte arrays and
// will crash if provided a slice of a non-array. Its purpose is to get around
// Go's generics limitations in generated code (use DefineArrayOfStaticBytes).
func DefineUnsafeArrayOfStaticBytes[T commonBytesLengths](c *Codec, blobs []T) {
if c.enc != nil {
EncodeUnsafeArrayOfStaticBytes(c.enc, blobs)
return
}
DecodeUnsafeArrayOfStaticBytes(c.dec, blobs)
}

// DefineCheckedArrayOfStaticBytes defines the next field as a static array of
// static binary blobs. This method can be used for plain slices of byte arrays,
// which is more expensive since it needs runtime size validation.
// which is more expensive since it needs runtime size validation.
func DefineCheckedArrayOfStaticBytes[T commonBytesLengths](c *Codec, blobs *[]T, size uint64) {
if c.enc != nil {
EncodeCheckedArrayOfStaticBytes(c.enc, *blobs)
Expand Down
23 changes: 15 additions & 8 deletions decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -426,29 +426,36 @@ func DecodeSliceOfUint64sContent[T ~uint64](dec *Decoder, ns *[]T, maxItems uint

// DecodeArrayOfStaticBytes parses a static array of static binary blobs.
func DecodeArrayOfStaticBytes[T commonBytesArrayLengths[U], U commonBytesLengths](dec *Decoder, blobs *T) {
// The code below should have used `(*blobs)[:]`, alas Go's generics compiler
// is missing that (i.e. a bug): https://github.com/golang/go/issues/51740
DecodeUnsafeArrayOfStaticBytes(dec, unsafe.Slice(&(*blobs)[0], len(*blobs)))
}

// DecodeUnsafeArrayOfStaticBytes parses a static array of static binary blobs.
func DecodeUnsafeArrayOfStaticBytes[T commonBytesLengths](dec *Decoder, blobs []T) {
if dec.err != nil {
return
}
if dec.inReader != nil {
for i := 0; i < len(*blobs); i++ {
for i := 0; i < len(blobs); i++ {
// The code below should have used `(*blobs)[i][:]`, alas Go's generics compiler
// is missing that (i.e. a bug): https://github.com/golang/go/issues/51740
_, dec.err = io.ReadFull(dec.inReader, unsafe.Slice(&(*blobs)[i][0], len((*blobs)[i])))
_, dec.err = io.ReadFull(dec.inReader, unsafe.Slice(&(blobs)[i][0], len((blobs)[i])))
if dec.err != nil {
return
}
dec.inRead += uint32(len((*blobs)[i]))
dec.inRead += uint32(len((blobs)[i]))
}
} else {
for i := 0; i < len(*blobs); i++ {
if len(dec.inBuffer) < len((*blobs)[i]) {
for i := 0; i < len(blobs); i++ {
if len(dec.inBuffer) < len((blobs)[i]) {
dec.err = io.ErrUnexpectedEOF
return
}
// The code below should have used `*blobs[i][:]`, alas Go's generics compiler
// The code below should have used `blobs[i][:]`, alas Go's generics compiler
// is missing that (i.e. a bug): https://github.com/golang/go/issues/51740
copy(unsafe.Slice(&(*blobs)[i][0], len((*blobs)[i])), dec.inBuffer)
dec.inBuffer = dec.inBuffer[len((*blobs)[i]):]
copy(unsafe.Slice(&(blobs)[i][0], len((blobs)[i])), dec.inBuffer)
dec.inBuffer = dec.inBuffer[len((blobs)[i]):]
}
}
}
Expand Down
22 changes: 15 additions & 7 deletions encoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -342,24 +342,32 @@ func EncodeSliceOfUint64sContent[T ~uint64](enc *Encoder, ns []T) {
// from escaping to the heap (and incurring an allocation) when passing it to
// the output stream.
func EncodeArrayOfStaticBytes[T commonBytesArrayLengths[U], U commonBytesLengths](enc *Encoder, blobs *T) {
// The code below should have used `(*blobs)[:]`, alas Go's generics compiler
// is missing that (i.e. a bug): https://github.com/golang/go/issues/51740
EncodeUnsafeArrayOfStaticBytes(enc, unsafe.Slice(&(*blobs)[0], len(*blobs)))
}

// EncodeUnsafeArrayOfStaticBytes serializes a static array of static binary
// blobs.
func EncodeUnsafeArrayOfStaticBytes[T commonBytesLengths](enc *Encoder, blobs []T) {
// Internally this method is essentially calling EncodeStaticBytes on all
// the blobs in a loop. Practically, we've inlined that call to make things
// a *lot* faster.
if enc.outWriter != nil {
for i := 0; i < len(*blobs); i++ { // don't range loop, T might be an array, copy is expensive
for i := 0; i < len(blobs); i++ { // don't range loop, T might be an array, copy is expensive
if enc.err != nil {
return
}
// The code below should have used `(*blobs)[i][:]`, alas Go's generics compiler
// The code below should have used `blobs[i][:]`, alas Go's generics compiler
// is missing that (i.e. a bug): https://github.com/golang/go/issues/51740
_, enc.err = enc.outWriter.Write(unsafe.Slice(&((*blobs)[i])[0], len((*blobs)[i])))
_, enc.err = enc.outWriter.Write(unsafe.Slice(&blobs[i][0], len(blobs[i])))
}
} else {
for i := 0; i < len(*blobs); i++ { // don't range loop, T might be an array, copy is expensive
// The code below should have used `(*blobs)[i][:]`, alas Go's generics compiler
for i := 0; i < len(blobs); i++ { // don't range loop, T might be an array, copy is expensive
// The code below should have used `blobs[i][:]`, alas Go's generics compiler
// is missing that (i.e. a bug): https://github.com/golang/go/issues/51740
copy(enc.outBuffer, unsafe.Slice(&((*blobs)[i])[0], len((*blobs)[i])))
enc.outBuffer = enc.outBuffer[len((*blobs)[i]):]
copy(enc.outBuffer, unsafe.Slice(&blobs[i][0], len(blobs[i])))
enc.outBuffer = enc.outBuffer[len(blobs[i]):]
}
}
}
Expand Down
42 changes: 21 additions & 21 deletions tests/testtypes/consensus-spec-tests/gen_beacon_state_ssz.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions tests/testtypes/consensus-spec-tests/gen_deposit_ssz.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit db92a3f

Please sign in to comment.