From 10ac1554fef052af802192e0e131742bb9193fc6 Mon Sep 17 00:00:00 2001 From: Zaq? Wiedmann Date: Tue, 25 Aug 2020 15:10:02 -0700 Subject: [PATCH] WIP: support using protobuf types in query parameters Useful for passing dates --- gengokit/httptransport/httptransport.go | 50 +++++++++++++++++++--- gengokit/httptransport/templates/server.go | 4 ++ svcdef/svcdef.go | 3 ++ 3 files changed, 50 insertions(+), 7 deletions(-) diff --git a/gengokit/httptransport/httptransport.go b/gengokit/httptransport/httptransport.go index 6aea6bfb..22e86725 100644 --- a/gengokit/httptransport/httptransport.go +++ b/gengokit/httptransport/httptransport.go @@ -11,9 +11,9 @@ import ( "text/template" "unicode" - log "github.com/sirupsen/logrus" gogen "github.com/gogo/protobuf/protoc-gen-gogo/generator" "github.com/pkg/errors" + log "github.com/sirupsen/logrus" "github.com/metaverse/truss/gengokit/httptransport/templates" "github.com/metaverse/truss/svcdef" @@ -139,9 +139,9 @@ func NewBinding(i int, meth *svcdef.ServiceMethod) *Binding { LocalName: fmt.Sprintf("%s%s", gogen.CamelCase(field.Name), gogen.CamelCase(meth.Name)), } - if field.Type.Message == nil && field.Type.Enum == nil && field.Type.Map == nil { + if field.Type.Message == nil && field.Type.Enum == nil && field.Type.Map == nil && !isProtobufType(field.Type.Name) && field.Type.Name != "time.Time" { newField.IsBaseType = true - } else { + } else if !isProtobufType(field.Type.Name) && field.Type.Name != "time.Time" { newField.GoType = "pb." + newField.GoType } @@ -165,7 +165,7 @@ func NewBinding(i int, meth *svcdef.ServiceMethod) *Binding { } // Emit warnings for certain cases - if !newField.IsBaseType && newField.Location != "body" { + if !newField.IsBaseType && newField.Location != "body" && !isSafeNonBaseType(newField.GoType) { log.Warnf( "%s.%s is a non-base type specified to be located outside of "+ "the body. Non-base types outside the body may result in "+ @@ -173,7 +173,7 @@ func NewBinding(i int, meth *svcdef.ServiceMethod) *Binding { meth.Name, newField.Name) } - if newField.Repeated && newField.Location == "path" { + if newField.Repeated && newField.Location == "path" && !isSafeNonBaseType(newField.GoType) { log.Warnf( "%s.%s is a repeated field specified to be in the path. "+ "Repeated fields are not supported in the path and may"+ @@ -185,6 +185,24 @@ func NewBinding(i int, meth *svcdef.ServiceMethod) *Binding { return &nBinding } +func isProtobufType(t string) bool { + if strings.HasPrefix(t, "types.") { + return true + } + return false +} + +func isSafeNonBaseType(t string) bool { + switch t { + case "time.Time": // gogo stdtime case + return true + case "types.Timestamp", "types.Duration": + return true + default: + return false + } +} + func GenServerTemplate(exec interface{}) (string, error) { code, err := ApplyTemplate("ServerTemplate", templates.ServerTemplate, exec, TemplateFuncs) if err != nil { @@ -381,6 +399,10 @@ func createDecodeConvertFunc(f Field) (string, bool) { var {{.LocalName}} *{{.GoType}} {{.LocalName}} = &{{.GoType}}{} err = json.Unmarshal([]byte({{.LocalName}}Str), {{.LocalName}})` + singlePBTypeUnmarshalTmpl := ` +var {{.LocalName}} *{{.GoType}} +{{.LocalName}} = &{{.GoType}}{} +err = jsonpb.UnmarshalString({{.LocalName}}Str, {{.LocalName}})` // All repeated args of any type are represented as slices, and bare // assignments to a slice accept a slice as the rvalue. As a result, // LocalName will be declared as a slice, and json.Unmarshal handles @@ -399,16 +421,30 @@ if err != nil { {{- end}} err = json.Unmarshal([]byte({{.LocalName}}Str), &{{.LocalName}})` + repeatedPBUnmarshalTmpl := ` +var {{.LocalName}} {{.GoType}} +{{- if and (and .IsBaseType .Repeated) (not (Contains .GoType "[]byte"))}} +err = jsonpb.UnmarshalString({{.LocalName}}Str, &{{.LocalName}}) +if err != nil { + {{.LocalName}}Str = "[" + {{.LocalName}}Str + "]" +} +{{- end}} +err = jsonpb.UnmarshalString({{.LocalName}}Str, &{{.LocalName}})` + errorCheckingTmpl := ` if err != nil { return nil, errors.Wrapf(err, "couldn't decode {{.LocalName}} from %v", {{.LocalName}}Str) }` var preamble string - if !f.Repeated { + if !f.Repeated && !isProtobufType(f.GoType) { preamble = singleCustomTypeUnmarshalTmpl - } else { + } else if !f.Repeated && isProtobufType(f.GoType) { + preamble = singlePBTypeUnmarshalTmpl + } else if !isProtobufType(f.GoType) { preamble = repeatedUnmarshalTmpl + } else { + preamble = repeatedPBUnmarshalTmpl } jsonConvTmpl := preamble + errorCheckingTmpl code, err := ApplyTemplate("UnmarshalNonBaseType", jsonConvTmpl, f, TemplateFuncs) diff --git a/gengokit/httptransport/templates/server.go b/gengokit/httptransport/templates/server.go index c81fdba3..c3b75c73 100644 --- a/gengokit/httptransport/templates/server.go +++ b/gengokit/httptransport/templates/server.go @@ -77,9 +77,11 @@ import ( "strconv" "strings" "io" + "time" "github.com/gogo/protobuf/jsonpb" "github.com/gogo/protobuf/proto" + "github.com/gogo/protobuf/types" "context" @@ -102,6 +104,8 @@ var ( _ = pb.New{{.Service.Name}}Client _ = io.Copy _ = errors.Wrap + _ = types.EmptyAny + _ = time.NewTimer ) // MakeHTTPHandler returns a handler that makes a set of endpoints available diff --git a/svcdef/svcdef.go b/svcdef/svcdef.go index cf1ae307..70a89694 100644 --- a/svcdef/svcdef.go +++ b/svcdef/svcdef.go @@ -523,6 +523,9 @@ func NewField(f *ast.Field) (*Field, error) { if oneof, ok := oneofs[ex.Name]; ok { rv.Type.Oneof = oneof } + case *ast.SelectorExpr: + packageIdent := ex.X.(*ast.Ident) + rv.Type.Name += fmt.Sprintf("%s.%s", packageIdent.Name, ex.Sel.Name) case *ast.StarExpr: rv.Type.StarExpr = true typeFollower(ex.X)