diff --git a/gengokit/httptransport/httptransport.go b/gengokit/httptransport/httptransport.go index b0d920f4..797938db 100644 --- a/gengokit/httptransport/httptransport.go +++ b/gengokit/httptransport/httptransport.go @@ -156,6 +156,10 @@ func NewBinding(i int, meth *svcdef.ServiceMethod) *Binding { newField.IsEnum = field.Type.Enum != nil newField.ConvertFunc, newField.ConvertFuncNeedsErrorCheck = createDecodeConvertFunc(newField) newField.TypeConversion = createDecodeTypeConversion(newField) + if newField.Location == "body_root" { + newField.Location = "body" + nBinding.RequestRootField = &newField + } nBinding.Fields = append(nBinding.Fields, &newField) diff --git a/gengokit/httptransport/httptransport_test.go b/gengokit/httptransport/httptransport_test.go index ff9b9246..5aa3b7ad 100644 --- a/gengokit/httptransport/httptransport_test.go +++ b/gengokit/httptransport/httptransport_test.go @@ -21,6 +21,97 @@ func init() { gopath = filepath.SplitList(os.Getenv("GOPATH")) } +func TestNewMethodWithBody(t *testing.T) { + defStr := ` + syntax = "proto3"; + + // General package + package general; + + import "github.com/metaverse/truss/deftree/googlethirdparty/annotations.proto"; + + message Inner { + string a = 1; + } + + message SumRequest { + int64 a = 1; + Inner in = 2; + } + + message SumReply { + int64 v = 1; + string err = 2; + } + + service SumSvc { + rpc Sum(SumRequest) returns (SumReply) { + option (google.api.http) = { + put: "/sum/{a}" + body: "in" + }; + } + } + ` + sd, err := svcdef.NewFromString(defStr, gopath) + if err != nil { + t.Fatal(err, "Failed to create a service from the definition string") + } + innerField := Field{ + Name: "In", + QueryParamName: "in", + CamelName: "In", + LowCamelName: "in", + LocalName: "InSum", + Location: "body", + GoType: "pb.Inner", + ConvertFunc: "\nvar InSum *pb.Inner\nInSum = &pb.Inner{}\nerr = json.Unmarshal([]byte(InSumStr), InSum)\nif err != nil {\n\treturn nil, errors.Wrapf(err, \"couldn't decode InSum from %v\", InSumStr)\n}", + ConvertFuncNeedsErrorCheck: false, + TypeConversion: "InSum", + IsBaseType: false, + } + binding := &Binding{ + Label: "SumZero", + PathTemplate: "/sum/{a}", + BasePath: "/sum/", + Verb: "put", + RequestRootField: &innerField, + Fields: []*Field{ + &Field{ + Name: "A", + QueryParamName: "a", + CamelName: "A", + LowCamelName: "a", + LocalName: "ASum", + Location: "path", + GoType: "int64", + ConvertFunc: "ASum, err := strconv.ParseInt(ASumStr, 10, 64)", + ConvertFuncNeedsErrorCheck: true, + TypeConversion: "ASum", + IsBaseType: true, + }, + &innerField, + }, + } + + meth := &Method{ + Name: "Sum", + RequestType: "SumRequest", + ResponseType: "SumReply", + Bindings: []*Binding{ + binding, + }, + } + binding.Parent = meth + + newMeth := NewMethod(sd.Service.Methods[0]) + t.Logf("%v\n", spew.Sdump(sd.Service.Methods[0])) + if got, want := newMeth, meth; !reflect.DeepEqual(got, want) { + diff := gentesthelper.DiffStrings(spew.Sdump(got), spew.Sdump(want)) + t.Errorf("got != want; methods differ: %v\n", diff) + } +} + func TestNewMethod(t *testing.T) { defStr := ` syntax = "proto3"; diff --git a/gengokit/httptransport/templates/server.go b/gengokit/httptransport/templates/server.go index c81fdba3..743fe6be 100644 --- a/gengokit/httptransport/templates/server.go +++ b/gengokit/httptransport/templates/server.go @@ -9,7 +9,14 @@ var ServerDecodeTemplate = ` // body. Primarily useful in a server. func DecodeHTTP{{$binding.Label}}Request(_ context.Context, r *http.Request) (interface{}, error) { defer r.Body.Close() + var req pb.{{GoName $binding.Parent.RequestType}} + {{$req_field := "req" -}} + {{if $binding.RequestRootField -}} + {{$req_field = print "req" ($binding.RequestRootField.Name) -}} + var {{$req_field}} {{$binding.RequestRootField.GoType}} + {{end -}} + buf, err := ioutil.ReadAll(r.Body) if err != nil { return nil, errors.Wrapf(err, "cannot read body of http request") @@ -19,7 +26,7 @@ var ServerDecodeTemplate = ` unmarshaller := jsonpb.Unmarshaler{ AllowUnknownFields: true, } - if err = unmarshaller.Unmarshal(bytes.NewBuffer(buf), &req); err != nil { + if err = unmarshaller.Unmarshal(bytes.NewBuffer(buf), &{{$req_field}}); err != nil { const size = 8196 if len(buf) > size { buf = buf[:size] @@ -31,6 +38,10 @@ var ServerDecodeTemplate = ` } } + {{if $binding.RequestRootField}} + req.{{$binding.RequestRootField.Name}} = &{{$req_field}} + {{end}} + pathParams := mux.Vars(r) _ = pathParams diff --git a/gengokit/httptransport/templates_test.go b/gengokit/httptransport/templates_test.go index 3ea31c5e..7fcb7b18 100644 --- a/gengokit/httptransport/templates_test.go +++ b/gengokit/httptransport/templates_test.go @@ -151,6 +151,7 @@ func TestGenServerDecode(t *testing.T) { if err != nil { t.Errorf("Failed to generate server decode code: %v", err) } + desired := ` // DecodeHTTPSumZeroRequest is a transport/http.DecodeRequestFunc that @@ -158,6 +159,7 @@ func TestGenServerDecode(t *testing.T) { // body. Primarily useful in a server. func DecodeHTTPSumZeroRequest(_ context.Context, r *http.Request) (interface{}, error) { defer r.Body.Close() + var req pb.SumRequest buf, err := ioutil.ReadAll(r.Body) if err != nil { @@ -211,3 +213,132 @@ func DecodeHTTPSumZeroRequest(_ context.Context, r *http.Request) (interface{}, t.Log(gentesthelper.DiffStrings(got, want)) } } + +func TestGenServerDecodeWithBody(t *testing.T) { + innerField := &Field{ + Name: "c", + QueryParamName: "c", + CamelName: "C", + LowCamelName: "c", + LocalName: "CSum", + Location: "body", + GoType: "pb.Inner", + ConvertFunc: "", + ConvertFuncNeedsErrorCheck: true, + TypeConversion: "CSum", + IsBaseType: true, + } + binding := &Binding{ + Label: "SumZero", + PathTemplate: "/sum/{a}", + BasePath: "/sum/", + Verb: "get", + RequestRootField: innerField, + Fields: []*Field{ + &Field{ + Name: "a", + QueryParamName: "a", + CamelName: "A", + LowCamelName: "a", + LocalName: "ASum", + Location: "path", + GoType: "int64", + ConvertFunc: "ASum, err := strconv.ParseInt(ASumStr, 10, 64)", + ConvertFuncNeedsErrorCheck: true, + TypeConversion: "ASum", + IsBaseType: true, + }, + &Field{ + Name: "b", + QueryParamName: "b", + CamelName: "B", + LowCamelName: "b", + LocalName: "BSum", + Location: "query", + GoType: "int64", + ConvertFunc: "BSum, err := strconv.ParseInt(BSumStr, 10, 64)", + ConvertFuncNeedsErrorCheck: true, + TypeConversion: "BSum", + IsBaseType: true, + }, + innerField, + }, + } + meth := &Method{ + Name: "Sum", + RequestType: "SumRequest", + ResponseType: "SumReply", + Bindings: []*Binding{ + binding, + }, + } + binding.Parent = meth + + str, err := binding.GenServerDecode() + if err != nil { + t.Errorf("Failed to generate server decode code: %v", err) + } + desired := ` + +// DecodeHTTPSumZeroRequest is a transport/http.DecodeRequestFunc that +// decodes a JSON-encoded sum request from the HTTP request +// body. Primarily useful in a server. +func DecodeHTTPSumZeroRequest(_ context.Context, r *http.Request) (interface{}, error) { + defer r.Body.Close() + + var req pb.SumRequest + var reqc pb.Inner + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return nil, errors.Wrapf(err, "cannot read body of http request") + } + if len(buf) > 0 { + // AllowUnknownFields stops the unmarshaler from failing if the JSON contains unknown fields. + unmarshaller := jsonpb.Unmarshaler{ + AllowUnknownFields: true, + } + if err = unmarshaller.Unmarshal(bytes.NewBuffer(buf), &reqc); err != nil { + const size = 8196 + if len(buf) > size { + buf = buf[:size] + } + return nil, httpError{errors.Wrapf(err, "request body '%s': cannot parse non-json request body", buf), + http.StatusBadRequest, + nil, + } + } + } + + req.c = &reqc + + pathParams := mux.Vars(r) + _ = pathParams + + queryParams := r.URL.Query() + _ = queryParams + + ASumStr := pathParams["a"] + ASum, err := strconv.ParseInt(ASumStr, 10, 64) + if err != nil { + return nil, errors.Wrap(err, fmt.Sprintf("Error while extracting ASum from path, pathParams: %v", pathParams)) + } + req.A = ASum + + if BSumStrArr, ok := queryParams["b"]; ok { + BSumStr := BSumStrArr[0] + BSum, err := strconv.ParseInt(BSumStr, 10, 64) + if err != nil { + return nil, errors.Wrap(err, fmt.Sprintf("Error while extracting BSum from query, queryParams: %v", queryParams)) + } + req.B = BSum + } + + return &req, err +} + +` + if got, want := strings.TrimSpace(str), strings.TrimSpace(desired); got != want { + t.Errorf("Generated code differs from result.\ngot = %s\nwant = %s", got, want) + t.Log(gentesthelper.DiffStrings(got, want)) + } +} diff --git a/gengokit/httptransport/types.go b/gengokit/httptransport/types.go index 55fc5bca..c8e6b0e1 100644 --- a/gengokit/httptransport/types.go +++ b/gengokit/httptransport/types.go @@ -28,6 +28,7 @@ type Binding struct { Verb string Fields []*Field OneofFields []*OneofField + RequestRootField *Field // A pointer back to the parent method of this binding. Used within some // binding methods Parent *Method diff --git a/gengokit/template/template.go b/gengokit/template/template.go index d251e337..1f1eec3b 100644 --- a/gengokit/template/template.go +++ b/gengokit/template/template.go @@ -425,31 +425,31 @@ type bintree struct { } var _bintree = &bintree{nil, map[string]*bintree{ - "cmd": &bintree{nil, map[string]*bintree{ - "NAME": &bintree{nil, map[string]*bintree{ - "main.gotemplate": &bintree{cmdNameMainGotemplate, map[string]*bintree{}}, + "cmd": {nil, map[string]*bintree{ + "NAME": {nil, map[string]*bintree{ + "main.gotemplate": {cmdNameMainGotemplate, map[string]*bintree{}}, }}, }}, - "handlers": &bintree{nil, map[string]*bintree{ - "handlers.gotemplate": &bintree{handlersHandlersGotemplate, map[string]*bintree{}}, - "hooks.gotemplate": &bintree{handlersHooksGotemplate, map[string]*bintree{}}, - "middlewares.gotemplate": &bintree{handlersMiddlewaresGotemplate, map[string]*bintree{}}, + "handlers": {nil, map[string]*bintree{ + "handlers.gotemplate": {handlersHandlersGotemplate, map[string]*bintree{}}, + "hooks.gotemplate": {handlersHooksGotemplate, map[string]*bintree{}}, + "middlewares.gotemplate": {handlersMiddlewaresGotemplate, map[string]*bintree{}}, }}, - "svc": &bintree{nil, map[string]*bintree{ - "client": &bintree{nil, map[string]*bintree{ - "grpc": &bintree{nil, map[string]*bintree{ - "client.gotemplate": &bintree{svcClientGrpcClientGotemplate, map[string]*bintree{}}, + "svc": {nil, map[string]*bintree{ + "client": {nil, map[string]*bintree{ + "grpc": {nil, map[string]*bintree{ + "client.gotemplate": {svcClientGrpcClientGotemplate, map[string]*bintree{}}, }}, - "http": &bintree{nil, map[string]*bintree{ - "client.gotemplate": &bintree{svcClientHttpClientGotemplate, map[string]*bintree{}}, + "http": {nil, map[string]*bintree{ + "client.gotemplate": {svcClientHttpClientGotemplate, map[string]*bintree{}}, }}, }}, - "endpoints.gotemplate": &bintree{svcEndpointsGotemplate, map[string]*bintree{}}, - "server": &bintree{nil, map[string]*bintree{ - "run.gotemplate": &bintree{svcServerRunGotemplate, map[string]*bintree{}}, + "endpoints.gotemplate": {svcEndpointsGotemplate, map[string]*bintree{}}, + "server": {nil, map[string]*bintree{ + "run.gotemplate": {svcServerRunGotemplate, map[string]*bintree{}}, }}, - "transport_grpc.gotemplate": &bintree{svcTransport_grpcGotemplate, map[string]*bintree{}}, - "transport_http.gotemplate": &bintree{svcTransport_httpGotemplate, map[string]*bintree{}}, + "transport_grpc.gotemplate": {svcTransport_grpcGotemplate, map[string]*bintree{}}, + "transport_http.gotemplate": {svcTransport_httpGotemplate, map[string]*bintree{}}, }}, }} diff --git a/svcdef/consolidate_http.go b/svcdef/consolidate_http.go index 84efe384..a343cfd3 100644 --- a/svcdef/consolidate_http.go +++ b/svcdef/consolidate_http.go @@ -148,11 +148,11 @@ func paramLocation(field *Field, binding *svcparse.HTTPBinding) string { if optField.Value == "*" { return "body" } else if optField.Value == field.Name { - return "body" + return "body_root" // Have to CamelCase the fields from the protobuf file, as they may // be lowercase while the name from the Go file will be CamelCased. } else if gogen.CamelCase(strings.Split(optField.Value, ".")[0]) == field.Name { - return "body" + return "body_root" } } } diff --git a/svcdef/consolidate_http_test.go b/svcdef/consolidate_http_test.go index 484cec91..340a07d6 100644 --- a/svcdef/consolidate_http_test.go +++ b/svcdef/consolidate_http_test.go @@ -124,7 +124,7 @@ service Map { Name string Location string }{ - {"A", "body"}, + {"A", "body_root"}, {"AA", "query"}, {"C", "query"}, {"MapField", "query"},