Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: use WriteErr function for handling error returned by handler #640

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 6 additions & 17 deletions error.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"errors"
"fmt"
"net/http"
"strconv"
)

// ErrorDetailer returns error details for responses & debugging. This enables
Expand Down Expand Up @@ -262,22 +261,12 @@ func WriteErr(api API, ctx Context, status int, msg string, errs ...error) error
// If it was not modified then this is a no-op.
status = err.GetStatus()

ct, negotiateErr := api.Negotiate(ctx.Header("Accept"))
if negotiateErr != nil {
return negotiateErr
writeErr := writeResponse(api, ctx, status, "", err)
if writeErr != nil {
// If we can't write the error, log it so we know what happened.
fmt.Printf("could not write error: %s\n", writeErr)
}

if ctf, ok := err.(ContentTypeFilter); ok {
ct = ctf.ContentType(ct)
}

ctx.SetHeader("Content-Type", ct)
ctx.SetStatus(status)
tval, terr := api.Transform(ctx, strconv.Itoa(status), err)
if terr != nil {
return terr
}
return api.Marshal(ctx.BodyWriter(), ct, tval)
return writeErr
}

// Status304NotModified returns a 304. This is not really an error, but
Expand Down Expand Up @@ -372,4 +361,4 @@ func Error504GatewayTimeout(msg string, errs ...error) StatusError {
}

// ErrorFormatter is a function that formats an error message
var ErrorFormatter func(format string, a ...any) string = fmt.Sprintf
var ErrorFormatter = fmt.Sprintf
4 changes: 2 additions & 2 deletions error_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ func TestNegotiateError(t *testing.T) {

req, _ := http.NewRequest("GET", "/", nil)
resp := httptest.NewRecorder()
ctx := humatest.NewContext(nil, req, resp)
ctx := humatest.NewContext(&huma.Operation{}, req, resp)
require.Error(t, huma.WriteErr(api, ctx, 400, "bad request"))
}

Expand All @@ -98,7 +98,7 @@ func TestTransformError(t *testing.T) {

req, _ := http.NewRequest("GET", "/", nil)
resp := httptest.NewRecorder()
ctx := humatest.NewContext(nil, req, resp)
ctx := humatest.NewContext(&huma.Operation{}, req, resp)

require.Error(t, huma.WriteErr(api, ctx, 400, "bad request"))
}
Expand Down
78 changes: 46 additions & 32 deletions huma.go
Original file line number Diff line number Diff line change
Expand Up @@ -463,9 +463,41 @@
},
}

func writeResponse(api API, ctx Context, status int, ct string, body any) error {
if ct == "" {
// If no content type was provided, try to negotiate one with the client.
var err error
ct, err = api.Negotiate(ctx.Header("Accept"))
if err != nil {
notAccept := NewErrorWithContext(ctx, http.StatusNotAcceptable, "unable to marshal response", err)
if e := transformAndWrite(api, ctx, http.StatusNotAcceptable, "application/json", notAccept); e != nil {
return e
}
return err
}

Check warning on line 477 in huma.go

View check run for this annotation

Codecov / codecov/patch

huma.go#L477

Added line #L477 was not covered by tests

if ctf, ok := body.(ContentTypeFilter); ok {
ct = ctf.ContentType(ct)
}

ctx.SetHeader("Content-Type", ct)
}

if err := transformAndWrite(api, ctx, status, ct, body); err != nil {
return err
}
return nil
}

func writeResponseWithPanic(api API, ctx Context, status int, ct string, body any) {
if err := writeResponse(api, ctx, status, ct, body); err != nil {
panic(err)
}
}

// transformAndWrite is a utility function to transform and write a response.
// It is best-effort as the status code and headers may have already been sent.
func transformAndWrite(api API, ctx Context, status int, ct string, body any) {
func transformAndWrite(api API, ctx Context, status int, ct string, body any) error {
// Try to transform and then marshal/write the response.
// Status code was already sent, so just log the error if something fails,
// and do our best to stuff it into the body of the response.
Expand All @@ -474,17 +506,18 @@
ctx.BodyWriter().Write([]byte("error transforming response"))
// When including tval in the panic message, the server may become unresponsive for some time if the value is very large
// therefore, it has been removed from the panic message
panic(fmt.Errorf("error transforming response for %s %s %d: %w", ctx.Operation().Method, ctx.Operation().Path, status, terr))
return fmt.Errorf("error transforming response for %s %s %d: %w", ctx.Operation().Method, ctx.Operation().Path, status, terr)
}
ctx.SetStatus(status)
if status != http.StatusNoContent && status != http.StatusNotModified {
if merr := api.Marshal(ctx.BodyWriter(), ct, tval); merr != nil {
ctx.BodyWriter().Write([]byte("error marshaling response"))
// When including tval in the panic message, the server may become unresponsive for some time if the value is very large
// therefore, it has been removed from the panic message
panic(fmt.Errorf("error marshaling response for %s %s %d: %w", ctx.Operation().Method, ctx.Operation().Path, status, merr))
return fmt.Errorf("error marshaling response for %s %s %d: %w", ctx.Operation().Method, ctx.Operation().Path, status, merr)
}
}
return nil
}

func parseArrElement[T any](values []string, parse func(string) (T, error)) ([]T, error) {
Expand Down Expand Up @@ -962,7 +995,7 @@
if f.Type() == reflect.TypeOf(values) {
f.Set(reflect.ValueOf(values))
} else {
//Change element type to support slice of string subtypes (enums)
// Change element type to support slice of string subtypes (enums)
enumValues := reflect.New(f.Type()).Elem()
for _, val := range values {
enumVal := reflect.New(f.Type().Elem()).Elem()
Expand Down Expand Up @@ -1402,21 +1435,16 @@
}

status := http.StatusInternalServerError

// handle status error
var se StatusError
if errors.As(err, &se) {
status = se.GetStatus()
err = se
} else {
err = NewError(http.StatusInternalServerError, err.Error())
}

ct, _ := api.Negotiate(ctx.Header("Accept"))
if ctf, ok := err.(ContentTypeFilter); ok {
ct = ctf.ContentType(ct)
writeResponseWithPanic(api, ctx, se.GetStatus(), "", se)
return
}

ctx.SetHeader("Content-Type", ct)
transformAndWrite(api, ctx, status, ct, err)
se = NewErrorWithContext(ctx, status, "unexpected error occurred", err)
writeResponseWithPanic(api, ctx, se.GetStatus(), "", se)
return
}

Expand All @@ -1441,7 +1469,8 @@
}
} else {
if f.Kind() == reflect.String && info.Name == "Content-Type" {
// Track custom content type.
// Track custom content type. This overrides any content negotiation
// that would happen when writing the response.
ct = f.String()
}
writeHeader(ctx.SetHeader, info, f)
Expand All @@ -1468,22 +1497,7 @@
return
}

// Only write a content type if one wasn't already written by the
// response headers handled above.
if ct == "" {
ct, err = api.Negotiate(ctx.Header("Accept"))
if err != nil {
WriteErr(api, ctx, http.StatusNotAcceptable, "unable to marshal response", err)
return
}
if ctf, ok := body.(ContentTypeFilter); ok {
ct = ctf.ContentType(ct)
}

ctx.SetHeader("Content-Type", ct)
}

transformAndWrite(api, ctx, status, ct, body)
writeResponseWithPanic(api, ctx, status, ct, body)
} else {
ctx.SetStatus(status)
}
Expand Down
16 changes: 16 additions & 0 deletions huma_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1365,6 +1365,22 @@ Content of example2.txt.
assert.Equal(t, http.StatusForbidden, resp.Code)
},
},
{
Name: "handler-generic-error",
Register: func(t *testing.T, api huma.API) {
huma.Register(api, huma.Operation{
Method: http.MethodGet,
Path: "/error",
}, func(ctx context.Context, input *struct{}) (*struct{}, error) {
return nil, errors.New("whoops")
})
},
Method: http.MethodGet,
URL: "/error",
Assert: func(t *testing.T, resp *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusInternalServerError, resp.Code)
},
},
{
Name: "response-headers",
Register: func(t *testing.T, api huma.API) {
Expand Down
3 changes: 2 additions & 1 deletion openapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -863,7 +863,8 @@ type Operation struct {
// This is a convenience for handlers that return a fixed set of errors
// where you do not wish to provide each one as an OpenAPI response object.
// Each error specified here is expanded into a response object with the
// schema generated from the type returned by `huma.NewError()`.
// schema generated from the type returned by `huma.NewError()`
// or `huma.NewErrorWithContext`.
Errors []int `yaml:"-"`

// SkipValidateParams disables validation of path, query, and header
Expand Down
Loading