diff --git a/error.go b/error.go index dbdb1043..7ac25d28 100644 --- a/error.go +++ b/error.go @@ -4,7 +4,6 @@ import ( "errors" "fmt" "net/http" - "strconv" ) // ErrorDetailer returns error details for responses & debugging. This enables @@ -262,47 +261,12 @@ func WriteErr(api API, ctx Context, status int, msg string, errs ...error) error // If it was not modified then this is a no-op. status = err.GetStatus() - ct, negotiateErr := api.Negotiate(ctx.Header("Accept")) - if negotiateErr != nil { - return negotiateErr + writeErr := writeResponse(api, ctx, status, "", err) + if writeErr != nil { + // If we can't write the error, log it so we know what happened. + fmt.Printf("could not write error: %s\n", writeErr) } - - if ctf, ok := err.(ContentTypeFilter); ok { - ct = ctf.ContentType(ct) - } - - ctx.SetHeader("Content-Type", ct) - ctx.SetStatus(status) - tval, terr := api.Transform(ctx, strconv.Itoa(status), err) - if terr != nil { - return terr - } - return api.Marshal(ctx.BodyWriter(), ct, tval) -} - -func writeStatusError(api API, ctx Context, err StatusError) error { - ct, negotiateErr := api.Negotiate(ctx.Header("Accept")) - if negotiateErr != nil { - return fmt.Errorf("failed to write status error: %w", negotiateErr) - } - if ctf, ok := err.(ContentTypeFilter); ok { - ct = ctf.ContentType(ct) - } - ctx.SetHeader("Content-Type", ct) - - status := err.GetStatus() - ctx.SetStatus(status) - - // If request accept no output, just set the status code and return. - if status == http.StatusNoContent || status == http.StatusNotModified { - return nil - } - - tval, terr := api.Transform(ctx, strconv.Itoa(status), err) - if terr != nil { - return terr - } - return api.Marshal(ctx.BodyWriter(), ct, tval) + return writeErr } // Status304NotModified returns a 304. This is not really an error, but diff --git a/error_test.go b/error_test.go index 43611af8..f7de497d 100644 --- a/error_test.go +++ b/error_test.go @@ -83,7 +83,7 @@ func TestNegotiateError(t *testing.T) { req, _ := http.NewRequest("GET", "/", nil) resp := httptest.NewRecorder() - ctx := humatest.NewContext(nil, req, resp) + ctx := humatest.NewContext(&huma.Operation{}, req, resp) require.Error(t, huma.WriteErr(api, ctx, 400, "bad request")) } @@ -98,7 +98,7 @@ func TestTransformError(t *testing.T) { req, _ := http.NewRequest("GET", "/", nil) resp := httptest.NewRecorder() - ctx := humatest.NewContext(nil, req, resp) + ctx := humatest.NewContext(&huma.Operation{}, req, resp) require.Error(t, huma.WriteErr(api, ctx, 400, "bad request")) } diff --git a/huma.go b/huma.go index 20d24846..33348c62 100644 --- a/huma.go +++ b/huma.go @@ -463,9 +463,52 @@ var bufPool = sync.Pool{ }, } +func writeResponse(api API, ctx Context, status int, ct string, body any) error { + if ct == "" { + var err error + ct, err = getContentType(api, ctx, body) + if err != nil { + // Couldn't negotiate a content type, so return an error. This is best + // effort and we default to JSON. This prevents loops that would result + // from calling `WriteErr`. + status := http.StatusInternalServerError + if se, ok := err.(StatusError); ok { + status = se.GetStatus() + } + if err := transformAndWrite(api, ctx, status, "application/json", err); err != nil { + return err + } + } + } + + if err := transformAndWrite(api, ctx, status, ct, body); err != nil { + return err + } + return nil +} + +func writeResponseWithPanic(api API, ctx Context, status int, ct string, body any) { + if err := writeResponse(api, ctx, status, ct, body); err != nil { + panic(err) + } +} + +func getContentType(api API, ctx Context, body any) (string, error) { + ct, err := api.Negotiate(ctx.Header("Accept")) + if err != nil { + return "", NewErrorWithContext(ctx, http.StatusNotAcceptable, "unable to marshal response", err) + } + if ctf, ok := body.(ContentTypeFilter); ok { + ct = ctf.ContentType(ct) + } + + ctx.SetHeader("Content-Type", ct) + return ct, nil +} + // transformAndWrite is a utility function to transform and write a response. // It is best-effort as the status code and headers may have already been sent. -func transformAndWrite(api API, ctx Context, status int, ct string, body any) { +func transformAndWrite(api API, ctx Context, status int, ct string, body any) error { // Try to transform and then marshal/write the response. // Status code was already sent, so just log the error if something fails, // and do our best to stuff it into the body of the response. @@ -474,7 +517,7 @@ func transformAndWrite(api API, ctx Context, status int, ct string, body any) { ctx.BodyWriter().Write([]byte("error transforming response")) // When including tval in the panic message, the server may become unresponsive for some time if the value is very large // therefore, it has been removed from the panic message - panic(fmt.Errorf("error transforming response for %s %s %d: %w", ctx.Operation().Method, ctx.Operation().Path, status, terr)) + return fmt.Errorf("error transforming response for %s %s %d: %w", ctx.Operation().Method, ctx.Operation().Path, status, terr) } ctx.SetStatus(status) if status != http.StatusNoContent && status != http.StatusNotModified { @@ -482,9 +525,10 @@ func transformAndWrite(api API, ctx Context, status int, ct string, body any) { ctx.BodyWriter().Write([]byte("error marshaling response")) // When including tval in the panic message, the server may become unresponsive for some time if the value is very large // therefore, it has been removed from the panic message - panic(fmt.Errorf("error marshaling response for %s %s %d: %w", ctx.Operation().Method, ctx.Operation().Path, status, merr)) + return fmt.Errorf("error marshaling response for %s %s %d: %w", ctx.Operation().Method, ctx.Operation().Path, status, merr) } } + return nil } func parseArrElement[T any](values []string, parse func(string) (T, error)) ([]T, error) { @@ -1406,14 +1450,12 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I) // handle status error var se StatusError if errors.As(err, &se) { - writeStatusError(api, ctx, se) + writeResponseWithPanic(api, ctx, se.GetStatus(), "", se) return } - if err := WriteErr(api, ctx, status, "unexpected error occurred", err); err != nil { - ctx.BodyWriter().Write([]byte("internal server error")) - panic(fmt.Errorf("failed to write error response for %s %s %d: %w", ctx.Operation().Method, ctx.Operation().Path, status, err)) - } + se = NewErrorWithContext(ctx, status, "unexpected error occurred", err) + writeResponseWithPanic(api, ctx, se.GetStatus(), "", se) return } @@ -1438,7 +1480,8 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I) } } else { if f.Kind() == reflect.String && info.Name == "Content-Type" { - // Track custom content type. + // Track custom content type. This overrides any content negotiation + // that would happen when writing the response. ct = f.String() } writeHeader(ctx.SetHeader, info, f) @@ -1465,22 +1508,7 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I) return } - // Only write a content type if one wasn't already written by the - // response headers handled above. - if ct == "" { - ct, err = api.Negotiate(ctx.Header("Accept")) - if err != nil { - WriteErr(api, ctx, http.StatusNotAcceptable, "unable to marshal response", err) - return - } - if ctf, ok := body.(ContentTypeFilter); ok { - ct = ctf.ContentType(ct) - } - - ctx.SetHeader("Content-Type", ct) - } - - transformAndWrite(api, ctx, status, ct, body) + writeResponseWithPanic(api, ctx, status, ct, body) } else { ctx.SetStatus(status) } diff --git a/huma_test.go b/huma_test.go index 2fe83521..7934f316 100644 --- a/huma_test.go +++ b/huma_test.go @@ -1365,6 +1365,22 @@ Content of example2.txt. assert.Equal(t, http.StatusForbidden, resp.Code) }, }, + { + Name: "handler-generic-error", + Register: func(t *testing.T, api huma.API) { + huma.Register(api, huma.Operation{ + Method: http.MethodGet, + Path: "/error", + }, func(ctx context.Context, input *struct{}) (*struct{}, error) { + return nil, fmt.Errorf("whoops") + }) + }, + Method: http.MethodGet, + URL: "/error", + Assert: func(t *testing.T, resp *httptest.ResponseRecorder) { + assert.Equal(t, http.StatusInternalServerError, resp.Code) + }, + }, { Name: "response-headers", Register: func(t *testing.T, api huma.API) {