From 1971cc66b13555e87d4ad9c962c980519c52a734 Mon Sep 17 00:00:00 2001 From: "Daniel G. Taylor" Date: Fri, 12 Apr 2024 17:00:16 -0700 Subject: [PATCH] fix: use errors.As to get StatusError, enabling wrapping --- error_test.go | 9 +++++++++ huma.go | 3 ++- huma_test.go | 17 +++++++++++++++++ 3 files changed, 28 insertions(+), 1 deletion(-) diff --git a/error_test.go b/error_test.go index d407b3b7..c4c27585 100644 --- a/error_test.go +++ b/error_test.go @@ -2,6 +2,7 @@ package huma_test import ( "errors" + "fmt" "net/http" "net/http/httptest" "testing" @@ -101,3 +102,11 @@ func TestTransformError(t *testing.T) { require.Error(t, huma.WriteErr(api, ctx, 400, "bad request")) } + +func TestErrorAs(t *testing.T) { + err := fmt.Errorf("wrapped: %w", huma.Error400BadRequest("test")) + + var e huma.StatusError + require.ErrorAs(t, err, &e) + assert.Equal(t, 400, e.GetStatus()) +} diff --git a/huma.go b/huma.go index 178c6621..2fa126ba 100644 --- a/huma.go +++ b/huma.go @@ -1274,7 +1274,8 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I) output, err := handler(ctx.Context(), &input) if err != nil { status := http.StatusInternalServerError - if se, ok := err.(StatusError); ok { + var se StatusError + if errors.As(err, &se) { status = se.GetStatus() } else { err = NewError(http.StatusInternalServerError, err.Error()) diff --git a/huma_test.go b/huma_test.go index 12532502..7c12efaa 100644 --- a/huma_test.go +++ b/huma_test.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "errors" + "fmt" "io" "mime/multipart" "net" @@ -656,6 +657,22 @@ Content of example2.txt. assert.Equal(t, http.StatusForbidden, resp.Code) }, }, + { + Name: "handler-wrapped-error", + Register: func(t *testing.T, api huma.API) { + huma.Register(api, huma.Operation{ + Method: http.MethodGet, + Path: "/error", + }, func(ctx context.Context, input *struct{}) (*struct{}, error) { + return nil, fmt.Errorf("wrapped: %w", huma.Error403Forbidden("nope")) + }) + }, + Method: http.MethodGet, + URL: "/error", + Assert: func(t *testing.T, resp *httptest.ResponseRecorder) { + assert.Equal(t, http.StatusForbidden, resp.Code) + }, + }, { Name: "response-headers", Register: func(t *testing.T, api huma.API) {