Skip to content

Commit

Permalink
refactor: response writing
Browse files Browse the repository at this point in the history
  • Loading branch information
danielgtaylor committed Nov 12, 2024
1 parent b66498a commit 0cb7a89
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 68 deletions.
46 changes: 5 additions & 41 deletions error.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"errors"
"fmt"
"net/http"
"strconv"
)

// ErrorDetailer returns error details for responses & debugging. This enables
Expand Down Expand Up @@ -262,47 +261,12 @@ func WriteErr(api API, ctx Context, status int, msg string, errs ...error) error
// If it was not modified then this is a no-op.
status = err.GetStatus()

ct, negotiateErr := api.Negotiate(ctx.Header("Accept"))
if negotiateErr != nil {
return negotiateErr
writeErr := writeResponse(api, ctx, status, "", err)
if writeErr != nil {
// If we can't write the error, log it so we know what happened.
fmt.Printf("could not write error: %s\n", writeErr)
}

if ctf, ok := err.(ContentTypeFilter); ok {
ct = ctf.ContentType(ct)
}

ctx.SetHeader("Content-Type", ct)
ctx.SetStatus(status)
tval, terr := api.Transform(ctx, strconv.Itoa(status), err)
if terr != nil {
return terr
}
return api.Marshal(ctx.BodyWriter(), ct, tval)
}

func writeStatusError(api API, ctx Context, err StatusError) error {
ct, negotiateErr := api.Negotiate(ctx.Header("Accept"))
if negotiateErr != nil {
return fmt.Errorf("failed to write status error: %w", negotiateErr)
}
if ctf, ok := err.(ContentTypeFilter); ok {
ct = ctf.ContentType(ct)
}
ctx.SetHeader("Content-Type", ct)

status := err.GetStatus()
ctx.SetStatus(status)

// If request accept no output, just set the status code and return.
if status == http.StatusNoContent || status == http.StatusNotModified {
return nil
}

tval, terr := api.Transform(ctx, strconv.Itoa(status), err)
if terr != nil {
return terr
}
return api.Marshal(ctx.BodyWriter(), ct, tval)
return writeErr
}

// Status304NotModified returns a 304. This is not really an error, but
Expand Down
4 changes: 2 additions & 2 deletions error_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ func TestNegotiateError(t *testing.T) {

req, _ := http.NewRequest("GET", "/", nil)
resp := httptest.NewRecorder()
ctx := humatest.NewContext(nil, req, resp)
ctx := humatest.NewContext(&huma.Operation{}, req, resp)
require.Error(t, huma.WriteErr(api, ctx, 400, "bad request"))
}

Expand All @@ -98,7 +98,7 @@ func TestTransformError(t *testing.T) {

req, _ := http.NewRequest("GET", "/", nil)
resp := httptest.NewRecorder()
ctx := humatest.NewContext(nil, req, resp)
ctx := humatest.NewContext(&huma.Operation{}, req, resp)

require.Error(t, huma.WriteErr(api, ctx, 400, "bad request"))
}
Expand Down
78 changes: 53 additions & 25 deletions huma.go
Original file line number Diff line number Diff line change
Expand Up @@ -463,9 +463,52 @@ var bufPool = sync.Pool{
},
}

func writeResponse(api API, ctx Context, status int, ct string, body any) error {
if ct == "" {
var err error
ct, err = getContentType(api, ctx, body)
if err != nil {
// Couldn't negotiate a content type, so return an error. This is best
// effort and we default to JSON. This prevents loops that would result
// from calling `WriteErr`.
status := http.StatusInternalServerError
if se, ok := err.(StatusError); ok {
status = se.GetStatus()
}
if err := transformAndWrite(api, ctx, status, "application/json", err); err != nil {
return err
}
}
}

if err := transformAndWrite(api, ctx, status, ct, body); err != nil {
return err
}
return nil
}

func writeResponseWithPanic(api API, ctx Context, status int, ct string, body any) {
if err := writeResponse(api, ctx, status, ct, body); err != nil {
panic(err)
}
}

func getContentType(api API, ctx Context, body any) (string, error) {
ct, err := api.Negotiate(ctx.Header("Accept"))
if err != nil {
return "", NewErrorWithContext(ctx, http.StatusNotAcceptable, "unable to marshal response", err)
}
if ctf, ok := body.(ContentTypeFilter); ok {
ct = ctf.ContentType(ct)
}

ctx.SetHeader("Content-Type", ct)
return ct, nil
}

// transformAndWrite is a utility function to transform and write a response.
// It is best-effort as the status code and headers may have already been sent.
func transformAndWrite(api API, ctx Context, status int, ct string, body any) {
func transformAndWrite(api API, ctx Context, status int, ct string, body any) error {
// Try to transform and then marshal/write the response.
// Status code was already sent, so just log the error if something fails,
// and do our best to stuff it into the body of the response.
Expand All @@ -474,17 +517,18 @@ func transformAndWrite(api API, ctx Context, status int, ct string, body any) {
ctx.BodyWriter().Write([]byte("error transforming response"))
// When including tval in the panic message, the server may become unresponsive for some time if the value is very large
// therefore, it has been removed from the panic message
panic(fmt.Errorf("error transforming response for %s %s %d: %w", ctx.Operation().Method, ctx.Operation().Path, status, terr))
return fmt.Errorf("error transforming response for %s %s %d: %w", ctx.Operation().Method, ctx.Operation().Path, status, terr)
}
ctx.SetStatus(status)
if status != http.StatusNoContent && status != http.StatusNotModified {
if merr := api.Marshal(ctx.BodyWriter(), ct, tval); merr != nil {
ctx.BodyWriter().Write([]byte("error marshaling response"))
// When including tval in the panic message, the server may become unresponsive for some time if the value is very large
// therefore, it has been removed from the panic message
panic(fmt.Errorf("error marshaling response for %s %s %d: %w", ctx.Operation().Method, ctx.Operation().Path, status, merr))
return fmt.Errorf("error marshaling response for %s %s %d: %w", ctx.Operation().Method, ctx.Operation().Path, status, merr)
}
}
return nil
}

func parseArrElement[T any](values []string, parse func(string) (T, error)) ([]T, error) {
Expand Down Expand Up @@ -1406,14 +1450,12 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I)
// handle status error
var se StatusError
if errors.As(err, &se) {
writeStatusError(api, ctx, se)
writeResponseWithPanic(api, ctx, se.GetStatus(), "", se)
return
}

if err := WriteErr(api, ctx, status, "unexpected error occurred", err); err != nil {
ctx.BodyWriter().Write([]byte("internal server error"))
panic(fmt.Errorf("failed to write error response for %s %s %d: %w", ctx.Operation().Method, ctx.Operation().Path, status, err))
}
se = NewErrorWithContext(ctx, status, "unexpected error occurred", err)
writeResponseWithPanic(api, ctx, se.GetStatus(), "", se)
return
}

Expand All @@ -1438,7 +1480,8 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I)
}
} else {
if f.Kind() == reflect.String && info.Name == "Content-Type" {
// Track custom content type.
// Track custom content type. This overrides any content negotiation
// that would happen when writing the response.
ct = f.String()
}
writeHeader(ctx.SetHeader, info, f)
Expand All @@ -1465,22 +1508,7 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I)
return
}

// Only write a content type if one wasn't already written by the
// response headers handled above.
if ct == "" {
ct, err = api.Negotiate(ctx.Header("Accept"))
if err != nil {
WriteErr(api, ctx, http.StatusNotAcceptable, "unable to marshal response", err)
return
}
if ctf, ok := body.(ContentTypeFilter); ok {
ct = ctf.ContentType(ct)
}

ctx.SetHeader("Content-Type", ct)
}

transformAndWrite(api, ctx, status, ct, body)
writeResponseWithPanic(api, ctx, status, ct, body)
} else {
ctx.SetStatus(status)
}
Expand Down
16 changes: 16 additions & 0 deletions huma_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1365,6 +1365,22 @@ Content of example2.txt.
assert.Equal(t, http.StatusForbidden, resp.Code)
},
},
{
Name: "handler-generic-error",
Register: func(t *testing.T, api huma.API) {
huma.Register(api, huma.Operation{
Method: http.MethodGet,
Path: "/error",
}, func(ctx context.Context, input *struct{}) (*struct{}, error) {
return nil, fmt.Errorf("whoops")
})
},
Method: http.MethodGet,
URL: "/error",
Assert: func(t *testing.T, resp *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusInternalServerError, resp.Code)
},
},
{
Name: "response-headers",
Register: func(t *testing.T, api huma.API) {

Check failure on line 1386 in huma_test.go

View workflow job for this annotation

GitHub Actions / Build & Test (1.22)

fmt.Errorf can be replaced with errors.New (perfsprint)
Expand Down

0 comments on commit 0cb7a89

Please sign in to comment.