diff --git a/error.go b/error.go index ec581ad2..7ac25d28 100644 --- a/error.go +++ b/error.go @@ -4,7 +4,6 @@ import ( "errors" "fmt" "net/http" - "strconv" ) // ErrorDetailer returns error details for responses & debugging. This enables @@ -262,22 +261,12 @@ func WriteErr(api API, ctx Context, status int, msg string, errs ...error) error // If it was not modified then this is a no-op. status = err.GetStatus() - ct, negotiateErr := api.Negotiate(ctx.Header("Accept")) - if negotiateErr != nil { - return negotiateErr + writeErr := writeResponse(api, ctx, status, "", err) + if writeErr != nil { + // If we can't write the error, log it so we know what happened. + fmt.Printf("could not write error: %s\n", writeErr) } - - if ctf, ok := err.(ContentTypeFilter); ok { - ct = ctf.ContentType(ct) - } - - ctx.SetHeader("Content-Type", ct) - ctx.SetStatus(status) - tval, terr := api.Transform(ctx, strconv.Itoa(status), err) - if terr != nil { - return terr - } - return api.Marshal(ctx.BodyWriter(), ct, tval) + return writeErr } // Status304NotModified returns a 304. This is not really an error, but @@ -372,4 +361,4 @@ func Error504GatewayTimeout(msg string, errs ...error) StatusError { } // ErrorFormatter is a function that formats an error message -var ErrorFormatter func(format string, a ...any) string = fmt.Sprintf +var ErrorFormatter = fmt.Sprintf diff --git a/error_test.go b/error_test.go index 43611af8..f7de497d 100644 --- a/error_test.go +++ b/error_test.go @@ -83,7 +83,7 @@ func TestNegotiateError(t *testing.T) { req, _ := http.NewRequest("GET", "/", nil) resp := httptest.NewRecorder() - ctx := humatest.NewContext(nil, req, resp) + ctx := humatest.NewContext(&huma.Operation{}, req, resp) require.Error(t, huma.WriteErr(api, ctx, 400, "bad request")) } @@ -98,7 +98,7 @@ func TestTransformError(t *testing.T) { req, _ := http.NewRequest("GET", "/", nil) resp := httptest.NewRecorder() - ctx := humatest.NewContext(nil, req, resp) + ctx := humatest.NewContext(&huma.Operation{}, req, resp) require.Error(t, huma.WriteErr(api, ctx, 400, "bad request")) } diff --git a/huma.go b/huma.go index bd7d5458..1c0f3f4f 100644 --- a/huma.go +++ b/huma.go @@ -464,9 +464,41 @@ var bufPool = sync.Pool{ }, } +func writeResponse(api API, ctx Context, status int, ct string, body any) error { + if ct == "" { + // If no content type was provided, try to negotiate one with the client. + var err error + ct, err = api.Negotiate(ctx.Header("Accept")) + if err != nil { + notAccept := NewErrorWithContext(ctx, http.StatusNotAcceptable, "unable to marshal response", err) + if e := transformAndWrite(api, ctx, http.StatusNotAcceptable, "application/json", notAccept); e != nil { + return e + } + return err + } + + if ctf, ok := body.(ContentTypeFilter); ok { + ct = ctf.ContentType(ct) + } + + ctx.SetHeader("Content-Type", ct) + } + + if err := transformAndWrite(api, ctx, status, ct, body); err != nil { + return err + } + return nil +} + +func writeResponseWithPanic(api API, ctx Context, status int, ct string, body any) { + if err := writeResponse(api, ctx, status, ct, body); err != nil { + panic(err) + } +} + // transformAndWrite is a utility function to transform and write a response. // It is best-effort as the status code and headers may have already been sent. -func transformAndWrite(api API, ctx Context, status int, ct string, body any) { +func transformAndWrite(api API, ctx Context, status int, ct string, body any) error { // Try to transform and then marshal/write the response. // Status code was already sent, so just log the error if something fails, // and do our best to stuff it into the body of the response. @@ -475,7 +507,7 @@ func transformAndWrite(api API, ctx Context, status int, ct string, body any) { ctx.BodyWriter().Write([]byte("error transforming response")) // When including tval in the panic message, the server may become unresponsive for some time if the value is very large // therefore, it has been removed from the panic message - panic(fmt.Errorf("error transforming response for %s %s %d: %w", ctx.Operation().Method, ctx.Operation().Path, status, terr)) + return fmt.Errorf("error transforming response for %s %s %d: %w", ctx.Operation().Method, ctx.Operation().Path, status, terr) } ctx.SetStatus(status) if status != http.StatusNoContent && status != http.StatusNotModified { @@ -483,9 +515,10 @@ func transformAndWrite(api API, ctx Context, status int, ct string, body any) { ctx.BodyWriter().Write([]byte("error marshaling response")) // When including tval in the panic message, the server may become unresponsive for some time if the value is very large // therefore, it has been removed from the panic message - panic(fmt.Errorf("error marshaling response for %s %s %d: %w", ctx.Operation().Method, ctx.Operation().Path, status, merr)) + return fmt.Errorf("error marshaling response for %s %s %d: %w", ctx.Operation().Method, ctx.Operation().Path, status, merr) } } + return nil } func parseArrElement[T any](values []string, parse func(string) (T, error)) ([]T, error) { @@ -963,7 +996,7 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I) if f.Type() == reflect.TypeOf(values) { f.Set(reflect.ValueOf(values)) } else { - //Change element type to support slice of string subtypes (enums) + // Change element type to support slice of string subtypes (enums) enumValues := reflect.New(f.Type()).Elem() for _, val := range values { enumVal := reflect.New(f.Type().Elem()).Elem() @@ -1403,21 +1436,16 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I) } status := http.StatusInternalServerError + + // handle status error var se StatusError if errors.As(err, &se) { - status = se.GetStatus() - err = se - } else { - err = NewError(http.StatusInternalServerError, err.Error()) - } - - ct, _ := api.Negotiate(ctx.Header("Accept")) - if ctf, ok := err.(ContentTypeFilter); ok { - ct = ctf.ContentType(ct) + writeResponseWithPanic(api, ctx, se.GetStatus(), "", se) + return } - ctx.SetHeader("Content-Type", ct) - transformAndWrite(api, ctx, status, ct, err) + se = NewErrorWithContext(ctx, status, "unexpected error occurred", err) + writeResponseWithPanic(api, ctx, se.GetStatus(), "", se) return } @@ -1442,7 +1470,8 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I) } } else { if f.Kind() == reflect.String && info.Name == "Content-Type" { - // Track custom content type. + // Track custom content type. This overrides any content negotiation + // that would happen when writing the response. ct = f.String() } writeHeader(ctx.SetHeader, info, f) @@ -1469,22 +1498,7 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I) return } - // Only write a content type if one wasn't already written by the - // response headers handled above. - if ct == "" { - ct, err = api.Negotiate(ctx.Header("Accept")) - if err != nil { - WriteErr(api, ctx, http.StatusNotAcceptable, "unable to marshal response", err) - return - } - if ctf, ok := body.(ContentTypeFilter); ok { - ct = ctf.ContentType(ct) - } - - ctx.SetHeader("Content-Type", ct) - } - - transformAndWrite(api, ctx, status, ct, body) + writeResponseWithPanic(api, ctx, status, ct, body) } else { ctx.SetStatus(status) } diff --git a/huma_test.go b/huma_test.go index 01ced429..8ca4c292 100644 --- a/huma_test.go +++ b/huma_test.go @@ -1376,6 +1376,22 @@ Content of example2.txt. assert.Equal(t, http.StatusForbidden, resp.Code) }, }, + { + Name: "handler-generic-error", + Register: func(t *testing.T, api huma.API) { + huma.Register(api, huma.Operation{ + Method: http.MethodGet, + Path: "/error", + }, func(ctx context.Context, input *struct{}) (*struct{}, error) { + return nil, errors.New("whoops") + }) + }, + Method: http.MethodGet, + URL: "/error", + Assert: func(t *testing.T, resp *httptest.ResponseRecorder) { + assert.Equal(t, http.StatusInternalServerError, resp.Code) + }, + }, { Name: "response-headers", Register: func(t *testing.T, api huma.API) { diff --git a/openapi.go b/openapi.go index 0518dc6a..8149daba 100644 --- a/openapi.go +++ b/openapi.go @@ -863,7 +863,8 @@ type Operation struct { // This is a convenience for handlers that return a fixed set of errors // where you do not wish to provide each one as an OpenAPI response object. // Each error specified here is expanded into a response object with the - // schema generated from the type returned by `huma.NewError()`. + // schema generated from the type returned by `huma.NewError()` + // or `huma.NewErrorWithContext`. Errors []int `yaml:"-"` // SkipValidateParams disables validation of path, query, and header