From f6beb3c9d48b02bf1f8c4d39e54d754545246767 Mon Sep 17 00:00:00 2001 From: Mike Lemmon Date: Sat, 21 Apr 2018 18:28:43 -0700 Subject: [PATCH 1/7] Get download-style streaming rpc working * Use "EncodeMessage" instead of "Marshal" in serveProtobuf so message length is written properly * Flush stream response after each message if response writer implements http.Flusher * Add streaming "MakeHats" endpoint to example service --- example/cmd/client/main.go | 43 +++- example/cmd/server/main.go | 41 ++- example/service.pb.go | 41 ++- example/service.proto | 6 + example/service.twirp.go | 381 +++++++++++++++++++++++++++- example/service_pb2.py | 61 ++++- example/service_pb2_twirp.py | 9 + internal/twirptest/service.twirp.go | 8 +- protoc-gen-twirp/generator.go | 10 +- 9 files changed, 569 insertions(+), 31 deletions(-) diff --git a/example/cmd/client/main.go b/example/cmd/client/main.go index 041fe18b..a89c634c 100644 --- a/example/cmd/client/main.go +++ b/example/cmd/client/main.go @@ -16,6 +16,7 @@ package main import ( "context" "fmt" + "io" "log" "net/http" @@ -24,12 +25,14 @@ import ( ) func main() { - client := example.NewHaberdasherJSONClient("http://localhost:8080", &http.Client{}) + client := example.NewHaberdasherProtobufClient("http://localhost:8080", &http.Client{}) var ( - hat *example.Hat - err error + hat *example.Hat + hatStream example.HatStream + err error ) + for i := 0; i < 5; i++ { hat, err = client.MakeHat(context.Background(), &example.Size{Inches: 12}) if err != nil { @@ -43,6 +46,38 @@ func main() { // This was some fatal error! log.Fatal(err) } + break + } + fmt.Printf("Response from MakeHat:\n\t%+v\n", hat) + + // Ask for a stream of hats + for i := 0; i < 5; i++ { + hatStream, err = client.MakeHats( + context.Background(), + &example.MakeHatsReq{Inches: 12, Quantity: 7}, + ) + if err != nil { + if twerr, ok := err.(twirp.Error); ok { + if twerr.Meta("retryable") != "" { + // Log the error and go again. + log.Printf("got error %q, retrying", twerr) + continue + } + } + // This was some fatal error! + log.Fatal(err) + } + break + } + fmt.Printf("Response from MakeHats:\n") + for { + hat, err = hatStream.Next(context.Background()) + if err != nil { + if err == io.EOF { + break + } + log.Fatal(err) + } + fmt.Printf("\t%+v\n", hat) } - fmt.Printf("%+v", hat) } diff --git a/example/cmd/server/main.go b/example/cmd/server/main.go index 7d77302b..5115f895 100644 --- a/example/cmd/server/main.go +++ b/example/cmd/server/main.go @@ -15,27 +15,58 @@ package main import ( "context" + "io" "log" "math/rand" "net/http" "os" + "time" "github.com/twitchtv/twirp" "github.com/twitchtv/twirp/example" "github.com/twitchtv/twirp/hooks/statsd" ) +func newRandomHat(inches int32) *example.Hat { + return &example.Hat{ + Size: inches, + Color: []string{"white", "black", "brown", "red", "blue"}[rand.Intn(4)], + Name: []string{"bowler", "baseball cap", "top hat", "derby"}[rand.Intn(3)], + } +} + +type randomHatStream struct{ i, q int32 } + +func (hs *randomHatStream) Next(ctx context.Context) (*example.Hat, error) { + if hs.q == 0 { + return nil, io.EOF + } + hs.q-- + time.Sleep(300 * time.Millisecond) + return newRandomHat(hs.i), nil +} + +func (hs *randomHatStream) End(err error) { + // TODO: something? +} + type randomHaberdasher struct{} func (h *randomHaberdasher) MakeHat(ctx context.Context, size *example.Size) (*example.Hat, error) { if size.Inches <= 0 { return nil, twirp.InvalidArgumentError("Inches", "I can't make a hat that small!") } - return &example.Hat{ - Size: size.Inches, - Color: []string{"white", "black", "brown", "red", "blue"}[rand.Intn(4)], - Name: []string{"bowler", "baseball cap", "top hat", "derby"}[rand.Intn(3)], - }, nil + return newRandomHat(size.Inches), nil +} + +func (h *randomHaberdasher) MakeHats(ctx context.Context, req *example.MakeHatsReq) (example.HatStream, error) { + if req.Inches <= 0 { + return nil, twirp.InvalidArgumentError("Inches", "I can't make hats that small!") + } + if req.Quantity < 0 { + return nil, twirp.InvalidArgumentError("Quantity", "I can't make a negative quantity of hats!") + } + return &randomHatStream{i: req.Inches, q: req.Quantity}, nil } func main() { diff --git a/example/service.pb.go b/example/service.pb.go index 708de366..03992cc8 100644 --- a/example/service.pb.go +++ b/example/service.pb.go @@ -11,6 +11,7 @@ It is generated from these files: It has these top-level messages: Hat Size + MakeHatsReq */ package example @@ -84,15 +85,40 @@ func (m *Size) GetInches() int32 { return 0 } +type MakeHatsReq struct { + Inches int32 `protobuf:"varint,1,opt,name=inches" json:"inches,omitempty"` + Quantity int32 `protobuf:"varint,2,opt,name=quantity" json:"quantity,omitempty"` +} + +func (m *MakeHatsReq) Reset() { *m = MakeHatsReq{} } +func (m *MakeHatsReq) String() string { return proto.CompactTextString(m) } +func (*MakeHatsReq) ProtoMessage() {} +func (*MakeHatsReq) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{2} } + +func (m *MakeHatsReq) GetInches() int32 { + if m != nil { + return m.Inches + } + return 0 +} + +func (m *MakeHatsReq) GetQuantity() int32 { + if m != nil { + return m.Quantity + } + return 0 +} + func init() { proto.RegisterType((*Hat)(nil), "twitch.twirp.example.Hat") proto.RegisterType((*Size)(nil), "twitch.twirp.example.Size") + proto.RegisterType((*MakeHatsReq)(nil), "twitch.twirp.example.MakeHatsReq") } func init() { proto.RegisterFile("service.proto", fileDescriptor0) } var fileDescriptor0 = []byte{ - // 185 bytes of a gzipped FileDescriptorProto + // 234 bytes of a gzipped FileDescriptorProto 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0x2d, 0x4e, 0x2d, 0x2a, 0xcb, 0x4c, 0x4e, 0xd5, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x12, 0x29, 0x29, 0xcf, 0x2c, 0x49, 0xce, 0xd0, 0x2b, 0x29, 0xcf, 0x2c, 0x2a, 0xd0, 0x4b, 0xad, 0x48, 0xcc, 0x2d, 0xc8, 0x49, 0x55, @@ -100,9 +126,12 @@ var fileDescriptor0 = []byte{ 0x54, 0x60, 0xd4, 0x60, 0x0d, 0x02, 0xb3, 0x85, 0x44, 0xb8, 0x58, 0x93, 0xf3, 0x73, 0xf2, 0x8b, 0x24, 0x98, 0x14, 0x18, 0x35, 0x38, 0x83, 0x20, 0x1c, 0x90, 0xca, 0xbc, 0xc4, 0xdc, 0x54, 0x09, 0x66, 0xb0, 0x20, 0x98, 0xad, 0x24, 0xc7, 0xc5, 0x12, 0x0c, 0xd2, 0x21, 0xc6, 0xc5, 0x96, 0x99, - 0x97, 0x9c, 0x91, 0x5a, 0x0c, 0x35, 0x07, 0xca, 0x33, 0xf2, 0xe7, 0xe2, 0xf6, 0x48, 0x4c, 0x4a, - 0x2d, 0x4a, 0x49, 0x2c, 0xce, 0x48, 0x2d, 0x12, 0x72, 0xe0, 0x62, 0xf7, 0x4d, 0xcc, 0x4e, 0x05, - 0xd9, 0x2b, 0xa5, 0x87, 0xcd, 0x55, 0x7a, 0x20, 0xd3, 0xa4, 0x24, 0xb1, 0xcb, 0x79, 0x24, 0x96, - 0x38, 0x71, 0x46, 0xb1, 0x43, 0xb9, 0x49, 0x6c, 0x60, 0xdf, 0x19, 0x03, 0x02, 0x00, 0x00, 0xff, - 0xff, 0x8a, 0xbb, 0x3b, 0x6a, 0xee, 0x00, 0x00, 0x00, + 0x97, 0x9c, 0x91, 0x5a, 0x0c, 0x35, 0x07, 0xca, 0x53, 0x72, 0xe4, 0xe2, 0xf6, 0x4d, 0xcc, 0x4e, + 0xf5, 0x48, 0x2c, 0x29, 0x0e, 0x4a, 0x2d, 0xc4, 0xa5, 0x4c, 0x48, 0x8a, 0x8b, 0xa3, 0xb0, 0x34, + 0x31, 0xaf, 0x24, 0xb3, 0xa4, 0x12, 0x6c, 0x27, 0x6b, 0x10, 0x9c, 0x6f, 0x34, 0x9b, 0x91, 0x8b, + 0xdb, 0x23, 0x31, 0x29, 0xb5, 0x28, 0x25, 0xb1, 0x38, 0x23, 0xb5, 0x48, 0xc8, 0x81, 0x8b, 0x1d, + 0x6a, 0xa4, 0x90, 0x94, 0x1e, 0x36, 0x9f, 0xe9, 0x81, 0x5c, 0x24, 0x25, 0x89, 0x5d, 0x0e, 0xa4, + 0xcd, 0x8b, 0x8b, 0x03, 0xe6, 0x28, 0x21, 0x45, 0xec, 0xca, 0x90, 0x1c, 0x8d, 0xc7, 0x24, 0x03, + 0x46, 0x27, 0xce, 0x28, 0x76, 0xa8, 0x40, 0x12, 0x1b, 0x38, 0xb4, 0x8d, 0x01, 0x01, 0x00, 0x00, + 0xff, 0xff, 0x78, 0xf6, 0xd2, 0xc4, 0x7e, 0x01, 0x00, 0x00, } diff --git a/example/service.proto b/example/service.proto index 875f55ab..f8207ee2 100644 --- a/example/service.proto +++ b/example/service.proto @@ -22,8 +22,14 @@ message Size { int32 inches = 1; } +message MakeHatsReq { + int32 inches = 1; + int32 quantity = 2; +} + // A Haberdasher makes hats for clients. service Haberdasher { // MakeHat produces a hat of mysterious, randomly-selected color! rpc MakeHat(Size) returns (Hat); + rpc MakeHats(MakeHatsReq) returns (stream Hat); } diff --git a/example/service.twirp.go b/example/service.twirp.go index 6ba02a14..7e6e1063 100644 --- a/example/service.twirp.go +++ b/example/service.twirp.go @@ -27,6 +27,8 @@ import io "io" import strconv "strconv" import json "encoding/json" import url "net/url" +import bufio "bufio" +import binary "encoding/binary" // ===================== // Haberdasher Interface @@ -36,6 +38,8 @@ import url "net/url" type Haberdasher interface { // MakeHat produces a hat of mysterious, randomly-selected color! MakeHat(ctx context.Context, in *Size) (*Hat, error) + + MakeHats(ctx context.Context, in *MakeHatsReq) (HatStream, error) } // =========================== @@ -44,15 +48,16 @@ type Haberdasher interface { type haberdasherProtobufClient struct { client HTTPClient - urls [1]string + urls [2]string } // NewHaberdasherProtobufClient creates a Protobuf client that implements the Haberdasher interface. // It communicates using Protobuf and can be configured with a custom HTTPClient. func NewHaberdasherProtobufClient(addr string, client HTTPClient) Haberdasher { prefix := urlBase(addr) + HaberdasherPathPrefix - urls := [1]string{ + urls := [2]string{ prefix + "MakeHat", + prefix + "MakeHats", } if httpClient, ok := client.(*http.Client); ok { return &haberdasherProtobufClient{ @@ -75,21 +80,53 @@ func (c *haberdasherProtobufClient) MakeHat(ctx context.Context, in *Size) (*Hat return out, err } +func (c *haberdasherProtobufClient) MakeHats(ctx context.Context, in *MakeHatsReq) (HatStream, error) { + ctx = ctxsetters.WithPackageName(ctx, "twitch.twirp.example") + ctx = ctxsetters.WithServiceName(ctx, "Haberdasher") + ctx = ctxsetters.WithMethodName(ctx, "MakeHats") + reqBodyBytes, err := proto.Marshal(in) + if err != nil { + return nil, clientError("failed to marshal proto request", err) + } + reqBody := bytes.NewBuffer(reqBodyBytes) + if err = ctx.Err(); err != nil { + return nil, clientError("aborted because context was done", err) + } + + req, err := newRequest(ctx, c.urls[1], reqBody, "application/protobuf") + if err != nil { + return nil, clientError("could not build request", err) + } + resp, err := c.client.Do(req) + if err != nil { + return nil, clientError("failed to do request", err) + } + + return &protoHatStreamReader{ + prs: protoStreamReader{ + r: bufio.NewReader(resp.Body), + c: resp.Body, + maxSize: 1 << 21, // 1GB + }, + }, nil +} + // ======================= // Haberdasher JSON Client // ======================= type haberdasherJSONClient struct { client HTTPClient - urls [1]string + urls [2]string } // NewHaberdasherJSONClient creates a JSON client that implements the Haberdasher interface. // It communicates using JSON and can be configured with a custom HTTPClient. func NewHaberdasherJSONClient(addr string, client HTTPClient) Haberdasher { prefix := urlBase(addr) + HaberdasherPathPrefix - urls := [1]string{ + urls := [2]string{ prefix + "MakeHat", + prefix + "MakeHats", } if httpClient, ok := client.(*http.Client); ok { return &haberdasherJSONClient{ @@ -112,6 +149,38 @@ func (c *haberdasherJSONClient) MakeHat(ctx context.Context, in *Size) (*Hat, er return out, err } +func (c *haberdasherJSONClient) MakeHats(ctx context.Context, in *MakeHatsReq) (HatStream, error) { + ctx = ctxsetters.WithPackageName(ctx, "twitch.twirp.example") + ctx = ctxsetters.WithServiceName(ctx, "Haberdasher") + ctx = ctxsetters.WithMethodName(ctx, "MakeHats") + reqBodyBytes, err := proto.Marshal(in) + if err != nil { + return nil, clientError("failed to marshal proto request", err) + } + reqBody := bytes.NewBuffer(reqBodyBytes) + if err = ctx.Err(); err != nil { + return nil, clientError("aborted because context was done", err) + } + + req, err := newRequest(ctx, c.urls[1], reqBody, "application/json") + if err != nil { + return nil, clientError("could not build request", err) + } + resp, err := c.client.Do(req) + if err != nil { + return nil, clientError("failed to do request", err) + } + + jrs, err := newJSONStreamReader(resp.Body) + if err != nil { + return nil, err + } + return &jsonHatStreamReader{ + jrs: jrs, + c: resp.Body, + }, nil +} + // ========================== // Haberdasher Server Handler // ========================== @@ -163,6 +232,9 @@ func (s *haberdasherServer) ServeHTTP(resp http.ResponseWriter, req *http.Reques case "/twirp/twitch.twirp.example.Haberdasher/MakeHat": s.serveMakeHat(ctx, resp, req) return + case "/twirp/twitch.twirp.example.Haberdasher/MakeHats": + s.serveMakeHats(ctx, resp, req) + return default: msg := fmt.Sprintf("no handler for path %q", req.URL.Path) err = badRouteError(msg, req.Method, req.URL.Path) @@ -315,6 +387,124 @@ func (s *haberdasherServer) serveMakeHatProtobuf(ctx context.Context, resp http. callResponseSent(ctx, s.hooks) } +func (s *haberdasherServer) serveMakeHats(ctx context.Context, resp http.ResponseWriter, req *http.Request) { + header := req.Header.Get("Content-Type") + i := strings.Index(header, ";") + if i == -1 { + i = len(header) + } + switch strings.TrimSpace(strings.ToLower(header[:i])) { + case "application/json": + s.serveMakeHatsJSON(ctx, resp, req) + case "application/protobuf": + s.serveMakeHatsProtobuf(ctx, resp, req) + default: + msg := fmt.Sprintf("unexpected Content-Type: %q", req.Header.Get("Content-Type")) + twerr := badRouteError(msg, req.Method, req.URL.Path) + s.writeError(ctx, resp, twerr) + } +} + +func (s *haberdasherServer) serveMakeHatsJSON(ctx context.Context, resp http.ResponseWriter, req *http.Request) { +} + +func (s *haberdasherServer) serveMakeHatsProtobuf(ctx context.Context, resp http.ResponseWriter, req *http.Request) { + var err error + ctx = ctxsetters.WithMethodName(ctx, "MakeHats") + ctx, err = callRequestRouted(ctx, s.hooks) + if err != nil { + s.writeError(ctx, resp, err) + return + } + + resp.Header().Set("Content-Type", "application/protobuf") + buf, err := ioutil.ReadAll(req.Body) + if err != nil { + err = wrapErr(err, "failed to read request body") + s.writeError(ctx, resp, twirp.InternalErrorWith(err)) + return + } + reqContent := new(MakeHatsReq) + if err = proto.Unmarshal(buf, reqContent); err != nil { + err = wrapErr(err, "failed to parse request proto") + s.writeError(ctx, resp, twirp.InternalErrorWith(err)) + return + } + + // Call service method + var respStream HatStream + func() { + defer func() { + // In case of a panic, serve a 500 error and then panic. + if r := recover(); r != nil { + s.writeError(ctx, resp, twirp.InternalError("Internal service panic")) + panic(r) + } + }() + respStream, err = s.MakeHats(ctx, reqContent) + }() + + if err != nil { + s.writeError(ctx, resp, err) + return + } + + if respStream == nil { + s.writeError(ctx, resp, twirp.InternalError("received a nil MakeHatsReq and nil error while calling MakeHats. nil responses are not supported")) + return + } + + respFlusher, canFlush := resp.(http.Flusher) + + ctx = callResponsePrepared(ctx, s.hooks) + + messages := proto.NewBuffer(nil) + + trailer := proto.NewBuffer(nil) + _ = trailer.EncodeVarint((2 << 3) | 2) // field tag + for { + msg, err := respStream.Next(ctx) + if err != nil { + // TODO: figure out trailers' proto encoding beyond just a string + if err == io.EOF { + _ = trailer.EncodeStringBytes("OK") + } else { + _ = trailer.EncodeStringBytes(err.Error()) + } + break + } + + messages.Reset() + _ = messages.EncodeVarint((1 << 3) | 2) // field tag + err = messages.EncodeMessage(msg) + if err != nil { + err = wrapErr(err, "failed to marshal proto message") + respStream.End(err) + break + } + + _, err = resp.Write(messages.Bytes()) + if err != nil { + err = wrapErr(err, "failed to send proto message") + respStream.End(err) + break + } + + if canFlush { + respFlusher.Flush() + } + + // TODO: Call a hook that we sent a message in a stream? + } + + _, err = resp.Write(trailer.Bytes()) + if err != nil { + // TODO: call error hook? + err = wrapErr(err, "failed to write trailer") + respStream.End(err) + } +} + func (s *haberdasherServer) ServiceDescriptor() ([]byte, int) { return twirpFileDescriptor0, 0 } @@ -323,6 +513,43 @@ func (s *haberdasherServer) ProtocGenTwirpVersion() string { return "v5.3.0" } +// HatStream represents a stream of Hat messages. +type HatStream interface { + Next(context.Context) (*Hat, error) + End(error) +} + +type protoHatStreamReader struct { + prs protoStreamReader +} + +func (r protoHatStreamReader) Next(context.Context) (*Hat, error) { + out := new(Hat) + err := r.prs.Read(out) + if err != nil { + return nil, err + } + return out, nil +} + +func (r protoHatStreamReader) End(error) { _ = r.prs.c.Close() } + +type jsonHatStreamReader struct { + jrs *jsonStreamReader + c io.Closer +} + +func (r jsonHatStreamReader) Next(context.Context) (*Hat, error) { + out := new(Hat) + err := r.jrs.Read(out) + if err != nil { + return nil, err + } + return out, nil +} + +func (r jsonHatStreamReader) End(error) { _ = r.c.Close() } + // ===== // Utils // ===== @@ -747,8 +974,139 @@ func callError(ctx context.Context, h *twirp.ServerHooks, err twirp.Error) conte return h.Error(ctx, err) } +type protoStreamReader struct { + r *bufio.Reader + c io.Closer + + maxSize int +} + +func (r protoStreamReader) Read(msg proto.Message) error { + // Get next field tag. + tag, err := binary.ReadUvarint(r.r) + if err != nil { + return err + } + + const ( + msgTag = (1 << 3) | 2 + trailerTag = (2 << 3) | 2 + ) + + if tag == trailerTag { + _ = r.c.Close() + return io.EOF + } + + if tag != msgTag { + return fmt.Errorf("invalid field tag: %v", tag) + } + + // This is a real message. How long is it? + l, err := binary.ReadUvarint(r.r) + if err != nil { + return err + } + if int(l) < 0 || int(l) > r.maxSize { + return io.ErrShortBuffer + } + buf := make([]byte, int(l)) + + // Go ahead and read a message. + _, err = io.ReadFull(r.r, buf) + if err != nil { + return err + } + + err = proto.Unmarshal(buf, msg) + if err != nil { + return err + } + return nil +} + +type jsonStreamReader struct { + dec *json.Decoder + unmarshaler *jsonpb.Unmarshaler + messageStreamDone bool +} + +func newJSONStreamReader(r io.Reader) (*jsonStreamReader, error) { + // stream should start with {"messages":[ + dec := json.NewDecoder(r) + t, err := dec.Token() + if err != nil { + return nil, err + } + delim, ok := t.(json.Delim) + if !ok || delim != '{' { + return nil, fmt.Errorf("missing leading { in JSON stream, found %q", t) + } + + t, err = dec.Token() + if err != nil { + return nil, err + } + key, ok := t.(string) + if !ok || key != "messages" { + return nil, fmt.Errorf("missing \"messages\" key in JSON stream, found %q", t) + } + + t, err = dec.Token() + if err != nil { + return nil, err + } + delim, ok = t.(json.Delim) + if !ok || delim != '[' { + return nil, fmt.Errorf("missing [ to open messages array in JSON stream, found %q", t) + } + + return &jsonStreamReader{ + dec: dec, + unmarshaler: &jsonpb.Unmarshaler{AllowUnknownFields: true}, + }, nil +} + +func (r *jsonStreamReader) Read(msg proto.Message) error { + if !r.messageStreamDone && r.dec.More() { + return r.unmarshaler.UnmarshalNext(r.dec, msg) + } + + // else, we hit the end of the message stream. finish up the array, and then read the trailer. + r.messageStreamDone = true + t, err := r.dec.Token() + if err != nil { + return err + } + delim, ok := t.(json.Delim) + if !ok || delim != ']' { + return fmt.Errorf("missing end of message array in JSON stream, found %q", t) + } + + t, err = r.dec.Token() + if err != nil { + return err + } + key, ok := t.(string) + if !ok || key != "trailer" { + return fmt.Errorf("missing trailer after messages in JSON stream, found %q", t) + } + + var tj twerrJSON + err = r.dec.Decode(&tj) + if err != nil { + return err + } + + if tj.Code == "stream_complete" { + return io.EOF + } + + return tj.toTwirpError() +} + var twirpFileDescriptor0 = []byte{ - // 185 bytes of a gzipped FileDescriptorProto + // 234 bytes of a gzipped FileDescriptorProto 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0x2d, 0x4e, 0x2d, 0x2a, 0xcb, 0x4c, 0x4e, 0xd5, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x12, 0x29, 0x29, 0xcf, 0x2c, 0x49, 0xce, 0xd0, 0x2b, 0x29, 0xcf, 0x2c, 0x2a, 0xd0, 0x4b, 0xad, 0x48, 0xcc, 0x2d, 0xc8, 0x49, 0x55, @@ -756,9 +1114,12 @@ var twirpFileDescriptor0 = []byte{ 0x54, 0x60, 0xd4, 0x60, 0x0d, 0x02, 0xb3, 0x85, 0x44, 0xb8, 0x58, 0x93, 0xf3, 0x73, 0xf2, 0x8b, 0x24, 0x98, 0x14, 0x18, 0x35, 0x38, 0x83, 0x20, 0x1c, 0x90, 0xca, 0xbc, 0xc4, 0xdc, 0x54, 0x09, 0x66, 0xb0, 0x20, 0x98, 0xad, 0x24, 0xc7, 0xc5, 0x12, 0x0c, 0xd2, 0x21, 0xc6, 0xc5, 0x96, 0x99, - 0x97, 0x9c, 0x91, 0x5a, 0x0c, 0x35, 0x07, 0xca, 0x33, 0xf2, 0xe7, 0xe2, 0xf6, 0x48, 0x4c, 0x4a, - 0x2d, 0x4a, 0x49, 0x2c, 0xce, 0x48, 0x2d, 0x12, 0x72, 0xe0, 0x62, 0xf7, 0x4d, 0xcc, 0x4e, 0x05, - 0xd9, 0x2b, 0xa5, 0x87, 0xcd, 0x55, 0x7a, 0x20, 0xd3, 0xa4, 0x24, 0xb1, 0xcb, 0x79, 0x24, 0x96, - 0x38, 0x71, 0x46, 0xb1, 0x43, 0xb9, 0x49, 0x6c, 0x60, 0xdf, 0x19, 0x03, 0x02, 0x00, 0x00, 0xff, - 0xff, 0x8a, 0xbb, 0x3b, 0x6a, 0xee, 0x00, 0x00, 0x00, + 0x97, 0x9c, 0x91, 0x5a, 0x0c, 0x35, 0x07, 0xca, 0x53, 0x72, 0xe4, 0xe2, 0xf6, 0x4d, 0xcc, 0x4e, + 0xf5, 0x48, 0x2c, 0x29, 0x0e, 0x4a, 0x2d, 0xc4, 0xa5, 0x4c, 0x48, 0x8a, 0x8b, 0xa3, 0xb0, 0x34, + 0x31, 0xaf, 0x24, 0xb3, 0xa4, 0x12, 0x6c, 0x27, 0x6b, 0x10, 0x9c, 0x6f, 0x34, 0x9b, 0x91, 0x8b, + 0xdb, 0x23, 0x31, 0x29, 0xb5, 0x28, 0x25, 0xb1, 0x38, 0x23, 0xb5, 0x48, 0xc8, 0x81, 0x8b, 0x1d, + 0x6a, 0xa4, 0x90, 0x94, 0x1e, 0x36, 0x9f, 0xe9, 0x81, 0x5c, 0x24, 0x25, 0x89, 0x5d, 0x0e, 0xa4, + 0xcd, 0x8b, 0x8b, 0x03, 0xe6, 0x28, 0x21, 0x45, 0xec, 0xca, 0x90, 0x1c, 0x8d, 0xc7, 0x24, 0x03, + 0x46, 0x27, 0xce, 0x28, 0x76, 0xa8, 0x40, 0x12, 0x1b, 0x38, 0xb4, 0x8d, 0x01, 0x01, 0x00, 0x00, + 0xff, 0xff, 0x78, 0xf6, 0xd2, 0xc4, 0x7e, 0x01, 0x00, 0x00, } diff --git a/example/service_pb2.py b/example/service_pb2.py index a6f0488b..d144b172 100644 --- a/example/service_pb2.py +++ b/example/service_pb2.py @@ -19,7 +19,7 @@ name='service.proto', package='twitch.twirp.example', syntax='proto3', - serialized_pb=_b('\n\rservice.proto\x12\x14twitch.twirp.example\"0\n\x03Hat\x12\x0c\n\x04size\x18\x01 \x01(\x05\x12\r\n\x05\x63olor\x18\x02 \x01(\t\x12\x0c\n\x04name\x18\x03 \x01(\t\"\x16\n\x04Size\x12\x0e\n\x06inches\x18\x01 \x01(\x05\x32O\n\x0bHaberdasher\x12@\n\x07MakeHat\x12\x1a.twitch.twirp.example.Size\x1a\x19.twitch.twirp.example.HatB\tZ\x07\x65xampleb\x06proto3') + serialized_pb=_b('\n\rservice.proto\x12\x14twitch.twirp.example\"0\n\x03Hat\x12\x0c\n\x04size\x18\x01 \x01(\x05\x12\r\n\x05\x63olor\x18\x02 \x01(\t\x12\x0c\n\x04name\x18\x03 \x01(\t\"\x16\n\x04Size\x12\x0e\n\x06inches\x18\x01 \x01(\x05\"/\n\x0bMakeHatsReq\x12\x0e\n\x06inches\x18\x01 \x01(\x05\x12\x10\n\x08quantity\x18\x02 \x01(\x05\x32\x9b\x01\n\x0bHaberdasher\x12@\n\x07MakeHat\x12\x1a.twitch.twirp.example.Size\x1a\x19.twitch.twirp.example.Hat\x12J\n\x08MakeHats\x12!.twitch.twirp.example.MakeHatsReq\x1a\x19.twitch.twirp.example.Hat0\x01\x42\tZ\x07\x65xampleb\x06proto3') ) @@ -100,8 +100,47 @@ serialized_end=111, ) + +_MAKEHATSREQ = _descriptor.Descriptor( + name='MakeHatsReq', + full_name='twitch.twirp.example.MakeHatsReq', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='inches', full_name='twitch.twirp.example.MakeHatsReq.inches', index=0, + number=1, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='quantity', full_name='twitch.twirp.example.MakeHatsReq.quantity', index=1, + number=2, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=113, + serialized_end=160, +) + DESCRIPTOR.message_types_by_name['Hat'] = _HAT DESCRIPTOR.message_types_by_name['Size'] = _SIZE +DESCRIPTOR.message_types_by_name['MakeHatsReq'] = _MAKEHATSREQ _sym_db.RegisterFileDescriptor(DESCRIPTOR) Hat = _reflection.GeneratedProtocolMessageType('Hat', (_message.Message,), dict( @@ -118,6 +157,13 @@ )) _sym_db.RegisterMessage(Size) +MakeHatsReq = _reflection.GeneratedProtocolMessageType('MakeHatsReq', (_message.Message,), dict( + DESCRIPTOR = _MAKEHATSREQ, + __module__ = 'service_pb2' + # @@protoc_insertion_point(class_scope:twitch.twirp.example.MakeHatsReq) + )) +_sym_db.RegisterMessage(MakeHatsReq) + DESCRIPTOR.has_options = True DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('Z\007example')) @@ -128,8 +174,8 @@ file=DESCRIPTOR, index=0, options=None, - serialized_start=113, - serialized_end=192, + serialized_start=163, + serialized_end=318, methods=[ _descriptor.MethodDescriptor( name='MakeHat', @@ -140,6 +186,15 @@ output_type=_HAT, options=None, ), + _descriptor.MethodDescriptor( + name='MakeHats', + full_name='twitch.twirp.example.Haberdasher.MakeHats', + index=1, + containing_service=None, + input_type=_MAKEHATSREQ, + output_type=_HAT, + options=None, + ), ]) _sym_db.RegisterServiceDescriptor(_HABERDASHER) diff --git a/example/service_pb2_twirp.py b/example/service_pb2_twirp.py index 1548accf..b32cbaf6 100644 --- a/example/service_pb2_twirp.py +++ b/example/service_pb2_twirp.py @@ -78,3 +78,12 @@ def make_hat(self, size): resp_str = self.__make_request(body=body, full_method=full_method) return deserialize(resp_str) + def make_hats(self, make_hats_req): + serialize = _sym_db.GetSymbol("twitch.twirp.example.MakeHatsReq").SerializeToString + deserialize = _sym_db.GetSymbol("twitch.twirp.example.Hat").FromString + + full_method = "/{}/{}".format(self.__service_name, "MakeHats") + body = serialize(make_hats_req) + resp_str = self.__make_request(body=body, full_method=full_method) + return deserialize(resp_str) + diff --git a/internal/twirptest/service.twirp.go b/internal/twirptest/service.twirp.go index 6c7a992a..1b4b8fb8 100644 --- a/internal/twirptest/service.twirp.go +++ b/internal/twirptest/service.twirp.go @@ -790,6 +790,8 @@ func (s *streamerServer) serveDownloadProtobuf(ctx context.Context, resp http.Re return } + respFlusher, canFlush := resp.(http.Flusher) + ctx = callResponsePrepared(ctx, s.hooks) messages := proto.NewBuffer(nil) @@ -810,7 +812,7 @@ func (s *streamerServer) serveDownloadProtobuf(ctx context.Context, resp http.Re messages.Reset() _ = messages.EncodeVarint((1 << 3) | 2) // field tag - err = messages.Marshal(msg) + err = messages.EncodeMessage(msg) if err != nil { err = wrapErr(err, "failed to marshal proto message") respStream.End(err) @@ -824,6 +826,10 @@ func (s *streamerServer) serveDownloadProtobuf(ctx context.Context, resp http.Re break } + if canFlush { + respFlusher.Flush() + } + // TODO: Call a hook that we sent a message in a stream? } diff --git a/protoc-gen-twirp/generator.go b/protoc-gen-twirp/generator.go index f1e64c87..dea2b1ef 100644 --- a/protoc-gen-twirp/generator.go +++ b/protoc-gen-twirp/generator.go @@ -1390,7 +1390,7 @@ func (t *twirp) generateServerProtobufMethod(service *descriptor.ServiceDescript t.P(` panic(r)`) t.P(` }`) t.P(` }()`) - t.P(` respStream, err = s.Download(ctx, reqContent)`) + t.P(` respStream, err = s.`, methName, `(ctx, reqContent)`) t.P(` }()`) t.P(``) t.P(` if err != nil {`) @@ -1403,6 +1403,8 @@ func (t *twirp) generateServerProtobufMethod(service *descriptor.ServiceDescript t.P(` return`) t.P(` }`) t.P(``) + t.P(` respFlusher, canFlush := resp.(http.Flusher)`) + t.P(``) t.P(` ctx = callResponsePrepared(ctx, s.hooks)`) t.P(``) t.P(` messages := `, t.pkgs["proto"], `.NewBuffer(nil)`) @@ -1423,7 +1425,7 @@ func (t *twirp) generateServerProtobufMethod(service *descriptor.ServiceDescript t.P(``) t.P(` messages.Reset()`) t.P(` _ = messages.EncodeVarint((1 << 3) | 2) // field tag`) - t.P(` err = messages.Marshal(msg)`) + t.P(` err = messages.EncodeMessage(msg)`) t.P(` if err != nil {`) t.P(` err = wrapErr(err, "failed to marshal proto message")`) t.P(` respStream.End(err)`) @@ -1437,6 +1439,10 @@ func (t *twirp) generateServerProtobufMethod(service *descriptor.ServiceDescript t.P(` break`) t.P(` }`) t.P(``) + t.P(` if canFlush {`) + t.P(` respFlusher.Flush()`) + t.P(` }`) + t.P(``) t.P(` // TODO: Call a hook that we sent a message in a stream?`) t.P(` }`) t.P(``) From 7ce909891615ba8233f6a3eb61688e8db48417f6 Mon Sep 17 00:00:00 2001 From: Mike Lemmon Date: Sun, 20 May 2018 23:35:29 -0700 Subject: [PATCH 2/7] Use json-encoded twirp error for streaming trailers * Also, add tests for example server * JSON clients currently fail to stream (rpc call gets EOF err instead of RespStream) --- .../clientcompat/clientcompat.twirp.go | 20 +- example/cmd/client/main.go | 55 +++--- example/cmd/server/main.go | 34 ++-- example/cmd/server/main_test.go | 186 ++++++++++++++++++ example/cmd/server/random_hat_stream.go | 85 ++++++++ example/service.twirp.go | 113 +++++++---- .../twirptest/gogo_compat/service.twirp.go | 10 +- .../twirptest/importable/importable.twirp.go | 10 +- internal/twirptest/importer/importer.twirp.go | 57 ++++-- .../twirptest/multiple/multiple1.twirp.go | 10 +- .../twirptest/multiple/multiple2.twirp.go | 20 +- .../no_package_name/no_package_name.twirp.go | 10 +- .../no_package_name_importer.twirp.go | 10 +- internal/twirptest/proto/proto.twirp.go | 10 +- internal/twirptest/service.twirp.go | 131 ++++++++---- protoc-gen-twirp/generator.go | 136 ++++++++----- 16 files changed, 688 insertions(+), 209 deletions(-) create mode 100644 example/cmd/server/main_test.go create mode 100644 example/cmd/server/random_hat_stream.go diff --git a/clientcompat/internal/clientcompat/clientcompat.twirp.go b/clientcompat/internal/clientcompat/clientcompat.twirp.go index 2e5b391a..c956416a 100644 --- a/clientcompat/internal/clientcompat/clientcompat.twirp.go +++ b/clientcompat/internal/clientcompat/clientcompat.twirp.go @@ -276,10 +276,14 @@ func (s *compatServiceServer) serveMethodJSON(ctx context.Context, resp http.Res func (s *compatServiceServer) serveMethodProtobuf(ctx context.Context, resp http.ResponseWriter, req *http.Request) { var err error + writeProtoError := func(err error) { + s.writeError(ctx, resp, err) + } + ctx = ctxsetters.WithMethodName(ctx, "Method") ctx, err = callRequestRouted(ctx, s.hooks) if err != nil { - s.writeError(ctx, resp, err) + writeProtoError(err) return } @@ -287,13 +291,13 @@ func (s *compatServiceServer) serveMethodProtobuf(ctx context.Context, resp http buf, err := ioutil.ReadAll(req.Body) if err != nil { err = wrapErr(err, "failed to read request body") - s.writeError(ctx, resp, twirp.InternalErrorWith(err)) + writeProtoError(twirp.InternalErrorWith(err)) return } reqContent := new(Req) if err = proto.Unmarshal(buf, reqContent); err != nil { err = wrapErr(err, "failed to parse request proto") - s.writeError(ctx, resp, twirp.InternalErrorWith(err)) + writeProtoError(twirp.InternalErrorWith(err)) return } @@ -420,10 +424,14 @@ func (s *compatServiceServer) serveNoopMethodJSON(ctx context.Context, resp http func (s *compatServiceServer) serveNoopMethodProtobuf(ctx context.Context, resp http.ResponseWriter, req *http.Request) { var err error + writeProtoError := func(err error) { + s.writeError(ctx, resp, err) + } + ctx = ctxsetters.WithMethodName(ctx, "NoopMethod") ctx, err = callRequestRouted(ctx, s.hooks) if err != nil { - s.writeError(ctx, resp, err) + writeProtoError(err) return } @@ -431,13 +439,13 @@ func (s *compatServiceServer) serveNoopMethodProtobuf(ctx context.Context, resp buf, err := ioutil.ReadAll(req.Body) if err != nil { err = wrapErr(err, "failed to read request body") - s.writeError(ctx, resp, twirp.InternalErrorWith(err)) + writeProtoError(twirp.InternalErrorWith(err)) return } reqContent := new(Empty) if err = proto.Unmarshal(buf, reqContent); err != nil { err = wrapErr(err, "failed to parse request proto") - s.writeError(ctx, resp, twirp.InternalErrorWith(err)) + writeProtoError(twirp.InternalErrorWith(err)) return } diff --git a/example/cmd/client/main.go b/example/cmd/client/main.go index a89c634c..e1a3cf48 100644 --- a/example/cmd/client/main.go +++ b/example/cmd/client/main.go @@ -15,10 +15,10 @@ package main import ( "context" - "fmt" "io" "log" "net/http" + "time" "github.com/twitchtv/twirp" "github.com/twitchtv/twirp/example" @@ -33,6 +33,7 @@ func main() { err error ) + // Call the MakeHat rpc for i := 0; i < 5; i++ { hat, err = client.MakeHat(context.Background(), &example.Size{Inches: 12}) if err != nil { @@ -48,36 +49,46 @@ func main() { } break } - fmt.Printf("Response from MakeHat:\n\t%+v\n", hat) + log.Println(`Response from MakeHat:`) + log.Printf("\t%+v\n", hat) - // Ask for a stream of hats - for i := 0; i < 5; i++ { - hatStream, err = client.MakeHats( - context.Background(), - &example.MakeHatsReq{Inches: 12, Quantity: 7}, + // Call the MakeHats streaming rpc + reqSentAt := time.Now() + quantity := int32(300000) + hatStream, err = client.MakeHats( + context.Background(), + &example.MakeHatsReq{Inches: 12, Quantity: quantity}, + ) + if err != nil { + log.Fatal(err) + } + log.Printf("Response from MakeHats:\n") + ii := 1 + printResults := func() { + took := time.Now().Sub(reqSentAt) + log.Printf( + "Received %.1f kHats per second (%d hats in %f seconds)\n", + float64(ii-1)/took.Seconds()/1000, + ii-1, took.Seconds(), ) - if err != nil { - if twerr, ok := err.(twirp.Error); ok { - if twerr.Meta("retryable") != "" { - // Log the error and go again. - log.Printf("got error %q, retrying", twerr) - continue - } - } - // This was some fatal error! - log.Fatal(err) - } - break } - fmt.Printf("Response from MakeHats:\n") - for { + defer printResults() + for ; true; ii++ { // Receive all the hats hat, err = hatStream.Next(context.Background()) if err != nil { if err == io.EOF { break } + printResults() log.Fatal(err) } - fmt.Printf("\t%+v\n", hat) + if ii%50000 == 0 { + log.Printf( + "\t[%4.1f khps] %6d %+v\n", + float64(ii)/time.Now().Sub(reqSentAt).Seconds()/1000, + ii, hat, + ) + } } + } diff --git a/example/cmd/server/main.go b/example/cmd/server/main.go index 5115f895..2e522387 100644 --- a/example/cmd/server/main.go +++ b/example/cmd/server/main.go @@ -15,12 +15,10 @@ package main import ( "context" - "io" "log" "math/rand" "net/http" "os" - "time" "github.com/twitchtv/twirp" "github.com/twitchtv/twirp/example" @@ -30,43 +28,33 @@ import ( func newRandomHat(inches int32) *example.Hat { return &example.Hat{ Size: inches, - Color: []string{"white", "black", "brown", "red", "blue"}[rand.Intn(4)], - Name: []string{"bowler", "baseball cap", "top hat", "derby"}[rand.Intn(3)], + Color: []string{"white", "black", "brown", "red", "blue"}[rand.Intn(5)], + Name: []string{"bowler", "baseball cap", "top hat", "derby"}[rand.Intn(4)], } } -type randomHatStream struct{ i, q int32 } +type randomHaberdasher struct{ quiet bool } -func (hs *randomHatStream) Next(ctx context.Context) (*example.Hat, error) { - if hs.q == 0 { - return nil, io.EOF - } - hs.q-- - time.Sleep(300 * time.Millisecond) - return newRandomHat(hs.i), nil -} - -func (hs *randomHatStream) End(err error) { - // TODO: something? -} - -type randomHaberdasher struct{} +var ( + errTooSmall = twirp.InvalidArgumentError("Inches", "I can't make hats that small!") + errNegativeQuantity = twirp.InvalidArgumentError("Quantity", "I can't make a negative quantity of hats!") +) func (h *randomHaberdasher) MakeHat(ctx context.Context, size *example.Size) (*example.Hat, error) { if size.Inches <= 0 { - return nil, twirp.InvalidArgumentError("Inches", "I can't make a hat that small!") + return nil, errTooSmall } return newRandomHat(size.Inches), nil } func (h *randomHaberdasher) MakeHats(ctx context.Context, req *example.MakeHatsReq) (example.HatStream, error) { if req.Inches <= 0 { - return nil, twirp.InvalidArgumentError("Inches", "I can't make hats that small!") + return nil, errTooSmall } if req.Quantity < 0 { - return nil, twirp.InvalidArgumentError("Quantity", "I can't make a negative quantity of hats!") + return nil, errNegativeQuantity } - return &randomHatStream{i: req.Inches, q: req.Quantity}, nil + return newRandomHatStream(req.Inches, req.Quantity, h.quiet), nil } func main() { diff --git a/example/cmd/server/main_test.go b/example/cmd/server/main_test.go new file mode 100644 index 00000000..acbfb415 --- /dev/null +++ b/example/cmd/server/main_test.go @@ -0,0 +1,186 @@ +package main + +import ( + "context" + "fmt" + "io" + "log" + "net/http" + "os" + "testing" + "time" + + "github.com/twitchtv/twirp/example" +) + +func TestMain(m *testing.M) { + go runServer() + os.Exit(m.Run()) +} + +func runServer() { + server := example.NewHaberdasherServer(&randomHaberdasher{quiet: true}, nil) + log.Fatal(http.ListenAndServe(":8080", server)) +} + +func newProtoClient() example.Haberdasher { + return example.NewHaberdasherProtobufClient("http://localhost:8080", &http.Client{}) +} + +func newJSONClient() example.Haberdasher { + return example.NewHaberdasherJSONClient("http://localhost:8080", &http.Client{}) +} + +type client struct { + name string + c example.Haberdasher +} + +func clients() []client { + return []client{ + {`Proto`, newProtoClient()}, + {`JSON`, newJSONClient()}, + } +} + +func compareErrors(got, expected error) error { + if got.Error() == expected.Error() { + return nil + } + return fmt.Errorf(`Expected err to be %#v, got %#v`, expected, got) +} + +func TestInvalidMakeHatsRequests(t *testing.T) { + type testReq struct { + name string + req *example.MakeHatsReq + expected error + } + testReqs := []testReq{ + { + name: `TooSmall`, + req: &example.MakeHatsReq{Inches: -5}, + expected: errTooSmall, + }, + { + name: `NegativeQuantity`, + req: &example.MakeHatsReq{Inches: 8, Quantity: -5}, + expected: errNegativeQuantity, + }, + } + + for _, cc := range clients() { + for _, re := range testReqs { + t.Run(re.name+cc.name, func(t *testing.T) { + hatStream, err := cc.c.MakeHats(context.Background(), re.req) + if err != nil { + t.Fatalf(`MakeHats request failed: %#v`, err) + } + _, err = hatStream.Next(context.Background()) + err = compareErrors(err, re.expected) + if err != nil { + t.Fatal(err) + } + }) + } + } +} + +func TestMakeHatsPerf(t *testing.T) { + type testReq struct { + name string + req *example.MakeHatsReq + } + testReqs := []testReq{ + { + name: `OneHundred`, + req: &example.MakeHatsReq{Inches: 5, Quantity: 100}, + }, + { + name: `OneThousand`, + req: &example.MakeHatsReq{Inches: 5, Quantity: 1000}, + }, + { + name: `TenThousand`, + req: &example.MakeHatsReq{Inches: 5, Quantity: 10000}, + }, + { + name: `OneHundredThousand`, + req: &example.MakeHatsReq{Inches: 5, Quantity: 100000}, + }, + // // OneMillion takes 6+ seconds if server is Flush()ing after every message, <1sec if no flushing + // { + // name: `OneMillion`, + // req: &example.MakeHatsReq{Inches: 5, Quantity: 1000000}, + // }, + } + + for _, cc := range clients() { + for _, re := range testReqs { + t.Run(re.name+cc.name, func(t *testing.T) { + reqSentAt := time.Now() + hatStream, err := cc.c.MakeHats(context.Background(), re.req) + if err != nil { + t.Fatalf(`MakeHats request failed: %#v (hatStream=%#v)`, err, hatStream) + } + ii := int32(0) + for ; true; ii++ { + _, err = hatStream.Next(context.Background()) + if err == io.EOF { + break + } + if err != nil { + t.Fatal(err) + } + } + if ii != re.req.Quantity { + t.Fatalf(`Expected to receive %d hats, got %d`, re.req.Quantity, ii) + } + took := time.Now().Sub(reqSentAt) + t.Logf( + "Received %.1f kHats per second (%d hats in %f seconds)\n", + float64(ii)/took.Seconds()/1000, + ii, took.Seconds(), + ) + }) + } + } +} + +func BenchmarkMakeHatsProto(b *testing.B) { + benchmarkMakeHats(b, newProtoClient()) +} + +func BenchmarkMakeHatsJSON(b *testing.B) { + benchmarkMakeHats(b, newJSONClient()) +} + +func benchmarkMakeHats(b *testing.B, cc example.Haberdasher) { + reqSentAt := time.Now() + hatStream, err := cc.MakeHats( + context.Background(), + &example.MakeHatsReq{Inches: 8, Quantity: int32(b.N)}, + ) + if err != nil { + b.Fatal(err) + } + ii := 0 + for ; true; ii++ { + _, err = hatStream.Next(context.Background()) + if err != nil { + if err == io.EOF { + break + } + b.Fatal(err) + } + } + if ii != b.N { + b.Fatalf(`Expected to receive %d hats, got %d`, b.N, ii) + } + took := time.Now().Sub(reqSentAt) + b.Logf( + "Received %.1f kHats per second (%d hats in %f seconds)\n", + float64(ii)/took.Seconds()/1000, + ii, took.Seconds(), + ) +} diff --git a/example/cmd/server/random_hat_stream.go b/example/cmd/server/random_hat_stream.go new file mode 100644 index 00000000..1ce70a59 --- /dev/null +++ b/example/cmd/server/random_hat_stream.go @@ -0,0 +1,85 @@ +// Copyright 2018 Twitch Interactive, Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may not +// use this file except in compliance with the License. A copy of the License is +// located at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// or in the "license" file anumSentompanying this file. This file is distributed on +// an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package main + +import ( + "context" + "io" + "log" + "time" + + "github.com/twitchtv/twirp" + "github.com/twitchtv/twirp/example" +) + +type randomHatStream struct { + inches, quantity, numSent int32 + startedAt time.Time + quiet bool +} + +func newRandomHatStream(inches, quantity int32, quiet bool) *randomHatStream { + return &randomHatStream{ + inches: inches, + quantity: quantity, + startedAt: time.Now(), + quiet: quiet, + } +} + +func (hs *randomHatStream) Next(ctx context.Context) (*example.Hat, error) { + defer func() { hs.numSent++ }() + if hs.numSent == hs.quantity { + if !hs.quiet { + log.Printf( + "[%4.1f khps] (%7d) Sending %v\n", + float64(hs.numSent)/time.Now().Sub(hs.startedAt).Seconds()/1000, + hs.numSent, io.EOF, + ) + } + return nil, io.EOF + } + + select { + case <-ctx.Done(): + err := errAborted(ctx.Err()) + if !hs.quiet { + log.Printf(`Context canceled: %#v`, ctx.Err()) + } + return nil, err + default: + hat := newRandomHat(hs.inches) + if !hs.quiet && hs.numSent%10000 == 0 && hs.numSent > 0 { + log.Printf( + "[%4.1f khps] (%7d) Sending %#v\n", + float64(hs.numSent)/time.Now().Sub(hs.startedAt).Seconds()/1000, + hs.numSent, hat, + ) + } + return hat, nil + } +} + +func (hs *randomHatStream) End(err error) { + if !hs.quiet { + log.Printf("randomHatStream ended with %#v\n", err) + } +} + +func errAborted(err error) error { + if err == nil { + return twirp.NewError(twirp.Aborted, `canceled`).WithMeta(`cause`, `unknown`) + } + return twirp.NewError(twirp.Aborted, err.Error()) +} diff --git a/example/service.twirp.go b/example/service.twirp.go index 7e6e1063..5bdb53b2 100644 --- a/example/service.twirp.go +++ b/example/service.twirp.go @@ -325,10 +325,14 @@ func (s *haberdasherServer) serveMakeHatJSON(ctx context.Context, resp http.Resp func (s *haberdasherServer) serveMakeHatProtobuf(ctx context.Context, resp http.ResponseWriter, req *http.Request) { var err error + writeProtoError := func(err error) { + s.writeError(ctx, resp, err) + } + ctx = ctxsetters.WithMethodName(ctx, "MakeHat") ctx, err = callRequestRouted(ctx, s.hooks) if err != nil { - s.writeError(ctx, resp, err) + writeProtoError(err) return } @@ -336,13 +340,13 @@ func (s *haberdasherServer) serveMakeHatProtobuf(ctx context.Context, resp http. buf, err := ioutil.ReadAll(req.Body) if err != nil { err = wrapErr(err, "failed to read request body") - s.writeError(ctx, resp, twirp.InternalErrorWith(err)) + writeProtoError(twirp.InternalErrorWith(err)) return } reqContent := new(Size) if err = proto.Unmarshal(buf, reqContent); err != nil { err = wrapErr(err, "failed to parse request proto") - s.writeError(ctx, resp, twirp.InternalErrorWith(err)) + writeProtoError(twirp.InternalErrorWith(err)) return } @@ -409,11 +413,42 @@ func (s *haberdasherServer) serveMakeHatsJSON(ctx context.Context, resp http.Res } func (s *haberdasherServer) serveMakeHatsProtobuf(ctx context.Context, resp http.ResponseWriter, req *http.Request) { - var err error + var ( + err error + respStream HatStream + ) + + // Prepare trailer + trailer := proto.NewBuffer(nil) + _ = trailer.EncodeVarint((2 << 3) | 2) // field tag + writeProtoError := func(err error) { + // JSON encode err as twirp err + // TODO: figure out what to do about updating context and headers + if err == io.EOF { + trailer.EncodeStringBytes("EOF") + return + } + twerr, ok := err.(twirp.Error) + if !ok { + twerr = twirp.InternalErrorWith(err) + } + _ = trailer.EncodeStringBytes( + string(marshalErrorToJSON(twerr)), + ) + } + defer func() { // Send trailer + _, err = resp.Write(trailer.Bytes()) + if err != nil { + // TODO: call error hook? + err = wrapErr(err, "failed to write trailer") + respStream.End(err) + } + }() + ctx = ctxsetters.WithMethodName(ctx, "MakeHats") ctx, err = callRequestRouted(ctx, s.hooks) if err != nil { - s.writeError(ctx, resp, err) + writeProtoError(err) return } @@ -421,23 +456,22 @@ func (s *haberdasherServer) serveMakeHatsProtobuf(ctx context.Context, resp http buf, err := ioutil.ReadAll(req.Body) if err != nil { err = wrapErr(err, "failed to read request body") - s.writeError(ctx, resp, twirp.InternalErrorWith(err)) + writeProtoError(twirp.InternalErrorWith(err)) return } reqContent := new(MakeHatsReq) if err = proto.Unmarshal(buf, reqContent); err != nil { err = wrapErr(err, "failed to parse request proto") - s.writeError(ctx, resp, twirp.InternalErrorWith(err)) + writeProtoError(twirp.InternalErrorWith(err)) return } // Call service method - var respStream HatStream func() { defer func() { // In case of a panic, serve a 500 error and then panic. if r := recover(); r != nil { - s.writeError(ctx, resp, twirp.InternalError("Internal service panic")) + writeProtoError(twirp.InternalError("Internal service panic")) panic(r) } }() @@ -445,32 +479,23 @@ func (s *haberdasherServer) serveMakeHatsProtobuf(ctx context.Context, resp http }() if err != nil { - s.writeError(ctx, resp, err) + writeProtoError(err) return } if respStream == nil { - s.writeError(ctx, resp, twirp.InternalError("received a nil MakeHatsReq and nil error while calling MakeHats. nil responses are not supported")) + writeProtoError(twirp.InternalError("received a nil MakeHatsReq and nil error while calling MakeHats. nil responses are not supported")) return } - respFlusher, canFlush := resp.(http.Flusher) - ctx = callResponsePrepared(ctx, s.hooks) + respFlusher, canFlush := resp.(http.Flusher) messages := proto.NewBuffer(nil) - - trailer := proto.NewBuffer(nil) - _ = trailer.EncodeVarint((2 << 3) | 2) // field tag for { msg, err := respStream.Next(ctx) if err != nil { - // TODO: figure out trailers' proto encoding beyond just a string - if err == io.EOF { - _ = trailer.EncodeStringBytes("OK") - } else { - _ = trailer.EncodeStringBytes(err.Error()) - } + writeProtoError(err) break } @@ -491,6 +516,7 @@ func (s *haberdasherServer) serveMakeHatsProtobuf(ctx context.Context, resp http } if canFlush { + // TODO: come up with a batching scheme to improve performance under high load respFlusher.Flush() } @@ -993,12 +1019,7 @@ func (r protoStreamReader) Read(msg proto.Message) error { trailerTag = (2 << 3) | 2 ) - if tag == trailerTag { - _ = r.c.Close() - return io.EOF - } - - if tag != msgTag { + if tag != msgTag && tag != trailerTag { return fmt.Errorf("invalid field tag: %v", tag) } @@ -1010,19 +1031,45 @@ func (r protoStreamReader) Read(msg proto.Message) error { if int(l) < 0 || int(l) > r.maxSize { return io.ErrShortBuffer } - buf := make([]byte, int(l)) + if tag == msgTag { + buf := make([]byte, int(l)) + + // Go ahead and read a message. + _, err = io.ReadFull(r.r, buf) + if err != nil { + return err + } - // Go ahead and read a message. + err = proto.Unmarshal(buf, msg) + if err != nil { + return err + } + return nil + } + + // This is a trailer, read it and then close the client + defer r.c.Close() + buf := make([]byte, int(l)) _, err = io.ReadFull(r.r, buf) if err != nil { return err } - err = proto.Unmarshal(buf, msg) + // Put the length back in front of the trailer so it can be decoded + buf = append(proto.EncodeVarint(l), buf...) + var trailer string + trailer, err = proto.NewBuffer(buf).DecodeStringBytes() if err != nil { - return err + return clientError("failed to read stream trailer", err) } - return nil + if trailer == "EOF" { + return io.EOF + } + var tj twerrJSON + if err = json.Unmarshal([]byte(trailer), &tj); err != nil { + return clientError("unable to decode stream trailer", err) + } + return tj.toTwirpError() } type jsonStreamReader struct { diff --git a/internal/twirptest/gogo_compat/service.twirp.go b/internal/twirptest/gogo_compat/service.twirp.go index eab0108a..341a2bd3 100644 --- a/internal/twirptest/gogo_compat/service.twirp.go +++ b/internal/twirptest/gogo_compat/service.twirp.go @@ -255,10 +255,14 @@ func (s *svcServer) serveSendJSON(ctx context.Context, resp http.ResponseWriter, func (s *svcServer) serveSendProtobuf(ctx context.Context, resp http.ResponseWriter, req *http.Request) { var err error + writeProtoError := func(err error) { + s.writeError(ctx, resp, err) + } + ctx = ctxsetters.WithMethodName(ctx, "Send") ctx, err = callRequestRouted(ctx, s.hooks) if err != nil { - s.writeError(ctx, resp, err) + writeProtoError(err) return } @@ -266,13 +270,13 @@ func (s *svcServer) serveSendProtobuf(ctx context.Context, resp http.ResponseWri buf, err := ioutil.ReadAll(req.Body) if err != nil { err = wrapErr(err, "failed to read request body") - s.writeError(ctx, resp, twirp.InternalErrorWith(err)) + writeProtoError(twirp.InternalErrorWith(err)) return } reqContent := new(Msg) if err = proto.Unmarshal(buf, reqContent); err != nil { err = wrapErr(err, "failed to parse request proto") - s.writeError(ctx, resp, twirp.InternalErrorWith(err)) + writeProtoError(twirp.InternalErrorWith(err)) return } diff --git a/internal/twirptest/importable/importable.twirp.go b/internal/twirptest/importable/importable.twirp.go index 55137de2..3c23ec05 100644 --- a/internal/twirptest/importable/importable.twirp.go +++ b/internal/twirptest/importable/importable.twirp.go @@ -254,10 +254,14 @@ func (s *svcServer) serveSendJSON(ctx context.Context, resp http.ResponseWriter, func (s *svcServer) serveSendProtobuf(ctx context.Context, resp http.ResponseWriter, req *http.Request) { var err error + writeProtoError := func(err error) { + s.writeError(ctx, resp, err) + } + ctx = ctxsetters.WithMethodName(ctx, "Send") ctx, err = callRequestRouted(ctx, s.hooks) if err != nil { - s.writeError(ctx, resp, err) + writeProtoError(err) return } @@ -265,13 +269,13 @@ func (s *svcServer) serveSendProtobuf(ctx context.Context, resp http.ResponseWri buf, err := ioutil.ReadAll(req.Body) if err != nil { err = wrapErr(err, "failed to read request body") - s.writeError(ctx, resp, twirp.InternalErrorWith(err)) + writeProtoError(twirp.InternalErrorWith(err)) return } reqContent := new(Msg) if err = proto.Unmarshal(buf, reqContent); err != nil { err = wrapErr(err, "failed to parse request proto") - s.writeError(ctx, resp, twirp.InternalErrorWith(err)) + writeProtoError(twirp.InternalErrorWith(err)) return } diff --git a/internal/twirptest/importer/importer.twirp.go b/internal/twirptest/importer/importer.twirp.go index f43934df..d3a19d44 100644 --- a/internal/twirptest/importer/importer.twirp.go +++ b/internal/twirptest/importer/importer.twirp.go @@ -263,10 +263,14 @@ func (s *svc2Server) serveSendJSON(ctx context.Context, resp http.ResponseWriter func (s *svc2Server) serveSendProtobuf(ctx context.Context, resp http.ResponseWriter, req *http.Request) { var err error + writeProtoError := func(err error) { + s.writeError(ctx, resp, err) + } + ctx = ctxsetters.WithMethodName(ctx, "Send") ctx, err = callRequestRouted(ctx, s.hooks) if err != nil { - s.writeError(ctx, resp, err) + writeProtoError(err) return } @@ -274,13 +278,13 @@ func (s *svc2Server) serveSendProtobuf(ctx context.Context, resp http.ResponseWr buf, err := ioutil.ReadAll(req.Body) if err != nil { err = wrapErr(err, "failed to read request body") - s.writeError(ctx, resp, twirp.InternalErrorWith(err)) + writeProtoError(twirp.InternalErrorWith(err)) return } reqContent := new(twirp_internal_twirptest_importable.Msg) if err = proto.Unmarshal(buf, reqContent); err != nil { err = wrapErr(err, "failed to parse request proto") - s.writeError(ctx, resp, twirp.InternalErrorWith(err)) + writeProtoError(twirp.InternalErrorWith(err)) return } @@ -347,11 +351,11 @@ func (s *svc2Server) serveStreamJSON(ctx context.Context, resp http.ResponseWrit } func (s *svc2Server) serveStreamProtobuf(ctx context.Context, resp http.ResponseWriter, req *http.Request) { - var err error + ctx = ctxsetters.WithMethodName(ctx, "Stream") ctx, err = callRequestRouted(ctx, s.hooks) if err != nil { - s.writeError(ctx, resp, err) + writeProtoError(err) return } @@ -846,12 +850,7 @@ func (r protoStreamReader) Read(msg proto.Message) error { trailerTag = (2 << 3) | 2 ) - if tag == trailerTag { - _ = r.c.Close() - return io.EOF - } - - if tag != msgTag { + if tag != msgTag && tag != trailerTag { return fmt.Errorf("invalid field tag: %v", tag) } @@ -863,19 +862,45 @@ func (r protoStreamReader) Read(msg proto.Message) error { if int(l) < 0 || int(l) > r.maxSize { return io.ErrShortBuffer } - buf := make([]byte, int(l)) + if tag == msgTag { + buf := make([]byte, int(l)) - // Go ahead and read a message. + // Go ahead and read a message. + _, err = io.ReadFull(r.r, buf) + if err != nil { + return err + } + + err = proto.Unmarshal(buf, msg) + if err != nil { + return err + } + return nil + } + + // This is a trailer, read it and then close the client + defer r.c.Close() + buf := make([]byte, int(l)) _, err = io.ReadFull(r.r, buf) if err != nil { return err } - err = proto.Unmarshal(buf, msg) + // Put the length back in front of the trailer so it can be decoded + buf = append(proto.EncodeVarint(l), buf...) + var trailer string + trailer, err = proto.NewBuffer(buf).DecodeStringBytes() if err != nil { - return err + return clientError("failed to read stream trailer", err) } - return nil + if trailer == "EOF" { + return io.EOF + } + var tj twerrJSON + if err = json.Unmarshal([]byte(trailer), &tj); err != nil { + return clientError("unable to decode stream trailer", err) + } + return tj.toTwirpError() } type jsonStreamReader struct { diff --git a/internal/twirptest/multiple/multiple1.twirp.go b/internal/twirptest/multiple/multiple1.twirp.go index 17fc3e54..afe9bd5e 100644 --- a/internal/twirptest/multiple/multiple1.twirp.go +++ b/internal/twirptest/multiple/multiple1.twirp.go @@ -255,10 +255,14 @@ func (s *svc1Server) serveSendJSON(ctx context.Context, resp http.ResponseWriter func (s *svc1Server) serveSendProtobuf(ctx context.Context, resp http.ResponseWriter, req *http.Request) { var err error + writeProtoError := func(err error) { + s.writeError(ctx, resp, err) + } + ctx = ctxsetters.WithMethodName(ctx, "Send") ctx, err = callRequestRouted(ctx, s.hooks) if err != nil { - s.writeError(ctx, resp, err) + writeProtoError(err) return } @@ -266,13 +270,13 @@ func (s *svc1Server) serveSendProtobuf(ctx context.Context, resp http.ResponseWr buf, err := ioutil.ReadAll(req.Body) if err != nil { err = wrapErr(err, "failed to read request body") - s.writeError(ctx, resp, twirp.InternalErrorWith(err)) + writeProtoError(twirp.InternalErrorWith(err)) return } reqContent := new(Msg1) if err = proto.Unmarshal(buf, reqContent); err != nil { err = wrapErr(err, "failed to parse request proto") - s.writeError(ctx, resp, twirp.InternalErrorWith(err)) + writeProtoError(twirp.InternalErrorWith(err)) return } diff --git a/internal/twirptest/multiple/multiple2.twirp.go b/internal/twirptest/multiple/multiple2.twirp.go index f3d14f49..68191aae 100644 --- a/internal/twirptest/multiple/multiple2.twirp.go +++ b/internal/twirptest/multiple/multiple2.twirp.go @@ -263,10 +263,14 @@ func (s *svc2Server) serveSendJSON(ctx context.Context, resp http.ResponseWriter func (s *svc2Server) serveSendProtobuf(ctx context.Context, resp http.ResponseWriter, req *http.Request) { var err error + writeProtoError := func(err error) { + s.writeError(ctx, resp, err) + } + ctx = ctxsetters.WithMethodName(ctx, "Send") ctx, err = callRequestRouted(ctx, s.hooks) if err != nil { - s.writeError(ctx, resp, err) + writeProtoError(err) return } @@ -274,13 +278,13 @@ func (s *svc2Server) serveSendProtobuf(ctx context.Context, resp http.ResponseWr buf, err := ioutil.ReadAll(req.Body) if err != nil { err = wrapErr(err, "failed to read request body") - s.writeError(ctx, resp, twirp.InternalErrorWith(err)) + writeProtoError(twirp.InternalErrorWith(err)) return } reqContent := new(Msg2) if err = proto.Unmarshal(buf, reqContent); err != nil { err = wrapErr(err, "failed to parse request proto") - s.writeError(ctx, resp, twirp.InternalErrorWith(err)) + writeProtoError(twirp.InternalErrorWith(err)) return } @@ -407,10 +411,14 @@ func (s *svc2Server) serveSamePackageProtoImportJSON(ctx context.Context, resp h func (s *svc2Server) serveSamePackageProtoImportProtobuf(ctx context.Context, resp http.ResponseWriter, req *http.Request) { var err error + writeProtoError := func(err error) { + s.writeError(ctx, resp, err) + } + ctx = ctxsetters.WithMethodName(ctx, "SamePackageProtoImport") ctx, err = callRequestRouted(ctx, s.hooks) if err != nil { - s.writeError(ctx, resp, err) + writeProtoError(err) return } @@ -418,13 +426,13 @@ func (s *svc2Server) serveSamePackageProtoImportProtobuf(ctx context.Context, re buf, err := ioutil.ReadAll(req.Body) if err != nil { err = wrapErr(err, "failed to read request body") - s.writeError(ctx, resp, twirp.InternalErrorWith(err)) + writeProtoError(twirp.InternalErrorWith(err)) return } reqContent := new(Msg1) if err = proto.Unmarshal(buf, reqContent); err != nil { err = wrapErr(err, "failed to parse request proto") - s.writeError(ctx, resp, twirp.InternalErrorWith(err)) + writeProtoError(twirp.InternalErrorWith(err)) return } diff --git a/internal/twirptest/no_package_name/no_package_name.twirp.go b/internal/twirptest/no_package_name/no_package_name.twirp.go index 7a18613e..49cc2080 100644 --- a/internal/twirptest/no_package_name/no_package_name.twirp.go +++ b/internal/twirptest/no_package_name/no_package_name.twirp.go @@ -251,10 +251,14 @@ func (s *svcServer) serveSendJSON(ctx context.Context, resp http.ResponseWriter, func (s *svcServer) serveSendProtobuf(ctx context.Context, resp http.ResponseWriter, req *http.Request) { var err error + writeProtoError := func(err error) { + s.writeError(ctx, resp, err) + } + ctx = ctxsetters.WithMethodName(ctx, "Send") ctx, err = callRequestRouted(ctx, s.hooks) if err != nil { - s.writeError(ctx, resp, err) + writeProtoError(err) return } @@ -262,13 +266,13 @@ func (s *svcServer) serveSendProtobuf(ctx context.Context, resp http.ResponseWri buf, err := ioutil.ReadAll(req.Body) if err != nil { err = wrapErr(err, "failed to read request body") - s.writeError(ctx, resp, twirp.InternalErrorWith(err)) + writeProtoError(twirp.InternalErrorWith(err)) return } reqContent := new(Msg) if err = proto.Unmarshal(buf, reqContent); err != nil { err = wrapErr(err, "failed to parse request proto") - s.writeError(ctx, resp, twirp.InternalErrorWith(err)) + writeProtoError(twirp.InternalErrorWith(err)) return } diff --git a/internal/twirptest/no_package_name_importer/no_package_name_importer.twirp.go b/internal/twirptest/no_package_name_importer/no_package_name_importer.twirp.go index f468092a..81af039a 100644 --- a/internal/twirptest/no_package_name_importer/no_package_name_importer.twirp.go +++ b/internal/twirptest/no_package_name_importer/no_package_name_importer.twirp.go @@ -253,10 +253,14 @@ func (s *svc2Server) serveMethodJSON(ctx context.Context, resp http.ResponseWrit func (s *svc2Server) serveMethodProtobuf(ctx context.Context, resp http.ResponseWriter, req *http.Request) { var err error + writeProtoError := func(err error) { + s.writeError(ctx, resp, err) + } + ctx = ctxsetters.WithMethodName(ctx, "Method") ctx, err = callRequestRouted(ctx, s.hooks) if err != nil { - s.writeError(ctx, resp, err) + writeProtoError(err) return } @@ -264,13 +268,13 @@ func (s *svc2Server) serveMethodProtobuf(ctx context.Context, resp http.Response buf, err := ioutil.ReadAll(req.Body) if err != nil { err = wrapErr(err, "failed to read request body") - s.writeError(ctx, resp, twirp.InternalErrorWith(err)) + writeProtoError(twirp.InternalErrorWith(err)) return } reqContent := new(no_package_name.Msg) if err = proto.Unmarshal(buf, reqContent); err != nil { err = wrapErr(err, "failed to parse request proto") - s.writeError(ctx, resp, twirp.InternalErrorWith(err)) + writeProtoError(twirp.InternalErrorWith(err)) return } diff --git a/internal/twirptest/proto/proto.twirp.go b/internal/twirptest/proto/proto.twirp.go index 31287117..b440297b 100644 --- a/internal/twirptest/proto/proto.twirp.go +++ b/internal/twirptest/proto/proto.twirp.go @@ -254,10 +254,14 @@ func (s *svcServer) serveSendJSON(ctx context.Context, resp http.ResponseWriter, func (s *svcServer) serveSendProtobuf(ctx context.Context, resp http.ResponseWriter, req *http.Request) { var err error + writeProtoError := func(err error) { + s.writeError(ctx, resp, err) + } + ctx = ctxsetters.WithMethodName(ctx, "Send") ctx, err = callRequestRouted(ctx, s.hooks) if err != nil { - s.writeError(ctx, resp, err) + writeProtoError(err) return } @@ -265,13 +269,13 @@ func (s *svcServer) serveSendProtobuf(ctx context.Context, resp http.ResponseWri buf, err := ioutil.ReadAll(req.Body) if err != nil { err = wrapErr(err, "failed to read request body") - s.writeError(ctx, resp, twirp.InternalErrorWith(err)) + writeProtoError(twirp.InternalErrorWith(err)) return } reqContent := new(Msg) if err = proto.Unmarshal(buf, reqContent); err != nil { err = wrapErr(err, "failed to parse request proto") - s.writeError(ctx, resp, twirp.InternalErrorWith(err)) + writeProtoError(twirp.InternalErrorWith(err)) return } diff --git a/internal/twirptest/service.twirp.go b/internal/twirptest/service.twirp.go index 1b4b8fb8..5bf4192f 100644 --- a/internal/twirptest/service.twirp.go +++ b/internal/twirptest/service.twirp.go @@ -255,10 +255,14 @@ func (s *haberdasherServer) serveMakeHatJSON(ctx context.Context, resp http.Resp func (s *haberdasherServer) serveMakeHatProtobuf(ctx context.Context, resp http.ResponseWriter, req *http.Request) { var err error + writeProtoError := func(err error) { + s.writeError(ctx, resp, err) + } + ctx = ctxsetters.WithMethodName(ctx, "MakeHat") ctx, err = callRequestRouted(ctx, s.hooks) if err != nil { - s.writeError(ctx, resp, err) + writeProtoError(err) return } @@ -266,13 +270,13 @@ func (s *haberdasherServer) serveMakeHatProtobuf(ctx context.Context, resp http. buf, err := ioutil.ReadAll(req.Body) if err != nil { err = wrapErr(err, "failed to read request body") - s.writeError(ctx, resp, twirp.InternalErrorWith(err)) + writeProtoError(twirp.InternalErrorWith(err)) return } reqContent := new(Size) if err = proto.Unmarshal(buf, reqContent); err != nil { err = wrapErr(err, "failed to parse request proto") - s.writeError(ctx, resp, twirp.InternalErrorWith(err)) + writeProtoError(twirp.InternalErrorWith(err)) return } @@ -628,10 +632,14 @@ func (s *streamerServer) serveTransactJSON(ctx context.Context, resp http.Respon func (s *streamerServer) serveTransactProtobuf(ctx context.Context, resp http.ResponseWriter, req *http.Request) { var err error + writeProtoError := func(err error) { + s.writeError(ctx, resp, err) + } + ctx = ctxsetters.WithMethodName(ctx, "Transact") ctx, err = callRequestRouted(ctx, s.hooks) if err != nil { - s.writeError(ctx, resp, err) + writeProtoError(err) return } @@ -639,13 +647,13 @@ func (s *streamerServer) serveTransactProtobuf(ctx context.Context, resp http.Re buf, err := ioutil.ReadAll(req.Body) if err != nil { err = wrapErr(err, "failed to read request body") - s.writeError(ctx, resp, twirp.InternalErrorWith(err)) + writeProtoError(twirp.InternalErrorWith(err)) return } reqContent := new(Req) if err = proto.Unmarshal(buf, reqContent); err != nil { err = wrapErr(err, "failed to parse request proto") - s.writeError(ctx, resp, twirp.InternalErrorWith(err)) + writeProtoError(twirp.InternalErrorWith(err)) return } @@ -712,11 +720,11 @@ func (s *streamerServer) serveUploadJSON(ctx context.Context, resp http.Response } func (s *streamerServer) serveUploadProtobuf(ctx context.Context, resp http.ResponseWriter, req *http.Request) { - var err error + ctx = ctxsetters.WithMethodName(ctx, "Upload") ctx, err = callRequestRouted(ctx, s.hooks) if err != nil { - s.writeError(ctx, resp, err) + writeProtoError(err) return } @@ -745,11 +753,42 @@ func (s *streamerServer) serveDownloadJSON(ctx context.Context, resp http.Respon } func (s *streamerServer) serveDownloadProtobuf(ctx context.Context, resp http.ResponseWriter, req *http.Request) { - var err error + var ( + err error + respStream RespStream + ) + + // Prepare trailer + trailer := proto.NewBuffer(nil) + _ = trailer.EncodeVarint((2 << 3) | 2) // field tag + writeProtoError := func(err error) { + // JSON encode err as twirp err + // TODO: figure out what to do about updating context and headers + if err == io.EOF { + trailer.EncodeStringBytes("EOF") + return + } + twerr, ok := err.(twirp.Error) + if !ok { + twerr = twirp.InternalErrorWith(err) + } + _ = trailer.EncodeStringBytes( + string(marshalErrorToJSON(twerr)), + ) + } + defer func() { // Send trailer + _, err = resp.Write(trailer.Bytes()) + if err != nil { + // TODO: call error hook? + err = wrapErr(err, "failed to write trailer") + respStream.End(err) + } + }() + ctx = ctxsetters.WithMethodName(ctx, "Download") ctx, err = callRequestRouted(ctx, s.hooks) if err != nil { - s.writeError(ctx, resp, err) + writeProtoError(err) return } @@ -757,23 +796,22 @@ func (s *streamerServer) serveDownloadProtobuf(ctx context.Context, resp http.Re buf, err := ioutil.ReadAll(req.Body) if err != nil { err = wrapErr(err, "failed to read request body") - s.writeError(ctx, resp, twirp.InternalErrorWith(err)) + writeProtoError(twirp.InternalErrorWith(err)) return } reqContent := new(Req) if err = proto.Unmarshal(buf, reqContent); err != nil { err = wrapErr(err, "failed to parse request proto") - s.writeError(ctx, resp, twirp.InternalErrorWith(err)) + writeProtoError(twirp.InternalErrorWith(err)) return } // Call service method - var respStream RespStream func() { defer func() { // In case of a panic, serve a 500 error and then panic. if r := recover(); r != nil { - s.writeError(ctx, resp, twirp.InternalError("Internal service panic")) + writeProtoError(twirp.InternalError("Internal service panic")) panic(r) } }() @@ -781,32 +819,23 @@ func (s *streamerServer) serveDownloadProtobuf(ctx context.Context, resp http.Re }() if err != nil { - s.writeError(ctx, resp, err) + writeProtoError(err) return } if respStream == nil { - s.writeError(ctx, resp, twirp.InternalError("received a nil Req and nil error while calling Download. nil responses are not supported")) + writeProtoError(twirp.InternalError("received a nil Req and nil error while calling Download. nil responses are not supported")) return } - respFlusher, canFlush := resp.(http.Flusher) - ctx = callResponsePrepared(ctx, s.hooks) + respFlusher, canFlush := resp.(http.Flusher) messages := proto.NewBuffer(nil) - - trailer := proto.NewBuffer(nil) - _ = trailer.EncodeVarint((2 << 3) | 2) // field tag for { msg, err := respStream.Next(ctx) if err != nil { - // TODO: figure out trailers' proto encoding beyond just a string - if err == io.EOF { - _ = trailer.EncodeStringBytes("OK") - } else { - _ = trailer.EncodeStringBytes(err.Error()) - } + writeProtoError(err) break } @@ -827,6 +856,7 @@ func (s *streamerServer) serveDownloadProtobuf(ctx context.Context, resp http.Re } if canFlush { + // TODO: come up with a batching scheme to improve performance under high load respFlusher.Flush() } @@ -863,11 +893,11 @@ func (s *streamerServer) serveCommunicateJSON(ctx context.Context, resp http.Res } func (s *streamerServer) serveCommunicateProtobuf(ctx context.Context, resp http.ResponseWriter, req *http.Request) { - var err error + ctx = ctxsetters.WithMethodName(ctx, "Communicate") ctx, err = callRequestRouted(ctx, s.hooks) if err != nil { - s.writeError(ctx, resp, err) + writeProtoError(err) return } @@ -1399,12 +1429,7 @@ func (r protoStreamReader) Read(msg proto.Message) error { trailerTag = (2 << 3) | 2 ) - if tag == trailerTag { - _ = r.c.Close() - return io.EOF - } - - if tag != msgTag { + if tag != msgTag && tag != trailerTag { return fmt.Errorf("invalid field tag: %v", tag) } @@ -1416,19 +1441,45 @@ func (r protoStreamReader) Read(msg proto.Message) error { if int(l) < 0 || int(l) > r.maxSize { return io.ErrShortBuffer } - buf := make([]byte, int(l)) + if tag == msgTag { + buf := make([]byte, int(l)) - // Go ahead and read a message. + // Go ahead and read a message. + _, err = io.ReadFull(r.r, buf) + if err != nil { + return err + } + + err = proto.Unmarshal(buf, msg) + if err != nil { + return err + } + return nil + } + + // This is a trailer, read it and then close the client + defer r.c.Close() + buf := make([]byte, int(l)) _, err = io.ReadFull(r.r, buf) if err != nil { return err } - err = proto.Unmarshal(buf, msg) + // Put the length back in front of the trailer so it can be decoded + buf = append(proto.EncodeVarint(l), buf...) + var trailer string + trailer, err = proto.NewBuffer(buf).DecodeStringBytes() if err != nil { - return err + return clientError("failed to read stream trailer", err) } - return nil + if trailer == "EOF" { + return io.EOF + } + var tj twerrJSON + if err = json.Unmarshal([]byte(trailer), &tj); err != nil { + return clientError("unable to decode stream trailer", err) + } + return tj.toTwirpError() } type jsonStreamReader struct { diff --git a/protoc-gen-twirp/generator.go b/protoc-gen-twirp/generator.go index dea2b1ef..8194d9d7 100644 --- a/protoc-gen-twirp/generator.go +++ b/protoc-gen-twirp/generator.go @@ -769,12 +769,7 @@ func (t *twirp) generateStreamUtils() { t.P(` trailerTag = (2 << 3) | 2`) t.P(` )`) t.P(``) - t.P(` if tag == trailerTag {`) - t.P(` _ = r.c.Close()`) - t.P(` return `, t.pkgs["io"], `.EOF`) - t.P(` }`) - t.P(``) - t.P(` if tag != msgTag {`) + t.P(` if tag != msgTag && tag != trailerTag {`) t.P(` return `, t.pkgs["fmt"], `.Errorf("invalid field tag: %v", tag)`) t.P(` }`) t.P(``) @@ -786,26 +781,52 @@ func (t *twirp) generateStreamUtils() { t.P(` if int(l) < 0 || int(l) > r.maxSize {`) t.P(` return `, t.pkgs["io"], `.ErrShortBuffer`) t.P(` }`) + t.P(` if tag == msgTag {`) + t.P(` buf := make([]byte, int(l))`) + t.P() + t.P(` // Go ahead and read a message.`) + t.P(` _, err = io.ReadFull(r.r, buf)`) + t.P(` if err != nil {`) + t.P(` return err`) + t.P(` }`) + t.P() + t.P(` err = proto.Unmarshal(buf, msg)`) + t.P(` if err != nil {`) + t.P(` return err`) + t.P(` }`) + t.P(` return nil`) + t.P(` }`) + t.P() + t.P(` // This is a trailer, read it and then close the client`) + t.P(` defer r.c.Close()`) t.P(` buf := make([]byte, int(l))`) - t.P(``) - t.P(` // Go ahead and read a message.`) t.P(` _, err = `, t.pkgs["io"], `.ReadFull(r.r, buf)`) t.P(` if err != nil {`) t.P(` return err`) t.P(` }`) - t.P(``) - t.P(` err = `, t.pkgs["proto"], `.Unmarshal(buf, msg)`) + t.P() + t.P(` // Put the length back in front of the trailer so it can be decoded`) + t.P(` buf = append(proto.EncodeVarint(l), buf...)`) + t.P(` var trailer string`) + t.P(` trailer, err = `, t.pkgs["proto"], `.NewBuffer(buf).DecodeStringBytes()`) t.P(` if err != nil {`) - t.P(` return err`) + t.P(` return clientError("failed to read stream trailer", err)`) + t.P(` }`) + t.P(` if trailer == "EOF" {`) + t.P(` return io.EOF`) t.P(` }`) - t.P(` return nil`) + t.P(` var tj twerrJSON`) + t.P(` if err = json.Unmarshal([]byte(trailer), &tj); err != nil {`) + t.P(` return clientError("unable to decode stream trailer", err)`) + t.P(` }`) + t.P(` return tj.toTwirpError()`) t.P(`}`) t.P() t.P(`type jsonStreamReader struct {`) t.P(` dec *`, t.pkgs["json"], `.Decoder`) t.P(` unmarshaler *`, t.pkgs["jsonpb"], `.Unmarshaler`) - t.P(` messageStreamDone bool`) + t.P(` messageStreamDone bool`) t.P(`}`) t.P() @@ -1308,31 +1329,72 @@ func (t *twirp) generateServerJSONMethod(service *descriptor.ServiceDescriptorPr func (t *twirp) generateServerProtobufMethod(service *descriptor.ServiceDescriptorProto, method *descriptor.MethodDescriptorProto) { servStruct := serviceStruct(service) methName := stringutils.CamelCase(method.GetName()) + rpcType := methodRPCType(method) t.P(`func (s *`, servStruct, `) serve`, methName, `Protobuf(ctx `, t.pkgs["context"], `.Context, resp `, t.pkgs["http"], `.ResponseWriter, req *`, t.pkgs["http"], `.Request) {`) - t.P(` var err error`) - t.P(` ctx = `, t.pkgs["ctxsetters"], `.WithMethodName(ctx, "`, methName, `")`) - t.P(` ctx, err = callRequestRouted(ctx, s.hooks)`) - t.P(` if err != nil {`) - t.P(` s.writeError(ctx, resp, err)`) - t.P(` return`) - t.P(` }`) + if rpcType == unary { + t.P(` var err error`) + t.P(` writeProtoError := func(err error) {`) + t.P(` s.writeError(ctx, resp, err)`) + t.P(` }`) + } + if rpcType == download { + t.P(` var (`) + t.P(` err error`) + t.P(` respStream `, t.methodOutputType(method)) + t.P(` )`) + t.P() + t.P(` // Prepare trailer`) + t.P(` trailer := proto.NewBuffer(nil)`) + t.P(` _ = trailer.EncodeVarint((2 << 3) | 2) // field tag`) + t.P(` writeProtoError := func(err error) {`) + t.P(` // JSON encode err as twirp err`) + t.P(` // TODO: figure out what to do about updating context and headers`) + t.P(` if err == io.EOF {`) + t.P(` trailer.EncodeStringBytes("EOF")`) + t.P(` return`) + t.P(` }`) + t.P(` twerr, ok := err.(twirp.Error)`) + t.P(` if !ok {`) + t.P(` twerr = twirp.InternalErrorWith(err)`) + t.P(` }`) + t.P(` _ = trailer.EncodeStringBytes(`) + t.P(` string(marshalErrorToJSON(twerr)),`) + t.P(` )`) + t.P(` }`) + t.P(` defer func() { // Send trailer`) + t.P(` _, err = resp.Write(trailer.Bytes())`) + t.P(` if err != nil {`) + t.P(` // TODO: call error hook?`) + t.P(` err = wrapErr(err, "failed to write trailer")`) + t.P(` respStream.End(err)`) + t.P(` }`) + t.P(` }()`) + } + + t.P() + t.P(` ctx = `, t.pkgs["ctxsetters"], `.WithMethodName(ctx, "`, methName, `")`) + t.P(` ctx, err = callRequestRouted(ctx, s.hooks)`) + t.P(` if err != nil {`) + t.P(` writeProtoError(err)`) + t.P(` return`) + t.P(` }`) + t.P() t.P(` resp.Header().Set("Content-Type", "application/protobuf")`) - rpcType := methodRPCType(method) if rpcType == unary || rpcType == download { t.P(` buf, err := `, t.pkgs["ioutil"], `.ReadAll(req.Body)`) t.P(` if err != nil {`) t.P(` err = wrapErr(err, "failed to read request body")`) - t.P(` s.writeError(ctx, resp, `, t.pkgs["twirp"], `.InternalErrorWith(err))`) + t.P(` writeProtoError(`, t.pkgs["twirp"], `.InternalErrorWith(err))`) t.P(` return`) t.P(` }`) t.P(` reqContent := new(`, t.methodInputType(method), `)`) t.P(` if err = `, t.pkgs["proto"], `.Unmarshal(buf, reqContent); err != nil {`) t.P(` err = wrapErr(err, "failed to parse request proto")`) - t.P(` s.writeError(ctx, resp, `, t.pkgs["twirp"], `.InternalErrorWith(err))`) + t.P(` writeProtoError(`, t.pkgs["twirp"], `.InternalErrorWith(err))`) t.P(` return`) t.P(` }`) t.P() @@ -1381,12 +1443,11 @@ func (t *twirp) generateServerProtobufMethod(service *descriptor.ServiceDescript if rpcType == download { t.P(` // Call service method`) - t.P(` var respStream `, t.methodOutputType(method)) t.P(` func() {`) t.P(` defer func() {`) t.P(` // In case of a panic, serve a 500 error and then panic.`) t.P(` if r := recover(); r != nil {`) - t.P(` s.writeError(ctx, resp, `, t.pkgs["twirp"], `.InternalError("Internal service panic"))`) + t.P(` writeProtoError(`, t.pkgs["twirp"], `.InternalError("Internal service panic"))`) t.P(` panic(r)`) t.P(` }`) t.P(` }()`) @@ -1394,32 +1455,23 @@ func (t *twirp) generateServerProtobufMethod(service *descriptor.ServiceDescript t.P(` }()`) t.P(``) t.P(` if err != nil {`) - t.P(` s.writeError(ctx, resp, err)`) + t.P(` writeProtoError(err)`) t.P(` return`) t.P(` }`) t.P(``) t.P(` if respStream == nil {`) - t.P(` s.writeError(ctx, resp, `, t.pkgs["twirp"], `.InternalError("received a nil `, t.methodInputType(method), ` and nil error while calling `, methName, `. nil responses are not supported"))`) + t.P(` writeProtoError(`, t.pkgs["twirp"], `.InternalError("received a nil `, t.methodInputType(method), ` and nil error while calling `, methName, `. nil responses are not supported"))`) t.P(` return`) t.P(` }`) t.P(``) - t.P(` respFlusher, canFlush := resp.(http.Flusher)`) - t.P(``) t.P(` ctx = callResponsePrepared(ctx, s.hooks)`) t.P(``) + t.P(` respFlusher, canFlush := resp.(http.Flusher)`) t.P(` messages := `, t.pkgs["proto"], `.NewBuffer(nil)`) - t.P(``) - t.P(` trailer := `, t.pkgs["proto"], `.NewBuffer(nil)`) - t.P(` _ = trailer.EncodeVarint((2 << 3) | 2) // field tag`) t.P(` for {`) t.P(` msg, err := respStream.Next(ctx)`) t.P(` if err != nil {`) - t.P(` // TODO: figure out trailers' proto encoding beyond just a string`) - t.P(` if err == `, t.pkgs["io"], `.EOF {`) - t.P(` _ = trailer.EncodeStringBytes("OK")`) - t.P(` } else {`) - t.P(` _ = trailer.EncodeStringBytes(err.Error())`) - t.P(` }`) + t.P(` writeProtoError(err)`) t.P(` break`) t.P(` }`) t.P(``) @@ -1440,18 +1492,12 @@ func (t *twirp) generateServerProtobufMethod(service *descriptor.ServiceDescript t.P(` }`) t.P(``) t.P(` if canFlush {`) + t.P(` // TODO: come up with a batching scheme to improve performance under high load`) t.P(` respFlusher.Flush()`) t.P(` }`) t.P(``) t.P(` // TODO: Call a hook that we sent a message in a stream?`) t.P(` }`) - t.P(``) - t.P(` _, err = resp.Write(trailer.Bytes())`) - t.P(` if err != nil {`) - t.P(` // TODO: call error hook?`) - t.P(` err = wrapErr(err, "failed to write trailer")`) - t.P(` respStream.End(err)`) - t.P(` }`) } t.P(`}`) From 966c23a912d22f26bce62e1271547a33a7f58e20 Mon Sep 17 00:00:00 2001 From: Mike Lemmon Date: Mon, 28 May 2018 01:38:07 -0700 Subject: [PATCH 3/7] Improve error handling in generated download-stream protobuf code * Add new "test_example" target to Makefile. Example tests are passing for protobuf clients; JSON tests would fail, but are commented out * Return normal twirp errors from download requests that fail before the stream is created * Consolidate generator logic for unary and download rpc types in generateServerProtobufMethod * In the example server, use chan of HatOrError type to send mid-stream errors --- Makefile | 7 +- .../clientcompat/clientcompat.twirp.go | 42 +-- example/cmd/client/main.go | 35 +- example/cmd/server/hat_stream_sender.go | 55 ++++ example/cmd/server/main.go | 46 ++- example/cmd/server/main_test.go | 106 +++--- example/cmd/server/random_hat_stream.go | 85 ----- example/service.twirp.go | 208 ++++++------ .../twirptest/gogo_compat/service.twirp.go | 21 +- .../twirptest/importable/importable.twirp.go | 21 +- internal/twirptest/importer/importer.twirp.go | 97 +++--- .../twirptest/multiple/multiple1.twirp.go | 21 +- .../twirptest/multiple/multiple2.twirp.go | 42 +-- .../no_package_name/no_package_name.twirp.go | 21 +- .../no_package_name_importer.twirp.go | 21 +- internal/twirptest/proto/proto.twirp.go | 21 +- internal/twirptest/service.twirp.go | 249 +++++++------- protoc-gen-twirp/generator.go | 309 +++++++++--------- 18 files changed, 702 insertions(+), 705 deletions(-) create mode 100644 example/cmd/server/hat_stream_sender.go delete mode 100644 example/cmd/server/random_hat_stream.go diff --git a/Makefile b/Makefile index d99096e0..8837a91e 100644 --- a/Makefile +++ b/Makefile @@ -11,11 +11,11 @@ generate: PATH=$(CURDIR)/_tools/bin:$(PATH) GOBIN="${PWD}/bin" go install -v ./protoc-gen-... $(RETOOL) do go generate ./... -test_all: setup test_core test_clients +test_all: setup test_core test_clients test_example test_core: generate # $(RETOOL) do errcheck -blank ./internal/twirptest - go test -race $(shell go list ./... | grep -v /vendor/ | grep -v /_tools/) + go test -race $(shell go list ./... | grep -v /vendor/ | grep -v /_tools/ | grep -v /example/) test_clients: test_go_client test_python_client @@ -25,6 +25,9 @@ test_go_client: generate build/clientcompat build/gocompat test_python_client: generate build/clientcompat build/pycompat ./build/clientcompat -client ./build/pycompat +test_example: generate + go test -race -bench=. $(shell go list ./example/...) + setup: ./install_proto.bash GOPATH=$(CURDIR)/_tools go install github.com/twitchtv/retool/... diff --git a/clientcompat/internal/clientcompat/clientcompat.twirp.go b/clientcompat/internal/clientcompat/clientcompat.twirp.go index c956416a..7a53b38d 100644 --- a/clientcompat/internal/clientcompat/clientcompat.twirp.go +++ b/clientcompat/internal/clientcompat/clientcompat.twirp.go @@ -276,38 +276,37 @@ func (s *compatServiceServer) serveMethodJSON(ctx context.Context, resp http.Res func (s *compatServiceServer) serveMethodProtobuf(ctx context.Context, resp http.ResponseWriter, req *http.Request) { var err error - writeProtoError := func(err error) { - s.writeError(ctx, resp, err) - } - ctx = ctxsetters.WithMethodName(ctx, "Method") ctx, err = callRequestRouted(ctx, s.hooks) if err != nil { - writeProtoError(err) + s.writeError(ctx, resp, err) return } - resp.Header().Set("Content-Type", "application/protobuf") buf, err := ioutil.ReadAll(req.Body) if err != nil { err = wrapErr(err, "failed to read request body") - writeProtoError(twirp.InternalErrorWith(err)) + s.writeError(ctx, resp, twirp.InternalErrorWith(err)) return } reqContent := new(Req) if err = proto.Unmarshal(buf, reqContent); err != nil { err = wrapErr(err, "failed to parse request proto") - writeProtoError(twirp.InternalErrorWith(err)) + s.writeError(ctx, resp, twirp.InternalErrorWith(err)) return } // Call service method var respContent *Resp + respFlusher, canFlush := resp.(http.Flusher) func() { defer func() { // In case of a panic, serve a 500 error and then panic. if r := recover(); r != nil { s.writeError(ctx, resp, twirp.InternalError("Internal service panic")) + if canFlush { + respFlusher.Flush() + } panic(r) } }() @@ -319,12 +318,11 @@ func (s *compatServiceServer) serveMethodProtobuf(ctx context.Context, resp http return } if respContent == nil { - s.writeError(ctx, resp, twirp.InternalError("received a nil *Resp and nil error while calling Method. nil responses are not supported")) + s.writeError(ctx, resp, twirp.InternalError("received a nil Req and nil error while calling Method. nil responses are not supported")) return } ctx = callResponsePrepared(ctx, s.hooks) - respBytes, err := proto.Marshal(respContent) if err != nil { err = wrapErr(err, "failed to marshal proto response") @@ -333,12 +331,15 @@ func (s *compatServiceServer) serveMethodProtobuf(ctx context.Context, resp http } ctx = ctxsetters.WithStatusCode(ctx, http.StatusOK) + resp.Header().Set("Content-Type", "application/protobuf") resp.WriteHeader(http.StatusOK) + if n, err := resp.Write(respBytes); err != nil { msg := fmt.Sprintf("failed to write response, %d of %d bytes written: %s", n, len(respBytes), err.Error()) twerr := twirp.NewError(twirp.Unknown, msg) callError(ctx, s.hooks, twerr) } + callResponseSent(ctx, s.hooks) } @@ -424,38 +425,37 @@ func (s *compatServiceServer) serveNoopMethodJSON(ctx context.Context, resp http func (s *compatServiceServer) serveNoopMethodProtobuf(ctx context.Context, resp http.ResponseWriter, req *http.Request) { var err error - writeProtoError := func(err error) { - s.writeError(ctx, resp, err) - } - ctx = ctxsetters.WithMethodName(ctx, "NoopMethod") ctx, err = callRequestRouted(ctx, s.hooks) if err != nil { - writeProtoError(err) + s.writeError(ctx, resp, err) return } - resp.Header().Set("Content-Type", "application/protobuf") buf, err := ioutil.ReadAll(req.Body) if err != nil { err = wrapErr(err, "failed to read request body") - writeProtoError(twirp.InternalErrorWith(err)) + s.writeError(ctx, resp, twirp.InternalErrorWith(err)) return } reqContent := new(Empty) if err = proto.Unmarshal(buf, reqContent); err != nil { err = wrapErr(err, "failed to parse request proto") - writeProtoError(twirp.InternalErrorWith(err)) + s.writeError(ctx, resp, twirp.InternalErrorWith(err)) return } // Call service method var respContent *Empty + respFlusher, canFlush := resp.(http.Flusher) func() { defer func() { // In case of a panic, serve a 500 error and then panic. if r := recover(); r != nil { s.writeError(ctx, resp, twirp.InternalError("Internal service panic")) + if canFlush { + respFlusher.Flush() + } panic(r) } }() @@ -467,12 +467,11 @@ func (s *compatServiceServer) serveNoopMethodProtobuf(ctx context.Context, resp return } if respContent == nil { - s.writeError(ctx, resp, twirp.InternalError("received a nil *Empty and nil error while calling NoopMethod. nil responses are not supported")) + s.writeError(ctx, resp, twirp.InternalError("received a nil Empty and nil error while calling NoopMethod. nil responses are not supported")) return } ctx = callResponsePrepared(ctx, s.hooks) - respBytes, err := proto.Marshal(respContent) if err != nil { err = wrapErr(err, "failed to marshal proto response") @@ -481,12 +480,15 @@ func (s *compatServiceServer) serveNoopMethodProtobuf(ctx context.Context, resp } ctx = ctxsetters.WithStatusCode(ctx, http.StatusOK) + resp.Header().Set("Content-Type", "application/protobuf") resp.WriteHeader(http.StatusOK) + if n, err := resp.Write(respBytes); err != nil { msg := fmt.Sprintf("failed to write response, %d of %d bytes written: %s", n, len(respBytes), err.Error()) twerr := twirp.NewError(twirp.Unknown, msg) callError(ctx, s.hooks, twerr) } + callResponseSent(ctx, s.hooks) } diff --git a/example/cmd/client/main.go b/example/cmd/client/main.go index e1a3cf48..c15d1121 100644 --- a/example/cmd/client/main.go +++ b/example/cmd/client/main.go @@ -28,12 +28,13 @@ func main() { client := example.NewHaberdasherProtobufClient("http://localhost:8080", &http.Client{}) var ( - hat *example.Hat - hatStream example.HatStream - err error + hat *example.Hat + err error ) + // // Call the MakeHat rpc + // for i := 0; i < 5; i++ { hat, err = client.MakeHat(context.Background(), &example.Size{Inches: 12}) if err != nil { @@ -52,10 +53,15 @@ func main() { log.Println(`Response from MakeHat:`) log.Printf("\t%+v\n", hat) + // // Call the MakeHats streaming rpc + // + const ( + printEvery = 50000 + quantity = int32(300000) + ) reqSentAt := time.Now() - quantity := int32(300000) - hatStream, err = client.MakeHats( + hatStream, err := client.MakeHats( context.Background(), &example.MakeHatsReq{Inches: 12, Quantity: quantity}, ) @@ -66,13 +72,9 @@ func main() { ii := 1 printResults := func() { took := time.Now().Sub(reqSentAt) - log.Printf( - "Received %.1f kHats per second (%d hats in %f seconds)\n", - float64(ii-1)/took.Seconds()/1000, - ii-1, took.Seconds(), - ) + khps := float64(ii-1) / took.Seconds() / 1000 + log.Printf("Received %.1f kHats per second (%d hats in %f seconds)\n", khps, ii-1, took.Seconds()) } - defer printResults() for ; true; ii++ { // Receive all the hats hat, err = hatStream.Next(context.Background()) if err != nil { @@ -82,13 +84,10 @@ func main() { printResults() log.Fatal(err) } - if ii%50000 == 0 { - log.Printf( - "\t[%4.1f khps] %6d %+v\n", - float64(ii)/time.Now().Sub(reqSentAt).Seconds()/1000, - ii, hat, - ) + if ii%printEvery == 0 { + khps := float64(ii) / time.Now().Sub(reqSentAt).Seconds() / 1000 + log.Printf("\t[%4.1f khps] %6d %+v\n", khps, ii, hat) } } - + printResults() } diff --git a/example/cmd/server/hat_stream_sender.go b/example/cmd/server/hat_stream_sender.go new file mode 100644 index 00000000..171ef7f3 --- /dev/null +++ b/example/cmd/server/hat_stream_sender.go @@ -0,0 +1,55 @@ +// Copyright 2018 Twitch Interactive, Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may not +// use this file except in compliance with the License. A copy of the License is +// located at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// or in the "license" file accompanying this file. This file is distributed on +// an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package main + +// This file is an alternate implementation of the twirp-provided stream +// constructor originally proposed in the "As the sender" section of +// https://github.com/twitchtv/twirp/issues/70#issuecomment-361365458 + +import ( + "context" + "io" + + "github.com/twitchtv/twirp/example" +) + +type HatOrError struct { + hat *example.Hat + err error +} + +func NewHatStream(ch chan HatOrError) *hatStreamSender { + return &hatStreamSender{ch: ch} +} + +type hatStreamSender struct { + ch <-chan HatOrError +} + +func (hs *hatStreamSender) Next(ctx context.Context) (*example.Hat, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case v, open := <-hs.ch: + if !open { + return nil, io.EOF + } + if v.err != nil { + return nil, v.err + } + return v.hat, nil + } +} + +func (hs *hatStreamSender) End(err error) {} // Should anything go here? diff --git a/example/cmd/server/main.go b/example/cmd/server/main.go index 2e522387..85d97168 100644 --- a/example/cmd/server/main.go +++ b/example/cmd/server/main.go @@ -25,36 +25,50 @@ import ( "github.com/twitchtv/twirp/hooks/statsd" ) -func newRandomHat(inches int32) *example.Hat { - return &example.Hat{ - Size: inches, - Color: []string{"white", "black", "brown", "red", "blue"}[rand.Intn(5)], - Name: []string{"bowler", "baseball cap", "top hat", "derby"}[rand.Intn(4)], - } -} - -type randomHaberdasher struct{ quiet bool } +type randomHaberdasher struct{} var ( errTooSmall = twirp.InvalidArgumentError("Inches", "I can't make hats that small!") errNegativeQuantity = twirp.InvalidArgumentError("Quantity", "I can't make a negative quantity of hats!") ) -func (h *randomHaberdasher) MakeHat(ctx context.Context, size *example.Size) (*example.Hat, error) { - if size.Inches <= 0 { +func newRandomHat(inches int32) (*example.Hat, error) { + if inches <= 0 { return nil, errTooSmall } - return newRandomHat(size.Inches), nil + return &example.Hat{ + Size: inches, + Color: []string{"white", "black", "brown", "red", "blue"}[rand.Intn(5)], + Name: []string{"bowler", "baseball cap", "top hat", "derby"}[rand.Intn(4)], + }, nil +} + +func (h *randomHaberdasher) MakeHat(ctx context.Context, size *example.Size) (*example.Hat, error) { + return newRandomHat(size.Inches) } func (h *randomHaberdasher) MakeHats(ctx context.Context, req *example.MakeHatsReq) (example.HatStream, error) { - if req.Inches <= 0 { - return nil, errTooSmall - } if req.Quantity < 0 { return nil, errNegativeQuantity } - return newRandomHatStream(req.Inches, req.Quantity, h.quiet), nil + // Normally we'd validate Inches here as well, but we let it fall through to error on newRandomHat to demonstrate mid-stream errors + // if req.Inches <= 0 { + // return nil, errTooSmall + // } + + ch := make(chan HatOrError, 100) // NB: the size of this buffer can make a big difference! + go func() { + for ii := int32(0); ii < req.Quantity; ii++ { + hat, err := newRandomHat(req.Inches) + select { + case <-ctx.Done(): + return + case ch <- HatOrError{hat, err}: + } + } + close(ch) + }() + return NewHatStream(ch), nil } func main() { diff --git a/example/cmd/server/main_test.go b/example/cmd/server/main_test.go index acbfb415..56a96a29 100644 --- a/example/cmd/server/main_test.go +++ b/example/cmd/server/main_test.go @@ -1,3 +1,16 @@ +// Copyright 2018 Twitch Interactive, Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may not +// use this file except in compliance with the License. A copy of the License is +// located at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// or in the "license" file accompanying this file. This file is distributed on +// an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + package main import ( @@ -19,7 +32,7 @@ func TestMain(m *testing.M) { } func runServer() { - server := example.NewHaberdasherServer(&randomHaberdasher{quiet: true}, nil) + server := example.NewHaberdasherServer(&randomHaberdasher{}, nil) log.Fatal(http.ListenAndServe(":8080", server)) } @@ -33,17 +46,23 @@ func newJSONClient() example.Haberdasher { type client struct { name string - c example.Haberdasher + svc example.Haberdasher } func clients() []client { return []client{ {`Proto`, newProtoClient()}, - {`JSON`, newJSONClient()}, + // {`JSON`, newJSONClient()}, } } func compareErrors(got, expected error) error { + if got == nil && expected == nil { + return nil + } + if got == nil || expected == nil { + return fmt.Errorf(`Expected err to be %#v, got %#v`, expected, got) + } if got.Error() == expected.Error() { return nil } @@ -57,69 +76,74 @@ func TestInvalidMakeHatsRequests(t *testing.T) { expected error } testReqs := []testReq{ - { - name: `TooSmall`, - req: &example.MakeHatsReq{Inches: -5}, - expected: errTooSmall, - }, { name: `NegativeQuantity`, req: &example.MakeHatsReq{Inches: 8, Quantity: -5}, expected: errNegativeQuantity, }, + // // TooSmall is currently not being checked before the stream is returned so this would fail + // { + // name: `TooSmall`, + // req: &example.MakeHatsReq{Inches: -5}, + // expected: errTooSmall, + // }, } for _, cc := range clients() { for _, re := range testReqs { - t.Run(re.name+cc.name, func(t *testing.T) { - hatStream, err := cc.c.MakeHats(context.Background(), re.req) - if err != nil { - t.Fatalf(`MakeHats request failed: %#v`, err) - } - _, err = hatStream.Next(context.Background()) + t.Run(re.name+`/`+cc.name, func(t *testing.T) { + hatStream, err := cc.svc.MakeHats(context.Background(), re.req) err = compareErrors(err, re.expected) if err != nil { t.Fatal(err) } + if hatStream != nil { + t.Fatalf(`expected hatStream to be nil, got %+v`, hatStream) + } }) } } } -func TestMakeHatsPerf(t *testing.T) { +func TestMakeHatsTooSmall(t *testing.T) { + for _, cc := range clients() { + t.Run(cc.name, func(t *testing.T) { + hatStream, err := cc.svc.MakeHats( + context.Background(), + &example.MakeHatsReq{Inches: -5, Quantity: 10}, + ) + err = compareErrors(err, nil) + if err != nil { + t.Fatal(err) + } + _, err = hatStream.Next(context.Background()) + err = compareErrors(err, errTooSmall) + if err != nil { + t.Fatal(err) + } + }) + } +} + +func TestMakeHatsLargeQuantities(t *testing.T) { type testReq struct { name string req *example.MakeHatsReq } testReqs := []testReq{ - { - name: `OneHundred`, - req: &example.MakeHatsReq{Inches: 5, Quantity: 100}, - }, - { - name: `OneThousand`, - req: &example.MakeHatsReq{Inches: 5, Quantity: 1000}, - }, - { - name: `TenThousand`, - req: &example.MakeHatsReq{Inches: 5, Quantity: 10000}, - }, - { - name: `OneHundredThousand`, - req: &example.MakeHatsReq{Inches: 5, Quantity: 100000}, - }, - // // OneMillion takes 6+ seconds if server is Flush()ing after every message, <1sec if no flushing - // { - // name: `OneMillion`, - // req: &example.MakeHatsReq{Inches: 5, Quantity: 1000000}, - // }, + {name: `OneHundred`, req: &example.MakeHatsReq{Inches: 5, Quantity: 100}}, + {name: `OneThousand`, req: &example.MakeHatsReq{Inches: 5, Quantity: 1000}}, + {name: `TenThousand`, req: &example.MakeHatsReq{Inches: 5, Quantity: 10000}}, + {name: `OneHundredThousand`, req: &example.MakeHatsReq{Inches: 5, Quantity: 100000}}, + // {name: `OneMillion`, req: &example.MakeHatsReq{Inches: 5, Quantity: 1000000}}, + // // OneMillion takes 6+ seconds if server is flushing after every message, <1sec if no flushing } for _, cc := range clients() { for _, re := range testReqs { - t.Run(re.name+cc.name, func(t *testing.T) { + t.Run(re.name+`/`+cc.name, func(t *testing.T) { reqSentAt := time.Now() - hatStream, err := cc.c.MakeHats(context.Background(), re.req) + hatStream, err := cc.svc.MakeHats(context.Background(), re.req) if err != nil { t.Fatalf(`MakeHats request failed: %#v (hatStream=%#v)`, err, hatStream) } @@ -151,9 +175,9 @@ func BenchmarkMakeHatsProto(b *testing.B) { benchmarkMakeHats(b, newProtoClient()) } -func BenchmarkMakeHatsJSON(b *testing.B) { - benchmarkMakeHats(b, newJSONClient()) -} +// func BenchmarkMakeHatsJSON(b *testing.B) { +// benchmarkMakeHats(b, newJSONClient()) +// } func benchmarkMakeHats(b *testing.B, cc example.Haberdasher) { reqSentAt := time.Now() diff --git a/example/cmd/server/random_hat_stream.go b/example/cmd/server/random_hat_stream.go deleted file mode 100644 index 1ce70a59..00000000 --- a/example/cmd/server/random_hat_stream.go +++ /dev/null @@ -1,85 +0,0 @@ -// Copyright 2018 Twitch Interactive, Inc. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may not -// use this file except in compliance with the License. A copy of the License is -// located at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// or in the "license" file anumSentompanying this file. This file is distributed on -// an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either -// express or implied. See the License for the specific language governing -// permissions and limitations under the License. - -package main - -import ( - "context" - "io" - "log" - "time" - - "github.com/twitchtv/twirp" - "github.com/twitchtv/twirp/example" -) - -type randomHatStream struct { - inches, quantity, numSent int32 - startedAt time.Time - quiet bool -} - -func newRandomHatStream(inches, quantity int32, quiet bool) *randomHatStream { - return &randomHatStream{ - inches: inches, - quantity: quantity, - startedAt: time.Now(), - quiet: quiet, - } -} - -func (hs *randomHatStream) Next(ctx context.Context) (*example.Hat, error) { - defer func() { hs.numSent++ }() - if hs.numSent == hs.quantity { - if !hs.quiet { - log.Printf( - "[%4.1f khps] (%7d) Sending %v\n", - float64(hs.numSent)/time.Now().Sub(hs.startedAt).Seconds()/1000, - hs.numSent, io.EOF, - ) - } - return nil, io.EOF - } - - select { - case <-ctx.Done(): - err := errAborted(ctx.Err()) - if !hs.quiet { - log.Printf(`Context canceled: %#v`, ctx.Err()) - } - return nil, err - default: - hat := newRandomHat(hs.inches) - if !hs.quiet && hs.numSent%10000 == 0 && hs.numSent > 0 { - log.Printf( - "[%4.1f khps] (%7d) Sending %#v\n", - float64(hs.numSent)/time.Now().Sub(hs.startedAt).Seconds()/1000, - hs.numSent, hat, - ) - } - return hat, nil - } -} - -func (hs *randomHatStream) End(err error) { - if !hs.quiet { - log.Printf("randomHatStream ended with %#v\n", err) - } -} - -func errAborted(err error) error { - if err == nil { - return twirp.NewError(twirp.Aborted, `canceled`).WithMeta(`cause`, `unknown`) - } - return twirp.NewError(twirp.Aborted, err.Error()) -} diff --git a/example/service.twirp.go b/example/service.twirp.go index 5bdb53b2..6242b2f3 100644 --- a/example/service.twirp.go +++ b/example/service.twirp.go @@ -101,6 +101,12 @@ func (c *haberdasherProtobufClient) MakeHats(ctx context.Context, in *MakeHatsRe if err != nil { return nil, clientError("failed to do request", err) } + if err = ctx.Err(); err != nil { + return nil, clientError("aborted because context was done", err) + } + if resp.StatusCode != 200 { + return nil, errorFromResponse(resp) + } return &protoHatStreamReader{ prs: protoStreamReader{ @@ -170,6 +176,12 @@ func (c *haberdasherJSONClient) MakeHats(ctx context.Context, in *MakeHatsReq) ( if err != nil { return nil, clientError("failed to do request", err) } + if err = ctx.Err(); err != nil { + return nil, clientError("aborted because context was done", err) + } + if resp.StatusCode != 200 { + return nil, errorFromResponse(resp) + } jrs, err := newJSONStreamReader(resp.Body) if err != nil { @@ -325,38 +337,37 @@ func (s *haberdasherServer) serveMakeHatJSON(ctx context.Context, resp http.Resp func (s *haberdasherServer) serveMakeHatProtobuf(ctx context.Context, resp http.ResponseWriter, req *http.Request) { var err error - writeProtoError := func(err error) { - s.writeError(ctx, resp, err) - } - ctx = ctxsetters.WithMethodName(ctx, "MakeHat") ctx, err = callRequestRouted(ctx, s.hooks) if err != nil { - writeProtoError(err) + s.writeError(ctx, resp, err) return } - resp.Header().Set("Content-Type", "application/protobuf") buf, err := ioutil.ReadAll(req.Body) if err != nil { err = wrapErr(err, "failed to read request body") - writeProtoError(twirp.InternalErrorWith(err)) + s.writeError(ctx, resp, twirp.InternalErrorWith(err)) return } reqContent := new(Size) if err = proto.Unmarshal(buf, reqContent); err != nil { err = wrapErr(err, "failed to parse request proto") - writeProtoError(twirp.InternalErrorWith(err)) + s.writeError(ctx, resp, twirp.InternalErrorWith(err)) return } // Call service method var respContent *Hat + respFlusher, canFlush := resp.(http.Flusher) func() { defer func() { // In case of a panic, serve a 500 error and then panic. if r := recover(); r != nil { s.writeError(ctx, resp, twirp.InternalError("Internal service panic")) + if canFlush { + respFlusher.Flush() + } panic(r) } }() @@ -368,12 +379,11 @@ func (s *haberdasherServer) serveMakeHatProtobuf(ctx context.Context, resp http. return } if respContent == nil { - s.writeError(ctx, resp, twirp.InternalError("received a nil *Hat and nil error while calling MakeHat. nil responses are not supported")) + s.writeError(ctx, resp, twirp.InternalError("received a nil Size and nil error while calling MakeHat. nil responses are not supported")) return } ctx = callResponsePrepared(ctx, s.hooks) - respBytes, err := proto.Marshal(respContent) if err != nil { err = wrapErr(err, "failed to marshal proto response") @@ -382,12 +392,15 @@ func (s *haberdasherServer) serveMakeHatProtobuf(ctx context.Context, resp http. } ctx = ctxsetters.WithStatusCode(ctx, http.StatusOK) + resp.Header().Set("Content-Type", "application/protobuf") resp.WriteHeader(http.StatusOK) + if n, err := resp.Write(respBytes); err != nil { msg := fmt.Sprintf("failed to write response, %d of %d bytes written: %s", n, len(respBytes), err.Error()) twerr := twirp.NewError(twirp.Unknown, msg) callError(ctx, s.hooks, twerr) } + callResponseSent(ctx, s.hooks) } @@ -413,89 +426,90 @@ func (s *haberdasherServer) serveMakeHatsJSON(ctx context.Context, resp http.Res } func (s *haberdasherServer) serveMakeHatsProtobuf(ctx context.Context, resp http.ResponseWriter, req *http.Request) { - var ( - err error - respStream HatStream - ) - - // Prepare trailer - trailer := proto.NewBuffer(nil) - _ = trailer.EncodeVarint((2 << 3) | 2) // field tag - writeProtoError := func(err error) { - // JSON encode err as twirp err - // TODO: figure out what to do about updating context and headers - if err == io.EOF { - trailer.EncodeStringBytes("EOF") - return - } - twerr, ok := err.(twirp.Error) - if !ok { - twerr = twirp.InternalErrorWith(err) - } - _ = trailer.EncodeStringBytes( - string(marshalErrorToJSON(twerr)), - ) - } - defer func() { // Send trailer - _, err = resp.Write(trailer.Bytes()) - if err != nil { - // TODO: call error hook? - err = wrapErr(err, "failed to write trailer") - respStream.End(err) - } - }() - + var err error ctx = ctxsetters.WithMethodName(ctx, "MakeHats") ctx, err = callRequestRouted(ctx, s.hooks) if err != nil { - writeProtoError(err) + s.writeError(ctx, resp, err) return } - resp.Header().Set("Content-Type", "application/protobuf") buf, err := ioutil.ReadAll(req.Body) if err != nil { err = wrapErr(err, "failed to read request body") - writeProtoError(twirp.InternalErrorWith(err)) + s.writeError(ctx, resp, twirp.InternalErrorWith(err)) return } reqContent := new(MakeHatsReq) if err = proto.Unmarshal(buf, reqContent); err != nil { err = wrapErr(err, "failed to parse request proto") - writeProtoError(twirp.InternalErrorWith(err)) + s.writeError(ctx, resp, twirp.InternalErrorWith(err)) return } // Call service method + var respContent HatStream + respFlusher, canFlush := resp.(http.Flusher) func() { defer func() { // In case of a panic, serve a 500 error and then panic. if r := recover(); r != nil { - writeProtoError(twirp.InternalError("Internal service panic")) + s.writeError(ctx, resp, twirp.InternalError("Internal service panic")) + if canFlush { + respFlusher.Flush() + } panic(r) } }() - respStream, err = s.MakeHats(ctx, reqContent) + respContent, err = s.MakeHats(ctx, reqContent) }() if err != nil { - writeProtoError(err) + s.writeError(ctx, resp, err) return } - - if respStream == nil { - writeProtoError(twirp.InternalError("received a nil MakeHatsReq and nil error while calling MakeHats. nil responses are not supported")) + if respContent == nil { + s.writeError(ctx, resp, twirp.InternalError("received a nil MakeHatsReq and nil error while calling MakeHats. nil responses are not supported")) return } ctx = callResponsePrepared(ctx, s.hooks) + ctx = ctxsetters.WithStatusCode(ctx, http.StatusOK) + resp.Header().Set("Content-Type", "application/protobuf") + resp.WriteHeader(http.StatusOK) + + // Prepare trailer + trailer := proto.NewBuffer(nil) + _ = trailer.EncodeVarint((2 << 3) | 2) // field tag + writeTrailer := func(err error) { + if err == io.EOF { + trailer.EncodeStringBytes("EOF") + return + } + // Write trailer as json-encoded twirp err + twerr, ok := err.(twirp.Error) + if !ok { + twerr = twirp.InternalErrorWith(err) + } + statusCode := twirp.ServerHTTPStatusFromErrorCode(twerr.Code()) + ctx = ctxsetters.WithStatusCode(ctx, statusCode) + ctx = callError(ctx, s.hooks, twerr) + if encodeErr := trailer.EncodeStringBytes(string(marshalErrorToJSON(twerr))); encodeErr != nil { + _ = trailer.EncodeStringBytes("{\"code\":\"" + string(twirp.Internal) + "\",\"msg\":\"There was an error but it could not be serialized into JSON\"}") // fallback + } + _, writeErr := resp.Write(trailer.Bytes()) + if writeErr != nil { + // Ignored, for the same reason as in the writeError func + _ = writeErr + } + respContent.End(twerr) + } - respFlusher, canFlush := resp.(http.Flusher) messages := proto.NewBuffer(nil) for { - msg, err := respStream.Next(ctx) + msg, err := respContent.Next(ctx) if err != nil { - writeProtoError(err) + writeTrailer(err) break } @@ -504,31 +518,29 @@ func (s *haberdasherServer) serveMakeHatsProtobuf(ctx context.Context, resp http err = messages.EncodeMessage(msg) if err != nil { err = wrapErr(err, "failed to marshal proto message") - respStream.End(err) + writeTrailer(err) break } _, err = resp.Write(messages.Bytes()) if err != nil { err = wrapErr(err, "failed to send proto message") - respStream.End(err) + writeTrailer(err) // likely to fail on write, but try anyway to ensure ctx gets error code set for responseSent hook break } if canFlush { - // TODO: come up with a batching scheme to improve performance under high load + // TODO: Come up with a batching scheme to improve performance under high load + // and/or provide a hook for the respStream to control flushing the response. + // Flushing after each message dramatically reduces high-load throughput -- + // difference can be more than 10x based on initial experiments respFlusher.Flush() } - // TODO: Call a hook that we sent a message in a stream? + // TODO: Call a hook that we sent a message in a stream? (combine with flush hook?) } - _, err = resp.Write(trailer.Bytes()) - if err != nil { - // TODO: call error hook? - err = wrapErr(err, "failed to write trailer") - respStream.End(err) - } + callResponseSent(ctx, s.hooks) } func (s *haberdasherServer) ServiceDescriptor() ([]byte, int) { @@ -1019,7 +1031,31 @@ func (r protoStreamReader) Read(msg proto.Message) error { trailerTag = (2 << 3) | 2 ) - if tag != msgTag && tag != trailerTag { + if tag == trailerTag { + // This is a trailer (twirp error), read it and then close the client + defer r.c.Close() + // Read the length delimiter + l, err := binary.ReadUvarint(r.r) + if err != nil { + return clientError("unable to read trailer's length delimiter", err) + } + sb := new(strings.Builder) + sb.Grow(int(l)) + _, err = io.Copy(sb, r.r) + if err != nil { + return clientError("unable to read trailer", err) + } + if sb.String() == "EOF" { + return io.EOF + } + var tj twerrJSON + if err = json.Unmarshal([]byte(sb.String()), &tj); err != nil { + return clientError("unable to decode trailer", err) + } + return tj.toTwirpError() + } + + if tag != msgTag { return fmt.Errorf("invalid field tag: %v", tag) } @@ -1031,45 +1067,17 @@ func (r protoStreamReader) Read(msg proto.Message) error { if int(l) < 0 || int(l) > r.maxSize { return io.ErrShortBuffer } - if tag == msgTag { - buf := make([]byte, int(l)) - - // Go ahead and read a message. - _, err = io.ReadFull(r.r, buf) - if err != nil { - return err - } - - err = proto.Unmarshal(buf, msg) - if err != nil { - return err - } - return nil - } - - // This is a trailer, read it and then close the client - defer r.c.Close() buf := make([]byte, int(l)) - _, err = io.ReadFull(r.r, buf) - if err != nil { + + // Go ahead and read a message. + if _, err = io.ReadFull(r.r, buf); err != nil { return err } - // Put the length back in front of the trailer so it can be decoded - buf = append(proto.EncodeVarint(l), buf...) - var trailer string - trailer, err = proto.NewBuffer(buf).DecodeStringBytes() - if err != nil { - return clientError("failed to read stream trailer", err) - } - if trailer == "EOF" { - return io.EOF - } - var tj twerrJSON - if err = json.Unmarshal([]byte(trailer), &tj); err != nil { - return clientError("unable to decode stream trailer", err) + if err = proto.Unmarshal(buf, msg); err != nil { + return err } - return tj.toTwirpError() + return nil } type jsonStreamReader struct { diff --git a/internal/twirptest/gogo_compat/service.twirp.go b/internal/twirptest/gogo_compat/service.twirp.go index 341a2bd3..7630c317 100644 --- a/internal/twirptest/gogo_compat/service.twirp.go +++ b/internal/twirptest/gogo_compat/service.twirp.go @@ -255,38 +255,37 @@ func (s *svcServer) serveSendJSON(ctx context.Context, resp http.ResponseWriter, func (s *svcServer) serveSendProtobuf(ctx context.Context, resp http.ResponseWriter, req *http.Request) { var err error - writeProtoError := func(err error) { - s.writeError(ctx, resp, err) - } - ctx = ctxsetters.WithMethodName(ctx, "Send") ctx, err = callRequestRouted(ctx, s.hooks) if err != nil { - writeProtoError(err) + s.writeError(ctx, resp, err) return } - resp.Header().Set("Content-Type", "application/protobuf") buf, err := ioutil.ReadAll(req.Body) if err != nil { err = wrapErr(err, "failed to read request body") - writeProtoError(twirp.InternalErrorWith(err)) + s.writeError(ctx, resp, twirp.InternalErrorWith(err)) return } reqContent := new(Msg) if err = proto.Unmarshal(buf, reqContent); err != nil { err = wrapErr(err, "failed to parse request proto") - writeProtoError(twirp.InternalErrorWith(err)) + s.writeError(ctx, resp, twirp.InternalErrorWith(err)) return } // Call service method var respContent *Msg + respFlusher, canFlush := resp.(http.Flusher) func() { defer func() { // In case of a panic, serve a 500 error and then panic. if r := recover(); r != nil { s.writeError(ctx, resp, twirp.InternalError("Internal service panic")) + if canFlush { + respFlusher.Flush() + } panic(r) } }() @@ -298,12 +297,11 @@ func (s *svcServer) serveSendProtobuf(ctx context.Context, resp http.ResponseWri return } if respContent == nil { - s.writeError(ctx, resp, twirp.InternalError("received a nil *Msg and nil error while calling Send. nil responses are not supported")) + s.writeError(ctx, resp, twirp.InternalError("received a nil Msg and nil error while calling Send. nil responses are not supported")) return } ctx = callResponsePrepared(ctx, s.hooks) - respBytes, err := proto.Marshal(respContent) if err != nil { err = wrapErr(err, "failed to marshal proto response") @@ -312,12 +310,15 @@ func (s *svcServer) serveSendProtobuf(ctx context.Context, resp http.ResponseWri } ctx = ctxsetters.WithStatusCode(ctx, http.StatusOK) + resp.Header().Set("Content-Type", "application/protobuf") resp.WriteHeader(http.StatusOK) + if n, err := resp.Write(respBytes); err != nil { msg := fmt.Sprintf("failed to write response, %d of %d bytes written: %s", n, len(respBytes), err.Error()) twerr := twirp.NewError(twirp.Unknown, msg) callError(ctx, s.hooks, twerr) } + callResponseSent(ctx, s.hooks) } diff --git a/internal/twirptest/importable/importable.twirp.go b/internal/twirptest/importable/importable.twirp.go index 3c23ec05..69cbfc75 100644 --- a/internal/twirptest/importable/importable.twirp.go +++ b/internal/twirptest/importable/importable.twirp.go @@ -254,38 +254,37 @@ func (s *svcServer) serveSendJSON(ctx context.Context, resp http.ResponseWriter, func (s *svcServer) serveSendProtobuf(ctx context.Context, resp http.ResponseWriter, req *http.Request) { var err error - writeProtoError := func(err error) { - s.writeError(ctx, resp, err) - } - ctx = ctxsetters.WithMethodName(ctx, "Send") ctx, err = callRequestRouted(ctx, s.hooks) if err != nil { - writeProtoError(err) + s.writeError(ctx, resp, err) return } - resp.Header().Set("Content-Type", "application/protobuf") buf, err := ioutil.ReadAll(req.Body) if err != nil { err = wrapErr(err, "failed to read request body") - writeProtoError(twirp.InternalErrorWith(err)) + s.writeError(ctx, resp, twirp.InternalErrorWith(err)) return } reqContent := new(Msg) if err = proto.Unmarshal(buf, reqContent); err != nil { err = wrapErr(err, "failed to parse request proto") - writeProtoError(twirp.InternalErrorWith(err)) + s.writeError(ctx, resp, twirp.InternalErrorWith(err)) return } // Call service method var respContent *Msg + respFlusher, canFlush := resp.(http.Flusher) func() { defer func() { // In case of a panic, serve a 500 error and then panic. if r := recover(); r != nil { s.writeError(ctx, resp, twirp.InternalError("Internal service panic")) + if canFlush { + respFlusher.Flush() + } panic(r) } }() @@ -297,12 +296,11 @@ func (s *svcServer) serveSendProtobuf(ctx context.Context, resp http.ResponseWri return } if respContent == nil { - s.writeError(ctx, resp, twirp.InternalError("received a nil *Msg and nil error while calling Send. nil responses are not supported")) + s.writeError(ctx, resp, twirp.InternalError("received a nil Msg and nil error while calling Send. nil responses are not supported")) return } ctx = callResponsePrepared(ctx, s.hooks) - respBytes, err := proto.Marshal(respContent) if err != nil { err = wrapErr(err, "failed to marshal proto response") @@ -311,12 +309,15 @@ func (s *svcServer) serveSendProtobuf(ctx context.Context, resp http.ResponseWri } ctx = ctxsetters.WithStatusCode(ctx, http.StatusOK) + resp.Header().Set("Content-Type", "application/protobuf") resp.WriteHeader(http.StatusOK) + if n, err := resp.Write(respBytes); err != nil { msg := fmt.Sprintf("failed to write response, %d of %d bytes written: %s", n, len(respBytes), err.Error()) twerr := twirp.NewError(twirp.Unknown, msg) callError(ctx, s.hooks, twerr) } + callResponseSent(ctx, s.hooks) } diff --git a/internal/twirptest/importer/importer.twirp.go b/internal/twirptest/importer/importer.twirp.go index d3a19d44..7f020bb8 100644 --- a/internal/twirptest/importer/importer.twirp.go +++ b/internal/twirptest/importer/importer.twirp.go @@ -263,38 +263,37 @@ func (s *svc2Server) serveSendJSON(ctx context.Context, resp http.ResponseWriter func (s *svc2Server) serveSendProtobuf(ctx context.Context, resp http.ResponseWriter, req *http.Request) { var err error - writeProtoError := func(err error) { - s.writeError(ctx, resp, err) - } - ctx = ctxsetters.WithMethodName(ctx, "Send") ctx, err = callRequestRouted(ctx, s.hooks) if err != nil { - writeProtoError(err) + s.writeError(ctx, resp, err) return } - resp.Header().Set("Content-Type", "application/protobuf") buf, err := ioutil.ReadAll(req.Body) if err != nil { err = wrapErr(err, "failed to read request body") - writeProtoError(twirp.InternalErrorWith(err)) + s.writeError(ctx, resp, twirp.InternalErrorWith(err)) return } reqContent := new(twirp_internal_twirptest_importable.Msg) if err = proto.Unmarshal(buf, reqContent); err != nil { err = wrapErr(err, "failed to parse request proto") - writeProtoError(twirp.InternalErrorWith(err)) + s.writeError(ctx, resp, twirp.InternalErrorWith(err)) return } // Call service method var respContent *twirp_internal_twirptest_importable.Msg + respFlusher, canFlush := resp.(http.Flusher) func() { defer func() { // In case of a panic, serve a 500 error and then panic. if r := recover(); r != nil { s.writeError(ctx, resp, twirp.InternalError("Internal service panic")) + if canFlush { + respFlusher.Flush() + } panic(r) } }() @@ -306,12 +305,11 @@ func (s *svc2Server) serveSendProtobuf(ctx context.Context, resp http.ResponseWr return } if respContent == nil { - s.writeError(ctx, resp, twirp.InternalError("received a nil *twirp_internal_twirptest_importable.Msg and nil error while calling Send. nil responses are not supported")) + s.writeError(ctx, resp, twirp.InternalError("received a nil twirp_internal_twirptest_importable.Msg and nil error while calling Send. nil responses are not supported")) return } ctx = callResponsePrepared(ctx, s.hooks) - respBytes, err := proto.Marshal(respContent) if err != nil { err = wrapErr(err, "failed to marshal proto response") @@ -320,12 +318,15 @@ func (s *svc2Server) serveSendProtobuf(ctx context.Context, resp http.ResponseWr } ctx = ctxsetters.WithStatusCode(ctx, http.StatusOK) + resp.Header().Set("Content-Type", "application/protobuf") resp.WriteHeader(http.StatusOK) + if n, err := resp.Write(respBytes); err != nil { msg := fmt.Sprintf("failed to write response, %d of %d bytes written: %s", n, len(respBytes), err.Error()) twerr := twirp.NewError(twirp.Unknown, msg) callError(ctx, s.hooks, twerr) } + callResponseSent(ctx, s.hooks) } @@ -351,15 +352,7 @@ func (s *svc2Server) serveStreamJSON(ctx context.Context, resp http.ResponseWrit } func (s *svc2Server) serveStreamProtobuf(ctx context.Context, resp http.ResponseWriter, req *http.Request) { - - ctx = ctxsetters.WithMethodName(ctx, "Stream") - ctx, err = callRequestRouted(ctx, s.hooks) - if err != nil { - writeProtoError(err) - return - } - - resp.Header().Set("Content-Type", "application/protobuf") + s.writeError(ctx, resp, twirp.InternalError("rpc type \"bidirectional\" is not implemented")) } func (s *svc2Server) ServiceDescriptor() ([]byte, int) { @@ -850,7 +843,31 @@ func (r protoStreamReader) Read(msg proto.Message) error { trailerTag = (2 << 3) | 2 ) - if tag != msgTag && tag != trailerTag { + if tag == trailerTag { + // This is a trailer (twirp error), read it and then close the client + defer r.c.Close() + // Read the length delimiter + l, err := binary.ReadUvarint(r.r) + if err != nil { + return clientError("unable to read trailer's length delimiter", err) + } + sb := new(strings.Builder) + sb.Grow(int(l)) + _, err = io.Copy(sb, r.r) + if err != nil { + return clientError("unable to read trailer", err) + } + if sb.String() == "EOF" { + return io.EOF + } + var tj twerrJSON + if err = json.Unmarshal([]byte(sb.String()), &tj); err != nil { + return clientError("unable to decode trailer", err) + } + return tj.toTwirpError() + } + + if tag != msgTag { return fmt.Errorf("invalid field tag: %v", tag) } @@ -862,45 +879,17 @@ func (r protoStreamReader) Read(msg proto.Message) error { if int(l) < 0 || int(l) > r.maxSize { return io.ErrShortBuffer } - if tag == msgTag { - buf := make([]byte, int(l)) - - // Go ahead and read a message. - _, err = io.ReadFull(r.r, buf) - if err != nil { - return err - } - - err = proto.Unmarshal(buf, msg) - if err != nil { - return err - } - return nil - } - - // This is a trailer, read it and then close the client - defer r.c.Close() buf := make([]byte, int(l)) - _, err = io.ReadFull(r.r, buf) - if err != nil { + + // Go ahead and read a message. + if _, err = io.ReadFull(r.r, buf); err != nil { return err } - // Put the length back in front of the trailer so it can be decoded - buf = append(proto.EncodeVarint(l), buf...) - var trailer string - trailer, err = proto.NewBuffer(buf).DecodeStringBytes() - if err != nil { - return clientError("failed to read stream trailer", err) - } - if trailer == "EOF" { - return io.EOF - } - var tj twerrJSON - if err = json.Unmarshal([]byte(trailer), &tj); err != nil { - return clientError("unable to decode stream trailer", err) + if err = proto.Unmarshal(buf, msg); err != nil { + return err } - return tj.toTwirpError() + return nil } type jsonStreamReader struct { diff --git a/internal/twirptest/multiple/multiple1.twirp.go b/internal/twirptest/multiple/multiple1.twirp.go index afe9bd5e..1515f5d1 100644 --- a/internal/twirptest/multiple/multiple1.twirp.go +++ b/internal/twirptest/multiple/multiple1.twirp.go @@ -255,38 +255,37 @@ func (s *svc1Server) serveSendJSON(ctx context.Context, resp http.ResponseWriter func (s *svc1Server) serveSendProtobuf(ctx context.Context, resp http.ResponseWriter, req *http.Request) { var err error - writeProtoError := func(err error) { - s.writeError(ctx, resp, err) - } - ctx = ctxsetters.WithMethodName(ctx, "Send") ctx, err = callRequestRouted(ctx, s.hooks) if err != nil { - writeProtoError(err) + s.writeError(ctx, resp, err) return } - resp.Header().Set("Content-Type", "application/protobuf") buf, err := ioutil.ReadAll(req.Body) if err != nil { err = wrapErr(err, "failed to read request body") - writeProtoError(twirp.InternalErrorWith(err)) + s.writeError(ctx, resp, twirp.InternalErrorWith(err)) return } reqContent := new(Msg1) if err = proto.Unmarshal(buf, reqContent); err != nil { err = wrapErr(err, "failed to parse request proto") - writeProtoError(twirp.InternalErrorWith(err)) + s.writeError(ctx, resp, twirp.InternalErrorWith(err)) return } // Call service method var respContent *Msg1 + respFlusher, canFlush := resp.(http.Flusher) func() { defer func() { // In case of a panic, serve a 500 error and then panic. if r := recover(); r != nil { s.writeError(ctx, resp, twirp.InternalError("Internal service panic")) + if canFlush { + respFlusher.Flush() + } panic(r) } }() @@ -298,12 +297,11 @@ func (s *svc1Server) serveSendProtobuf(ctx context.Context, resp http.ResponseWr return } if respContent == nil { - s.writeError(ctx, resp, twirp.InternalError("received a nil *Msg1 and nil error while calling Send. nil responses are not supported")) + s.writeError(ctx, resp, twirp.InternalError("received a nil Msg1 and nil error while calling Send. nil responses are not supported")) return } ctx = callResponsePrepared(ctx, s.hooks) - respBytes, err := proto.Marshal(respContent) if err != nil { err = wrapErr(err, "failed to marshal proto response") @@ -312,12 +310,15 @@ func (s *svc1Server) serveSendProtobuf(ctx context.Context, resp http.ResponseWr } ctx = ctxsetters.WithStatusCode(ctx, http.StatusOK) + resp.Header().Set("Content-Type", "application/protobuf") resp.WriteHeader(http.StatusOK) + if n, err := resp.Write(respBytes); err != nil { msg := fmt.Sprintf("failed to write response, %d of %d bytes written: %s", n, len(respBytes), err.Error()) twerr := twirp.NewError(twirp.Unknown, msg) callError(ctx, s.hooks, twerr) } + callResponseSent(ctx, s.hooks) } diff --git a/internal/twirptest/multiple/multiple2.twirp.go b/internal/twirptest/multiple/multiple2.twirp.go index 68191aae..3fabd56e 100644 --- a/internal/twirptest/multiple/multiple2.twirp.go +++ b/internal/twirptest/multiple/multiple2.twirp.go @@ -263,38 +263,37 @@ func (s *svc2Server) serveSendJSON(ctx context.Context, resp http.ResponseWriter func (s *svc2Server) serveSendProtobuf(ctx context.Context, resp http.ResponseWriter, req *http.Request) { var err error - writeProtoError := func(err error) { - s.writeError(ctx, resp, err) - } - ctx = ctxsetters.WithMethodName(ctx, "Send") ctx, err = callRequestRouted(ctx, s.hooks) if err != nil { - writeProtoError(err) + s.writeError(ctx, resp, err) return } - resp.Header().Set("Content-Type", "application/protobuf") buf, err := ioutil.ReadAll(req.Body) if err != nil { err = wrapErr(err, "failed to read request body") - writeProtoError(twirp.InternalErrorWith(err)) + s.writeError(ctx, resp, twirp.InternalErrorWith(err)) return } reqContent := new(Msg2) if err = proto.Unmarshal(buf, reqContent); err != nil { err = wrapErr(err, "failed to parse request proto") - writeProtoError(twirp.InternalErrorWith(err)) + s.writeError(ctx, resp, twirp.InternalErrorWith(err)) return } // Call service method var respContent *Msg2 + respFlusher, canFlush := resp.(http.Flusher) func() { defer func() { // In case of a panic, serve a 500 error and then panic. if r := recover(); r != nil { s.writeError(ctx, resp, twirp.InternalError("Internal service panic")) + if canFlush { + respFlusher.Flush() + } panic(r) } }() @@ -306,12 +305,11 @@ func (s *svc2Server) serveSendProtobuf(ctx context.Context, resp http.ResponseWr return } if respContent == nil { - s.writeError(ctx, resp, twirp.InternalError("received a nil *Msg2 and nil error while calling Send. nil responses are not supported")) + s.writeError(ctx, resp, twirp.InternalError("received a nil Msg2 and nil error while calling Send. nil responses are not supported")) return } ctx = callResponsePrepared(ctx, s.hooks) - respBytes, err := proto.Marshal(respContent) if err != nil { err = wrapErr(err, "failed to marshal proto response") @@ -320,12 +318,15 @@ func (s *svc2Server) serveSendProtobuf(ctx context.Context, resp http.ResponseWr } ctx = ctxsetters.WithStatusCode(ctx, http.StatusOK) + resp.Header().Set("Content-Type", "application/protobuf") resp.WriteHeader(http.StatusOK) + if n, err := resp.Write(respBytes); err != nil { msg := fmt.Sprintf("failed to write response, %d of %d bytes written: %s", n, len(respBytes), err.Error()) twerr := twirp.NewError(twirp.Unknown, msg) callError(ctx, s.hooks, twerr) } + callResponseSent(ctx, s.hooks) } @@ -411,38 +412,37 @@ func (s *svc2Server) serveSamePackageProtoImportJSON(ctx context.Context, resp h func (s *svc2Server) serveSamePackageProtoImportProtobuf(ctx context.Context, resp http.ResponseWriter, req *http.Request) { var err error - writeProtoError := func(err error) { - s.writeError(ctx, resp, err) - } - ctx = ctxsetters.WithMethodName(ctx, "SamePackageProtoImport") ctx, err = callRequestRouted(ctx, s.hooks) if err != nil { - writeProtoError(err) + s.writeError(ctx, resp, err) return } - resp.Header().Set("Content-Type", "application/protobuf") buf, err := ioutil.ReadAll(req.Body) if err != nil { err = wrapErr(err, "failed to read request body") - writeProtoError(twirp.InternalErrorWith(err)) + s.writeError(ctx, resp, twirp.InternalErrorWith(err)) return } reqContent := new(Msg1) if err = proto.Unmarshal(buf, reqContent); err != nil { err = wrapErr(err, "failed to parse request proto") - writeProtoError(twirp.InternalErrorWith(err)) + s.writeError(ctx, resp, twirp.InternalErrorWith(err)) return } // Call service method var respContent *Msg1 + respFlusher, canFlush := resp.(http.Flusher) func() { defer func() { // In case of a panic, serve a 500 error and then panic. if r := recover(); r != nil { s.writeError(ctx, resp, twirp.InternalError("Internal service panic")) + if canFlush { + respFlusher.Flush() + } panic(r) } }() @@ -454,12 +454,11 @@ func (s *svc2Server) serveSamePackageProtoImportProtobuf(ctx context.Context, re return } if respContent == nil { - s.writeError(ctx, resp, twirp.InternalError("received a nil *Msg1 and nil error while calling SamePackageProtoImport. nil responses are not supported")) + s.writeError(ctx, resp, twirp.InternalError("received a nil Msg1 and nil error while calling SamePackageProtoImport. nil responses are not supported")) return } ctx = callResponsePrepared(ctx, s.hooks) - respBytes, err := proto.Marshal(respContent) if err != nil { err = wrapErr(err, "failed to marshal proto response") @@ -468,12 +467,15 @@ func (s *svc2Server) serveSamePackageProtoImportProtobuf(ctx context.Context, re } ctx = ctxsetters.WithStatusCode(ctx, http.StatusOK) + resp.Header().Set("Content-Type", "application/protobuf") resp.WriteHeader(http.StatusOK) + if n, err := resp.Write(respBytes); err != nil { msg := fmt.Sprintf("failed to write response, %d of %d bytes written: %s", n, len(respBytes), err.Error()) twerr := twirp.NewError(twirp.Unknown, msg) callError(ctx, s.hooks, twerr) } + callResponseSent(ctx, s.hooks) } diff --git a/internal/twirptest/no_package_name/no_package_name.twirp.go b/internal/twirptest/no_package_name/no_package_name.twirp.go index 49cc2080..289b2f00 100644 --- a/internal/twirptest/no_package_name/no_package_name.twirp.go +++ b/internal/twirptest/no_package_name/no_package_name.twirp.go @@ -251,38 +251,37 @@ func (s *svcServer) serveSendJSON(ctx context.Context, resp http.ResponseWriter, func (s *svcServer) serveSendProtobuf(ctx context.Context, resp http.ResponseWriter, req *http.Request) { var err error - writeProtoError := func(err error) { - s.writeError(ctx, resp, err) - } - ctx = ctxsetters.WithMethodName(ctx, "Send") ctx, err = callRequestRouted(ctx, s.hooks) if err != nil { - writeProtoError(err) + s.writeError(ctx, resp, err) return } - resp.Header().Set("Content-Type", "application/protobuf") buf, err := ioutil.ReadAll(req.Body) if err != nil { err = wrapErr(err, "failed to read request body") - writeProtoError(twirp.InternalErrorWith(err)) + s.writeError(ctx, resp, twirp.InternalErrorWith(err)) return } reqContent := new(Msg) if err = proto.Unmarshal(buf, reqContent); err != nil { err = wrapErr(err, "failed to parse request proto") - writeProtoError(twirp.InternalErrorWith(err)) + s.writeError(ctx, resp, twirp.InternalErrorWith(err)) return } // Call service method var respContent *Msg + respFlusher, canFlush := resp.(http.Flusher) func() { defer func() { // In case of a panic, serve a 500 error and then panic. if r := recover(); r != nil { s.writeError(ctx, resp, twirp.InternalError("Internal service panic")) + if canFlush { + respFlusher.Flush() + } panic(r) } }() @@ -294,12 +293,11 @@ func (s *svcServer) serveSendProtobuf(ctx context.Context, resp http.ResponseWri return } if respContent == nil { - s.writeError(ctx, resp, twirp.InternalError("received a nil *Msg and nil error while calling Send. nil responses are not supported")) + s.writeError(ctx, resp, twirp.InternalError("received a nil Msg and nil error while calling Send. nil responses are not supported")) return } ctx = callResponsePrepared(ctx, s.hooks) - respBytes, err := proto.Marshal(respContent) if err != nil { err = wrapErr(err, "failed to marshal proto response") @@ -308,12 +306,15 @@ func (s *svcServer) serveSendProtobuf(ctx context.Context, resp http.ResponseWri } ctx = ctxsetters.WithStatusCode(ctx, http.StatusOK) + resp.Header().Set("Content-Type", "application/protobuf") resp.WriteHeader(http.StatusOK) + if n, err := resp.Write(respBytes); err != nil { msg := fmt.Sprintf("failed to write response, %d of %d bytes written: %s", n, len(respBytes), err.Error()) twerr := twirp.NewError(twirp.Unknown, msg) callError(ctx, s.hooks, twerr) } + callResponseSent(ctx, s.hooks) } diff --git a/internal/twirptest/no_package_name_importer/no_package_name_importer.twirp.go b/internal/twirptest/no_package_name_importer/no_package_name_importer.twirp.go index 81af039a..2fee1ec1 100644 --- a/internal/twirptest/no_package_name_importer/no_package_name_importer.twirp.go +++ b/internal/twirptest/no_package_name_importer/no_package_name_importer.twirp.go @@ -253,38 +253,37 @@ func (s *svc2Server) serveMethodJSON(ctx context.Context, resp http.ResponseWrit func (s *svc2Server) serveMethodProtobuf(ctx context.Context, resp http.ResponseWriter, req *http.Request) { var err error - writeProtoError := func(err error) { - s.writeError(ctx, resp, err) - } - ctx = ctxsetters.WithMethodName(ctx, "Method") ctx, err = callRequestRouted(ctx, s.hooks) if err != nil { - writeProtoError(err) + s.writeError(ctx, resp, err) return } - resp.Header().Set("Content-Type", "application/protobuf") buf, err := ioutil.ReadAll(req.Body) if err != nil { err = wrapErr(err, "failed to read request body") - writeProtoError(twirp.InternalErrorWith(err)) + s.writeError(ctx, resp, twirp.InternalErrorWith(err)) return } reqContent := new(no_package_name.Msg) if err = proto.Unmarshal(buf, reqContent); err != nil { err = wrapErr(err, "failed to parse request proto") - writeProtoError(twirp.InternalErrorWith(err)) + s.writeError(ctx, resp, twirp.InternalErrorWith(err)) return } // Call service method var respContent *no_package_name.Msg + respFlusher, canFlush := resp.(http.Flusher) func() { defer func() { // In case of a panic, serve a 500 error and then panic. if r := recover(); r != nil { s.writeError(ctx, resp, twirp.InternalError("Internal service panic")) + if canFlush { + respFlusher.Flush() + } panic(r) } }() @@ -296,12 +295,11 @@ func (s *svc2Server) serveMethodProtobuf(ctx context.Context, resp http.Response return } if respContent == nil { - s.writeError(ctx, resp, twirp.InternalError("received a nil *no_package_name.Msg and nil error while calling Method. nil responses are not supported")) + s.writeError(ctx, resp, twirp.InternalError("received a nil no_package_name.Msg and nil error while calling Method. nil responses are not supported")) return } ctx = callResponsePrepared(ctx, s.hooks) - respBytes, err := proto.Marshal(respContent) if err != nil { err = wrapErr(err, "failed to marshal proto response") @@ -310,12 +308,15 @@ func (s *svc2Server) serveMethodProtobuf(ctx context.Context, resp http.Response } ctx = ctxsetters.WithStatusCode(ctx, http.StatusOK) + resp.Header().Set("Content-Type", "application/protobuf") resp.WriteHeader(http.StatusOK) + if n, err := resp.Write(respBytes); err != nil { msg := fmt.Sprintf("failed to write response, %d of %d bytes written: %s", n, len(respBytes), err.Error()) twerr := twirp.NewError(twirp.Unknown, msg) callError(ctx, s.hooks, twerr) } + callResponseSent(ctx, s.hooks) } diff --git a/internal/twirptest/proto/proto.twirp.go b/internal/twirptest/proto/proto.twirp.go index b440297b..ba09d979 100644 --- a/internal/twirptest/proto/proto.twirp.go +++ b/internal/twirptest/proto/proto.twirp.go @@ -254,38 +254,37 @@ func (s *svcServer) serveSendJSON(ctx context.Context, resp http.ResponseWriter, func (s *svcServer) serveSendProtobuf(ctx context.Context, resp http.ResponseWriter, req *http.Request) { var err error - writeProtoError := func(err error) { - s.writeError(ctx, resp, err) - } - ctx = ctxsetters.WithMethodName(ctx, "Send") ctx, err = callRequestRouted(ctx, s.hooks) if err != nil { - writeProtoError(err) + s.writeError(ctx, resp, err) return } - resp.Header().Set("Content-Type", "application/protobuf") buf, err := ioutil.ReadAll(req.Body) if err != nil { err = wrapErr(err, "failed to read request body") - writeProtoError(twirp.InternalErrorWith(err)) + s.writeError(ctx, resp, twirp.InternalErrorWith(err)) return } reqContent := new(Msg) if err = proto.Unmarshal(buf, reqContent); err != nil { err = wrapErr(err, "failed to parse request proto") - writeProtoError(twirp.InternalErrorWith(err)) + s.writeError(ctx, resp, twirp.InternalErrorWith(err)) return } // Call service method var respContent *Msg + respFlusher, canFlush := resp.(http.Flusher) func() { defer func() { // In case of a panic, serve a 500 error and then panic. if r := recover(); r != nil { s.writeError(ctx, resp, twirp.InternalError("Internal service panic")) + if canFlush { + respFlusher.Flush() + } panic(r) } }() @@ -297,12 +296,11 @@ func (s *svcServer) serveSendProtobuf(ctx context.Context, resp http.ResponseWri return } if respContent == nil { - s.writeError(ctx, resp, twirp.InternalError("received a nil *Msg and nil error while calling Send. nil responses are not supported")) + s.writeError(ctx, resp, twirp.InternalError("received a nil Msg and nil error while calling Send. nil responses are not supported")) return } ctx = callResponsePrepared(ctx, s.hooks) - respBytes, err := proto.Marshal(respContent) if err != nil { err = wrapErr(err, "failed to marshal proto response") @@ -311,12 +309,15 @@ func (s *svcServer) serveSendProtobuf(ctx context.Context, resp http.ResponseWri } ctx = ctxsetters.WithStatusCode(ctx, http.StatusOK) + resp.Header().Set("Content-Type", "application/protobuf") resp.WriteHeader(http.StatusOK) + if n, err := resp.Write(respBytes); err != nil { msg := fmt.Sprintf("failed to write response, %d of %d bytes written: %s", n, len(respBytes), err.Error()) twerr := twirp.NewError(twirp.Unknown, msg) callError(ctx, s.hooks, twerr) } + callResponseSent(ctx, s.hooks) } diff --git a/internal/twirptest/service.twirp.go b/internal/twirptest/service.twirp.go index 5bf4192f..ffd56dc0 100644 --- a/internal/twirptest/service.twirp.go +++ b/internal/twirptest/service.twirp.go @@ -255,38 +255,37 @@ func (s *haberdasherServer) serveMakeHatJSON(ctx context.Context, resp http.Resp func (s *haberdasherServer) serveMakeHatProtobuf(ctx context.Context, resp http.ResponseWriter, req *http.Request) { var err error - writeProtoError := func(err error) { - s.writeError(ctx, resp, err) - } - ctx = ctxsetters.WithMethodName(ctx, "MakeHat") ctx, err = callRequestRouted(ctx, s.hooks) if err != nil { - writeProtoError(err) + s.writeError(ctx, resp, err) return } - resp.Header().Set("Content-Type", "application/protobuf") buf, err := ioutil.ReadAll(req.Body) if err != nil { err = wrapErr(err, "failed to read request body") - writeProtoError(twirp.InternalErrorWith(err)) + s.writeError(ctx, resp, twirp.InternalErrorWith(err)) return } reqContent := new(Size) if err = proto.Unmarshal(buf, reqContent); err != nil { err = wrapErr(err, "failed to parse request proto") - writeProtoError(twirp.InternalErrorWith(err)) + s.writeError(ctx, resp, twirp.InternalErrorWith(err)) return } // Call service method var respContent *Hat + respFlusher, canFlush := resp.(http.Flusher) func() { defer func() { // In case of a panic, serve a 500 error and then panic. if r := recover(); r != nil { s.writeError(ctx, resp, twirp.InternalError("Internal service panic")) + if canFlush { + respFlusher.Flush() + } panic(r) } }() @@ -298,12 +297,11 @@ func (s *haberdasherServer) serveMakeHatProtobuf(ctx context.Context, resp http. return } if respContent == nil { - s.writeError(ctx, resp, twirp.InternalError("received a nil *Hat and nil error while calling MakeHat. nil responses are not supported")) + s.writeError(ctx, resp, twirp.InternalError("received a nil Size and nil error while calling MakeHat. nil responses are not supported")) return } ctx = callResponsePrepared(ctx, s.hooks) - respBytes, err := proto.Marshal(respContent) if err != nil { err = wrapErr(err, "failed to marshal proto response") @@ -312,12 +310,15 @@ func (s *haberdasherServer) serveMakeHatProtobuf(ctx context.Context, resp http. } ctx = ctxsetters.WithStatusCode(ctx, http.StatusOK) + resp.Header().Set("Content-Type", "application/protobuf") resp.WriteHeader(http.StatusOK) + if n, err := resp.Write(respBytes); err != nil { msg := fmt.Sprintf("failed to write response, %d of %d bytes written: %s", n, len(respBytes), err.Error()) twerr := twirp.NewError(twirp.Unknown, msg) callError(ctx, s.hooks, twerr) } + callResponseSent(ctx, s.hooks) } @@ -400,6 +401,12 @@ func (c *streamerProtobufClient) Download(ctx context.Context, in *Req) (RespStr if err != nil { return nil, clientError("failed to do request", err) } + if err = ctx.Err(); err != nil { + return nil, clientError("aborted because context was done", err) + } + if resp.StatusCode != 200 { + return nil, errorFromResponse(resp) + } return &protoRespStreamReader{ prs: protoStreamReader{ @@ -471,6 +478,12 @@ func (c *streamerJSONClient) Download(ctx context.Context, in *Req) (RespStream, if err != nil { return nil, clientError("failed to do request", err) } + if err = ctx.Err(); err != nil { + return nil, clientError("aborted because context was done", err) + } + if resp.StatusCode != 200 { + return nil, errorFromResponse(resp) + } jrs, err := newJSONStreamReader(resp.Body) if err != nil { @@ -632,38 +645,37 @@ func (s *streamerServer) serveTransactJSON(ctx context.Context, resp http.Respon func (s *streamerServer) serveTransactProtobuf(ctx context.Context, resp http.ResponseWriter, req *http.Request) { var err error - writeProtoError := func(err error) { - s.writeError(ctx, resp, err) - } - ctx = ctxsetters.WithMethodName(ctx, "Transact") ctx, err = callRequestRouted(ctx, s.hooks) if err != nil { - writeProtoError(err) + s.writeError(ctx, resp, err) return } - resp.Header().Set("Content-Type", "application/protobuf") buf, err := ioutil.ReadAll(req.Body) if err != nil { err = wrapErr(err, "failed to read request body") - writeProtoError(twirp.InternalErrorWith(err)) + s.writeError(ctx, resp, twirp.InternalErrorWith(err)) return } reqContent := new(Req) if err = proto.Unmarshal(buf, reqContent); err != nil { err = wrapErr(err, "failed to parse request proto") - writeProtoError(twirp.InternalErrorWith(err)) + s.writeError(ctx, resp, twirp.InternalErrorWith(err)) return } // Call service method var respContent *Resp + respFlusher, canFlush := resp.(http.Flusher) func() { defer func() { // In case of a panic, serve a 500 error and then panic. if r := recover(); r != nil { s.writeError(ctx, resp, twirp.InternalError("Internal service panic")) + if canFlush { + respFlusher.Flush() + } panic(r) } }() @@ -675,12 +687,11 @@ func (s *streamerServer) serveTransactProtobuf(ctx context.Context, resp http.Re return } if respContent == nil { - s.writeError(ctx, resp, twirp.InternalError("received a nil *Resp and nil error while calling Transact. nil responses are not supported")) + s.writeError(ctx, resp, twirp.InternalError("received a nil Req and nil error while calling Transact. nil responses are not supported")) return } ctx = callResponsePrepared(ctx, s.hooks) - respBytes, err := proto.Marshal(respContent) if err != nil { err = wrapErr(err, "failed to marshal proto response") @@ -689,12 +700,15 @@ func (s *streamerServer) serveTransactProtobuf(ctx context.Context, resp http.Re } ctx = ctxsetters.WithStatusCode(ctx, http.StatusOK) + resp.Header().Set("Content-Type", "application/protobuf") resp.WriteHeader(http.StatusOK) + if n, err := resp.Write(respBytes); err != nil { msg := fmt.Sprintf("failed to write response, %d of %d bytes written: %s", n, len(respBytes), err.Error()) twerr := twirp.NewError(twirp.Unknown, msg) callError(ctx, s.hooks, twerr) } + callResponseSent(ctx, s.hooks) } @@ -720,15 +734,7 @@ func (s *streamerServer) serveUploadJSON(ctx context.Context, resp http.Response } func (s *streamerServer) serveUploadProtobuf(ctx context.Context, resp http.ResponseWriter, req *http.Request) { - - ctx = ctxsetters.WithMethodName(ctx, "Upload") - ctx, err = callRequestRouted(ctx, s.hooks) - if err != nil { - writeProtoError(err) - return - } - - resp.Header().Set("Content-Type", "application/protobuf") + s.writeError(ctx, resp, twirp.InternalError("rpc type \"upload\" is not implemented")) } func (s *streamerServer) serveDownload(ctx context.Context, resp http.ResponseWriter, req *http.Request) { @@ -753,89 +759,90 @@ func (s *streamerServer) serveDownloadJSON(ctx context.Context, resp http.Respon } func (s *streamerServer) serveDownloadProtobuf(ctx context.Context, resp http.ResponseWriter, req *http.Request) { - var ( - err error - respStream RespStream - ) - - // Prepare trailer - trailer := proto.NewBuffer(nil) - _ = trailer.EncodeVarint((2 << 3) | 2) // field tag - writeProtoError := func(err error) { - // JSON encode err as twirp err - // TODO: figure out what to do about updating context and headers - if err == io.EOF { - trailer.EncodeStringBytes("EOF") - return - } - twerr, ok := err.(twirp.Error) - if !ok { - twerr = twirp.InternalErrorWith(err) - } - _ = trailer.EncodeStringBytes( - string(marshalErrorToJSON(twerr)), - ) - } - defer func() { // Send trailer - _, err = resp.Write(trailer.Bytes()) - if err != nil { - // TODO: call error hook? - err = wrapErr(err, "failed to write trailer") - respStream.End(err) - } - }() - + var err error ctx = ctxsetters.WithMethodName(ctx, "Download") ctx, err = callRequestRouted(ctx, s.hooks) if err != nil { - writeProtoError(err) + s.writeError(ctx, resp, err) return } - resp.Header().Set("Content-Type", "application/protobuf") buf, err := ioutil.ReadAll(req.Body) if err != nil { err = wrapErr(err, "failed to read request body") - writeProtoError(twirp.InternalErrorWith(err)) + s.writeError(ctx, resp, twirp.InternalErrorWith(err)) return } reqContent := new(Req) if err = proto.Unmarshal(buf, reqContent); err != nil { err = wrapErr(err, "failed to parse request proto") - writeProtoError(twirp.InternalErrorWith(err)) + s.writeError(ctx, resp, twirp.InternalErrorWith(err)) return } // Call service method + var respContent RespStream + respFlusher, canFlush := resp.(http.Flusher) func() { defer func() { // In case of a panic, serve a 500 error and then panic. if r := recover(); r != nil { - writeProtoError(twirp.InternalError("Internal service panic")) + s.writeError(ctx, resp, twirp.InternalError("Internal service panic")) + if canFlush { + respFlusher.Flush() + } panic(r) } }() - respStream, err = s.Download(ctx, reqContent) + respContent, err = s.Download(ctx, reqContent) }() if err != nil { - writeProtoError(err) + s.writeError(ctx, resp, err) return } - - if respStream == nil { - writeProtoError(twirp.InternalError("received a nil Req and nil error while calling Download. nil responses are not supported")) + if respContent == nil { + s.writeError(ctx, resp, twirp.InternalError("received a nil Req and nil error while calling Download. nil responses are not supported")) return } ctx = callResponsePrepared(ctx, s.hooks) + ctx = ctxsetters.WithStatusCode(ctx, http.StatusOK) + resp.Header().Set("Content-Type", "application/protobuf") + resp.WriteHeader(http.StatusOK) + + // Prepare trailer + trailer := proto.NewBuffer(nil) + _ = trailer.EncodeVarint((2 << 3) | 2) // field tag + writeTrailer := func(err error) { + if err == io.EOF { + trailer.EncodeStringBytes("EOF") + return + } + // Write trailer as json-encoded twirp err + twerr, ok := err.(twirp.Error) + if !ok { + twerr = twirp.InternalErrorWith(err) + } + statusCode := twirp.ServerHTTPStatusFromErrorCode(twerr.Code()) + ctx = ctxsetters.WithStatusCode(ctx, statusCode) + ctx = callError(ctx, s.hooks, twerr) + if encodeErr := trailer.EncodeStringBytes(string(marshalErrorToJSON(twerr))); encodeErr != nil { + _ = trailer.EncodeStringBytes("{\"code\":\"" + string(twirp.Internal) + "\",\"msg\":\"There was an error but it could not be serialized into JSON\"}") // fallback + } + _, writeErr := resp.Write(trailer.Bytes()) + if writeErr != nil { + // Ignored, for the same reason as in the writeError func + _ = writeErr + } + respContent.End(twerr) + } - respFlusher, canFlush := resp.(http.Flusher) messages := proto.NewBuffer(nil) for { - msg, err := respStream.Next(ctx) + msg, err := respContent.Next(ctx) if err != nil { - writeProtoError(err) + writeTrailer(err) break } @@ -844,31 +851,29 @@ func (s *streamerServer) serveDownloadProtobuf(ctx context.Context, resp http.Re err = messages.EncodeMessage(msg) if err != nil { err = wrapErr(err, "failed to marshal proto message") - respStream.End(err) + writeTrailer(err) break } _, err = resp.Write(messages.Bytes()) if err != nil { err = wrapErr(err, "failed to send proto message") - respStream.End(err) + writeTrailer(err) // likely to fail on write, but try anyway to ensure ctx gets error code set for responseSent hook break } if canFlush { - // TODO: come up with a batching scheme to improve performance under high load + // TODO: Come up with a batching scheme to improve performance under high load + // and/or provide a hook for the respStream to control flushing the response. + // Flushing after each message dramatically reduces high-load throughput -- + // difference can be more than 10x based on initial experiments respFlusher.Flush() } - // TODO: Call a hook that we sent a message in a stream? + // TODO: Call a hook that we sent a message in a stream? (combine with flush hook?) } - _, err = resp.Write(trailer.Bytes()) - if err != nil { - // TODO: call error hook? - err = wrapErr(err, "failed to write trailer") - respStream.End(err) - } + callResponseSent(ctx, s.hooks) } func (s *streamerServer) serveCommunicate(ctx context.Context, resp http.ResponseWriter, req *http.Request) { @@ -893,15 +898,7 @@ func (s *streamerServer) serveCommunicateJSON(ctx context.Context, resp http.Res } func (s *streamerServer) serveCommunicateProtobuf(ctx context.Context, resp http.ResponseWriter, req *http.Request) { - - ctx = ctxsetters.WithMethodName(ctx, "Communicate") - ctx, err = callRequestRouted(ctx, s.hooks) - if err != nil { - writeProtoError(err) - return - } - - resp.Header().Set("Content-Type", "application/protobuf") + s.writeError(ctx, resp, twirp.InternalError("rpc type \"bidirectional\" is not implemented")) } func (s *streamerServer) ServiceDescriptor() ([]byte, int) { @@ -1429,7 +1426,31 @@ func (r protoStreamReader) Read(msg proto.Message) error { trailerTag = (2 << 3) | 2 ) - if tag != msgTag && tag != trailerTag { + if tag == trailerTag { + // This is a trailer (twirp error), read it and then close the client + defer r.c.Close() + // Read the length delimiter + l, err := binary.ReadUvarint(r.r) + if err != nil { + return clientError("unable to read trailer's length delimiter", err) + } + sb := new(strings.Builder) + sb.Grow(int(l)) + _, err = io.Copy(sb, r.r) + if err != nil { + return clientError("unable to read trailer", err) + } + if sb.String() == "EOF" { + return io.EOF + } + var tj twerrJSON + if err = json.Unmarshal([]byte(sb.String()), &tj); err != nil { + return clientError("unable to decode trailer", err) + } + return tj.toTwirpError() + } + + if tag != msgTag { return fmt.Errorf("invalid field tag: %v", tag) } @@ -1441,45 +1462,17 @@ func (r protoStreamReader) Read(msg proto.Message) error { if int(l) < 0 || int(l) > r.maxSize { return io.ErrShortBuffer } - if tag == msgTag { - buf := make([]byte, int(l)) - - // Go ahead and read a message. - _, err = io.ReadFull(r.r, buf) - if err != nil { - return err - } - - err = proto.Unmarshal(buf, msg) - if err != nil { - return err - } - return nil - } - - // This is a trailer, read it and then close the client - defer r.c.Close() buf := make([]byte, int(l)) - _, err = io.ReadFull(r.r, buf) - if err != nil { + + // Go ahead and read a message. + if _, err = io.ReadFull(r.r, buf); err != nil { return err } - // Put the length back in front of the trailer so it can be decoded - buf = append(proto.EncodeVarint(l), buf...) - var trailer string - trailer, err = proto.NewBuffer(buf).DecodeStringBytes() - if err != nil { - return clientError("failed to read stream trailer", err) - } - if trailer == "EOF" { - return io.EOF - } - var tj twerrJSON - if err = json.Unmarshal([]byte(trailer), &tj); err != nil { - return clientError("unable to decode stream trailer", err) + if err = proto.Unmarshal(buf, msg); err != nil { + return err } - return tj.toTwirpError() + return nil } type jsonStreamReader struct { diff --git a/protoc-gen-twirp/generator.go b/protoc-gen-twirp/generator.go index 8194d9d7..c098ec1e 100644 --- a/protoc-gen-twirp/generator.go +++ b/protoc-gen-twirp/generator.go @@ -769,7 +769,31 @@ func (t *twirp) generateStreamUtils() { t.P(` trailerTag = (2 << 3) | 2`) t.P(` )`) t.P(``) - t.P(` if tag != msgTag && tag != trailerTag {`) + t.P(` if tag == trailerTag {`) + t.P(` // This is a trailer (twirp error), read it and then close the client`) + t.P(` defer r.c.Close()`) + t.P(` // Read the length delimiter`) + t.P(` l, err := binary.ReadUvarint(r.r)`) + t.P(` if err != nil {`) + t.P(` return clientError("unable to read trailer's length delimiter", err)`) + t.P(` }`) + t.P(` sb := new(`, t.pkgs["strings"], `.Builder)`) + t.P(` sb.Grow(int(l))`) + t.P(` _, err = `, t.pkgs["io"], `.Copy(sb, r.r)`) + t.P(` if err != nil {`) + t.P(` return clientError("unable to read trailer", err)`) + t.P(` }`) + t.P(` if sb.String() == "EOF" {`) + t.P(` return `, t.pkgs["io"], `.EOF`) + t.P(` }`) + t.P(` var tj twerrJSON`) + t.P(` if err = `, t.pkgs["json"], `.Unmarshal([]byte(sb.String()), &tj); err != nil {`) + t.P(` return clientError("unable to decode trailer", err)`) + t.P(` }`) + t.P(` return tj.toTwirpError()`) + t.P(` }`) + t.P(``) + t.P(` if tag != msgTag {`) t.P(` return `, t.pkgs["fmt"], `.Errorf("invalid field tag: %v", tag)`) t.P(` }`) t.P(``) @@ -781,45 +805,17 @@ func (t *twirp) generateStreamUtils() { t.P(` if int(l) < 0 || int(l) > r.maxSize {`) t.P(` return `, t.pkgs["io"], `.ErrShortBuffer`) t.P(` }`) - t.P(` if tag == msgTag {`) - t.P(` buf := make([]byte, int(l))`) - t.P() - t.P(` // Go ahead and read a message.`) - t.P(` _, err = io.ReadFull(r.r, buf)`) - t.P(` if err != nil {`) - t.P(` return err`) - t.P(` }`) - t.P() - t.P(` err = proto.Unmarshal(buf, msg)`) - t.P(` if err != nil {`) - t.P(` return err`) - t.P(` }`) - t.P(` return nil`) - t.P(` }`) - t.P() - t.P(` // This is a trailer, read it and then close the client`) - t.P(` defer r.c.Close()`) t.P(` buf := make([]byte, int(l))`) - t.P(` _, err = `, t.pkgs["io"], `.ReadFull(r.r, buf)`) - t.P(` if err != nil {`) + t.P() + t.P(` // Go ahead and read a message.`) + t.P(` if _, err = `, t.pkgs["io"], `.ReadFull(r.r, buf); err != nil {`) t.P(` return err`) t.P(` }`) t.P() - t.P(` // Put the length back in front of the trailer so it can be decoded`) - t.P(` buf = append(proto.EncodeVarint(l), buf...)`) - t.P(` var trailer string`) - t.P(` trailer, err = `, t.pkgs["proto"], `.NewBuffer(buf).DecodeStringBytes()`) - t.P(` if err != nil {`) - t.P(` return clientError("failed to read stream trailer", err)`) - t.P(` }`) - t.P(` if trailer == "EOF" {`) - t.P(` return io.EOF`) - t.P(` }`) - t.P(` var tj twerrJSON`) - t.P(` if err = json.Unmarshal([]byte(trailer), &tj); err != nil {`) - t.P(` return clientError("unable to decode stream trailer", err)`) + t.P(` if err = `, t.pkgs["proto"], `.Unmarshal(buf, msg); err != nil {`) + t.P(` return err`) t.P(` }`) - t.P(` return tj.toTwirpError()`) + t.P(` return nil`) t.P(`}`) t.P() @@ -899,7 +895,7 @@ func (t *twirp) generateStreamUtils() { t.P(` }`) t.P(``) t.P(` if tj.Code == "stream_complete" {`) - t.P(` return io.EOF`) + t.P(` return `, t.pkgs["io"], `.EOF`) t.P(` }`) t.P(``) t.P(` return tj.toTwirpError()`) @@ -1099,6 +1095,12 @@ func (t *twirp) generateClient(name string, file *descriptor.FileDescriptorProto t.P(` if err != nil {`) t.P(` return nil, clientError("failed to do request", err)`) t.P(` }`) + t.P(` if err = ctx.Err(); err != nil {`) + t.P(` return nil, clientError("aborted because context was done", err)`) + t.P(` }`) + t.P(` if resp.StatusCode != 200 {`) + t.P(` return nil, errorFromResponse(resp)`) + t.P(` }`) t.P(``) if name == "Protobuf" { t.P(` return &proto`, withoutPackageName(outputType), `StreamReader{`) @@ -1330,148 +1332,128 @@ func (t *twirp) generateServerProtobufMethod(service *descriptor.ServiceDescript servStruct := serviceStruct(service) methName := stringutils.CamelCase(method.GetName()) rpcType := methodRPCType(method) + var respType string + if rpcType == download || rpcType == bidirectional { + respType = t.methodOutputType(method) + } else { + respType = `*` + t.methodOutputType(method) + } t.P(`func (s *`, servStruct, `) serve`, methName, `Protobuf(ctx `, t.pkgs["context"], `.Context, resp `, t.pkgs["http"], `.ResponseWriter, req *`, t.pkgs["http"], `.Request) {`) - - if rpcType == unary { - t.P(` var err error`) - t.P(` writeProtoError := func(err error) {`) - t.P(` s.writeError(ctx, resp, err)`) - t.P(` }`) - } - if rpcType == download { - t.P(` var (`) - t.P(` err error`) - t.P(` respStream `, t.methodOutputType(method)) - t.P(` )`) - t.P() - t.P(` // Prepare trailer`) - t.P(` trailer := proto.NewBuffer(nil)`) - t.P(` _ = trailer.EncodeVarint((2 << 3) | 2) // field tag`) - t.P(` writeProtoError := func(err error) {`) - t.P(` // JSON encode err as twirp err`) - t.P(` // TODO: figure out what to do about updating context and headers`) - t.P(` if err == io.EOF {`) - t.P(` trailer.EncodeStringBytes("EOF")`) - t.P(` return`) - t.P(` }`) - t.P(` twerr, ok := err.(twirp.Error)`) - t.P(` if !ok {`) - t.P(` twerr = twirp.InternalErrorWith(err)`) - t.P(` }`) - t.P(` _ = trailer.EncodeStringBytes(`) - t.P(` string(marshalErrorToJSON(twerr)),`) - t.P(` )`) - t.P(` }`) - t.P(` defer func() { // Send trailer`) - t.P(` _, err = resp.Write(trailer.Bytes())`) - t.P(` if err != nil {`) - t.P(` // TODO: call error hook?`) - t.P(` err = wrapErr(err, "failed to write trailer")`) - t.P(` respStream.End(err)`) - t.P(` }`) - t.P(` }()`) + if rpcType != unary && rpcType != download { + t.P(` s.writeError(ctx, resp, `, t.pkgs["twirp"], `.InternalError("rpc type \"`, string(rpcType), `\" is not implemented"))`) + t.P(`}`) + t.P(``) + return } - - t.P() + t.P(` var err error`) t.P(` ctx = `, t.pkgs["ctxsetters"], `.WithMethodName(ctx, "`, methName, `")`) t.P(` ctx, err = callRequestRouted(ctx, s.hooks)`) t.P(` if err != nil {`) - t.P(` writeProtoError(err)`) + t.P(` s.writeError(ctx, resp, err)`) t.P(` return`) t.P(` }`) + t.P(``) - t.P() - t.P(` resp.Header().Set("Content-Type", "application/protobuf")`) + // Read the request + t.P(` buf, err := `, t.pkgs["ioutil"], `.ReadAll(req.Body)`) + t.P(` if err != nil {`) + t.P(` err = wrapErr(err, "failed to read request body")`) + t.P(` s.writeError(ctx, resp, `, t.pkgs["twirp"], `.InternalErrorWith(err))`) + t.P(` return`) + t.P(` }`) + t.P(` reqContent := new(`, t.methodInputType(method), `)`) + t.P(` if err = `, t.pkgs["proto"], `.Unmarshal(buf, reqContent); err != nil {`) + t.P(` err = wrapErr(err, "failed to parse request proto")`) + t.P(` s.writeError(ctx, resp, `, t.pkgs["twirp"], `.InternalErrorWith(err))`) + t.P(` return`) + t.P(` }`) + t.P(``) - if rpcType == unary || rpcType == download { - t.P(` buf, err := `, t.pkgs["ioutil"], `.ReadAll(req.Body)`) - t.P(` if err != nil {`) - t.P(` err = wrapErr(err, "failed to read request body")`) - t.P(` writeProtoError(`, t.pkgs["twirp"], `.InternalErrorWith(err))`) - t.P(` return`) - t.P(` }`) - t.P(` reqContent := new(`, t.methodInputType(method), `)`) - t.P(` if err = `, t.pkgs["proto"], `.Unmarshal(buf, reqContent); err != nil {`) - t.P(` err = wrapErr(err, "failed to parse request proto")`) - t.P(` writeProtoError(`, t.pkgs["twirp"], `.InternalErrorWith(err))`) - t.P(` return`) - t.P(` }`) - t.P() - } - if rpcType == unary { - t.P(` // Call service method`) - t.P(` var respContent *`, t.methodOutputType(method)) - t.P(` func() {`) - t.P(` defer func() {`) - t.P(` // In case of a panic, serve a 500 error and then panic.`) - t.P(` if r := recover(); r != nil {`) - t.P(` s.writeError(ctx, resp, `, t.pkgs["twirp"], `.InternalError("Internal service panic"))`) - t.P(` panic(r)`) - t.P(` }`) - t.P(` }()`) - t.P(` respContent, err = s.`, methName, `(ctx, reqContent)`) - t.P(` }()`) - t.P() - t.P(` if err != nil {`) - t.P(` s.writeError(ctx, resp, err)`) - t.P(` return`) - t.P(` }`) - t.P(` if respContent == nil {`) - t.P(` s.writeError(ctx, resp, `, t.pkgs["twirp"], `.InternalError("received a nil *`, t.goTypeName(method.GetOutputType()), ` and nil error while calling `, methName, `. nil responses are not supported"))`) - t.P(` return`) - t.P(` }`) - t.P() - t.P(` ctx = callResponsePrepared(ctx, s.hooks)`) - t.P() - t.P(` respBytes, err := `, t.pkgs["proto"], `.Marshal(respContent)`) - t.P(` if err != nil {`) - t.P(` err = wrapErr(err, "failed to marshal proto response")`) - t.P(` s.writeError(ctx, resp, `, t.pkgs["twirp"], `.InternalErrorWith(err))`) - t.P(` return`) - t.P(` }`) - t.P() - t.P(` ctx = `, t.pkgs["ctxsetters"], `.WithStatusCode(ctx, `, t.pkgs["http"], `.StatusOK)`) - t.P(` resp.WriteHeader(`, t.pkgs["http"], `.StatusOK)`) - t.P(` if n, err := resp.Write(respBytes); err != nil {`) - t.P(` msg := fmt.Sprintf("failed to write response, %d of %d bytes written: %s", n, len(respBytes), err.Error())`) - t.P(` twerr := `, t.pkgs["twirp"], `.NewError(`, t.pkgs["twirp"], `.Unknown, msg)`) - t.P(` callError(ctx, s.hooks, twerr)`) - t.P(` }`) - t.P(` callResponseSent(ctx, s.hooks)`) - } + // Prepare the response + t.P(` // Call service method`) + t.P(` var respContent `, respType) + t.P(` respFlusher, canFlush := resp.(`, t.pkgs["http"], `.Flusher)`) + t.P(` func() {`) + t.P(` defer func() {`) + t.P(` // In case of a panic, serve a 500 error and then panic.`) + t.P(` if r := recover(); r != nil {`) + t.P(` s.writeError(ctx, resp, `, t.pkgs["twirp"], `.InternalError("Internal service panic"))`) + t.P(` if canFlush {`) + t.P(` respFlusher.Flush()`) + t.P(` }`) + t.P(` panic(r)`) + t.P(` }`) + t.P(` }()`) + t.P(` respContent, err = s.`, methName, `(ctx, reqContent)`) + t.P(` }()`) + t.P(``) + t.P(` if err != nil {`) + t.P(` s.writeError(ctx, resp, err)`) + t.P(` return`) + t.P(` }`) + t.P(` if respContent == nil {`) + t.P(` s.writeError(ctx, resp, `, t.pkgs["twirp"], `.InternalError("received a nil `, t.methodInputType(method), ` and nil error while calling `, methName, `. nil responses are not supported"))`) + t.P(` return`) + t.P(` }`) + t.P(``) + t.P(` ctx = callResponsePrepared(ctx, s.hooks)`) - if rpcType == download { - t.P(` // Call service method`) - t.P(` func() {`) - t.P(` defer func() {`) - t.P(` // In case of a panic, serve a 500 error and then panic.`) - t.P(` if r := recover(); r != nil {`) - t.P(` writeProtoError(`, t.pkgs["twirp"], `.InternalError("Internal service panic"))`) - t.P(` panic(r)`) - t.P(` }`) - t.P(` }()`) - t.P(` respStream, err = s.`, methName, `(ctx, reqContent)`) - t.P(` }()`) - t.P(``) + // Send the response + if rpcType == unary { + t.P(` respBytes, err := `, t.pkgs["proto"], `.Marshal(respContent)`) t.P(` if err != nil {`) - t.P(` writeProtoError(err)`) + t.P(` err = wrapErr(err, "failed to marshal proto response")`) + t.P(` s.writeError(ctx, resp, `, t.pkgs["twirp"], `.InternalErrorWith(err))`) t.P(` return`) t.P(` }`) t.P(``) - t.P(` if respStream == nil {`) - t.P(` writeProtoError(`, t.pkgs["twirp"], `.InternalError("received a nil `, t.methodInputType(method), ` and nil error while calling `, methName, `. nil responses are not supported"))`) - t.P(` return`) + } + t.P(` ctx = `, t.pkgs["ctxsetters"], `.WithStatusCode(ctx, `, t.pkgs["http"], `.StatusOK)`) + t.P(` resp.Header().Set("Content-Type", "application/protobuf")`) + t.P(` resp.WriteHeader(`, t.pkgs["http"], `.StatusOK)`) + t.P(``) + switch rpcType { + case unary: + t.P(` if n, err := resp.Write(respBytes); err != nil {`) + t.P(` msg := fmt.Sprintf("failed to write response, %d of %d bytes written: %s", n, len(respBytes), err.Error())`) + t.P(` twerr := `, t.pkgs["twirp"], `.NewError(`, t.pkgs["twirp"], `.Unknown, msg)`) + t.P(` callError(ctx, s.hooks, twerr)`) t.P(` }`) t.P(``) - t.P(` ctx = callResponsePrepared(ctx, s.hooks)`) + case download: + t.P(` // Prepare trailer`) + t.P(` trailer := `, t.pkgs["proto"], `.NewBuffer(nil)`) + t.P(` _ = trailer.EncodeVarint((2 << 3) | 2) // field tag`) + t.P(` writeTrailer := func(err error) {`) + t.P(` if err == `, t.pkgs["io"], `.EOF {`) + t.P(` trailer.EncodeStringBytes("EOF")`) + t.P(` return`) + t.P(` }`) + t.P(` // Write trailer as json-encoded twirp err`) + t.P(` twerr, ok := err.(twirp.Error)`) + t.P(` if !ok {`) + t.P(` twerr = twirp.InternalErrorWith(err)`) + t.P(` }`) + t.P(` statusCode := `, t.pkgs["twirp"], `.ServerHTTPStatusFromErrorCode(twerr.Code())`) + t.P(` ctx = `, t.pkgs["ctxsetters"], `.WithStatusCode(ctx, statusCode)`) + t.P(` ctx = callError(ctx, s.hooks, twerr)`) + t.P(` if encodeErr := trailer.EncodeStringBytes(string(marshalErrorToJSON(twerr))); encodeErr != nil {`) + t.P(` _ = trailer.EncodeStringBytes("{\"code\":\"" + string(`, t.pkgs["twirp"], `.Internal) + "\",\"msg\":\"There was an error but it could not be serialized into JSON\"}") // fallback`) + t.P(` }`) + t.P(` _, writeErr := resp.Write(trailer.Bytes())`) + t.P(` if writeErr != nil {`) + t.P(` // Ignored, for the same reason as in the writeError func`) + t.P(` _ = writeErr`) + t.P(` }`) + t.P(` respContent.End(twerr)`) + t.P(` }`) t.P(``) - t.P(` respFlusher, canFlush := resp.(http.Flusher)`) t.P(` messages := `, t.pkgs["proto"], `.NewBuffer(nil)`) t.P(` for {`) - t.P(` msg, err := respStream.Next(ctx)`) + t.P(` msg, err := respContent.Next(ctx)`) t.P(` if err != nil {`) - t.P(` writeProtoError(err)`) + t.P(` writeTrailer(err)`) t.P(` break`) t.P(` }`) t.P(``) @@ -1480,26 +1462,31 @@ func (t *twirp) generateServerProtobufMethod(service *descriptor.ServiceDescript t.P(` err = messages.EncodeMessage(msg)`) t.P(` if err != nil {`) t.P(` err = wrapErr(err, "failed to marshal proto message")`) - t.P(` respStream.End(err)`) + t.P(` writeTrailer(err)`) t.P(` break`) t.P(` }`) t.P(``) t.P(` _, err = resp.Write(messages.Bytes())`) t.P(` if err != nil {`) t.P(` err = wrapErr(err, "failed to send proto message")`) - t.P(` respStream.End(err)`) + t.P(` writeTrailer(err) // likely to fail on write, but try anyway to ensure ctx gets error code set for responseSent hook`) t.P(` break`) t.P(` }`) t.P(``) t.P(` if canFlush {`) - t.P(` // TODO: come up with a batching scheme to improve performance under high load`) + t.P(` // TODO: Come up with a batching scheme to improve performance under high load`) + t.P(` // and/or provide a hook for the respStream to control flushing the response.`) + t.P(` // Flushing after each message dramatically reduces high-load throughput --`) + t.P(` // difference can be more than 10x based on initial experiments`) t.P(` respFlusher.Flush()`) t.P(` }`) t.P(``) - t.P(` // TODO: Call a hook that we sent a message in a stream?`) + t.P(` // TODO: Call a hook that we sent a message in a stream? (combine with flush hook?)`) t.P(` }`) + t.P(``) } + t.P(` callResponseSent(ctx, s.hooks)`) t.P(`}`) t.P() } From 6547c316bdd55e56d27b07eac68559b92cb2a2fc Mon Sep 17 00:00:00 2001 From: Mike Lemmon Date: Mon, 4 Jun 2018 14:34:26 -0700 Subject: [PATCH 4/7] generator: Use bytes.Buffer instead of strings.Builder for compatibility with older versions of go --- example/service.twirp.go | 2 +- internal/twirptest/importer/importer.twirp.go | 2 +- internal/twirptest/service.twirp.go | 2 +- protoc-gen-twirp/generator.go | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/example/service.twirp.go b/example/service.twirp.go index 6242b2f3..5ecf4fef 100644 --- a/example/service.twirp.go +++ b/example/service.twirp.go @@ -1039,7 +1039,7 @@ func (r protoStreamReader) Read(msg proto.Message) error { if err != nil { return clientError("unable to read trailer's length delimiter", err) } - sb := new(strings.Builder) + sb := new(bytes.Buffer) sb.Grow(int(l)) _, err = io.Copy(sb, r.r) if err != nil { diff --git a/internal/twirptest/importer/importer.twirp.go b/internal/twirptest/importer/importer.twirp.go index 7f020bb8..739e772e 100644 --- a/internal/twirptest/importer/importer.twirp.go +++ b/internal/twirptest/importer/importer.twirp.go @@ -851,7 +851,7 @@ func (r protoStreamReader) Read(msg proto.Message) error { if err != nil { return clientError("unable to read trailer's length delimiter", err) } - sb := new(strings.Builder) + sb := new(bytes.Buffer) sb.Grow(int(l)) _, err = io.Copy(sb, r.r) if err != nil { diff --git a/internal/twirptest/service.twirp.go b/internal/twirptest/service.twirp.go index ffd56dc0..2f3fbe73 100644 --- a/internal/twirptest/service.twirp.go +++ b/internal/twirptest/service.twirp.go @@ -1434,7 +1434,7 @@ func (r protoStreamReader) Read(msg proto.Message) error { if err != nil { return clientError("unable to read trailer's length delimiter", err) } - sb := new(strings.Builder) + sb := new(bytes.Buffer) sb.Grow(int(l)) _, err = io.Copy(sb, r.r) if err != nil { diff --git a/protoc-gen-twirp/generator.go b/protoc-gen-twirp/generator.go index c098ec1e..d09edb08 100644 --- a/protoc-gen-twirp/generator.go +++ b/protoc-gen-twirp/generator.go @@ -777,7 +777,7 @@ func (t *twirp) generateStreamUtils() { t.P(` if err != nil {`) t.P(` return clientError("unable to read trailer's length delimiter", err)`) t.P(` }`) - t.P(` sb := new(`, t.pkgs["strings"], `.Builder)`) + t.P(` sb := new(`, t.pkgs["bytes"], `.Buffer)`) t.P(` sb.Grow(int(l))`) t.P(` _, err = `, t.pkgs["io"], `.Copy(sb, r.r)`) t.P(` if err != nil {`) From 8f561c53b45fe7a18dea4ba1c8d38fdb106ead93 Mon Sep 17 00:00:00 2001 From: Mike Lemmon Date: Mon, 4 Jun 2018 18:43:31 -0700 Subject: [PATCH 5/7] Add NewStream constructor to generated .twirp.go files * Define `OrError` struct that is a union of the `` and `error` return values from `Next` * Define `NewStream` constructor that takes a channel of `OrError` and returns an implementation of the `Stream` interface --- example/cmd/server/hat_stream_sender.go | 55 ----------------- example/cmd/server/main.go | 6 +- example/service.twirp.go | 30 ++++++++++ internal/twirptest/importer/importer.twirp.go | 30 ++++++++++ internal/twirptest/service.twirp.go | 60 +++++++++++++++++++ protoc-gen-twirp/generator.go | 34 +++++++++++ 6 files changed, 157 insertions(+), 58 deletions(-) delete mode 100644 example/cmd/server/hat_stream_sender.go diff --git a/example/cmd/server/hat_stream_sender.go b/example/cmd/server/hat_stream_sender.go deleted file mode 100644 index 171ef7f3..00000000 --- a/example/cmd/server/hat_stream_sender.go +++ /dev/null @@ -1,55 +0,0 @@ -// Copyright 2018 Twitch Interactive, Inc. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may not -// use this file except in compliance with the License. A copy of the License is -// located at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// or in the "license" file accompanying this file. This file is distributed on -// an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either -// express or implied. See the License for the specific language governing -// permissions and limitations under the License. - -package main - -// This file is an alternate implementation of the twirp-provided stream -// constructor originally proposed in the "As the sender" section of -// https://github.com/twitchtv/twirp/issues/70#issuecomment-361365458 - -import ( - "context" - "io" - - "github.com/twitchtv/twirp/example" -) - -type HatOrError struct { - hat *example.Hat - err error -} - -func NewHatStream(ch chan HatOrError) *hatStreamSender { - return &hatStreamSender{ch: ch} -} - -type hatStreamSender struct { - ch <-chan HatOrError -} - -func (hs *hatStreamSender) Next(ctx context.Context) (*example.Hat, error) { - select { - case <-ctx.Done(): - return nil, ctx.Err() - case v, open := <-hs.ch: - if !open { - return nil, io.EOF - } - if v.err != nil { - return nil, v.err - } - return v.hat, nil - } -} - -func (hs *hatStreamSender) End(err error) {} // Should anything go here? diff --git a/example/cmd/server/main.go b/example/cmd/server/main.go index 85d97168..da5f8f4d 100644 --- a/example/cmd/server/main.go +++ b/example/cmd/server/main.go @@ -56,19 +56,19 @@ func (h *randomHaberdasher) MakeHats(ctx context.Context, req *example.MakeHatsR // return nil, errTooSmall // } - ch := make(chan HatOrError, 100) // NB: the size of this buffer can make a big difference! + ch := make(chan example.HatOrError, 100) // NB: the size of this buffer can make a big difference! go func() { for ii := int32(0); ii < req.Quantity; ii++ { hat, err := newRandomHat(req.Inches) select { case <-ctx.Done(): return - case ch <- HatOrError{hat, err}: + case ch <- example.HatOrError{Hat: hat, Err: err}: } } close(ch) }() - return NewHatStream(ch), nil + return example.NewHatStream(ch), nil } func main() { diff --git a/example/service.twirp.go b/example/service.twirp.go index 5ecf4fef..af1797b9 100644 --- a/example/service.twirp.go +++ b/example/service.twirp.go @@ -588,6 +588,36 @@ func (r jsonHatStreamReader) Next(context.Context) (*Hat, error) { func (r jsonHatStreamReader) End(error) { _ = r.c.Close() } +type HatOrError struct { + Hat *Hat + Err error +} + +func NewHatStream(ch chan HatOrError) *hatStreamSender { + return &hatStreamSender{ch: ch} +} + +type hatStreamSender struct { + ch <-chan HatOrError +} + +func (ss *hatStreamSender) Next(ctx context.Context) (*Hat, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case v, open := <-ss.ch: + if !open { + return nil, io.EOF + } + if v.Err != nil { + return nil, v.Err + } + return v.Hat, nil + } +} + +func (ss *hatStreamSender) End(err error) {} + // ===== // Utils // ===== diff --git a/internal/twirptest/importer/importer.twirp.go b/internal/twirptest/importer/importer.twirp.go index 739e772e..cf6f133b 100644 --- a/internal/twirptest/importer/importer.twirp.go +++ b/internal/twirptest/importer/importer.twirp.go @@ -400,6 +400,36 @@ func (r jsonMsgStreamReader) Next(context.Context) (*twirp_internal_twirptest_im func (r jsonMsgStreamReader) End(error) { _ = r.c.Close() } +type MsgOrError struct { + Msg *twirp_internal_twirptest_importable.Msg + Err error +} + +func NewMsgStream(ch chan MsgOrError) *msgStreamSender { + return &msgStreamSender{ch: ch} +} + +type msgStreamSender struct { + ch <-chan MsgOrError +} + +func (ss *msgStreamSender) Next(ctx context.Context) (*twirp_internal_twirptest_importable.Msg, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case v, open := <-ss.ch: + if !open { + return nil, io.EOF + } + if v.Err != nil { + return nil, v.Err + } + return v.Msg, nil + } +} + +func (ss *msgStreamSender) End(err error) {} + // ===== // Utils // ===== diff --git a/internal/twirptest/service.twirp.go b/internal/twirptest/service.twirp.go index 2f3fbe73..7eff7458 100644 --- a/internal/twirptest/service.twirp.go +++ b/internal/twirptest/service.twirp.go @@ -946,6 +946,36 @@ func (r jsonReqStreamReader) Next(context.Context) (*Req, error) { func (r jsonReqStreamReader) End(error) { _ = r.c.Close() } +type ReqOrError struct { + Req *Req + Err error +} + +func NewReqStream(ch chan ReqOrError) *reqStreamSender { + return &reqStreamSender{ch: ch} +} + +type reqStreamSender struct { + ch <-chan ReqOrError +} + +func (ss *reqStreamSender) Next(ctx context.Context) (*Req, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case v, open := <-ss.ch: + if !open { + return nil, io.EOF + } + if v.Err != nil { + return nil, v.Err + } + return v.Req, nil + } +} + +func (ss *reqStreamSender) End(err error) {} + // RespStream represents a stream of Resp messages. type RespStream interface { Next(context.Context) (*Resp, error) @@ -983,6 +1013,36 @@ func (r jsonRespStreamReader) Next(context.Context) (*Resp, error) { func (r jsonRespStreamReader) End(error) { _ = r.c.Close() } +type RespOrError struct { + Resp *Resp + Err error +} + +func NewRespStream(ch chan RespOrError) *respStreamSender { + return &respStreamSender{ch: ch} +} + +type respStreamSender struct { + ch <-chan RespOrError +} + +func (ss *respStreamSender) Next(ctx context.Context) (*Resp, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case v, open := <-ss.ch: + if !open { + return nil, io.EOF + } + if v.Err != nil { + return nil, v.Err + } + return v.Resp, nil + } +} + +func (ss *respStreamSender) End(err error) {} + // ===== // Utils // ===== diff --git a/protoc-gen-twirp/generator.go b/protoc-gen-twirp/generator.go index d09edb08..e829c1d0 100644 --- a/protoc-gen-twirp/generator.go +++ b/protoc-gen-twirp/generator.go @@ -1526,6 +1526,7 @@ func (t *twirp) generateStreamType(typeName string) { t.streamTypes[typeName] = true }() + typeNameWithoutPackage := withoutPackageName(typeName) streamTypeName := withoutPackageName(typeName) + "Stream" t.P(`// `, streamTypeName, ` represents a stream of `, typeName, ` messages.`) t.P(`type `, streamTypeName, ` interface {`) @@ -1569,6 +1570,39 @@ func (t *twirp) generateStreamType(typeName string) { t.P() t.P(`func (r `, jsonReaderTypeName, `) End(error) { _ = r.c.Close() }`) t.P() + + typeOrErrorTypeName := typeNameWithoutPackage + "OrError" + t.P(`type `, typeOrErrorTypeName, ` struct {`) + t.P(` `, typeNameWithoutPackage, ` *`, typeName) + t.P(` Err error`) + t.P(`}`) + t.P() + streamSenderName := unexported(streamTypeName + "Sender") + t.P(`func New`, streamTypeName, `(ch chan `, typeOrErrorTypeName, `) *`, streamSenderName, ` {`) + t.P(` return &`, streamSenderName, `{ch: ch}`) + t.P(`}`) + t.P() + t.P(`type `, streamSenderName, ` struct {`) + t.P(` ch <-chan `, typeOrErrorTypeName) + t.P(`}`) + t.P() + t.P(`func (ss *`, streamSenderName, `) Next(ctx context.Context) (*`, typeName, `, error) {`) + t.P(` select {`) + t.P(` case <-ctx.Done():`) + t.P(` return nil, ctx.Err()`) + t.P(` case v, open := <-ss.ch:`) + t.P(` if !open {`) + t.P(` return nil, io.EOF`) + t.P(` }`) + t.P(` if v.Err != nil {`) + t.P(` return nil, v.Err`) + t.P(` }`) + t.P(` return v.`, typeNameWithoutPackage, `, nil`) + t.P(` }`) + t.P(`}`) + t.P() + t.P(`func (ss *`, streamSenderName, `) End(err error) {}`) + t.P() } // serviceMetadataVarName is the variable name used in generated code to refer From fe19e0193b0ef902003c5f441b3b715bc240ad64 Mon Sep 17 00:00:00 2001 From: Mike Lemmon Date: Tue, 19 Jun 2018 18:04:58 -0700 Subject: [PATCH 6/7] Remove RespStream type and replace with chan-based API for download-style rpcs --- example/cmd/client/main.go | 14 +- example/cmd/server/main.go | 8 +- example/cmd/server/main_test.go | 27 +- example/service.twirp.go | 169 ++++++------- internal/twirptest/importer/importer.twirp.go | 78 +----- internal/twirptest/service.twirp.go | 235 ++++++----------- protoc-gen-twirp/generator.go | 238 ++++++++---------- 7 files changed, 270 insertions(+), 499 deletions(-) diff --git a/example/cmd/client/main.go b/example/cmd/client/main.go index c15d1121..50ae514c 100644 --- a/example/cmd/client/main.go +++ b/example/cmd/client/main.go @@ -15,7 +15,6 @@ package main import ( "context" - "io" "log" "net/http" "time" @@ -75,19 +74,16 @@ func main() { khps := float64(ii-1) / took.Seconds() / 1000 log.Printf("Received %.1f kHats per second (%d hats in %f seconds)\n", khps, ii-1, took.Seconds()) } - for ; true; ii++ { // Receive all the hats - hat, err = hatStream.Next(context.Background()) - if err != nil { - if err == io.EOF { - break - } + for hatOrErr := range hatStream { + if hatOrErr.Err != nil { printResults() - log.Fatal(err) + log.Fatal(hatOrErr.Err) } if ii%printEvery == 0 { khps := float64(ii) / time.Now().Sub(reqSentAt).Seconds() / 1000 - log.Printf("\t[%4.1f khps] %6d %+v\n", khps, ii, hat) + log.Printf("\t[%4.1f khps] %6d %+v\n", khps, ii, hatOrErr.Msg) } + ii++ } printResults() } diff --git a/example/cmd/server/main.go b/example/cmd/server/main.go index da5f8f4d..1df7a47c 100644 --- a/example/cmd/server/main.go +++ b/example/cmd/server/main.go @@ -47,7 +47,7 @@ func (h *randomHaberdasher) MakeHat(ctx context.Context, size *example.Size) (*e return newRandomHat(size.Inches) } -func (h *randomHaberdasher) MakeHats(ctx context.Context, req *example.MakeHatsReq) (example.HatStream, error) { +func (h *randomHaberdasher) MakeHats(ctx context.Context, req *example.MakeHatsReq) (<-chan example.HatOrError, error) { if req.Quantity < 0 { return nil, errNegativeQuantity } @@ -58,17 +58,17 @@ func (h *randomHaberdasher) MakeHats(ctx context.Context, req *example.MakeHatsR ch := make(chan example.HatOrError, 100) // NB: the size of this buffer can make a big difference! go func() { + defer close(ch) for ii := int32(0); ii < req.Quantity; ii++ { hat, err := newRandomHat(req.Inches) select { case <-ctx.Done(): return - case ch <- example.HatOrError{Hat: hat, Err: err}: + case ch <- example.HatOrError{Msg: hat, Err: err}: } } - close(ch) }() - return example.NewHatStream(ch), nil + return ch, nil } func main() { diff --git a/example/cmd/server/main_test.go b/example/cmd/server/main_test.go index 56a96a29..a2ce1feb 100644 --- a/example/cmd/server/main_test.go +++ b/example/cmd/server/main_test.go @@ -16,7 +16,6 @@ package main import ( "context" "fmt" - "io" "log" "net/http" "os" @@ -116,8 +115,8 @@ func TestMakeHatsTooSmall(t *testing.T) { if err != nil { t.Fatal(err) } - _, err = hatStream.Next(context.Background()) - err = compareErrors(err, errTooSmall) + hatOrErr := <-hatStream + err = compareErrors(hatOrErr.Err, errTooSmall) if err != nil { t.Fatal(err) } @@ -148,14 +147,11 @@ func TestMakeHatsLargeQuantities(t *testing.T) { t.Fatalf(`MakeHats request failed: %#v (hatStream=%#v)`, err, hatStream) } ii := int32(0) - for ; true; ii++ { - _, err = hatStream.Next(context.Background()) - if err == io.EOF { - break - } - if err != nil { - t.Fatal(err) + for hatOrErr := range hatStream { + if hatOrErr.Err != nil { + t.Fatal(hatOrErr.Err) } + ii++ } if ii != re.req.Quantity { t.Fatalf(`Expected to receive %d hats, got %d`, re.req.Quantity, ii) @@ -189,14 +185,11 @@ func benchmarkMakeHats(b *testing.B, cc example.Haberdasher) { b.Fatal(err) } ii := 0 - for ; true; ii++ { - _, err = hatStream.Next(context.Background()) - if err != nil { - if err == io.EOF { - break - } - b.Fatal(err) + for hatOrErr := range hatStream { + if hatOrErr.Err != nil { + b.Fatal(hatOrErr.Err) } + ii++ } if ii != b.N { b.Fatalf(`Expected to receive %d hats, got %d`, b.N, ii) diff --git a/example/service.twirp.go b/example/service.twirp.go index af1797b9..e26b62b1 100644 --- a/example/service.twirp.go +++ b/example/service.twirp.go @@ -39,7 +39,7 @@ type Haberdasher interface { // MakeHat produces a hat of mysterious, randomly-selected color! MakeHat(ctx context.Context, in *Size) (*Hat, error) - MakeHats(ctx context.Context, in *MakeHatsReq) (HatStream, error) + MakeHats(ctx context.Context, in *MakeHatsReq) (<-chan HatOrError, error) } // =========================== @@ -80,7 +80,7 @@ func (c *haberdasherProtobufClient) MakeHat(ctx context.Context, in *Size) (*Hat return out, err } -func (c *haberdasherProtobufClient) MakeHats(ctx context.Context, in *MakeHatsReq) (HatStream, error) { +func (c *haberdasherProtobufClient) MakeHats(ctx context.Context, in *MakeHatsReq) (<-chan HatOrError, error) { ctx = ctxsetters.WithPackageName(ctx, "twitch.twirp.example") ctx = ctxsetters.WithServiceName(ctx, "Haberdasher") ctx = ctxsetters.WithMethodName(ctx, "MakeHats") @@ -108,13 +108,29 @@ func (c *haberdasherProtobufClient) MakeHats(ctx context.Context, in *MakeHatsRe return nil, errorFromResponse(resp) } - return &protoHatStreamReader{ - prs: protoStreamReader{ + respStream := make(chan HatOrError) + go func() { + defer func() { + resp.Body.Close() + close(respStream) + }() + reader := protoStreamReader{ r: bufio.NewReader(resp.Body), - c: resp.Body, maxSize: 1 << 21, // 1GB - }, - }, nil + } + out := new(Hat) + for { + if err = reader.Read(out); err != nil { + if err == io.EOF { + return + } + respStream <- HatOrError{Err: err} + return + } + respStream <- HatOrError{Msg: out} + } + }() + return respStream, nil } // ======================= @@ -155,7 +171,7 @@ func (c *haberdasherJSONClient) MakeHat(ctx context.Context, in *Size) (*Hat, er return out, err } -func (c *haberdasherJSONClient) MakeHats(ctx context.Context, in *MakeHatsReq) (HatStream, error) { +func (c *haberdasherJSONClient) MakeHats(ctx context.Context, in *MakeHatsReq) (<-chan HatOrError, error) { ctx = ctxsetters.WithPackageName(ctx, "twitch.twirp.example") ctx = ctxsetters.WithServiceName(ctx, "Haberdasher") ctx = ctxsetters.WithMethodName(ctx, "MakeHats") @@ -183,14 +199,30 @@ func (c *haberdasherJSONClient) MakeHats(ctx context.Context, in *MakeHatsReq) ( return nil, errorFromResponse(resp) } - jrs, err := newJSONStreamReader(resp.Body) - if err != nil { - return nil, err - } - return &jsonHatStreamReader{ - jrs: jrs, - c: resp.Body, - }, nil + respStream := make(chan HatOrError) + go func() { + defer func() { + resp.Body.Close() + close(respStream) + }() + reader, err := newJSONStreamReader(resp.Body) + if err != nil { + respStream <- HatOrError{Err: err} + return + } + out := new(Hat) + for { + if err = reader.Read(out); err != nil { + if err == io.EOF { + return + } + respStream <- HatOrError{Err: err} + return + } + respStream <- HatOrError{Msg: out} + } + }() + return respStream, nil } // ========================== @@ -448,7 +480,7 @@ func (s *haberdasherServer) serveMakeHatsProtobuf(ctx context.Context, resp http } // Call service method - var respContent HatStream + var respContent <-chan HatOrError respFlusher, canFlush := resp.(http.Flusher) func() { defer func() { @@ -502,15 +534,24 @@ func (s *haberdasherServer) serveMakeHatsProtobuf(ctx context.Context, resp http // Ignored, for the same reason as in the writeError func _ = writeErr } - respContent.End(twerr) } messages := proto.NewBuffer(nil) for { - msg, err := respContent.Next(ctx) - if err != nil { - writeTrailer(err) - break + var msg *Hat + select { + case <-ctx.Done(): + return + case msgOrErr, open := <-respContent: + if !open { + writeTrailer(io.EOF) + return + } + if msgOrErr.Err != nil { + writeTrailer(msgOrErr.Err) + return + } + msg = msgOrErr.Msg } messages.Reset() @@ -519,14 +560,14 @@ func (s *haberdasherServer) serveMakeHatsProtobuf(ctx context.Context, resp http if err != nil { err = wrapErr(err, "failed to marshal proto message") writeTrailer(err) - break + return } _, err = resp.Write(messages.Bytes()) if err != nil { err = wrapErr(err, "failed to send proto message") writeTrailer(err) // likely to fail on write, but try anyway to ensure ctx gets error code set for responseSent hook - break + return } if canFlush { @@ -551,73 +592,11 @@ func (s *haberdasherServer) ProtocGenTwirpVersion() string { return "v5.3.0" } -// HatStream represents a stream of Hat messages. -type HatStream interface { - Next(context.Context) (*Hat, error) - End(error) -} - -type protoHatStreamReader struct { - prs protoStreamReader -} - -func (r protoHatStreamReader) Next(context.Context) (*Hat, error) { - out := new(Hat) - err := r.prs.Read(out) - if err != nil { - return nil, err - } - return out, nil -} - -func (r protoHatStreamReader) End(error) { _ = r.prs.c.Close() } - -type jsonHatStreamReader struct { - jrs *jsonStreamReader - c io.Closer -} - -func (r jsonHatStreamReader) Next(context.Context) (*Hat, error) { - out := new(Hat) - err := r.jrs.Read(out) - if err != nil { - return nil, err - } - return out, nil -} - -func (r jsonHatStreamReader) End(error) { _ = r.c.Close() } - type HatOrError struct { - Hat *Hat + Msg *Hat Err error } -func NewHatStream(ch chan HatOrError) *hatStreamSender { - return &hatStreamSender{ch: ch} -} - -type hatStreamSender struct { - ch <-chan HatOrError -} - -func (ss *hatStreamSender) Next(ctx context.Context) (*Hat, error) { - select { - case <-ctx.Done(): - return nil, ctx.Err() - case v, open := <-ss.ch: - if !open { - return nil, io.EOF - } - if v.Err != nil { - return nil, v.Err - } - return v.Hat, nil - } -} - -func (ss *hatStreamSender) End(err error) {} - // ===== // Utils // ===== @@ -1043,9 +1022,7 @@ func callError(ctx context.Context, h *twirp.ServerHooks, err twirp.Error) conte } type protoStreamReader struct { - r *bufio.Reader - c io.Closer - + r *bufio.Reader maxSize int } @@ -1061,9 +1038,7 @@ func (r protoStreamReader) Read(msg proto.Message) error { trailerTag = (2 << 3) | 2 ) - if tag == trailerTag { - // This is a trailer (twirp error), read it and then close the client - defer r.c.Close() + if tag == trailerTag { // Received a json twirp error or "EOF" // Read the length delimiter l, err := binary.ReadUvarint(r.r) if err != nil { @@ -1180,13 +1155,13 @@ func (r *jsonStreamReader) Read(msg proto.Message) error { var tj twerrJSON err = r.dec.Decode(&tj) if err != nil { + var eof string + if _ = r.dec.Decode(&eof); eof == "EOF" { + return io.EOF + } return err } - if tj.Code == "stream_complete" { - return io.EOF - } - return tj.toTwirpError() } diff --git a/internal/twirptest/importer/importer.twirp.go b/internal/twirptest/importer/importer.twirp.go index cf6f133b..94acaf71 100644 --- a/internal/twirptest/importer/importer.twirp.go +++ b/internal/twirptest/importer/importer.twirp.go @@ -363,73 +363,11 @@ func (s *svc2Server) ProtocGenTwirpVersion() string { return "v5.3.0" } -// MsgStream represents a stream of twirp_internal_twirptest_importable.Msg messages. -type MsgStream interface { - Next(context.Context) (*twirp_internal_twirptest_importable.Msg, error) - End(error) -} - -type protoMsgStreamReader struct { - prs protoStreamReader -} - -func (r protoMsgStreamReader) Next(context.Context) (*twirp_internal_twirptest_importable.Msg, error) { - out := new(twirp_internal_twirptest_importable.Msg) - err := r.prs.Read(out) - if err != nil { - return nil, err - } - return out, nil -} - -func (r protoMsgStreamReader) End(error) { _ = r.prs.c.Close() } - -type jsonMsgStreamReader struct { - jrs *jsonStreamReader - c io.Closer -} - -func (r jsonMsgStreamReader) Next(context.Context) (*twirp_internal_twirptest_importable.Msg, error) { - out := new(twirp_internal_twirptest_importable.Msg) - err := r.jrs.Read(out) - if err != nil { - return nil, err - } - return out, nil -} - -func (r jsonMsgStreamReader) End(error) { _ = r.c.Close() } - type MsgOrError struct { Msg *twirp_internal_twirptest_importable.Msg Err error } -func NewMsgStream(ch chan MsgOrError) *msgStreamSender { - return &msgStreamSender{ch: ch} -} - -type msgStreamSender struct { - ch <-chan MsgOrError -} - -func (ss *msgStreamSender) Next(ctx context.Context) (*twirp_internal_twirptest_importable.Msg, error) { - select { - case <-ctx.Done(): - return nil, ctx.Err() - case v, open := <-ss.ch: - if !open { - return nil, io.EOF - } - if v.Err != nil { - return nil, v.Err - } - return v.Msg, nil - } -} - -func (ss *msgStreamSender) End(err error) {} - // ===== // Utils // ===== @@ -855,9 +793,7 @@ func callError(ctx context.Context, h *twirp.ServerHooks, err twirp.Error) conte } type protoStreamReader struct { - r *bufio.Reader - c io.Closer - + r *bufio.Reader maxSize int } @@ -873,9 +809,7 @@ func (r protoStreamReader) Read(msg proto.Message) error { trailerTag = (2 << 3) | 2 ) - if tag == trailerTag { - // This is a trailer (twirp error), read it and then close the client - defer r.c.Close() + if tag == trailerTag { // Received a json twirp error or "EOF" // Read the length delimiter l, err := binary.ReadUvarint(r.r) if err != nil { @@ -992,13 +926,13 @@ func (r *jsonStreamReader) Read(msg proto.Message) error { var tj twerrJSON err = r.dec.Decode(&tj) if err != nil { + var eof string + if _ = r.dec.Decode(&eof); eof == "EOF" { + return io.EOF + } return err } - if tj.Code == "stream_complete" { - return io.EOF - } - return tj.toTwirpError() } diff --git a/internal/twirptest/service.twirp.go b/internal/twirptest/service.twirp.go index 7eff7458..f4f29ad0 100644 --- a/internal/twirptest/service.twirp.go +++ b/internal/twirptest/service.twirp.go @@ -337,7 +337,7 @@ func (s *haberdasherServer) ProtocGenTwirpVersion() string { type Streamer interface { Transact(ctx context.Context, in *Req) (*Resp, error) - Download(ctx context.Context, in *Req) (RespStream, error) + Download(ctx context.Context, in *Req) (<-chan RespOrError, error) } // ======================== @@ -380,7 +380,7 @@ func (c *streamerProtobufClient) Transact(ctx context.Context, in *Req) (*Resp, return out, err } -func (c *streamerProtobufClient) Download(ctx context.Context, in *Req) (RespStream, error) { +func (c *streamerProtobufClient) Download(ctx context.Context, in *Req) (<-chan RespOrError, error) { ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest") ctx = ctxsetters.WithServiceName(ctx, "Streamer") ctx = ctxsetters.WithMethodName(ctx, "Download") @@ -408,13 +408,29 @@ func (c *streamerProtobufClient) Download(ctx context.Context, in *Req) (RespStr return nil, errorFromResponse(resp) } - return &protoRespStreamReader{ - prs: protoStreamReader{ + respStream := make(chan RespOrError) + go func() { + defer func() { + resp.Body.Close() + close(respStream) + }() + reader := protoStreamReader{ r: bufio.NewReader(resp.Body), - c: resp.Body, maxSize: 1 << 21, // 1GB - }, - }, nil + } + out := new(Resp) + for { + if err = reader.Read(out); err != nil { + if err == io.EOF { + return + } + respStream <- RespOrError{Err: err} + return + } + respStream <- RespOrError{Msg: out} + } + }() + return respStream, nil } // ==================== @@ -457,7 +473,7 @@ func (c *streamerJSONClient) Transact(ctx context.Context, in *Req) (*Resp, erro return out, err } -func (c *streamerJSONClient) Download(ctx context.Context, in *Req) (RespStream, error) { +func (c *streamerJSONClient) Download(ctx context.Context, in *Req) (<-chan RespOrError, error) { ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest") ctx = ctxsetters.WithServiceName(ctx, "Streamer") ctx = ctxsetters.WithMethodName(ctx, "Download") @@ -485,14 +501,30 @@ func (c *streamerJSONClient) Download(ctx context.Context, in *Req) (RespStream, return nil, errorFromResponse(resp) } - jrs, err := newJSONStreamReader(resp.Body) - if err != nil { - return nil, err - } - return &jsonRespStreamReader{ - jrs: jrs, - c: resp.Body, - }, nil + respStream := make(chan RespOrError) + go func() { + defer func() { + resp.Body.Close() + close(respStream) + }() + reader, err := newJSONStreamReader(resp.Body) + if err != nil { + respStream <- RespOrError{Err: err} + return + } + out := new(Resp) + for { + if err = reader.Read(out); err != nil { + if err == io.EOF { + return + } + respStream <- RespOrError{Err: err} + return + } + respStream <- RespOrError{Msg: out} + } + }() + return respStream, nil } // ======================= @@ -781,7 +813,7 @@ func (s *streamerServer) serveDownloadProtobuf(ctx context.Context, resp http.Re } // Call service method - var respContent RespStream + var respContent <-chan RespOrError respFlusher, canFlush := resp.(http.Flusher) func() { defer func() { @@ -835,15 +867,24 @@ func (s *streamerServer) serveDownloadProtobuf(ctx context.Context, resp http.Re // Ignored, for the same reason as in the writeError func _ = writeErr } - respContent.End(twerr) } messages := proto.NewBuffer(nil) for { - msg, err := respContent.Next(ctx) - if err != nil { - writeTrailer(err) - break + var msg *Resp + select { + case <-ctx.Done(): + return + case msgOrErr, open := <-respContent: + if !open { + writeTrailer(io.EOF) + return + } + if msgOrErr.Err != nil { + writeTrailer(msgOrErr.Err) + return + } + msg = msgOrErr.Msg } messages.Reset() @@ -852,14 +893,14 @@ func (s *streamerServer) serveDownloadProtobuf(ctx context.Context, resp http.Re if err != nil { err = wrapErr(err, "failed to marshal proto message") writeTrailer(err) - break + return } _, err = resp.Write(messages.Bytes()) if err != nil { err = wrapErr(err, "failed to send proto message") writeTrailer(err) // likely to fail on write, but try anyway to ensure ctx gets error code set for responseSent hook - break + return } if canFlush { @@ -909,140 +950,16 @@ func (s *streamerServer) ProtocGenTwirpVersion() string { return "v5.3.0" } -// ReqStream represents a stream of Req messages. -type ReqStream interface { - Next(context.Context) (*Req, error) - End(error) -} - -type protoReqStreamReader struct { - prs protoStreamReader -} - -func (r protoReqStreamReader) Next(context.Context) (*Req, error) { - out := new(Req) - err := r.prs.Read(out) - if err != nil { - return nil, err - } - return out, nil -} - -func (r protoReqStreamReader) End(error) { _ = r.prs.c.Close() } - -type jsonReqStreamReader struct { - jrs *jsonStreamReader - c io.Closer -} - -func (r jsonReqStreamReader) Next(context.Context) (*Req, error) { - out := new(Req) - err := r.jrs.Read(out) - if err != nil { - return nil, err - } - return out, nil -} - -func (r jsonReqStreamReader) End(error) { _ = r.c.Close() } - type ReqOrError struct { - Req *Req + Msg *Req Err error } -func NewReqStream(ch chan ReqOrError) *reqStreamSender { - return &reqStreamSender{ch: ch} -} - -type reqStreamSender struct { - ch <-chan ReqOrError -} - -func (ss *reqStreamSender) Next(ctx context.Context) (*Req, error) { - select { - case <-ctx.Done(): - return nil, ctx.Err() - case v, open := <-ss.ch: - if !open { - return nil, io.EOF - } - if v.Err != nil { - return nil, v.Err - } - return v.Req, nil - } -} - -func (ss *reqStreamSender) End(err error) {} - -// RespStream represents a stream of Resp messages. -type RespStream interface { - Next(context.Context) (*Resp, error) - End(error) -} - -type protoRespStreamReader struct { - prs protoStreamReader -} - -func (r protoRespStreamReader) Next(context.Context) (*Resp, error) { - out := new(Resp) - err := r.prs.Read(out) - if err != nil { - return nil, err - } - return out, nil -} - -func (r protoRespStreamReader) End(error) { _ = r.prs.c.Close() } - -type jsonRespStreamReader struct { - jrs *jsonStreamReader - c io.Closer -} - -func (r jsonRespStreamReader) Next(context.Context) (*Resp, error) { - out := new(Resp) - err := r.jrs.Read(out) - if err != nil { - return nil, err - } - return out, nil -} - -func (r jsonRespStreamReader) End(error) { _ = r.c.Close() } - type RespOrError struct { - Resp *Resp - Err error -} - -func NewRespStream(ch chan RespOrError) *respStreamSender { - return &respStreamSender{ch: ch} -} - -type respStreamSender struct { - ch <-chan RespOrError -} - -func (ss *respStreamSender) Next(ctx context.Context) (*Resp, error) { - select { - case <-ctx.Done(): - return nil, ctx.Err() - case v, open := <-ss.ch: - if !open { - return nil, io.EOF - } - if v.Err != nil { - return nil, v.Err - } - return v.Resp, nil - } + Msg *Resp + Err error } -func (ss *respStreamSender) End(err error) {} - // ===== // Utils // ===== @@ -1468,9 +1385,7 @@ func callError(ctx context.Context, h *twirp.ServerHooks, err twirp.Error) conte } type protoStreamReader struct { - r *bufio.Reader - c io.Closer - + r *bufio.Reader maxSize int } @@ -1486,9 +1401,7 @@ func (r protoStreamReader) Read(msg proto.Message) error { trailerTag = (2 << 3) | 2 ) - if tag == trailerTag { - // This is a trailer (twirp error), read it and then close the client - defer r.c.Close() + if tag == trailerTag { // Received a json twirp error or "EOF" // Read the length delimiter l, err := binary.ReadUvarint(r.r) if err != nil { @@ -1605,13 +1518,13 @@ func (r *jsonStreamReader) Read(msg proto.Message) error { var tj twerrJSON err = r.dec.Decode(&tj) if err != nil { + var eof string + if _ = r.dec.Decode(&eof); eof == "EOF" { + return io.EOF + } return err } - if tj.Code == "stream_complete" { - return io.EOF - } - return tj.toTwirpError() } diff --git a/protoc-gen-twirp/generator.go b/protoc-gen-twirp/generator.go index e829c1d0..ac4b4b5c 100644 --- a/protoc-gen-twirp/generator.go +++ b/protoc-gen-twirp/generator.go @@ -752,8 +752,6 @@ func (t *twirp) generateUtils() { func (t *twirp) generateStreamUtils() { t.P(`type protoStreamReader struct {`) t.P(` r *`, t.pkgs["bufio"], `.Reader`) - t.P(` c `, t.pkgs["io"], `.Closer`) - t.P(``) t.P(` maxSize int`) t.P(`}`) t.P(``) @@ -769,9 +767,7 @@ func (t *twirp) generateStreamUtils() { t.P(` trailerTag = (2 << 3) | 2`) t.P(` )`) t.P(``) - t.P(` if tag == trailerTag {`) - t.P(` // This is a trailer (twirp error), read it and then close the client`) - t.P(` defer r.c.Close()`) + t.P(` if tag == trailerTag { // Received a json twirp error or "EOF"`) t.P(` // Read the length delimiter`) t.P(` l, err := binary.ReadUvarint(r.r)`) t.P(` if err != nil {`) @@ -891,13 +887,13 @@ func (t *twirp) generateStreamUtils() { t.P(` var tj twerrJSON`) t.P(` err = r.dec.Decode(&tj)`) t.P(` if err != nil {`) + t.P(` var eof string`) + t.P(` if _ = r.dec.Decode(&eof); eof == "EOF" {`) + t.P(` return io.EOF`) + t.P(` }`) t.P(` return err`) t.P(` }`) t.P(``) - t.P(` if tj.Code == "stream_complete" {`) - t.P(` return `, t.pkgs["io"], `.EOF`) - t.P(` }`) - t.P(``) t.P(` return tj.toTwirpError()`) t.P(`}`) t.P() @@ -1062,69 +1058,80 @@ func (t *twirp) generateClient(name string, file *descriptor.FileDescriptorProto sig := t.signature(method) if sig != "" { t.P(`func (c *`, structName, `) `, sig, " {") - t.P(` ctx = `, t.pkgs["ctxsetters"], `.WithPackageName(ctx, "`, pkgName, `")`) - t.P(` ctx = `, t.pkgs["ctxsetters"], `.WithServiceName(ctx, "`, servName, `")`) - t.P(` ctx = `, t.pkgs["ctxsetters"], `.WithMethodName(ctx, "`, methName, `")`) + t.P(` ctx = `, t.pkgs["ctxsetters"], `.WithPackageName(ctx, "`, pkgName, `")`) + t.P(` ctx = `, t.pkgs["ctxsetters"], `.WithServiceName(ctx, "`, servName, `")`) + t.P(` ctx = `, t.pkgs["ctxsetters"], `.WithMethodName(ctx, "`, methName, `")`) switch methodRPCType(method) { case unary: - t.P(` out := new(`, outputType, `)`) - t.P(` err := do`, name, `Request(ctx, c.client, c.urls[`, strconv.Itoa(i), `], in, out)`) - t.P(` return out, err`) - t.P(`}`) - t.P() + t.P(` out := new(`, outputType, `)`) + t.P(` err := do`, name, `Request(ctx, c.client, c.urls[`, strconv.Itoa(i), `], in, out)`) + t.P(` return out, err`) case download: - t.P(` reqBodyBytes, err := `, t.pkgs["proto"], `.Marshal(in)`) - t.P(` if err != nil {`) + t.P(` reqBodyBytes, err := `, t.pkgs["proto"], `.Marshal(in)`) + t.P(` if err != nil {`) t.P(` return nil, clientError("failed to marshal proto request", err)`) - t.P(` }`) - t.P(` reqBody := `, t.pkgs["bytes"], `.NewBuffer(reqBodyBytes)`) - t.P(` if err = ctx.Err(); err != nil {`) - t.P(` return nil, clientError("aborted because context was done", err)`) - t.P(` }`) + t.P(` }`) + t.P(` reqBody := `, t.pkgs["bytes"], `.NewBuffer(reqBodyBytes)`) + t.P(` if err = ctx.Err(); err != nil {`) + t.P(` return nil, clientError("aborted because context was done", err)`) + t.P(` }`) t.P(``) if name == "Protobuf" { - t.P(` req, err := newRequest(ctx, c.urls[`, strconv.Itoa(i), `], reqBody, "application/protobuf")`) + t.P(` req, err := newRequest(ctx, c.urls[`, strconv.Itoa(i), `], reqBody, "application/protobuf")`) } else { - t.P(` req, err := newRequest(ctx, c.urls[`, strconv.Itoa(i), `], reqBody, "application/json")`) + t.P(` req, err := newRequest(ctx, c.urls[`, strconv.Itoa(i), `], reqBody, "application/json")`) } - t.P(` if err != nil {`) - t.P(` return nil, clientError("could not build request", err)`) - t.P(` }`) - t.P(` resp, err := c.client.Do(req)`) - t.P(` if err != nil {`) - t.P(` return nil, clientError("failed to do request", err)`) - t.P(` }`) - t.P(` if err = ctx.Err(); err != nil {`) - t.P(` return nil, clientError("aborted because context was done", err)`) - t.P(` }`) - t.P(` if resp.StatusCode != 200 {`) - t.P(` return nil, errorFromResponse(resp)`) - t.P(` }`) + t.P(` if err != nil {`) + t.P(` return nil, clientError("could not build request", err)`) + t.P(` }`) + t.P(` resp, err := c.client.Do(req)`) + t.P(` if err != nil {`) + t.P(` return nil, clientError("failed to do request", err)`) + t.P(` }`) + t.P(` if err = ctx.Err(); err != nil {`) + t.P(` return nil, clientError("aborted because context was done", err)`) + t.P(` }`) + t.P(` if resp.StatusCode != 200 {`) + t.P(` return nil, errorFromResponse(resp)`) + t.P(` }`) t.P(``) + t.P(` respStream := make(chan `, typeOrErrorFromType(outputType), `)`) + t.P(` go func() {`) + t.P(` defer func() {`) + t.P(` resp.Body.Close()`) + t.P(` close(respStream)`) + t.P(` }()`) if name == "Protobuf" { - t.P(` return &proto`, withoutPackageName(outputType), `StreamReader{`) - t.P(` prs: protoStreamReader{`) - t.P(` r: `, t.pkgs["bufio"], `.NewReader(resp.Body),`) - t.P(` c: resp.Body,`) - t.P(` maxSize: 1 << 21, // 1GB`) - t.P(` },`) - t.P(` }, nil`) - t.P(`}`) + t.P(` reader := protoStreamReader{`) + t.P(` r: `, t.pkgs["bufio"], `.NewReader(resp.Body),`) + t.P(` maxSize: 1 << 21, // 1GB`) + t.P(` }`) } else { - t.P(` jrs, err := newJSONStreamReader(resp.Body)`) - t.P(` if err != nil {`) - t.P(` return nil, err`) - t.P(` }`) - t.P(` return &json`, withoutPackageName(outputType), `StreamReader{`) - t.P(` jrs: jrs,`) - t.P(` c: resp.Body,`) - t.P(` }, nil`) - t.P(`}`) + t.P(` reader, err := newJSONStreamReader(resp.Body)`) + t.P(` if err != nil {`) + t.P(` respStream <- `, typeOrErrorFromType(outputType), `{Err: err}`) + t.P(` return`) + t.P(` }`) } + t.P(` out := new(`, outputType, `)`) + t.P(` for {`) + t.P(` if err = reader.Read(out); err != nil {`) + t.P(` if err == `, t.pkgs["io"], `.EOF {`) + t.P(` return`) + t.P(` }`) + t.P(` respStream <- `, typeOrErrorFromType(outputType), `{Err: err}`) + t.P(` return`) + t.P(` }`) + t.P(` respStream <- `, typeOrErrorFromType(outputType), `{Msg: out}`) + t.P(` }`) + t.P(` }()`) + t.P(` return respStream, nil`) default: - t.P(` return nil, nil}`) + t.P(` return nil, nil`) } + t.P(`}`) + t.P() } } } @@ -1259,6 +1266,13 @@ func (t *twirp) generateServerMethod(service *descriptor.ServiceDescriptorProto, func (t *twirp) generateServerJSONMethod(service *descriptor.ServiceDescriptorProto, method *descriptor.MethodDescriptorProto) { servStruct := serviceStruct(service) methName := stringutils.CamelCase(method.GetName()) + rpcType := methodRPCType(method) + var respType string + if rpcType == download || rpcType == bidirectional { + respType = t.methodOutputType(method) + } else { + respType = `*` + t.methodOutputType(method) + } t.P(`func (s *`, servStruct, `) serve`, methName, `JSON(ctx `, t.pkgs["context"], `.Context, resp `, t.pkgs["http"], `.ResponseWriter, req *`, t.pkgs["http"], `.Request) {`) @@ -1281,7 +1295,7 @@ func (t *twirp) generateServerJSONMethod(service *descriptor.ServiceDescriptorPr t.P(` }`) t.P() t.P(` // Call service method`) - t.P(` var respContent *`, t.methodOutputType(method)) + t.P(` var respContent ` + respType) t.P(` func() {`) t.P(` defer func() {`) t.P(` // In case of a panic, serve a 500 error and then panic.`) @@ -1446,15 +1460,24 @@ func (t *twirp) generateServerProtobufMethod(service *descriptor.ServiceDescript t.P(` // Ignored, for the same reason as in the writeError func`) t.P(` _ = writeErr`) t.P(` }`) - t.P(` respContent.End(twerr)`) t.P(` }`) t.P(``) t.P(` messages := `, t.pkgs["proto"], `.NewBuffer(nil)`) t.P(` for {`) - t.P(` msg, err := respContent.Next(ctx)`) - t.P(` if err != nil {`) - t.P(` writeTrailer(err)`) - t.P(` break`) + t.P(` var msg *` + t.goTypeName(method.GetOutputType())) + t.P(` select {`) + t.P(` case <-ctx.Done():`) + t.P(` return`) + t.P(` case msgOrErr, open := <-respContent:`) + t.P(` if !open {`) + t.P(` writeTrailer(` + t.pkgs["io"] + `.EOF)`) + t.P(` return`) + t.P(` }`) + t.P(` if msgOrErr.Err != nil {`) + t.P(` writeTrailer(msgOrErr.Err)`) + t.P(` return`) + t.P(` }`) + t.P(` msg = msgOrErr.Msg`) t.P(` }`) t.P(``) t.P(` messages.Reset()`) @@ -1463,14 +1486,14 @@ func (t *twirp) generateServerProtobufMethod(service *descriptor.ServiceDescript t.P(` if err != nil {`) t.P(` err = wrapErr(err, "failed to marshal proto message")`) t.P(` writeTrailer(err)`) - t.P(` break`) + t.P(` return`) t.P(` }`) t.P(``) t.P(` _, err = resp.Write(messages.Bytes())`) t.P(` if err != nil {`) t.P(` err = wrapErr(err, "failed to send proto message")`) t.P(` writeTrailer(err) // likely to fail on write, but try anyway to ensure ctx gets error code set for responseSent hook`) - t.P(` break`) + t.P(` return`) t.P(` }`) t.P(``) t.P(` if canFlush {`) @@ -1526,83 +1549,12 @@ func (t *twirp) generateStreamType(typeName string) { t.streamTypes[typeName] = true }() - typeNameWithoutPackage := withoutPackageName(typeName) - streamTypeName := withoutPackageName(typeName) + "Stream" - t.P(`// `, streamTypeName, ` represents a stream of `, typeName, ` messages.`) - t.P(`type `, streamTypeName, ` interface {`) - t.P(` Next(context.Context) (*`, typeName, `, error)`) - t.P(` End(error)`) - t.P(`}`) - t.P() - - protoReaderTypeName := "proto" + streamTypeName + "Reader" - - t.P(`type `, protoReaderTypeName, ` struct {`) - t.P(` prs protoStreamReader`) - t.P(`}`) - t.P() - t.P(`func (r `, protoReaderTypeName, `) Next(context.Context) (*`, typeName, `, error) {`) - t.P(` out := new(`, typeName, `)`) - t.P(` err := r.prs.Read(out)`) - t.P(` if err != nil {`) - t.P(` return nil, err`) - t.P(` }`) - t.P(` return out, nil`) - t.P(`}`) - t.P() - t.P(`func (r `, protoReaderTypeName, `) End(error) { _ = r.prs.c.Close() }`) - t.P() - - jsonReaderTypeName := "json" + streamTypeName + "Reader" - t.P(`type `, jsonReaderTypeName, ` struct {`) - t.P(` jrs *jsonStreamReader`) - t.P(` c `, t.pkgs["io"], `.Closer`) - t.P(`}`) - t.P() - t.P(`func (r `, jsonReaderTypeName, `) Next(`, t.pkgs["context"], `.Context) (*`, typeName, `, error) {`) - t.P(` out := new(`, typeName, `)`) - t.P(` err := r.jrs.Read(out)`) - t.P(` if err != nil {`) - t.P(` return nil, err`) - t.P(` }`) - t.P(` return out, nil`) - t.P(`}`) - t.P() - t.P(`func (r `, jsonReaderTypeName, `) End(error) { _ = r.c.Close() }`) - t.P() - - typeOrErrorTypeName := typeNameWithoutPackage + "OrError" + typeOrErrorTypeName := typeOrErrorFromType(typeName) t.P(`type `, typeOrErrorTypeName, ` struct {`) - t.P(` `, typeNameWithoutPackage, ` *`, typeName) + t.P(` Msg *`, typeName) t.P(` Err error`) t.P(`}`) t.P() - streamSenderName := unexported(streamTypeName + "Sender") - t.P(`func New`, streamTypeName, `(ch chan `, typeOrErrorTypeName, `) *`, streamSenderName, ` {`) - t.P(` return &`, streamSenderName, `{ch: ch}`) - t.P(`}`) - t.P() - t.P(`type `, streamSenderName, ` struct {`) - t.P(` ch <-chan `, typeOrErrorTypeName) - t.P(`}`) - t.P() - t.P(`func (ss *`, streamSenderName, `) Next(ctx context.Context) (*`, typeName, `, error) {`) - t.P(` select {`) - t.P(` case <-ctx.Done():`) - t.P(` return nil, ctx.Err()`) - t.P(` case v, open := <-ss.ch:`) - t.P(` if !open {`) - t.P(` return nil, io.EOF`) - t.P(` }`) - t.P(` if v.Err != nil {`) - t.P(` return nil, v.Err`) - t.P(` }`) - t.P(` return v.`, typeNameWithoutPackage, `, nil`) - t.P(` }`) - t.P(`}`) - t.P() - t.P(`func (ss *`, streamSenderName, `) End(err error) {}`) - t.P() } // serviceMetadataVarName is the variable name used in generated code to refer @@ -1715,7 +1667,7 @@ func (t *twirp) methodInputType(method *descriptor.MethodDescriptorProto) string func (t *twirp) methodOutputType(method *descriptor.MethodDescriptorProto) string { name := t.goTypeName(method.GetOutputType()) if method.GetServerStreaming() { - name = withoutPackageName(name) + "Stream" + name = streamTypeFromType(name) } return name } @@ -1785,6 +1737,14 @@ func withoutPackageName(v string) string { return v[idx+1:] } +func typeOrErrorFromType(name string) string { + return withoutPackageName(name) + "OrError" +} + +func streamTypeFromType(name string) string { + return "<-chan " + typeOrErrorFromType(name) +} + func fileDescSliceContains(slice []*descriptor.FileDescriptorProto, f *descriptor.FileDescriptorProto) bool { for _, sf := range slice { if f == sf { From b8fadd598fa158db659b12aa54b298c183cd71b4 Mon Sep 17 00:00:00 2001 From: Mike Lemmon Date: Thu, 21 Jun 2018 10:40:44 -0700 Subject: [PATCH 7/7] Fix a bug in writeTrailer where EOF never got written to the response --- protoc-gen-twirp/generator.go | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/protoc-gen-twirp/generator.go b/protoc-gen-twirp/generator.go index ac4b4b5c..c88c0b3d 100644 --- a/protoc-gen-twirp/generator.go +++ b/protoc-gen-twirp/generator.go @@ -1440,6 +1440,13 @@ func (t *twirp) generateServerProtobufMethod(service *descriptor.ServiceDescript t.P(` trailer := `, t.pkgs["proto"], `.NewBuffer(nil)`) t.P(` _ = trailer.EncodeVarint((2 << 3) | 2) // field tag`) t.P(` writeTrailer := func(err error) {`) + t.P(` defer func() {`) + t.P(` _, writeErr := resp.Write(trailer.Bytes())`) + t.P(` if writeErr != nil {`) + t.P(` // Ignored, for the same reason as in the writeError func`) + t.P(` _ = writeErr`) + t.P(` }`) + t.P(` }()`) t.P(` if err == `, t.pkgs["io"], `.EOF {`) t.P(` trailer.EncodeStringBytes("EOF")`) t.P(` return`) @@ -1455,11 +1462,6 @@ func (t *twirp) generateServerProtobufMethod(service *descriptor.ServiceDescript t.P(` if encodeErr := trailer.EncodeStringBytes(string(marshalErrorToJSON(twerr))); encodeErr != nil {`) t.P(` _ = trailer.EncodeStringBytes("{\"code\":\"" + string(`, t.pkgs["twirp"], `.Internal) + "\",\"msg\":\"There was an error but it could not be serialized into JSON\"}") // fallback`) t.P(` }`) - t.P(` _, writeErr := resp.Write(trailer.Bytes())`) - t.P(` if writeErr != nil {`) - t.P(` // Ignored, for the same reason as in the writeError func`) - t.P(` _ = writeErr`) - t.P(` }`) t.P(` }`) t.P(``) t.P(` messages := `, t.pkgs["proto"], `.NewBuffer(nil)`)