Skip to content

Commit

Permalink
cmd/sszgen, tests/testtypes: continue crunching through the types for…
Browse files Browse the repository at this point in the history
… the generator
  • Loading branch information
karalabe committed Jul 8, 2024
1 parent 85ea6e6 commit 67adb6c
Show file tree
Hide file tree
Showing 37 changed files with 712 additions and 418 deletions.
139 changes: 109 additions & 30 deletions cmd/sszgen/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"bytes"
"fmt"
"go/types"
"html/template"
"math"
"sort"
)
Expand Down Expand Up @@ -94,6 +95,22 @@ func (ctx *genContext) reset() {
ctx.topType = true
}

func generate(ctx *genContext, typ *sszContainer) ([]byte, error) {
var codes [][]byte
for _, fn := range []func(ctx *genContext, typ *sszContainer) ([]byte, error){
generateSizeSSZ,
generateDefineSSZ,
} {
code, err := fn(ctx, typ)
if err != nil {
return nil, err
}
codes = append(codes, code)
}
fmt.Println(string(bytes.Join(codes, []byte("\n"))))
return bytes.Join(codes, []byte("\n")), nil
}

func generateSizeSSZ(ctx *genContext, typ *sszContainer) ([]byte, error) {
var b bytes.Buffer
ctx.reset()
Expand All @@ -104,7 +121,7 @@ func generateSizeSSZ(ctx *genContext, typ *sszContainer) ([]byte, error) {
// time or if runtime resolutions are needed
var runtime bool
for i := range typ.opsets {
if typ.opsets[i].(*opsetStatic).bytes == 0 {
if typ.opsets[i].(*opsetStatic).bytes == nil {
runtime = true
break
}
Expand All @@ -115,8 +132,12 @@ func generateSizeSSZ(ctx *genContext, typ *sszContainer) ([]byte, error) {
fmt.Fprintf(&b, "// Cached static size computed on package init.\n")
fmt.Fprintf(&b, "var staticSizeCache%s = ", typ.named.Obj().Name())
for i := range typ.opsets {
if bytes := typ.opsets[i].(*opsetStatic).bytes; bytes > 0 {
fmt.Fprintf(&b, "%d", bytes)
if bytes := typ.opsets[i].(*opsetStatic).bytes; bytes != nil {
if len(bytes) == 1 {
fmt.Fprintf(&b, "%d", bytes[0])
} else {
fmt.Fprintf(&b, "%d*%d", bytes[0], bytes[1])
}
} else {
typ := typ.types[i].(*types.Pointer).Elem().(*types.Named)
pkg := typ.Obj().Pkg()
Expand All @@ -140,20 +161,25 @@ func generateSizeSSZ(ctx *genContext, typ *sszContainer) ([]byte, error) {
fmt.Fprintf(&b, "func (obj *%s) SizeSSZ() uint32 {\n", typ.named.Obj().Name())
fmt.Fprint(&b, " return ")
for i := range typ.opsets {
fmt.Fprintf(&b, "%d", typ.opsets[i].(*opsetStatic).bytes)
bytes := typ.opsets[i].(*opsetStatic).bytes
if len(bytes) == 1 {
fmt.Fprintf(&b, "%d", bytes[0])
} else {
fmt.Fprintf(&b, "%d*%d", bytes[0], bytes[1])
}
if i < len(typ.opsets)-1 {
fmt.Fprint(&b, " + ")
}
}
fmt.Fprintf(&b, "}\n")
fmt.Fprintf(&b, "\n}\n")
}
} else {
// Iterate through the fields to see if the static size can be computed
// compile time or if runtime resolutions are needed even for statics.
var runtime bool
for i := range typ.opsets {
if typ, ok := typ.opsets[i].(*opsetStatic); ok {
if typ.bytes == 0 {
if typ.bytes == nil {
runtime = true
break
}
Expand All @@ -167,8 +193,12 @@ func generateSizeSSZ(ctx *genContext, typ *sszContainer) ([]byte, error) {
for i := range typ.opsets {
switch t := typ.opsets[i].(type) {
case *opsetStatic:
if t.bytes > 0 {
fmt.Fprintf(&b, "%d", t.bytes)
if t.bytes != nil {
if len(t.bytes) == 1 {
fmt.Fprintf(&b, "%d", t.bytes[0])
} else {
fmt.Fprintf(&b, "%d*%d", t.bytes[0], t.bytes[1])
}
} else {
typ := typ.types[i].(*types.Pointer).Elem().(*types.Named)
pkg := typ.Obj().Pkg()
Expand All @@ -190,7 +220,37 @@ func generateSizeSSZ(ctx *genContext, typ *sszContainer) ([]byte, error) {
fmt.Fprintf(&b, "func (obj *%s) SizeSSZ(fixed bool) uint32 {\n", typ.named.Obj().Name())
fmt.Fprintf(&b, " var size = uint32(staticSizeCache%s)\n", typ.named.Obj().Name())
fmt.Fprintf(&b, " if (fixed) {\n")
fmt.Fprintf(&b, " return staticSizeCache%s\n", typ.named.Obj().Name())
fmt.Fprintf(&b, " return size\n")
fmt.Fprintf(&b, " }\n")
for i := range typ.opsets {
if _, ok := typ.opsets[i].(*opsetDynamic); ok {
fmt.Fprintf(&b, " size += obj.%s.SizeSSZ(false)\n", typ.fields[i])
}
}
fmt.Fprintf(&b, " return size\n")
fmt.Fprintf(&b, "}\n")
} else {
fmt.Fprintf(&b, "\n\n// SizeSSZ returns either the static size of the object if fixed == true, or\n// the total size otherwise.\n")
fmt.Fprintf(&b, "func (obj *%s) SizeSSZ(fixed bool) uint32 {\n", typ.named.Obj().Name())
fmt.Fprintf(&b, " var size = uint32(")
for i := range typ.opsets {
switch t := typ.opsets[i].(type) {
case *opsetStatic:
if len(t.bytes) == 1 {
fmt.Fprintf(&b, "%d", t.bytes[0])
} else {
fmt.Fprintf(&b, "%d*%d", t.bytes[0], t.bytes[1])
}
case *opsetDynamic:
fmt.Fprintf(&b, "%d", offsetBytes)
}
if i < len(typ.opsets)-1 {
fmt.Fprint(&b, " + ")
}
}
fmt.Fprintf(&b, ")\n")
fmt.Fprintf(&b, " if (fixed) {\n")
fmt.Fprintf(&b, " return size\n")
fmt.Fprintf(&b, " }\n")
for i := range typ.opsets {
if _, ok := typ.opsets[i].(*opsetDynamic); ok {
Expand All @@ -214,13 +274,18 @@ func generateDefineSSZ(ctx *genContext, typ *sszContainer) ([]byte, error) {
// Iterate through the fields names to compute some comment formatting mods
var (
maxFieldLength = 0
maxBytes = 0
maxBytes = 1
)
for i, field := range typ.fields {
maxFieldLength = max(maxFieldLength, len(field))
switch opset := typ.opsets[i].(type) {
case *opsetStatic:
maxBytes = max(maxBytes, opset.bytes)
switch len(opset.bytes) {
case 1:
maxBytes = max(maxBytes, opset.bytes[0])
case 2:
maxBytes = max(maxBytes, opset.bytes[0]*opset.bytes[1])
}
case *opsetDynamic:
maxBytes = max(maxBytes, offsetBytes) // offset size
}
Expand All @@ -240,41 +305,55 @@ func generateDefineSSZ(ctx *genContext, typ *sszContainer) ([]byte, error) {
field := typ.fields[i]
switch opset := typ.opsets[i].(type) {
case *opsetStatic:
if opset.bytes > 0 {
fmt.Fprintf(&b, "ssz.%s(codec, &obj.%s) // Field ("+indexRule+") - "+nameRule+" - %"+sizeRule+"d bytes\n", opset.define, field, i, field, opset.bytes)
} else {
call := generateCall(opset.define, "codec", "obj."+field, opset.bytes...)
switch len(opset.bytes) {
case 0:
typ := typ.types[i].(*types.Pointer).Elem().(*types.Named)
fmt.Fprintf(&b, " ssz.%s(codec, &obj.%s) // Field ("+indexRule+") - "+nameRule+" - %"+sizeRule+"s bytes (%s)\n", opset.define, field, i, field, "?", typ.Obj().Name())
fmt.Fprintf(&b, " ssz.%s // Field ("+indexRule+") - "+nameRule+" - %"+sizeRule+"s bytes (%s)\n", call, i, field, "?", typ.Obj().Name())
case 1:
fmt.Fprintf(&b, " ssz.%s // Field ("+indexRule+") - "+nameRule+" - %"+sizeRule+"d bytes\n", call, i, field, opset.bytes[0])
case 2:
fmt.Fprintf(&b, " ssz.%s // Field ("+indexRule+") - "+nameRule+" - %"+sizeRule+"d bytes\n", call, i, field, opset.bytes[0]*opset.bytes[1])
}
case *opsetDynamic:
fmt.Fprintf(&b, " ssz.%s(codec, &obj.%s) // Offset ("+indexRule+") - "+nameRule+" - %"+sizeRule+"d bytes\n", opset.defineOffset, field, i, field, offsetBytes)
call := generateCall(opset.defineOffset, "codec", "obj."+field, opset.limits...)
fmt.Fprintf(&b, " ssz.%s // Offset ("+indexRule+") - "+nameRule+" - %"+sizeRule+"d bytes\n", call, i, field, offsetBytes)
}
}
if !typ.static {
fmt.Fprint(&b, "\n // Define the dynamic data (fields)\n")
for i := 0; i < len(typ.fields); i++ {
field := typ.fields[i]
if opset, ok := (typ.opsets[i]).(*opsetDynamic); ok {
fmt.Fprintf(&b, " ssz.%s(codec, &obj.%s) // Field ("+indexRule+") - "+nameRule+" - ? bytes\n", opset.defineContent, field, i, field)
call := generateCall(opset.defineContent, "codec", "obj."+field, opset.limits...)
fmt.Fprintf(&b, " ssz.%s // Field ("+indexRule+") - "+nameRule+" - ? bytes\n", call, i, field)
}
}
}
fmt.Fprint(&b, "}\n")
return b.Bytes(), nil
}

func generate(ctx *genContext, typ *sszContainer) ([]byte, error) {
var codes [][]byte
for _, fn := range []func(ctx *genContext, typ *sszContainer) ([]byte, error){
generateSizeSSZ,
generateDefineSSZ,
} {
code, err := fn(ctx, typ)
if err != nil {
return nil, err
}
codes = append(codes, code)
// generateCall parses a Go template and fills it with the provided data. This
// could be done more optimally, but we really don't care for a code generator.
func generateCall(tmpl string, recv string, field string, limits ...int) string {
t, err := template.New("").Parse(tmpl)
if err != nil {
panic(err)
}
//fmt.Println(string(bytes.Join(codes, []byte("\n"))))
return bytes.Join(codes, []byte("\n")), nil
d := map[string]interface{}{
"Codec": recv,
"Field": field,
}
if len(limits) > 0 {
d["MaxSize"] = limits[len(limits)-1]
}
if len(limits) > 1 {
d["MaxItems"] = limits[len(limits)-2]
}
buf := new(bytes.Buffer)
if err := t.Execute(buf, d); err != nil {
panic(err)
}
return string(buf.Bytes())
}
Loading

0 comments on commit 67adb6c

Please sign in to comment.