From 4fe37fb1aa9256914e85b67e45b01c0657aa9c11 Mon Sep 17 00:00:00 2001 From: Vojtech Vitek Date: Sun, 10 Dec 2023 16:44:03 +0100 Subject: [PATCH] Support webrpc SSE streaming https://github.com/webrpc/webrpc/pull/237 --- client.go.tmpl | 190 ++++++++++++++++++++++++++++++++++++++---------- errors.go.tmpl | 6 +- helpers.go.tmpl | 4 +- imports.go.tmpl | 11 ++- main.go.tmpl | 17 ++++- server.go.tmpl | 120 +++++++++++++++++++++++------- types.go.tmpl | 130 ++++++++++++++++++++++++++++++--- 7 files changed, 392 insertions(+), 86 deletions(-) diff --git a/client.go.tmpl b/client.go.tmpl index 9d4ddc9..d038a39 100644 --- a/client.go.tmpl +++ b/client.go.tmpl @@ -1,41 +1,47 @@ {{define "client"}} {{- $typeMap := .TypeMap -}} {{- $typePrefix := .TypePrefix -}} -{{- if .Services -}} +{{- $services := .Services -}} +{{- $opts := .Opts }} + +{{- if $services -}} // // Client // -{{range .Services -}} +{{range $services -}} const {{.Name}}PathPrefix = "/rpc/{{.Name}}/" {{end}} -{{- range .Services -}} +{{- range $_, $service := $services -}} -{{ $serviceName := (printf "%sClient" (.Name | firstLetterToLower)) }} -type {{$serviceName}} struct { +{{- $serviceNameClient := (printf "%sClient" ($service.Name | firstLetterToLower)) }} +{{- $ServiceNameClient := (printf "%sClient" ($service.Name | firstLetterToUpper)) }} +type {{$serviceNameClient}} struct { client HTTPClient - urls [{{len .Methods}}]string + urls [{{len $service.Methods}}]string } -func New{{.Name | firstLetterToUpper }}Client(addr string, client HTTPClient) {{.Name}} { - prefix := urlBase(addr) + {{.Name}}PathPrefix - urls := [{{len .Methods}}]string{ - {{- range .Methods}} - prefix + "{{.Name}}", +func New{{$ServiceNameClient}}(addr string, client HTTPClient) {{$ServiceNameClient}} { + prefix := urlBase(addr) + {{$service.Name}}PathPrefix + urls := [{{len $service.Methods}}]string{ + {{- range $_, $method := $service.Methods}} + prefix + "{{$method.Name}}", {{- end}} } - return &{{$serviceName}}{ + return &{{$serviceNameClient}}{ client: client, urls: urls, } } -{{- range $i, $method := .Methods -}} +{{- range $i, $method := $service.Methods -}} {{- $inputs := $method.Inputs -}} {{- $outputs := $method.Outputs }} -func (c *{{$serviceName}}) {{.Name}}(ctx context.Context{{range $_, $input := $inputs}}, {{$input.Name}} {{template "field" dict "Name" $input.Name "Type" $input.Type "Optional" $input.Optional "TypeMap" $typeMap "TypePrefix" $typePrefix "TypeMeta" $input.Meta}}{{end}}) {{if len .Outputs}}({{end}}{{range $i, $output := .Outputs}}{{template "field" dict "Name" $output.Name "Type" $output.Type "Optional" $output.Optional "TypeMap" $typeMap "TypePrefix" $typePrefix "TypeMeta" $output.Meta}}{{if lt $i (len $method.Outputs)}}, {{end}}{{end}}error{{if len .Outputs}}){{end}} { +{{ if eq $method.StreamOutput false -}} + +func (c *{{$serviceNameClient}}) {{$method.Name}}(ctx context.Context{{range $_, $input := $inputs}}, {{$input.Name}} {{template "field" dict "Name" $input.Name "Type" $input.Type "Optional" $input.Optional "TypeMap" $typeMap "TypePrefix" $typePrefix "TypeMeta" $input.Meta}}{{end}}) {{if len $outputs}}({{end}}{{range $i, $output := $outputs}}{{template "field" dict "Name" $output.Name "Type" $output.Type "Optional" $output.Optional "TypeMap" $typeMap "TypePrefix" $typePrefix "TypeMeta" $output.Meta}}{{if lt $i (len $method.Outputs)}}, {{end}}{{end}}error{{if len $outputs}}){{end}} { {{- $inputVar := "nil" -}} {{- $outputVar := "nil" -}} {{- if $inputs | len}} @@ -49,15 +55,129 @@ func (c *{{$serviceName}}) {{.Name}}(ctx context.Context{{range $_, $input := $i {{- if $outputs | len}} {{- $outputVar = "&out"}} out := struct { - {{- range $i, $output := .Outputs}} + {{- range $i, $output := $outputs}} Ret{{$i}} {{template "field" dict "Name" $output.Name "Type" $output.Type "Optional" $output.Optional "TypeMap" $typeMap "TypePrefix" $typePrefix "TypeMeta" $output.Meta "JsonTags" true}} {{- end}} }{} {{ end }} - err := doJSONRequest(ctx, c.client, c.urls[{{$i}}], {{$inputVar}}, {{$outputVar}}) - return {{range $i, $output := .Outputs}}out.Ret{{$i}}, {{end}}err + + resp, err := doHTTPRequest(ctx, c.client, c.urls[{{$i}}], {{$inputVar}}, {{$outputVar}}) + defer func() { + if resp != nil { + cerr := resp.Body.Close() + if err == nil && cerr != nil { + err = ErrWebrpcRequestFailed.WithCause(fmt.Errorf("failed to close response body: %w", cerr)) + } + } + }() + + return {{range $i, $output := $outputs}}out.Ret{{$i}}, {{end}}err } + +{{- else -}} + +func (c *{{$serviceNameClient}}) {{$method.Name}}(ctx context.Context{{range $_, $input := $inputs}}, {{$input.Name}} {{template "field" dict "Name" $input.Name "Type" $input.Type "Optional" $input.Optional "TypeMap" $typeMap "TypePrefix" $typePrefix "TypeMeta" $input.Meta}}{{end}}) ({{$method.Name}}StreamReader, error) { + {{- $inputVar := "nil" -}} + {{- if $inputs | len}} + {{- $inputVar = "in"}} + in := struct { + {{- range $i, $input := $inputs}} + Arg{{$i}} {{template "field" dict "Name" $input.Name "Type" $input.Type "Optional" $input.Optional "TypeMap" $typeMap "TypePrefix" $typePrefix "TypeMeta" $input.Meta "JsonTags" true}} + {{- end}} + }{ {{- range $i, $input := $inputs}}{{if gt $i 0}}, {{end}}{{$input.Name}}{{end}}} + {{- end}} + + resp, err := doHTTPRequest(ctx, c.client, c.urls[{{$i}}], {{$inputVar}}, nil) + if err != nil { + if resp != nil { + resp.Body.Close() + } + return nil, err + } + + buf := bufio.NewReader(resp.Body) + return &{{$method.Name | firstLetterToLower}}StreamReader{streamReader{ctx: ctx, c: resp.Body, r: buf, d: json.NewDecoder(buf)}}, nil +} + +{{- end -}} + {{- end -}} + +{{- range $i, $method := $service.Methods -}} +{{ if eq $method.StreamOutput true }} + +type subscribeMessagesStreamReader struct { + streamReader +} + +func (r *subscribeMessagesStreamReader) Read() ({{end}}{{range $i, $output := $method.Outputs}}{{template "field" dict "Name" $output.Name "Type" $output.Type "Optional" $output.Optional "TypeMap" $typeMap "TypePrefix" $typePrefix "TypeMeta" $output.Meta}}{{if lt $i (len $method.Outputs)}}, {{end}}error) { + var out struct{ + Ret0 *Message `json:"message"` + WebRPCError *WebRPCError `json:"webrpcError"` + } + + err := r.streamReader.read(&out) + if err != nil { + return nil, err + } + + if out.WebRPCError != nil { + return nil, out.WebRPCError + } + + return out.Ret0, nil +} +{{- end }} +{{- end }} +{{- end }} + +{{- if $opts.streaming }} + +type streamReader struct { + ctx context.Context + c io.Closer + r *bufio.Reader + d *json.Decoder +} + +func (r *streamReader) read(v interface {}) error { + for { + // Read newlines (keep-alive pings) and unblock decoder on ctx timeout. + select { + case <-r.ctx.Done(): + r.c.Close() + return ErrWebrpcClientDisconnected.WithCause(r.ctx.Err()) + default: + } + + b, err := r.r.ReadByte() + if err != nil { + return r.handleReadError(err) + } + if b != '\n' { + r.r.UnreadByte() + break + } + } + + if err := r.d.Decode(&v); err != nil { + return r.handleReadError(err) + } + + return nil +} + +func (r *streamReader) handleReadError(err error) error { + defer r.c.Close() + if errors.Is(err, io.EOF) { + return ErrWebrpcStreamFinished.WithCause(err) + } + if errors.Is(err, io.ErrUnexpectedEOF) { + return ErrWebrpcStreamLost.WithCause(err) + } + return ErrWebrpcBadResponse.WithCause(fmt.Errorf("reading stream: %w", err)) +} + {{- end }} // HTTPClient is the interface used by generated clients to send HTTP requests. @@ -100,65 +220,55 @@ func newRequest(ctx context.Context, url string, reqBody io.Reader, contentType return req, nil } -// doJSONRequest is common code to make a request to the remote service. -func doJSONRequest(ctx context.Context, client HTTPClient, url string, in, out interface{}) error { +// doHTTPRequest is common code to make a request to the remote service. +func doHTTPRequest(ctx context.Context, client HTTPClient, url string, in, out interface{}) (*http.Response, error) { reqBody, err := json.Marshal(in) if err != nil { - return ErrWebrpcRequestFailed.WithCause(fmt.Errorf("failed to marshal JSON body: %w", err)) + return nil, ErrWebrpcRequestFailed.WithCause(fmt.Errorf("failed to marshal JSON body: %w", err)) } if err = ctx.Err(); err != nil { - return ErrWebrpcRequestFailed.WithCause(fmt.Errorf("aborted because context was done: %w", err)) + return nil, ErrWebrpcRequestFailed.WithCause(fmt.Errorf("aborted because context was done: %w", err)) } req, err := newRequest(ctx, url, bytes.NewBuffer(reqBody), "application/json") if err != nil { - return ErrWebrpcRequestFailed.WithCause(fmt.Errorf("could not build request: %w", err)) + return nil, ErrWebrpcRequestFailed.WithCause(fmt.Errorf("could not build request: %w", err)) } + resp, err := client.Do(req) if err != nil { - return ErrWebrpcRequestFailed.WithCause(err) - } - - defer func() { - cerr := resp.Body.Close() - if err == nil && cerr != nil { - err = ErrWebrpcRequestFailed.WithCause(fmt.Errorf("failed to close response body: %w", cerr)) - } - }() - - if err = ctx.Err(); err != nil { - return ErrWebrpcRequestFailed.WithCause(fmt.Errorf("aborted because context was done: %w", err)) + return nil, ErrWebrpcRequestFailed.WithCause(err) } if resp.StatusCode != 200 { respBody, err := io.ReadAll(resp.Body) if err != nil { - return ErrWebrpcBadResponse.WithCause(fmt.Errorf("failed to read server error response body: %w", err)) + return nil, ErrWebrpcBadResponse.WithCause(fmt.Errorf("failed to read server error response body: %w", err)) } var rpcErr WebRPCError if err := json.Unmarshal(respBody, &rpcErr); err != nil { - return ErrWebrpcBadResponse.WithCause(fmt.Errorf("failed to unmarshal server error: %w", err)) + return nil, ErrWebrpcBadResponse.WithCause(fmt.Errorf("failed to unmarshal server error: %w", err)) } if rpcErr.Cause != "" { rpcErr.cause = errors.New(rpcErr.Cause) } - return rpcErr + return nil, rpcErr } if out != nil { respBody, err := io.ReadAll(resp.Body) if err != nil { - return ErrWebrpcBadResponse.WithCause(fmt.Errorf("failed to read response body: %w", err)) + return nil, ErrWebrpcBadResponse.WithCause(fmt.Errorf("failed to read response body: %w", err)) } err = json.Unmarshal(respBody, &out) if err != nil { - return ErrWebrpcBadResponse.WithCause(fmt.Errorf("failed to unmarshal JSON response body: %w", err)) + return nil, ErrWebrpcBadResponse.WithCause(fmt.Errorf("failed to unmarshal JSON response body: %w", err)) } } - return nil + return resp, nil } func WithHTTPRequestHeaders(ctx context.Context, h http.Header) (context.Context, error) { diff --git a/errors.go.tmpl b/errors.go.tmpl index 5f7c52e..42395a0 100644 --- a/errors.go.tmpl +++ b/errors.go.tmpl @@ -1,7 +1,7 @@ {{define "errors"}} {{- $webrpcErrors := .WebrpcErrors -}} {{- $schemaErrors := .SchemaErrors -}} -{{- $opts := .Opts }} +{{- $opts := .Opts -}} // // Errors // @@ -82,6 +82,6 @@ var ( {{ printf "Err%s = WebRPCError{Code: %v, Name: %q, Message: %q, HTTPStatus: %v}" $error.Name $error.Code $error.Name $error.Message $error.HTTPStatus}} {{- end}} ) -{{- end}} +{{ end -}} -{{- end }} +{{- end -}} diff --git a/helpers.go.tmpl b/helpers.go.tmpl index f0b430c..2fe60d4 100644 --- a/helpers.go.tmpl +++ b/helpers.go.tmpl @@ -88,6 +88,6 @@ func initializeNils(v reflect.Value) { } } } -{{ end }} +{{- end -}} -{{ end -}} +{{- end -}} diff --git a/imports.go.tmpl b/imports.go.tmpl index cd6e123..966059f 100644 --- a/imports.go.tmpl +++ b/imports.go.tmpl @@ -32,6 +32,15 @@ {{- end -}} {{- end -}} +{{- if $opts.streaming -}} + {{- if $opts.server }} + {{- set $stdlibImports "sync" "" -}} + {{- end -}} + {{- if $opts.client }} + {{- set $stdlibImports "bufio" "" -}} + {{- end -}} +{{- end -}} + {{- /* Import "time" if there's at least one timestamp. */ -}} {{ if $opts.types -}} {{ if eq $opts.importTypesFrom "" -}} @@ -85,7 +94,7 @@ import ( {{if ne $rename ""}}{{$rename}} {{end}}"{{$import}}" {{end}} {{- end }} -{{- end }} +{{- end -}} ) {{- if eq $opts.json "jsoniter" }} diff --git a/main.go.tmpl b/main.go.tmpl index 1aed6f3..e298dd4 100644 --- a/main.go.tmpl +++ b/main.go.tmpl @@ -17,6 +17,15 @@ {{- $typePrefix = (printf "%s." $typePrefix) -}} {{- end -}} +{{- set $opts "" false -}} +{{- range $_, $service := .Services -}} + {{- range $_, $method := $service.Methods -}} + {{ if eq $method.StreamOutput true -}} + {{- set $opts "streaming" true -}} + {{- end -}} + {{- end -}} +{{- end }} + {{- /* Print help on -help. */ -}} {{- if exists .Opts "help" -}} {{- template "help" $opts -}} @@ -93,20 +102,20 @@ func WebRPCSchemaHash() string { {{ template "types" dict "Services" .Services "Types" .Types "TypeMap" $typeMap "TypePrefix" $typePrefix "Opts" $opts }} {{ end -}} -{{- if $opts.server}} +{{- if $opts.server }} {{ template "server" dict "Services" .Services "TypeMap" $typeMap "TypePrefix" $typePrefix "Opts" $opts }} {{ end -}} {{ if $opts.client }} -{{ template "client" dict "Services" .Services "TypeMap" $typeMap "TypePrefix" $typePrefix }} +{{ template "client" dict "Services" .Services "TypeMap" $typeMap "Opts" $opts "TypePrefix" $typePrefix }} {{ end -}} {{ template "helpers" dict "Opts" $opts }} -{{- template "errors" dict "WebrpcErrors" .WebrpcErrors "SchemaErrors" .Errors "Opts" $opts "TypePrefix" $typePrefix }} +{{ template "errors" dict "WebrpcErrors" .WebrpcErrors "SchemaErrors" .Errors "Opts" $opts "TypePrefix" $typePrefix }} {{- if $opts.legacyErrors }} {{ template "legacyErrors" . }} {{- end }} -{{ end }} +{{- end }} diff --git a/server.go.tmpl b/server.go.tmpl index 9622475..0bd6275 100644 --- a/server.go.tmpl +++ b/server.go.tmpl @@ -1,9 +1,10 @@ {{- define "server"}} +{{- $services := .Services -}} {{- $typeMap := .TypeMap -}} {{- $typePrefix := .TypePrefix -}} {{- $opts := .Opts -}} -{{- if .Services -}} +{{- if $services -}} // // Server // @@ -12,16 +13,16 @@ type WebRPCServer interface { http.Handler } -{{- range .Services}} -{{- $name := .Name -}} -{{ $serviceName := (printf "%sServer" (.Name | firstLetterToLower)) }} +{{- range $_, $service := $services}} +{{- $name := $service.Name -}} +{{ $serviceName := (printf "%sServer" (firstLetterToLower $service.Name)) }} type {{$serviceName}} struct { - {{$typePrefix}}{{.Name}} + {{$typePrefix}}{{$service.Name}} OnError func(r *http.Request, rpcErr *WebRPCError) } -func New{{ .Name | firstLetterToUpper }}Server(svc {{$typePrefix}}{{.Name}}) *{{$serviceName}} { +func New{{firstLetterToUpper $service.Name}}Server(svc {{$typePrefix}}{{.Name}}) *{{$serviceName}} { return &{{$serviceName}}{ {{.Name}}: svc, } @@ -43,8 +44,9 @@ func (s *{{$serviceName}}) ServeHTTP(w http.ResponseWriter, r *http.Request) { var handler func(ctx context.Context, w http.ResponseWriter, r *http.Request) switch r.URL.Path { - {{- range .Methods}} - case "/rpc/{{$name}}/{{.Name}}": handler = s.serve{{.Name | firstLetterToUpper}}JSON + {{- range $_, $method := $service.Methods}} + case "/rpc/{{$name}}/{{$method.Name}}": + handler = s.serve{{$method.Name | firstLetterToUpper}}{{if $method.StreamOutput}}NDJSON{{else}}JSON{{end}} {{- end}} default: err := ErrWebrpcBadRoute.WithCause(fmt.Errorf("no handler for path %q", r.URL.Path)) @@ -73,11 +75,12 @@ func (s *{{$serviceName}}) ServeHTTP(w http.ResponseWriter, r *http.Request) { s.sendErrorJSON(w, r, err) } } -{{range .Methods }} -func (s *{{$serviceName}}) serve{{ .Name | firstLetterToUpper }}JSON(ctx context.Context, w http.ResponseWriter, r *http.Request) { - ctx = context.WithValue(ctx, MethodNameCtxKey, "{{.Name}}") +{{range $_, $method := $service.Methods}} +{{- if eq $method.StreamOutput false }} +func (s *{{$serviceName}}) serve{{firstLetterToUpper $method.Name}}JSON(ctx context.Context, w http.ResponseWriter, r *http.Request) { + ctx = context.WithValue(ctx, MethodNameCtxKey, "{{$method.Name}}") - {{ if gt (len .Inputs) 0 -}} + {{ if gt (len $method.Inputs) 0 -}} reqBody, err := io.ReadAll(r.Body) if err != nil { s.sendErrorJSON(w, r, ErrWebrpcBadRequest.WithCause(fmt.Errorf("failed to read request data: %w", err))) @@ -86,7 +89,7 @@ func (s *{{$serviceName}}) serve{{ .Name | firstLetterToUpper }}JSON(ctx context defer r.Body.Close() reqPayload := struct { - {{- range $i, $input := .Inputs}} + {{- range $i, $input := $method.Inputs}} Arg{{$i}} {{template "field" dict "Name" $input.Name "Type" $input.Type "Optional" $input.Optional "TypeMap" $typeMap "TypePrefix" $typePrefix "TypeMeta" $input.Meta "JsonTags" true}} {{- end}} }{} @@ -94,11 +97,10 @@ func (s *{{$serviceName}}) serve{{ .Name | firstLetterToUpper }}JSON(ctx context s.sendErrorJSON(w, r, ErrWebrpcBadRequest.WithCause(fmt.Errorf("failed to unmarshal request data: %w", err))) return } - - {{ end -}} + {{- end }} // Call service method implementation. - {{range $i, $output := .Outputs}}ret{{$i}}, {{end}}err {{if or (eq (len .Inputs) 0) (gt (len .Outputs) 0)}}:{{end}}= s.{{$name}}.{{.Name}}(ctx{{range $i, $_ := .Inputs}}, reqPayload.Arg{{$i}}{{end}}) + {{range $i, $output := $method.Outputs}}ret{{$i}}, {{end}}err {{if or (eq (len $method.Inputs) 0) (gt (len $method.Outputs) 0)}}:{{end}}= s.{{$name}}.{{$method.Name}}(ctx{{range $i, $_ := $method.Inputs}}, reqPayload.Arg{{$i}}{{end}}) if err != nil { rpcErr, ok := err.(WebRPCError) if !ok { @@ -108,16 +110,17 @@ func (s *{{$serviceName}}) serve{{ .Name | firstLetterToUpper }}JSON(ctx context return } - {{- if gt (len .Outputs) 0}} + {{- if gt (len $method.Outputs) 0}} respPayload := struct { - {{- range $i, $output := .Outputs}} + {{- range $i, $output := $method.Outputs}} Ret{{$i}} {{template "field" dict "Name" $output.Name "Type" $output.Type "Optional" $output.Optional "TypeMap" $typeMap "TypePrefix" $typePrefix "TypeMeta" $output.Meta "JsonTags" true}} {{- end}} - }{ {{- range $i, $_ := .Outputs}}{{if gt $i 0}}, {{end}}ret{{$i}}{{end}}} + }{ {{- range $i, $_ := $method.Outputs}}{{if gt $i 0}}, {{end}}ret{{$i}}{{end}}} {{- end}} - {{- if .Outputs | len}} + + {{- if $method.Outputs | len}} {{ if $opts.fixEmptyArrays -}} respBody, err := json.Marshal(initializeNilSlices(respPayload)) {{ else -}} @@ -132,26 +135,93 @@ func (s *{{$serviceName}}) serve{{ .Name | firstLetterToUpper }}JSON(ctx context w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - {{- if .Outputs | len}} + {{- if $method.Outputs | len}} w.Write(respBody) {{- else }} w.Write([]byte("{}")) {{- end}} } -{{end}} +{{ else }} +func (s *{{$serviceName}}) serve{{firstLetterToUpper $method.Name}}NDJSON(ctx context.Context, w http.ResponseWriter, r *http.Request) { + ctx = context.WithValue(ctx, MethodNameCtxKey, "{{$method.Name}}") + + {{ if gt (len $method.Inputs) 0 -}} + reqBody, err := io.ReadAll(r.Body) + if err != nil { + s.sendErrorJSON(w, r, ErrWebrpcBadRequest.WithCause(fmt.Errorf("failed to read request data: %w", err))) + return + } + defer r.Body.Close() + + reqPayload := struct { + {{- range $i, $input := $method.Inputs}} + Arg{{$i}} {{template "field" dict "Name" $input.Name "Type" $input.Type "Optional" $input.Optional "TypeMap" $typeMap "TypePrefix" $typePrefix "TypeMeta" $input.Meta "JsonTags" true}} + {{- end}} + }{} + if err := json.Unmarshal(reqBody, &reqPayload); err != nil { + s.sendErrorJSON(w, r, ErrWebrpcBadRequest.WithCause(fmt.Errorf("failed to unmarshal request data: %w", err))) + return + } + {{- end }} + + f, ok := w.(http.Flusher) + if !ok { + s.sendErrorJSON(w, r, ErrWebrpcInternalError.WithCause(fmt.Errorf("server http.ResponseWriter doesn't support .Flush() method"))) + return + } + + w.Header().Set("Cache-Control", "no-store") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("Content-Type", "application/x-ndjson") + w.Header().Set("X-Content-Type-Options", "nosniff") + w.WriteHeader(http.StatusOK) + + streamWriter := &{{firstLetterToLower $method.Name}}StreamWriter{streamWriter{w: w, f: f, e: json.NewEncoder(w), sendError: s.sendErrorJSON}} + if err := streamWriter.ping(); err != nil { + s.sendErrorJSON(w, r, ErrWebrpcStreamLost.WithCause(fmt.Errorf("failed to establish SSE stream: %w", err))) + return + } + + ctx, cancel := context.WithCancel(r.Context()) + defer cancel() + + go streamWriter.keepAlive(ctx) + + // Call service method implementation. + if err := s.Chat.SubscribeMessages(ctx, reqPayload.Arg0, streamWriter); err != nil { + rpcErr, ok := err.(WebRPCError) + if !ok { + rpcErr = ErrWebrpcEndpoint.WithCause(err) + } + streamWriter.sendError(w, r, rpcErr) + return + } +} +{{- end}} +{{- end}} func (s *{{$serviceName}}) sendErrorJSON(w http.ResponseWriter, r *http.Request, rpcErr WebRPCError) { if s.OnError != nil { s.OnError(r, &rpcErr) } + {{ if $opts.streaming -}} + if w.Header().Get("Content-Type") == "application/x-ndjson" { + out := struct { + WebRPCError WebRPCError `json:"webrpcError"` + }{ WebRPCError: rpcErr } + json.NewEncoder(w).Encode(out) + return + } + {{- end }} + w.Header().Set("Content-Type", "application/json") w.WriteHeader(rpcErr.HTTPStatus) respBody, _ := json.Marshal(rpcErr) w.Write(respBody) } -{{end -}} +{{- end}} func RespondWithError(w http.ResponseWriter, err error) { rpcErr, ok := err.(WebRPCError) @@ -166,5 +236,5 @@ func RespondWithError(w http.ResponseWriter, err error) { w.Write(respBody) } -{{- end -}} -{{- end -}} +{{ end -}} +{{ end -}} diff --git a/types.go.tmpl b/types.go.tmpl index 5333312..13a070a 100644 --- a/types.go.tmpl +++ b/types.go.tmpl @@ -5,7 +5,7 @@ {{- $services := .Services -}} {{- $opts := .Opts -}} -{{- if $types -}} +{{- if $types }} // // Types // @@ -19,17 +19,9 @@ {{template "struct" dict "Name" $type.Name "TypeMap" $typeMap "TypePrefix" $typePrefix "Fields" $type.Fields}} {{ end -}} -{{- end -}} -{{- end -}} +{{ end }} +{{ end -}} -{{- if and $services $opts.types -}} -{{ range $_, $service := $services}} -type {{$service.Name}} interface { - {{- range $_, $method := $service.Methods}} - {{.Name}}(ctx context.Context{{range $_, $input := $method.Inputs}}, {{$input.Name}} {{template "field" dict "Name" $input.Name "Type" $input.Type "TypeMap" $typeMap "TypePrefix" $typePrefix "Optional" $input.Optional "TypeMeta" $input.Meta}}{{end}}) {{if len .Outputs}}({{end}}{{range $i, $output := .Outputs}}{{template "field" dict "Name" $output.Name "Type" $output.Type "TypeMap" $typeMap "TypePrefix" $typePrefix "Optional" $output.Optional "TypeMeta" $output.Meta}}{{if lt $i (len $method.Outputs)}}, {{end}}{{end}}error{{if len $method.Outputs}}){{end}} - {{- end}} -} -{{end}} var WebRPCServices = map[string][]string{ {{- range $_, $service := $services}} "{{$service.Name}}": { @@ -39,6 +31,122 @@ var WebRPCServices = map[string][]string{ }, {{- end}} } + +// +// Server types +// + +{{ range $_, $service := $services -}} +type {{$service.Name}} interface { + {{- range $_, $method := $service.Methods}} + {{ if eq $method.StreamOutput true -}} + {{$method.Name}}(ctx context.Context{{range $_, $input := $method.Inputs}}, {{$input.Name}} {{template "field" dict "Name" $input.Name "Type" $input.Type "TypeMap" $typeMap "TypePrefix" $typePrefix "Optional" $input.Optional "TypeMeta" $input.Meta}}{{end}}, stream {{$method.Name}}StreamWriter) error + {{- else -}} + {{$method.Name}}(ctx context.Context{{range $_, $input := $method.Inputs}}, {{$input.Name}} {{template "field" dict "Name" $input.Name "Type" $input.Type "TypeMap" $typeMap "TypePrefix" $typePrefix "Optional" $input.Optional "TypeMeta" $input.Meta}}{{end}}) {{if len .Outputs}}({{end}}{{range $i, $output := .Outputs}}{{template "field" dict "Name" $output.Name "Type" $output.Type "TypeMap" $typeMap "TypePrefix" $typePrefix "Optional" $output.Optional "TypeMeta" $output.Meta}}{{if lt $i (len $method.Outputs)}}, {{end}}{{end}}error{{if len $method.Outputs}}){{end}} + {{- end -}} + + {{- end}} +} + +{{- range $_, $method := $service.Methods}} +{{ if eq $method.StreamOutput true -}} +type {{$method.Name}}StreamWriter interface { + Write(message *Message) error +} +{{- end }} +{{- end }} + +{{- range $_, $method := $service.Methods}} +{{ if eq $method.StreamOutput true -}} +type {{firstLetterToLower $method.Name}}StreamWriter struct { + streamWriter +} +{{- end -}} +{{- end -}} + +{{- end }} + +{{ if $opts.streaming -}} + +type streamWriter struct { + mu sync.Mutex // Guards concurrent writes to w. + w http.ResponseWriter + f http.Flusher + e *json.Encoder + + sendError func(w http.ResponseWriter, r *http.Request, rpcErr WebRPCError) +} + +const StreamKeepAliveInterval = 10*time.Second + +func (w *streamWriter) keepAlive(ctx context.Context) { + for { + select { + case <-time.After(StreamKeepAliveInterval): + err := w.ping() + if err != nil { + return + } + case <-ctx.Done(): + return + } + } +} + +func (w *streamWriter) ping() error { + defer w.f.Flush() + + w.mu.Lock() + defer w.mu.Unlock() + + _, err := w.w.Write([]byte("\n")) + return err +} + +func (w *streamWriter) write(respPayload interface{}) error { + defer w.f.Flush() + + w.mu.Lock() + defer w.mu.Unlock() + + return w.e.Encode(respPayload) +} + +func (w *streamWriter) Write(message *Message) error { + respPayload := struct { + Arg0 *Message `json:"message"` + }{message} + + return w.write(respPayload) +} +{{- end }} + +// +// Client types +// + +{{ if and $services $opts.types -}} +{{ range $_, $service := $services -}} +type {{$service.Name}}Client interface { + {{- range $_, $method := $service.Methods}} + {{ if eq $method.StreamOutput true -}} + {{$method.Name}}(ctx context.Context{{range $_, $input := $method.Inputs}}, {{$input.Name}} {{template "field" dict "Name" $input.Name "Type" $input.Type "TypeMap" $typeMap "TypePrefix" $typePrefix "Optional" $input.Optional "TypeMeta" $input.Meta}}{{end}}) ({{$method.Name}}StreamReader, error) + {{- else -}} + {{$method.Name}}(ctx context.Context{{range $_, $input := $method.Inputs}}, {{$input.Name}} {{template "field" dict "Name" $input.Name "Type" $input.Type "TypeMap" $typeMap "TypePrefix" $typePrefix "Optional" $input.Optional "TypeMeta" $input.Meta}}{{end}}) {{if len .Outputs}}({{end}}{{range $i, $output := .Outputs}}{{template "field" dict "Name" $output.Name "Type" $output.Type "TypeMap" $typeMap "TypePrefix" $typePrefix "Optional" $output.Optional "TypeMeta" $output.Meta}}{{if lt $i (len $method.Outputs)}}, {{end}}{{end}}error{{if len $method.Outputs}}){{end}} + {{- end -}} + + {{- end}} +} + +{{- range $_, $method := $service.Methods}} +{{ if eq $method.StreamOutput true -}} +type {{$method.Name}}StreamReader interface { + Read() (message *Message, err error) +} +{{- end }} +{{- end }} + +{{- end -}} {{- end -}} {{- end -}}