From 258794986dae91253533296aa34af8831d51d5df Mon Sep 17 00:00:00 2001 From: "Daniel G. Taylor" Date: Tue, 8 Oct 2024 22:15:26 -0700 Subject: [PATCH] fix: marshal empty security object --- huma_test.go | 18 ++++++++++++++++++ openapi.go | 25 ++++++++++++++++++++++--- 2 files changed, 40 insertions(+), 3 deletions(-) diff --git a/huma_test.go b/huma_test.go index 7fa0f8d5..f49cbb2f 100644 --- a/huma_test.go +++ b/huma_test.go @@ -1792,6 +1792,24 @@ Content of example2.txt. URL: "/one-of", Body: `[{"foo": "first"}, {"foo": "second"}]`, }, + { + Name: "security-override-public", + Register: func(t *testing.T, api huma.API) { + huma.Register(api, huma.Operation{ + Method: http.MethodGet, + Path: "/public", + Security: []map[string][]string{}, // No security for this call! + }, func(ctx context.Context, input *struct{}) (*struct{}, error) { + return nil, nil + }) + // Note: the empty security object should be serialized as an empty + // array in the OpenAPI document. + b, _ := api.OpenAPI().Paths["/public"].Get.MarshalJSON() + assert.Contains(t, string(b), `"security":[]`) + }, + Method: http.MethodGet, + URL: "/public", + }, } { t.Run(feature.Name, func(t *testing.T) { r := chi.NewRouter() diff --git a/openapi.go b/openapi.go index 4a8106a6..0518dc6a 100644 --- a/openapi.go +++ b/openapi.go @@ -51,13 +51,32 @@ func isEmptyValue(v reflect.Value) bool { return false } +// isNilValue returns true if the given value is nil. +func isNilValue(v any) bool { + if v == nil { + return true + } + + // Nil is typed and may not always match above, so for some types we can + // use reflection instead. This is a bit slower, but works. + // https://www.calhoun.io/when-nil-isnt-equal-to-nil/ + // https://go.dev/doc/faq#nil_error + vv := reflect.ValueOf(v) + switch vv.Kind() { + case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice: + return vv.IsNil() + } + + return false +} + // marshalJSON marshals a list of fields and their values into JSON. It supports // inlined extensions. func marshalJSON(fields []jsonFieldInfo, extensions map[string]any) ([]byte, error) { value := make(map[string]any, len(extensions)+len(fields)) for _, v := range fields { - if v.omit == omitNil && v.value == nil { + if v.omit == omitNil && isNilValue(v.value) { continue } if v.omit == omitEmpty { @@ -956,7 +975,7 @@ func (o *Operation) MarshalJSON() ([]byte, error) { {"responses", o.Responses, omitEmpty}, {"callbacks", o.Callbacks, omitEmpty}, {"deprecated", o.Deprecated, omitEmpty}, - {"security", o.Security, omitEmpty}, + {"security", o.Security, omitNil}, {"servers", o.Servers, omitEmpty}, }, o.Extensions) } @@ -1507,7 +1526,7 @@ func (o *OpenAPI) MarshalJSON() ([]byte, error) { {"paths", o.Paths, omitEmpty}, {"webhooks", o.Webhooks, omitEmpty}, {"components", o.Components, omitEmpty}, - {"security", o.Security, omitEmpty}, + {"security", o.Security, omitNil}, {"tags", o.Tags, omitEmpty}, {"externalDocs", o.ExternalDocs, omitEmpty}, }, o.Extensions)