From 6440d17269dffeba73fa06bbf8e3a790a72afccd Mon Sep 17 00:00:00 2001 From: Jacob Weinstock Date: Mon, 11 Sep 2023 15:48:47 -0600 Subject: [PATCH] Copy req/resp bodies to buffers: Reading from io.Readers and then creating a nop closer and dealing with closing the readers was more difficult than it needed to be. Signed-off-by: Jacob Weinstock --- providers/rpc/http.go | 23 ++++++++--------------- providers/rpc/http_test.go | 10 ++++++---- providers/rpc/logging.go | 13 +++++-------- providers/rpc/rpc.go | 25 +++++++++++++------------ 4 files changed, 32 insertions(+), 39 deletions(-) diff --git a/providers/rpc/http.go b/providers/rpc/http.go index b51d00a2..bde373ab 100644 --- a/providers/rpc/http.go +++ b/providers/rpc/http.go @@ -5,7 +5,6 @@ import ( "context" "encoding/json" "fmt" - "io" "net/http" "strings" "time" @@ -45,32 +44,26 @@ func (p *Provider) createRequest(ctx context.Context, rp RequestPayload) (*http. return req, nil } -func (p *Provider) handleResponse(resp *http.Response, reqKeysAndValues []any) (ResponsePayload, error) { +func (p *Provider) handleResponse(statusCode int, headers http.Header, body *bytes.Buffer, reqKeysAndValues []any) (ResponsePayload, error) { kvs := reqKeysAndValues defer func() { if !p.LogNotificationsDisabled { - kvs = append(kvs, responseKVS(resp)...) + kvs = append(kvs, responseKVS(statusCode, headers, body)...) p.Logger.Info("rpc notification details", kvs...) } }() - bodyBytes, err := io.ReadAll(resp.Body) - if err != nil { - return ResponsePayload{}, fmt.Errorf("failed to read response body: %v", err) - } res := ResponsePayload{} - if err := json.Unmarshal(bodyBytes, &res); err != nil { - if resp.StatusCode != http.StatusOK { - return ResponsePayload{}, fmt.Errorf("unexpected status code: %d, response error(optional): %v", resp.StatusCode, res.Error) + if err := json.Unmarshal(body.Bytes(), &res); err != nil { + if statusCode != http.StatusOK { + return ResponsePayload{}, fmt.Errorf("unexpected status code: %d, response error(optional): %v", statusCode, res.Error) } example, _ := json.Marshal(ResponsePayload{ID: 123, Host: p.Host, Error: &ResponseError{Code: 1, Message: "error message"}}) - return ResponsePayload{}, fmt.Errorf("failed to parse response: got: %q, error: %w, expected response json spec: %v", string(bodyBytes), err, string(example)) + return ResponsePayload{}, fmt.Errorf("failed to parse response: got: %q, error: %w, expected response json spec: %v", body.String(), err, string(example)) } - if resp.StatusCode != http.StatusOK { - return ResponsePayload{}, fmt.Errorf("unexpected status code: %d, response error(optional): %v", resp.StatusCode, res.Error) + if statusCode != http.StatusOK { + return ResponsePayload{}, fmt.Errorf("unexpected status code: %d, response error(optional): %v", statusCode, res.Error) } - // reset the body so it can be read again by deferred functions. - resp.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) return res, nil } diff --git a/providers/rpc/http_test.go b/providers/rpc/http_test.go index 3d837b08..c211ff62 100644 --- a/providers/rpc/http_test.go +++ b/providers/rpc/http_test.go @@ -15,9 +15,9 @@ import ( ) func testRequest(method, reqURL string, body RequestPayload, headers http.Header) *http.Request { - var buf bytes.Buffer - _ = json.NewEncoder(&buf).Encode(body) - req, _ := http.NewRequestWithContext(context.Background(), method, reqURL, &buf) + buf := new(bytes.Buffer) + _ = json.NewEncoder(buf).Encode(body) + req, _ := http.NewRequestWithContext(context.Background(), method, reqURL, buf) req.Header = headers return req } @@ -97,7 +97,9 @@ func TestResponseKVS(t *testing.T) { } for name, tc := range tests { t.Run(name, func(t *testing.T) { - kvs := responseKVS(tc.resp) + buf := new(bytes.Buffer) + _, _ = io.Copy(buf, tc.resp.Body) + kvs := responseKVS(tc.resp.StatusCode, tc.resp.Header, buf) if diff := cmp.Diff(kvs, tc.expected); diff != "" { t.Fatalf("requestKVS() mismatch (-want +got):\n%s", diff) } diff --git a/providers/rpc/logging.go b/providers/rpc/logging.go index df456ead..5e689c68 100644 --- a/providers/rpc/logging.go +++ b/providers/rpc/logging.go @@ -42,18 +42,15 @@ func requestKVS(req *http.Request) []interface{} { } // responseKVS returns a slice of key, value sets. Used for logging. -func responseKVS(resp *http.Response) []interface{} { +func responseKVS(statusCode int, headers http.Header, body *bytes.Buffer) []interface{} { var r responseDetails - if resp != nil && resp.Body != nil { + if body.Len() > 0 { var p ResponsePayload - reqBody, err := io.ReadAll(resp.Body) - if err == nil { - _ = json.Unmarshal(reqBody, &p) - } + _ = json.Unmarshal(body.Bytes(), &p) r = responseDetails{ - StatusCode: resp.StatusCode, + StatusCode: statusCode, Body: p, - Headers: resp.Header, + Headers: headers, } } diff --git a/providers/rpc/rpc.go b/providers/rpc/rpc.go index 96a09aa0..396c48ff 100644 --- a/providers/rpc/rpc.go +++ b/providers/rpc/rpc.go @@ -182,9 +182,9 @@ func (p *Provider) Open(ctx context.Context) error { return err } p.listenerURL = u - var buf bytes.Buffer - _ = json.NewEncoder(&buf).Encode(RequestPayload{}) - testReq, err := http.NewRequestWithContext(ctx, p.Opts.Request.HTTPMethod, p.listenerURL.String(), bytes.NewReader(buf.Bytes())) + buf := new(bytes.Buffer) + _ = json.NewEncoder(buf).Encode(RequestPayload{}) + testReq, err := http.NewRequestWithContext(ctx, p.Opts.Request.HTTPMethod, p.listenerURL.String(), buf) if err != nil { return err } @@ -198,8 +198,7 @@ func (p *Provider) Open(ctx context.Context) error { } // test that the consumer responses with the expected contract (ResponsePayload{}). - var res ResponsePayload - if err := json.NewDecoder(resp.Body).Decode(&res); err != nil { + if err := json.NewDecoder(resp.Body).Decode(&ResponsePayload{}); err != nil { return fmt.Errorf("issue with the rpc consumer response: %v", err) } @@ -292,19 +291,17 @@ func (p *Provider) process(ctx context.Context, rp RequestPayload) (ResponsePayl } // create the signature payload - // get the body and reset it as readers can only be read once. - body, err := io.ReadAll(req.Body) - if err != nil { - return ResponsePayload{}, err + reqBuf := new(bytes.Buffer) + if _, err := io.Copy(reqBuf, req.Body); err != nil { + return ResponsePayload{}, fmt.Errorf("failed to read request body: %w", err) } - req.Body = io.NopCloser(bytes.NewBuffer(body)) headersForSig := http.Header{} for _, h := range p.Opts.Signature.IncludedPayloadHeaders { if val := req.Header.Get(h); val != "" { headersForSig.Add(h, val) } } - sigPay := createSignaturePayload(body, headersForSig) + sigPay := createSignaturePayload(reqBuf.Bytes(), headersForSig) // sign the signature payload sigs, err := sign(sigPay, p.Opts.HMAC.Hashes, p.Opts.HMAC.PrefixSigDisabled) @@ -336,12 +333,16 @@ func (p *Provider) process(ctx context.Context, rp RequestPayload) (ResponsePayl return ResponsePayload{}, err } defer resp.Body.Close() + respBuf := new(bytes.Buffer) + if _, err := io.Copy(respBuf, resp.Body); err != nil { + return ResponsePayload{}, fmt.Errorf("failed to read response body: %w", err) + } // handle the response if resp.ContentLength > maxContentLenAllowed || resp.ContentLength < 0 { return ResponsePayload{}, fmt.Errorf("response body is too large: %d bytes, max allowed: %d bytes", resp.ContentLength, maxContentLenAllowed) } - respPayload, err := p.handleResponse(resp, kvs) + respPayload, err := p.handleResponse(resp.StatusCode, resp.Header, respBuf, kvs) if err != nil { return ResponsePayload{}, err }