diff --git a/cmd/protoc-gen-go-rerpc/rerpc.go b/cmd/protoc-gen-go-rerpc/rerpc.go index 5f903c82..61cb57da 100644 --- a/cmd/protoc-gen-go-rerpc/rerpc.go +++ b/cmd/protoc-gen-go-rerpc/rerpc.go @@ -100,10 +100,14 @@ func service(file *protogen.File, g *protogen.GeneratedFile, service *protogen.S serverConstructor(g, service, serverName) serverImplementation(g, service, serverName) - // client stream types + clientStreams(g, service, clientName) serverStreams(g, service, serverName) } +func clientStreamName(cname string, method *protogen.Method) string { + return cname + "_" + method.GoName +} + func clientInterface(g *protogen.GeneratedFile, service *protogen.Service, name string) { comment(g, name, " is a client for the ", service.Desc.FullName(), " service.") if service.Desc.Options().(*descriptorpb.ServiceOptions).GetDeprecated() { @@ -112,18 +116,32 @@ func clientInterface(g *protogen.GeneratedFile, service *protogen.Service, name } g.Annotate(name, service.Location) g.P("type ", name, " interface {") - for _, method := range unaryMethods(service) { + for _, method := range service.Methods { g.Annotate(name+"."+method.GoName, method.Location) - g.P(method.Comments.Leading, clientSignature(g, method)) + g.P(method.Comments.Leading, clientSignature(g, name, method)) } g.P("}") g.P() } -func clientSignature(g *protogen.GeneratedFile, method *protogen.Method) string { +func clientSignature(g *protogen.GeneratedFile, cname string, method *protogen.Method) string { if method.Desc.Options().(*descriptorpb.MethodOptions).GetDeprecated() { deprecated(g) } + if method.Desc.IsStreamingClient() { + // client and bidi streaming + return method.GoName + "(ctx " + g.QualifiedGoIdent(contextContext) + + ", opts ..." + g.QualifiedGoIdent(rerpcPackage.Ident("CallOption")) + ") " + + "*" + clientStreamName(cname, method) + } + if method.Desc.IsStreamingServer() { + // server streaming + return method.GoName + "(ctx " + g.QualifiedGoIdent(contextContext) + + ", req *" + g.QualifiedGoIdent(method.Input.GoIdent) + + ", opts ..." + g.QualifiedGoIdent(rerpcPackage.Ident("CallOption")) + ") " + + "(*" + clientStreamName(cname, method) + ", error)" + } + // unary return method.GoName + "(ctx " + g.QualifiedGoIdent(contextContext) + ", req *" + g.QualifiedGoIdent(method.Input.GoIdent) + ", opts ..." + g.QualifiedGoIdent(rerpcPackage.Ident("CallOption")) + ") " + @@ -173,24 +191,35 @@ func clientImplementation(g *protogen.GeneratedFile, service *protogen.Service, g.P() // Client method implementations. - for _, method := range unaryMethods(service) { - clientMethod(g, service, method) + for _, method := range service.Methods { + clientMethod(g, service, name, method) } } -func clientMethod(g *protogen.GeneratedFile, service *protogen.Service, method *protogen.Method) { +func clientMethod(g *protogen.GeneratedFile, service *protogen.Service, cname string, method *protogen.Method) { + isStreamingClient := method.Desc.IsStreamingClient() + isStreamingServer := method.Desc.IsStreamingServer() comment(g, method.GoName, " calls ", method.Desc.FullName(), ".", " Call options passed here apply only to this call.") if method.Desc.Options().(*descriptorpb.MethodOptions).GetDeprecated() { g.P("//") deprecated(g) } - g.P("func (c *", unexport(method.Parent.GoName), "ClientReRPC) ", clientSignature(g, method), "{") + g.P("func (c *", unexport(method.Parent.GoName), "ClientReRPC) ", clientSignature(g, cname, method), " {") g.P("merged := c.mergeOptions(opts)") + g.P("ic := ", rerpcPackage.Ident("ConfiguredCallInterceptor"), "(merged...)") g.P("ctx, call := ", rerpcPackage.Ident("NewCall"), "(") g.P("ctx,") g.P("c.doer,") - g.P(rerpcPackage.Ident("StreamTypeUnary"), ",") + if isStreamingClient && isStreamingServer { + g.P(rerpcPackage.Ident("StreamTypeBidirectional"), ",") + } else if isStreamingClient { + g.P(rerpcPackage.Ident("StreamTypeClient"), ",") + } else if isStreamingServer { + g.P(rerpcPackage.Ident("StreamTypeServer"), ",") + } else { + g.P(rerpcPackage.Ident("StreamTypeUnary"), ",") + } g.P("c.baseURL,") g.P(`"`, service.Desc.ParentFile().Package(), `", // protobuf package`) g.P(`"`, service.Desc.Name(), `", // protobuf service`) @@ -198,6 +227,31 @@ func clientMethod(g *protogen.GeneratedFile, service *protogen.Service, method * g.P("merged...,") g.P(")") + if isStreamingClient || isStreamingServer { + g.P("if ic != nil {") + g.P("call = ic.WrapStream(call)") + g.P("}") + g.P("stream := call(ctx)") + if !isStreamingClient && isStreamingServer { + // server streaming, we need to send the request. + g.P("if err := stream.Send(req); err != nil {") + g.P("_ = stream.CloseSend(err)") + g.P("_ = stream.CloseReceive()") + g.P("return nil, err") + g.P("}") + g.P("if err := stream.CloseSend(nil); err != nil {") + g.P("_ = stream.CloseReceive()") + g.P("return nil, err") + g.P("}") + g.P("return New", clientStreamName(cname, method), "(stream), nil") + } else { + g.P("return New", clientStreamName(cname, method), "(stream)") + } + g.P("}") + g.P() + return + } + g.P("wrapped := ", rerpcPackage.Ident("Func"), "(func(ctx ", contextContext, ", msg ", protoMessage, ") (", protoMessage, ", error) {") g.P("stream := call(ctx)") g.P("if err := stream.Send(req); err != nil {") @@ -217,7 +271,7 @@ func clientMethod(g *protogen.GeneratedFile, service *protogen.Service, method * g.P("return &res, stream.CloseReceive()") g.P("})") - g.P("if ic := ", rerpcPackage.Ident("ConfiguredCallInterceptor"), "(merged...); ic != nil {") + g.P("if ic != nil {") g.P("wrapped = ic.Wrap(wrapped)") g.P("}") g.P("res, err := wrapped(ctx, req)") @@ -233,8 +287,74 @@ func clientMethod(g *protogen.GeneratedFile, service *protogen.Service, method * g.P() } +func clientStreams(g *protogen.GeneratedFile, service *protogen.Service, name string) { + for _, method := range service.Methods { + if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() { + continue + } + streamName := clientStreamName(name, method) + isDeprecated := method.Desc.Options().(*descriptorpb.MethodOptions).GetDeprecated() + comment(g, streamName, " is the client-side stream for the ", method.Desc.FullName(), " procedure.") + if isDeprecated { + g.P("//") + deprecated(g) + } + g.P("type ", streamName, " struct {") + g.P("stream ", rerpcPackage.Ident("Stream")) + g.P("}") + g.P() + g.P("func New", streamName, "(stream ", rerpcPackage.Ident("Stream"), ") *", streamName, " {") + g.P("return &", streamName, "{stream}") + g.P("}") + g.P() + + if method.Desc.IsStreamingClient() { + g.P("func (s *", streamName, ") Send(msg *", method.Input.GoIdent, ") error {") + g.P("return s.stream.Send(msg)") + g.P("}") + g.P() + if method.Desc.IsStreamingServer() { + // Bidi: otherwise, we'll never know when they're done sending. + g.P("func (s *", streamName, ") CloseSend() error {") + g.P("return s.stream.CloseSend(nil)") + g.P("}") + g.P() + } else { + // Client-only streaming. + g.P("func (s *", streamName, ") ReceiveAndClose() (*", method.Output.GoIdent, ", error) {") + g.P("if err := s.stream.CloseSend(nil); err != nil {") + g.P("return nil, err") + g.P("}") + g.P("var res ", method.Output.GoIdent) + g.P("err := s.stream.Receive(&res)") + g.P("return &res, err") + g.P("}") + g.P() + } + } + if method.Desc.IsStreamingServer() { + g.P("func (s *", streamName, ") Receive() (*", method.Output.GoIdent, ", error) {") + g.P("var req ", method.Output.GoIdent) + g.P("if err := s.stream.Receive(&req); err != nil {") + g.P("return nil, err") + g.P("}") + g.P("return &req, nil") + g.P("}") + g.P() + closeName := "Close" + if method.Desc.IsStreamingClient() { + closeName = "CloseReceive" + } + g.P("func (s *", streamName, ") ", closeName, "() error {") + g.P("return s.stream.CloseReceive()") + g.P("}") + g.P() + } + } +} + func serverStreamName(sname string, method *protogen.Method) string { - return sname + "_" + method.GoName + "Server" + return sname + "_" + method.GoName } func serverInterface(g *protogen.GeneratedFile, service *protogen.Service, name string) { @@ -487,14 +607,3 @@ func serverStreams(g *protogen.GeneratedFile, service *protogen.Service, sname s } func unexport(s string) string { return strings.ToLower(s[:1]) + s[1:] } - -func unaryMethods(service *protogen.Service) []*protogen.Method { - unary := make([]*protogen.Method, 0, len(service.Methods)) - for _, m := range service.Methods { - if m.Desc.IsStreamingServer() || m.Desc.IsStreamingClient() { - continue - } - unary = append(unary, m) - } - return unary -} diff --git a/internal/crosstest/v1test/cross_rerpc.pb.go b/internal/crosstest/v1test/cross_rerpc.pb.go index 15c5afa1..6e55231c 100644 --- a/internal/crosstest/v1test/cross_rerpc.pb.go +++ b/internal/crosstest/v1test/cross_rerpc.pb.go @@ -65,6 +65,7 @@ func (c *crossServiceClientReRPC) mergeOptions(opts []rerpc.CallOption) []rerpc. // here apply only to this call. func (c *crossServiceClientReRPC) Ping(ctx context.Context, req *PingRequest, opts ...rerpc.CallOption) (*PingResponse, error) { merged := c.mergeOptions(opts) + ic := rerpc.ConfiguredCallInterceptor(merged...) ctx, call := rerpc.NewCall( ctx, c.doer, @@ -93,7 +94,7 @@ func (c *crossServiceClientReRPC) Ping(ctx context.Context, req *PingRequest, op } return &res, stream.CloseReceive() }) - if ic := rerpc.ConfiguredCallInterceptor(merged...); ic != nil { + if ic != nil { wrapped = ic.Wrap(wrapped) } res, err := wrapped(ctx, req) @@ -111,6 +112,7 @@ func (c *crossServiceClientReRPC) Ping(ctx context.Context, req *PingRequest, op // here apply only to this call. func (c *crossServiceClientReRPC) Fail(ctx context.Context, req *FailRequest, opts ...rerpc.CallOption) (*FailResponse, error) { merged := c.mergeOptions(opts) + ic := rerpc.ConfiguredCallInterceptor(merged...) ctx, call := rerpc.NewCall( ctx, c.doer, @@ -139,7 +141,7 @@ func (c *crossServiceClientReRPC) Fail(ctx context.Context, req *FailRequest, op } return &res, stream.CloseReceive() }) - if ic := rerpc.ConfiguredCallInterceptor(merged...); ic != nil { + if ic != nil { wrapped = ic.Wrap(wrapped) } res, err := wrapped(ctx, req) diff --git a/internal/ping/v1test/ping_rerpc.pb.go b/internal/ping/v1test/ping_rerpc.pb.go index c0fa6385..2442232e 100644 --- a/internal/ping/v1test/ping_rerpc.pb.go +++ b/internal/ping/v1test/ping_rerpc.pb.go @@ -28,6 +28,9 @@ const _ = rerpc.SupportsCodeGenV0 // requires reRPC v0.0.1 or later type PingServiceClientReRPC interface { Ping(ctx context.Context, req *PingRequest, opts ...rerpc.CallOption) (*PingResponse, error) Fail(ctx context.Context, req *FailRequest, opts ...rerpc.CallOption) (*FailResponse, error) + Sum(ctx context.Context, opts ...rerpc.CallOption) *PingServiceClientReRPC_Sum + CountUp(ctx context.Context, req *CountUpRequest, opts ...rerpc.CallOption) (*PingServiceClientReRPC_CountUp, error) + CumSum(ctx context.Context, opts ...rerpc.CallOption) *PingServiceClientReRPC_CumSum } type pingServiceClientReRPC struct { @@ -65,6 +68,7 @@ func (c *pingServiceClientReRPC) mergeOptions(opts []rerpc.CallOption) []rerpc.C // apply only to this call. func (c *pingServiceClientReRPC) Ping(ctx context.Context, req *PingRequest, opts ...rerpc.CallOption) (*PingResponse, error) { merged := c.mergeOptions(opts) + ic := rerpc.ConfiguredCallInterceptor(merged...) ctx, call := rerpc.NewCall( ctx, c.doer, @@ -93,7 +97,7 @@ func (c *pingServiceClientReRPC) Ping(ctx context.Context, req *PingRequest, opt } return &res, stream.CloseReceive() }) - if ic := rerpc.ConfiguredCallInterceptor(merged...); ic != nil { + if ic != nil { wrapped = ic.Wrap(wrapped) } res, err := wrapped(ctx, req) @@ -111,6 +115,7 @@ func (c *pingServiceClientReRPC) Ping(ctx context.Context, req *PingRequest, opt // apply only to this call. func (c *pingServiceClientReRPC) Fail(ctx context.Context, req *FailRequest, opts ...rerpc.CallOption) (*FailResponse, error) { merged := c.mergeOptions(opts) + ic := rerpc.ConfiguredCallInterceptor(merged...) ctx, call := rerpc.NewCall( ctx, c.doer, @@ -139,7 +144,7 @@ func (c *pingServiceClientReRPC) Fail(ctx context.Context, req *FailRequest, opt } return &res, stream.CloseReceive() }) - if ic := rerpc.ConfiguredCallInterceptor(merged...); ic != nil { + if ic != nil { wrapped = ic.Wrap(wrapped) } res, err := wrapped(ctx, req) @@ -153,6 +158,81 @@ func (c *pingServiceClientReRPC) Fail(ctx context.Context, req *FailRequest, opt return typed, nil } +// Sum calls internal.ping.v1test.PingService.Sum. Call options passed here +// apply only to this call. +func (c *pingServiceClientReRPC) Sum(ctx context.Context, opts ...rerpc.CallOption) *PingServiceClientReRPC_Sum { + merged := c.mergeOptions(opts) + ic := rerpc.ConfiguredCallInterceptor(merged...) + ctx, call := rerpc.NewCall( + ctx, + c.doer, + rerpc.StreamTypeClient, + c.baseURL, + "internal.ping.v1test", // protobuf package + "PingService", // protobuf service + "Sum", // protobuf method + merged..., + ) + if ic != nil { + call = ic.WrapStream(call) + } + stream := call(ctx) + return NewPingServiceClientReRPC_Sum(stream) +} + +// CountUp calls internal.ping.v1test.PingService.CountUp. Call options passed +// here apply only to this call. +func (c *pingServiceClientReRPC) CountUp(ctx context.Context, req *CountUpRequest, opts ...rerpc.CallOption) (*PingServiceClientReRPC_CountUp, error) { + merged := c.mergeOptions(opts) + ic := rerpc.ConfiguredCallInterceptor(merged...) + ctx, call := rerpc.NewCall( + ctx, + c.doer, + rerpc.StreamTypeServer, + c.baseURL, + "internal.ping.v1test", // protobuf package + "PingService", // protobuf service + "CountUp", // protobuf method + merged..., + ) + if ic != nil { + call = ic.WrapStream(call) + } + stream := call(ctx) + if err := stream.Send(req); err != nil { + _ = stream.CloseSend(err) + _ = stream.CloseReceive() + return nil, err + } + if err := stream.CloseSend(nil); err != nil { + _ = stream.CloseReceive() + return nil, err + } + return NewPingServiceClientReRPC_CountUp(stream), nil +} + +// CumSum calls internal.ping.v1test.PingService.CumSum. Call options passed +// here apply only to this call. +func (c *pingServiceClientReRPC) CumSum(ctx context.Context, opts ...rerpc.CallOption) *PingServiceClientReRPC_CumSum { + merged := c.mergeOptions(opts) + ic := rerpc.ConfiguredCallInterceptor(merged...) + ctx, call := rerpc.NewCall( + ctx, + c.doer, + rerpc.StreamTypeBidirectional, + c.baseURL, + "internal.ping.v1test", // protobuf package + "PingService", // protobuf service + "CumSum", // protobuf method + merged..., + ) + if ic != nil { + call = ic.WrapStream(call) + } + stream := call(ctx) + return NewPingServiceClientReRPC_CumSum(stream) +} + // PingServiceReRPC is a server for the internal.ping.v1test.PingService // service. To make sure that adding methods to this protobuf service doesn't // break all implementations of this interface, all implementations must embed @@ -164,9 +244,9 @@ func (c *pingServiceClientReRPC) Fail(ctx context.Context, req *FailRequest, opt type PingServiceReRPC interface { Ping(context.Context, *PingRequest) (*PingResponse, error) Fail(context.Context, *FailRequest) (*FailResponse, error) - Sum(context.Context, *PingServiceReRPC_SumServer) error - CountUp(context.Context, *CountUpRequest, *PingServiceReRPC_CountUpServer) error - CumSum(context.Context, *PingServiceReRPC_CumSumServer) error + Sum(context.Context, *PingServiceReRPC_Sum) error + CountUp(context.Context, *CountUpRequest, *PingServiceReRPC_CountUp) error + CumSum(context.Context, *PingServiceReRPC_CumSum) error mustEmbedUnimplementedPingServiceReRPC() } @@ -300,7 +380,7 @@ func NewPingServiceHandlerReRPC(svc PingServiceReRPC, opts ...rerpc.HandlerOptio sf = ic.WrapStream(sf) } stream := sf(ctx) - typed := NewPingServiceReRPC_SumServer(stream) + typed := NewPingServiceReRPC_Sum(stream) err := svc.Sum(stream.Context(), typed) _ = stream.CloseReceive() if err != nil { @@ -329,7 +409,7 @@ func NewPingServiceHandlerReRPC(svc PingServiceReRPC, opts ...rerpc.HandlerOptio sf = ic.WrapStream(sf) } stream := sf(ctx) - typed := NewPingServiceReRPC_CountUpServer(stream) + typed := NewPingServiceReRPC_CountUp(stream) var req CountUpRequest if err := stream.Receive(&req); err != nil { _ = stream.CloseReceive() @@ -367,7 +447,7 @@ func NewPingServiceHandlerReRPC(svc PingServiceReRPC, opts ...rerpc.HandlerOptio sf = ic.WrapStream(sf) } stream := sf(ctx) - typed := NewPingServiceReRPC_CumSumServer(stream) + typed := NewPingServiceReRPC_CumSum(stream) err := svc.CumSum(stream.Context(), typed) _ = stream.CloseReceive() if err != nil { @@ -407,31 +487,106 @@ func (UnimplementedPingServiceReRPC) Fail(context.Context, *FailRequest) (*FailR return nil, rerpc.Errorf(rerpc.CodeUnimplemented, "internal.ping.v1test.PingService.Fail isn't implemented") } -func (UnimplementedPingServiceReRPC) Sum(context.Context, *PingServiceReRPC_SumServer) error { +func (UnimplementedPingServiceReRPC) Sum(context.Context, *PingServiceReRPC_Sum) error { return rerpc.Errorf(rerpc.CodeUnimplemented, "internal.ping.v1test.PingService.Sum isn't implemented") } -func (UnimplementedPingServiceReRPC) CountUp(context.Context, *CountUpRequest, *PingServiceReRPC_CountUpServer) error { +func (UnimplementedPingServiceReRPC) CountUp(context.Context, *CountUpRequest, *PingServiceReRPC_CountUp) error { return rerpc.Errorf(rerpc.CodeUnimplemented, "internal.ping.v1test.PingService.CountUp isn't implemented") } -func (UnimplementedPingServiceReRPC) CumSum(context.Context, *PingServiceReRPC_CumSumServer) error { +func (UnimplementedPingServiceReRPC) CumSum(context.Context, *PingServiceReRPC_CumSum) error { return rerpc.Errorf(rerpc.CodeUnimplemented, "internal.ping.v1test.PingService.CumSum isn't implemented") } func (UnimplementedPingServiceReRPC) mustEmbedUnimplementedPingServiceReRPC() {} -// PingServiceReRPC_SumServer is the server-side stream for the +// PingServiceClientReRPC_Sum is the client-side stream for the +// internal.ping.v1test.PingService.Sum procedure. +type PingServiceClientReRPC_Sum struct { + stream rerpc.Stream +} + +func NewPingServiceClientReRPC_Sum(stream rerpc.Stream) *PingServiceClientReRPC_Sum { + return &PingServiceClientReRPC_Sum{stream} +} + +func (s *PingServiceClientReRPC_Sum) Send(msg *SumRequest) error { + return s.stream.Send(msg) +} + +func (s *PingServiceClientReRPC_Sum) ReceiveAndClose() (*SumResponse, error) { + if err := s.stream.CloseSend(nil); err != nil { + return nil, err + } + var res SumResponse + err := s.stream.Receive(&res) + return &res, err +} + +// PingServiceClientReRPC_CountUp is the client-side stream for the +// internal.ping.v1test.PingService.CountUp procedure. +type PingServiceClientReRPC_CountUp struct { + stream rerpc.Stream +} + +func NewPingServiceClientReRPC_CountUp(stream rerpc.Stream) *PingServiceClientReRPC_CountUp { + return &PingServiceClientReRPC_CountUp{stream} +} + +func (s *PingServiceClientReRPC_CountUp) Receive() (*CountUpResponse, error) { + var req CountUpResponse + if err := s.stream.Receive(&req); err != nil { + return nil, err + } + return &req, nil +} + +func (s *PingServiceClientReRPC_CountUp) Close() error { + return s.stream.CloseReceive() +} + +// PingServiceClientReRPC_CumSum is the client-side stream for the +// internal.ping.v1test.PingService.CumSum procedure. +type PingServiceClientReRPC_CumSum struct { + stream rerpc.Stream +} + +func NewPingServiceClientReRPC_CumSum(stream rerpc.Stream) *PingServiceClientReRPC_CumSum { + return &PingServiceClientReRPC_CumSum{stream} +} + +func (s *PingServiceClientReRPC_CumSum) Send(msg *CumSumRequest) error { + return s.stream.Send(msg) +} + +func (s *PingServiceClientReRPC_CumSum) CloseSend() error { + return s.stream.CloseSend(nil) +} + +func (s *PingServiceClientReRPC_CumSum) Receive() (*CumSumResponse, error) { + var req CumSumResponse + if err := s.stream.Receive(&req); err != nil { + return nil, err + } + return &req, nil +} + +func (s *PingServiceClientReRPC_CumSum) CloseReceive() error { + return s.stream.CloseReceive() +} + +// PingServiceReRPC_Sum is the server-side stream for the // internal.ping.v1test.PingService.Sum procedure. -type PingServiceReRPC_SumServer struct { +type PingServiceReRPC_Sum struct { stream rerpc.Stream } -func NewPingServiceReRPC_SumServer(stream rerpc.Stream) *PingServiceReRPC_SumServer { - return &PingServiceReRPC_SumServer{stream} +func NewPingServiceReRPC_Sum(stream rerpc.Stream) *PingServiceReRPC_Sum { + return &PingServiceReRPC_Sum{stream} } -func (s *PingServiceReRPC_SumServer) Receive() (*SumRequest, error) { +func (s *PingServiceReRPC_Sum) Receive() (*SumRequest, error) { var req SumRequest if err := s.stream.Receive(&req); err != nil { return nil, err @@ -439,38 +594,38 @@ func (s *PingServiceReRPC_SumServer) Receive() (*SumRequest, error) { return &req, nil } -func (s *PingServiceReRPC_SumServer) SendAndClose(msg *SumResponse) error { +func (s *PingServiceReRPC_Sum) SendAndClose(msg *SumResponse) error { if err := s.stream.CloseReceive(); err != nil { return err } return s.stream.Send(msg) } -// PingServiceReRPC_CountUpServer is the server-side stream for the +// PingServiceReRPC_CountUp is the server-side stream for the // internal.ping.v1test.PingService.CountUp procedure. -type PingServiceReRPC_CountUpServer struct { +type PingServiceReRPC_CountUp struct { stream rerpc.Stream } -func NewPingServiceReRPC_CountUpServer(stream rerpc.Stream) *PingServiceReRPC_CountUpServer { - return &PingServiceReRPC_CountUpServer{stream} +func NewPingServiceReRPC_CountUp(stream rerpc.Stream) *PingServiceReRPC_CountUp { + return &PingServiceReRPC_CountUp{stream} } -func (s *PingServiceReRPC_CountUpServer) Send(msg *CountUpResponse) error { +func (s *PingServiceReRPC_CountUp) Send(msg *CountUpResponse) error { return s.stream.Send(msg) } -// PingServiceReRPC_CumSumServer is the server-side stream for the +// PingServiceReRPC_CumSum is the server-side stream for the // internal.ping.v1test.PingService.CumSum procedure. -type PingServiceReRPC_CumSumServer struct { +type PingServiceReRPC_CumSum struct { stream rerpc.Stream } -func NewPingServiceReRPC_CumSumServer(stream rerpc.Stream) *PingServiceReRPC_CumSumServer { - return &PingServiceReRPC_CumSumServer{stream} +func NewPingServiceReRPC_CumSum(stream rerpc.Stream) *PingServiceReRPC_CumSum { + return &PingServiceReRPC_CumSum{stream} } -func (s *PingServiceReRPC_CumSumServer) Receive() (*CumSumRequest, error) { +func (s *PingServiceReRPC_CumSum) Receive() (*CumSumRequest, error) { var req CumSumRequest if err := s.stream.Receive(&req); err != nil { return nil, err @@ -478,6 +633,6 @@ func (s *PingServiceReRPC_CumSumServer) Receive() (*CumSumRequest, error) { return &req, nil } -func (s *PingServiceReRPC_CumSumServer) Send(msg *CumSumResponse) error { +func (s *PingServiceReRPC_CumSum) Send(msg *CumSumResponse) error { return s.stream.Send(msg) } diff --git a/rerpc_test.go b/rerpc_test.go index 3ad6f4e0..051e7701 100644 --- a/rerpc_test.go +++ b/rerpc_test.go @@ -12,6 +12,7 @@ import ( "net/http" "net/http/httptest" "strings" + "sync" "testing" "google.golang.org/protobuf/proto" @@ -38,7 +39,7 @@ func (p pingServer) Fail(ctx context.Context, req *pingpb.FailRequest) (*pingpb. return nil, rerpc.Errorf(rerpc.Code(req.Code), errMsg) } -func (p pingServer) Sum(ctx context.Context, stream *pingpb.PingServiceReRPC_SumServer) error { +func (p pingServer) Sum(ctx context.Context, stream *pingpb.PingServiceReRPC_Sum) error { var sum int64 for { if err := ctx.Err(); err != nil { @@ -56,7 +57,7 @@ func (p pingServer) Sum(ctx context.Context, stream *pingpb.PingServiceReRPC_Sum } } -func (p pingServer) CountUp(ctx context.Context, req *pingpb.CountUpRequest, stream *pingpb.PingServiceReRPC_CountUpServer) error { +func (p pingServer) CountUp(ctx context.Context, req *pingpb.CountUpRequest, stream *pingpb.PingServiceReRPC_CountUp) error { if req.Number <= 0 { return rerpc.Errorf(rerpc.CodeInvalidArgument, "number must be positive: got %v", req.Number) } @@ -71,7 +72,7 @@ func (p pingServer) CountUp(ctx context.Context, req *pingpb.CountUpRequest, str return nil } -func (p pingServer) CumSum(ctx context.Context, stream *pingpb.PingServiceReRPC_CumSumServer) error { +func (p pingServer) CumSum(ctx context.Context, stream *pingpb.PingServiceReRPC_CumSum) error { var sum int64 for { if err := ctx.Err(); err != nil { @@ -325,6 +326,87 @@ func TestServerProtoGRPC(t *testing.T) { assert.Equal(t, res, expect, "ping response") }) } + testSum := func(t *testing.T, client pingpb.PingServiceClientReRPC) { + t.Run("sum", func(t *testing.T) { + const upTo = 10 + const expect = 55 // 1+10 + 2+9 + ... + 5+6 = 55 + stream := client.Sum(context.Background()) + for i := int64(1); i <= upTo; i++ { + err := stream.Send(&pingpb.SumRequest{Number: i}) + assert.Nil(t, err, "Send %v", assert.Fmt(i)) + } + res, err := stream.ReceiveAndClose() + assert.Nil(t, err, "ReceiveAndClose error") + assert.Equal(t, res, &pingpb.SumResponse{Sum: 55}, "response") + }) + } + testCountUp := func(t *testing.T, client pingpb.PingServiceClientReRPC) { + t.Run("count_up", func(t *testing.T) { + const n = 5 + got := make([]int64, 0, n) + expect := make([]int64, 0, n) + for i := 1; i <= n; i++ { + expect = append(expect, int64(i)) + } + stream, err := client.CountUp( + context.Background(), + &pingpb.CountUpRequest{Number: n}, + ) + for { + msg, err := stream.Receive() + if errors.Is(err, io.EOF) { + break + } + assert.Nil(t, err, "receive error") + got = append(got, msg.Number) + } + err = stream.Close() + assert.Nil(t, err, "close error") + assert.Equal(t, got, expect, "responses") + }) + } + testCumSum := func(t *testing.T, client pingpb.PingServiceClientReRPC, expectSuccess bool) { + t.Run("cumsum", func(t *testing.T) { + send := []int64{3, 5, 1} + expect := []int64{3, 8, 9} + var got []int64 + stream := client.CumSum(context.Background()) + if !expectSuccess { + err := stream.Send(&pingpb.CumSumRequest{}) + assert.Nil(t, err, "first send on HTTP/1.1") // succeeds, haven't gotten response back yet + assert.Nil(t, stream.CloseSend(), "close send error on HTTP/1.1") + _, err = stream.Receive() + assert.NotNil(t, err, "first receive on HTTP/1.1") // should be 505 + assert.True(t, strings.Contains(err.Error(), "HTTP status 505"), "expected 505, got %v", assert.Fmt(err)) + assert.Nil(t, stream.CloseReceive(), "close receive error on HTTP/1.1") + return + } + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + for i, n := range send { + err := stream.Send(&pingpb.CumSumRequest{Number: n}) + assert.Nil(t, err, "send error #%v", assert.Fmt(i)) + } + assert.Nil(t, stream.CloseSend(), "close send error") + }() + go func() { + defer wg.Done() + for { + msg, err := stream.Receive() + if errors.Is(err, io.EOF) { + break + } + assert.Nil(t, err, "receive error") + got = append(got, msg.Sum) + } + assert.Nil(t, stream.CloseReceive(), "close receive error") + }() + wg.Wait() + assert.Equal(t, got, expect, "sums") + }) + } testErrors := func(t *testing.T, client pingpb.PingServiceClientReRPC) { t.Run("errors", func(t *testing.T) { req := &pingpb.FailRequest{Code: int32(rerpc.CodeResourceExhausted)} @@ -543,10 +625,13 @@ func TestServerProtoGRPC(t *testing.T) { assert.Equal(t, res, expect, "response") }) } - testMatrix := func(t *testing.T, server *httptest.Server) { + testMatrix := func(t *testing.T, server *httptest.Server, bidi bool) { t.Run("identity", func(t *testing.T) { client := pingpb.NewPingServiceClientReRPC(server.URL, server.Client()) testPing(t, client) + testSum(t, client) + testCountUp(t, client) + testCumSum(t, client, bidi) testErrors(t, client) testHealth(t, server.URL, server.Client()) }) @@ -557,6 +642,9 @@ func TestServerProtoGRPC(t *testing.T) { rerpc.Gzip(true), ) testPing(t, client) + testSum(t, client) + testCountUp(t, client) + testCumSum(t, client, bidi) testErrors(t, client) testHealth(t, server.URL, server.Client(), rerpc.Gzip(true)) }) @@ -565,14 +653,14 @@ func TestServerProtoGRPC(t *testing.T) { t.Run("http1", func(t *testing.T) { server := httptest.NewServer(mux) defer server.Close() - testMatrix(t, server) + testMatrix(t, server, false /* bidi */) }) t.Run("http2", func(t *testing.T) { server := httptest.NewUnstartedServer(mux) server.EnableHTTP2 = true server.StartTLS() defer server.Close() - testMatrix(t, server) + testMatrix(t, server, true /* bidi */) testReflection(t, server.URL, server.Client()) }) }