Skip to content

Commit

Permalink
fix: backport Huma v2 example pointer fix from #148
Browse files Browse the repository at this point in the history
  • Loading branch information
danielgtaylor committed Oct 23, 2023
1 parent 62c3193 commit dfa7ffa
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 3 deletions.
21 changes: 18 additions & 3 deletions schema/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,18 @@ func F(value float64) *float64 {
return &value
}

func deref(t reflect.Type) reflect.Type {
for t.Kind() == reflect.Ptr {
t = t.Elem()
}
return t
}

// getTagValue returns a value of the schema's type for the given tag string.
// Uses JSON parsing if the schema is not a string.
func getTagValue(s *Schema, t reflect.Type, value string) (interface{}, error) {
// Special case: strings don't need quotes.
if s.Type == TypeString {
if s.Type == TypeString || (t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.String) {
return value, nil
}

Expand Down Expand Up @@ -105,11 +112,19 @@ func getTagValue(s *Schema, t reflect.Type, value string) (interface{}, error) {
tmp = reflect.Append(tmp, vv.Index(i).Elem().Convert(t.Elem()))
}
v = tmp.Interface()
} else if !tv.ConvertibleTo(t) {
} else if !tv.ConvertibleTo(deref(t)) {
return nil, fmt.Errorf("unable to convert %v to %v: %w", tv, t, ErrSchemaInvalid)
}

v = reflect.ValueOf(v).Convert(t).Interface()
converted := reflect.ValueOf(v).Convert(deref(t))
if t.Kind() == reflect.Ptr {
// Special case: if the field is a pointer, we need to get a pointer
// to the converted value.
tmp := reflect.New(t.Elem())
tmp.Elem().Set(converted)
converted = tmp
}
v = converted.Interface()
}

return v, nil
Expand Down
3 changes: 3 additions & 0 deletions schema/schema_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,11 +145,14 @@ func TestSchemaDefault(t *testing.T) {
func TestSchemaExample(t *testing.T) {
type Example struct {
Foo string `json:"foo" example:"ex"`
Bar *int64 `json:"bar" example:"5"`
}

s, err := Generate(reflect.ValueOf(Example{}).Type())
assert.NoError(t, err)
assert.Equal(t, "ex", s.Properties["foo"].Example)
ex := int64(5)
assert.Equal(t, &ex, s.Properties["bar"].Example)
}

func TestSchemaNullable(t *testing.T) {
Expand Down

0 comments on commit dfa7ffa

Please sign in to comment.