Skip to content

Commit

Permalink
Generate client-side streaming code
Browse files Browse the repository at this point in the history
Generate the client-side streaming methods and types. Now that we've got
both sides set up, also add some reRPC-on-reRPC tests. (We'll add
cross-tests in a future PR.)

This is a big step closer to solving #1.
  • Loading branch information
akshayjshah committed Feb 28, 2022
1 parent b6b5c6f commit dc72cf6
Show file tree
Hide file tree
Showing 4 changed files with 412 additions and 58 deletions.
153 changes: 131 additions & 22 deletions cmd/protoc-gen-go-rerpc/rerpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,14 @@ func service(file *protogen.File, g *protogen.GeneratedFile, service *protogen.S
serverConstructor(g, service, serverName)
serverImplementation(g, service, serverName)

// client stream types
clientStreams(g, service, clientName)
serverStreams(g, service, serverName)
}

func clientStreamName(cname string, method *protogen.Method) string {
return cname + "_" + method.GoName
}

func clientInterface(g *protogen.GeneratedFile, service *protogen.Service, name string) {
comment(g, name, " is a client for the ", service.Desc.FullName(), " service.")
if service.Desc.Options().(*descriptorpb.ServiceOptions).GetDeprecated() {
Expand All @@ -112,18 +116,32 @@ func clientInterface(g *protogen.GeneratedFile, service *protogen.Service, name
}
g.Annotate(name, service.Location)
g.P("type ", name, " interface {")
for _, method := range unaryMethods(service) {
for _, method := range service.Methods {
g.Annotate(name+"."+method.GoName, method.Location)
g.P(method.Comments.Leading, clientSignature(g, method))
g.P(method.Comments.Leading, clientSignature(g, name, method))
}
g.P("}")
g.P()
}

func clientSignature(g *protogen.GeneratedFile, method *protogen.Method) string {
func clientSignature(g *protogen.GeneratedFile, cname string, method *protogen.Method) string {
if method.Desc.Options().(*descriptorpb.MethodOptions).GetDeprecated() {
deprecated(g)
}
if method.Desc.IsStreamingClient() {
// client and bidi streaming
return method.GoName + "(ctx " + g.QualifiedGoIdent(contextContext) +
", opts ..." + g.QualifiedGoIdent(rerpcPackage.Ident("CallOption")) + ") " +
"*" + clientStreamName(cname, method)
}
if method.Desc.IsStreamingServer() {
// server streaming
return method.GoName + "(ctx " + g.QualifiedGoIdent(contextContext) +
", req *" + g.QualifiedGoIdent(method.Input.GoIdent) +
", opts ..." + g.QualifiedGoIdent(rerpcPackage.Ident("CallOption")) + ") " +
"(*" + clientStreamName(cname, method) + ", error)"
}
// unary
return method.GoName + "(ctx " + g.QualifiedGoIdent(contextContext) +
", req *" + g.QualifiedGoIdent(method.Input.GoIdent) +
", opts ..." + g.QualifiedGoIdent(rerpcPackage.Ident("CallOption")) + ") " +
Expand Down Expand Up @@ -173,31 +191,67 @@ func clientImplementation(g *protogen.GeneratedFile, service *protogen.Service,
g.P()

// Client method implementations.
for _, method := range unaryMethods(service) {
clientMethod(g, service, method)
for _, method := range service.Methods {
clientMethod(g, service, name, method)
}
}

func clientMethod(g *protogen.GeneratedFile, service *protogen.Service, method *protogen.Method) {
func clientMethod(g *protogen.GeneratedFile, service *protogen.Service, cname string, method *protogen.Method) {
isStreamingClient := method.Desc.IsStreamingClient()
isStreamingServer := method.Desc.IsStreamingServer()
comment(g, method.GoName, " calls ", method.Desc.FullName(), ".",
" Call options passed here apply only to this call.")
if method.Desc.Options().(*descriptorpb.MethodOptions).GetDeprecated() {
g.P("//")
deprecated(g)
}
g.P("func (c *", unexport(method.Parent.GoName), "ClientReRPC) ", clientSignature(g, method), "{")
g.P("func (c *", unexport(method.Parent.GoName), "ClientReRPC) ", clientSignature(g, cname, method), " {")
g.P("merged := c.mergeOptions(opts)")
g.P("ic := ", rerpcPackage.Ident("ConfiguredCallInterceptor"), "(merged...)")
g.P("ctx, call := ", rerpcPackage.Ident("NewCall"), "(")
g.P("ctx,")
g.P("c.doer,")
g.P(rerpcPackage.Ident("StreamTypeUnary"), ",")
if isStreamingClient && isStreamingServer {
g.P(rerpcPackage.Ident("StreamTypeBidirectional"), ",")
} else if isStreamingClient {
g.P(rerpcPackage.Ident("StreamTypeClient"), ",")
} else if isStreamingServer {
g.P(rerpcPackage.Ident("StreamTypeServer"), ",")
} else {
g.P(rerpcPackage.Ident("StreamTypeUnary"), ",")
}
g.P("c.baseURL,")
g.P(`"`, service.Desc.ParentFile().Package(), `", // protobuf package`)
g.P(`"`, service.Desc.Name(), `", // protobuf service`)
g.P(`"`, method.Desc.Name(), `", // protobuf method`)
g.P("merged...,")
g.P(")")

if isStreamingClient || isStreamingServer {
g.P("if ic != nil {")
g.P("call = ic.WrapStream(call)")
g.P("}")
g.P("stream := call(ctx)")
if !isStreamingClient && isStreamingServer {
// server streaming, we need to send the request.
g.P("if err := stream.Send(req); err != nil {")
g.P("_ = stream.CloseSend(err)")
g.P("_ = stream.CloseReceive()")
g.P("return nil, err")
g.P("}")
g.P("if err := stream.CloseSend(nil); err != nil {")
g.P("_ = stream.CloseReceive()")
g.P("return nil, err")
g.P("}")
g.P("return New", clientStreamName(cname, method), "(stream), nil")
} else {
g.P("return New", clientStreamName(cname, method), "(stream)")
}
g.P("}")
g.P()
return
}

g.P("wrapped := ", rerpcPackage.Ident("Func"), "(func(ctx ", contextContext, ", msg ", protoMessage, ") (", protoMessage, ", error) {")
g.P("stream := call(ctx)")
g.P("if err := stream.Send(req); err != nil {")
Expand All @@ -217,7 +271,7 @@ func clientMethod(g *protogen.GeneratedFile, service *protogen.Service, method *
g.P("return &res, stream.CloseReceive()")
g.P("})")

g.P("if ic := ", rerpcPackage.Ident("ConfiguredCallInterceptor"), "(merged...); ic != nil {")
g.P("if ic != nil {")
g.P("wrapped = ic.Wrap(wrapped)")
g.P("}")
g.P("res, err := wrapped(ctx, req)")
Expand All @@ -233,8 +287,74 @@ func clientMethod(g *protogen.GeneratedFile, service *protogen.Service, method *
g.P()
}

func clientStreams(g *protogen.GeneratedFile, service *protogen.Service, name string) {
for _, method := range service.Methods {
if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() {
continue
}
streamName := clientStreamName(name, method)
isDeprecated := method.Desc.Options().(*descriptorpb.MethodOptions).GetDeprecated()
comment(g, streamName, " is the client-side stream for the ", method.Desc.FullName(), " procedure.")
if isDeprecated {
g.P("//")
deprecated(g)
}
g.P("type ", streamName, " struct {")
g.P("stream ", rerpcPackage.Ident("Stream"))
g.P("}")
g.P()
g.P("func New", streamName, "(stream ", rerpcPackage.Ident("Stream"), ") *", streamName, " {")
g.P("return &", streamName, "{stream}")
g.P("}")
g.P()

if method.Desc.IsStreamingClient() {
g.P("func (s *", streamName, ") Send(msg *", method.Input.GoIdent, ") error {")
g.P("return s.stream.Send(msg)")
g.P("}")
g.P()
if method.Desc.IsStreamingServer() {
// Bidi: otherwise, we'll never know when they're done sending.
g.P("func (s *", streamName, ") CloseSend() error {")
g.P("return s.stream.CloseSend(nil)")
g.P("}")
g.P()
} else {
// Client-only streaming.
g.P("func (s *", streamName, ") ReceiveAndClose() (*", method.Output.GoIdent, ", error) {")
g.P("if err := s.stream.CloseSend(nil); err != nil {")
g.P("return nil, err")
g.P("}")
g.P("var res ", method.Output.GoIdent)
g.P("err := s.stream.Receive(&res)")
g.P("return &res, err")
g.P("}")
g.P()
}
}
if method.Desc.IsStreamingServer() {
g.P("func (s *", streamName, ") Receive() (*", method.Output.GoIdent, ", error) {")
g.P("var req ", method.Output.GoIdent)
g.P("if err := s.stream.Receive(&req); err != nil {")
g.P("return nil, err")
g.P("}")
g.P("return &req, nil")
g.P("}")
g.P()
closeName := "Close"
if method.Desc.IsStreamingClient() {
closeName = "CloseReceive"
}
g.P("func (s *", streamName, ") ", closeName, "() error {")
g.P("return s.stream.CloseReceive()")
g.P("}")
g.P()
}
}
}

func serverStreamName(sname string, method *protogen.Method) string {
return sname + "_" + method.GoName + "Server"
return sname + "_" + method.GoName
}

func serverInterface(g *protogen.GeneratedFile, service *protogen.Service, name string) {
Expand Down Expand Up @@ -487,14 +607,3 @@ func serverStreams(g *protogen.GeneratedFile, service *protogen.Service, sname s
}

func unexport(s string) string { return strings.ToLower(s[:1]) + s[1:] }

func unaryMethods(service *protogen.Service) []*protogen.Method {
unary := make([]*protogen.Method, 0, len(service.Methods))
for _, m := range service.Methods {
if m.Desc.IsStreamingServer() || m.Desc.IsStreamingClient() {
continue
}
unary = append(unary, m)
}
return unary
}
6 changes: 4 additions & 2 deletions internal/crosstest/v1test/cross_rerpc.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit dc72cf6

Please sign in to comment.