Skip to content

Commit

Permalink
fix: use errors.As to get StatusError, enabling wrapping
Browse files Browse the repository at this point in the history
  • Loading branch information
danielgtaylor committed Apr 13, 2024
1 parent 864ecb4 commit 1971cc6
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 1 deletion.
9 changes: 9 additions & 0 deletions error_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package huma_test

import (
"errors"
"fmt"
"net/http"
"net/http/httptest"
"testing"
Expand Down Expand Up @@ -101,3 +102,11 @@ func TestTransformError(t *testing.T) {

require.Error(t, huma.WriteErr(api, ctx, 400, "bad request"))
}

func TestErrorAs(t *testing.T) {
err := fmt.Errorf("wrapped: %w", huma.Error400BadRequest("test"))

var e huma.StatusError
require.ErrorAs(t, err, &e)
assert.Equal(t, 400, e.GetStatus())
}
3 changes: 2 additions & 1 deletion huma.go
Original file line number Diff line number Diff line change
Expand Up @@ -1274,7 +1274,8 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I)
output, err := handler(ctx.Context(), &input)
if err != nil {
status := http.StatusInternalServerError
if se, ok := err.(StatusError); ok {
var se StatusError
if errors.As(err, &se) {
status = se.GetStatus()
} else {
err = NewError(http.StatusInternalServerError, err.Error())
Expand Down
17 changes: 17 additions & 0 deletions huma_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"mime/multipart"
"net"
Expand Down Expand Up @@ -656,6 +657,22 @@ Content of example2.txt.
assert.Equal(t, http.StatusForbidden, resp.Code)
},
},
{
Name: "handler-wrapped-error",
Register: func(t *testing.T, api huma.API) {
huma.Register(api, huma.Operation{
Method: http.MethodGet,
Path: "/error",
}, func(ctx context.Context, input *struct{}) (*struct{}, error) {
return nil, fmt.Errorf("wrapped: %w", huma.Error403Forbidden("nope"))
})
},
Method: http.MethodGet,
URL: "/error",
Assert: func(t *testing.T, resp *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusForbidden, resp.Code)
},
},
{
Name: "response-headers",
Register: func(t *testing.T, api huma.API) {
Expand Down

0 comments on commit 1971cc6

Please sign in to comment.