Skip to content

Commit

Permalink
Merge pull request #242 from danielgtaylor/custom-schema-example
Browse files Browse the repository at this point in the history
fix: allow example/enum tags for custom schemas
  • Loading branch information
danielgtaylor authored Feb 15, 2024
2 parents ab77bc6 + 04887d7 commit 0e5a8bb
Show file tree
Hide file tree
Showing 5 changed files with 297 additions and 35 deletions.
9 changes: 5 additions & 4 deletions huma.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ func findParams(registry Registry, op *Operation, t reflect.Type) *findResult[*p

var example any
if e := f.Tag.Get("example"); e != "" {
example = jsonTagValue(f, f.Type, f.Tag.Get("example"))
example = jsonTagValue(registry, f.Type.Name(), pfi.Schema, f.Tag.Get("example"))
}

if def := f.Tag.Get("default"); def != "" {
Expand Down Expand Up @@ -190,13 +190,14 @@ func findResolvers(resolverType, t reflect.Type) *findResult[bool] {
}, nil)
}

func findDefaults(t reflect.Type) *findResult[any] {
func findDefaults(registry Registry, t reflect.Type) *findResult[any] {
return findInType(t, nil, func(sf reflect.StructField, i []int) any {
if d := sf.Tag.Get("default"); d != "" {
if sf.Type.Kind() == reflect.Pointer {
panic("pointers cannot have default values")
}
return jsonTagValue(sf, sf.Type, d)
s := registry.Schema(sf.Type, true, "")
return convertType(sf.Type.Name(), sf.Type, jsonTagValue(registry, sf.Name, s, d))
}
return nil
})
Expand Down Expand Up @@ -560,7 +561,7 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I)
}

resolvers := findResolvers(resolverType, inputType)
defaults := findDefaults(inputType)
defaults := findDefaults(registry, inputType)

if op.Responses == nil {
op.Responses = map[string]*Response{}
Expand Down
3 changes: 3 additions & 0 deletions registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,9 @@ func (r *mapRegistry) Schema(t reflect.Type, allowRef bool, hint string) *Schema
}

func (r *mapRegistry) SchemaFromRef(ref string) *Schema {
if !strings.HasPrefix(ref, r.prefix) {
return nil
}
return r.schemas[ref[len(r.prefix):]]
}

Expand Down
121 changes: 96 additions & 25 deletions schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -280,26 +280,64 @@ func floatTag(f reflect.StructField, tag string) *float64 {
return nil
}

func jsonTagValue(f reflect.StructField, t reflect.Type, value string) any {
// Special case: strings don't need quotes.
if t.Kind() == reflect.String || (t.Kind() == reflect.Pointer && t.Elem().Kind() == reflect.String) {
return value
// ensureType panics if the given value does not match the JSON Schema type.
func ensureType(r Registry, fieldName string, s *Schema, value string, v any) {
if s.Ref != "" {
s = r.SchemaFromRef(s.Ref)
if s == nil {
// We may not have access to this type, e.g. custom schema provided
// by the user with remote refs. Skip validation.
return
}
}

// Special case: array of strings with comma-separated values and no quotes.
if t.Kind() == reflect.Slice && t.Elem().Kind() == reflect.String && value[0] != '[' {
values := []string{}
for _, s := range strings.Split(value, ",") {
values = append(values, strings.TrimSpace(s))
switch s.Type {
case TypeBoolean:
if _, ok := v.(bool); !ok {
panic(fmt.Errorf("invalid boolean tag value '%s' for field '%s': %w", value, fieldName, ErrSchemaInvalid))
}
case TypeInteger, TypeNumber:
if _, ok := v.(float64); !ok {
panic(fmt.Errorf("invalid number tag value '%s' for field '%s': %w", value, fieldName, ErrSchemaInvalid))
}
return values
}

var v any
if err := json.Unmarshal([]byte(value), &v); err != nil {
panic(fmt.Errorf("invalid tag for field '%s': %w", f.Name, err))
if s.Type == TypeInteger {
if v.(float64) != float64(int(v.(float64))) {
panic(fmt.Errorf("invalid integer tag value '%s' for field '%s': %w", value, fieldName, ErrSchemaInvalid))
}
}
case TypeString:
if _, ok := v.(string); !ok {
panic(fmt.Errorf("invalid string tag value '%s' for field '%s': %w", value, fieldName, ErrSchemaInvalid))
}
case TypeArray:
if _, ok := v.([]any); !ok {
panic(fmt.Errorf("invalid array tag value '%s' for field '%s': %w", value, fieldName, ErrSchemaInvalid))
}

if s.Items != nil {
for i, item := range v.([]any) {
b, _ := json.Marshal(item)
ensureType(r, fieldName+"["+strconv.Itoa(i)+"]", s.Items, string(b), item)
}
}
case TypeObject:
if _, ok := v.(map[string]any); !ok {
panic(fmt.Errorf("invalid object tag value '%s' for field '%s': %w", value, fieldName, ErrSchemaInvalid))
}

for name, prop := range s.Properties {
if val, ok := v.(map[string]any)[name]; ok {
b, _ := json.Marshal(val)
ensureType(r, fieldName+"."+name, prop, string(b), val)
}
}
}
}

// convertType panics if the given value does not match or cannot be converted
// to the field's Go type.
func convertType(fieldName string, t reflect.Type, v any) any {
vv := reflect.ValueOf(v)
tv := reflect.TypeOf(v)
if v != nil && tv != t {
Expand All @@ -310,14 +348,14 @@ func jsonTagValue(f reflect.StructField, t reflect.Type, value string) any {
tmp := reflect.MakeSlice(t, 0, vv.Len())
for i := 0; i < vv.Len(); i++ {
if !vv.Index(i).Elem().Type().ConvertibleTo(t.Elem()) {
panic(fmt.Errorf("unable to convert %v to %v for field '%s': %w", vv.Index(i).Interface(), t.Elem(), f.Name, ErrSchemaInvalid))
panic(fmt.Errorf("unable to convert %v to %v for field '%s': %w", vv.Index(i).Interface(), t.Elem(), fieldName, ErrSchemaInvalid))
}

tmp = reflect.Append(tmp, vv.Index(i).Elem().Convert(t.Elem()))
}
v = tmp.Interface()
} else if !tv.ConvertibleTo(deref(t)) {
panic(fmt.Errorf("unable to convert %v to %v for field '%s': %w", tv, t, f.Name, ErrSchemaInvalid))
panic(fmt.Errorf("unable to convert %v to %v for field '%s': %w", tv, t, fieldName, ErrSchemaInvalid))
}

converted := reflect.ValueOf(v).Convert(deref(t))
Expand All @@ -330,16 +368,47 @@ func jsonTagValue(f reflect.StructField, t reflect.Type, value string) any {
}
v = converted.Interface()
}
return v
}

func jsonTagValue(r Registry, fieldName string, s *Schema, value string) any {
if s.Ref != "" {
s = r.SchemaFromRef(s.Ref)
if s == nil {
return nil
}
}

// Special case: strings don't need quotes.
if s.Type == TypeString {
return value
}

// Special case: array of strings with comma-separated values and no quotes.
if s.Type == TypeArray && s.Items != nil && s.Items.Type == TypeString && value[0] != '[' {
values := []string{}
for _, s := range strings.Split(value, ",") {
values = append(values, strings.TrimSpace(s))
}
return values
}

var v any
if err := json.Unmarshal([]byte(value), &v); err != nil {
panic(fmt.Errorf("invalid %s tag value '%s' for field '%s': %w", s.Type, value, fieldName, err))
}

ensureType(r, fieldName, s, value, v)

return v
}

// jsonTag returns a value of the schema's type for the given tag string.
// Uses JSON parsing if the schema is not a string.
func jsonTag(f reflect.StructField, name string) any {
func jsonTag(r Registry, f reflect.StructField, s *Schema, name string) any {
t := f.Type
if value := f.Tag.Get(name); value != "" {
return jsonTagValue(f, t, value)
return convertType(f.Name, t, jsonTagValue(r, f.Name, s, value))
}
return nil
}
Expand Down Expand Up @@ -378,20 +447,22 @@ func SchemaFromField(registry Registry, f reflect.StructField, hint string) *Sch
if enc := f.Tag.Get("encoding"); enc != "" {
fs.ContentEncoding = enc
}
fs.Default = jsonTag(f, "default")
fs.Default = jsonTag(registry, f, fs, "default")

if e := jsonTag(f, "example"); e != nil {
fs.Examples = []any{e}
if value := f.Tag.Get("example"); value != "" {
if e := jsonTagValue(registry, f.Name, fs, value); e != nil {
fs.Examples = []any{e}
}
}

if enum := f.Tag.Get("enum"); enum != "" {
fType := f.Type
if fs.Type == TypeArray {
fType = fType.Elem()
s := fs
if s.Type == TypeArray {
s = s.Items
}
enumValues := []any{}
for _, e := range strings.Split(enum, ",") {
enumValues = append(enumValues, jsonTagValue(f, fType, e))
enumValues = append(enumValues, jsonTagValue(registry, f.Name, s, e))
}
if fs.Type == TypeArray {
fs.Items.Enum = enumValues
Expand Down
Loading

0 comments on commit 0e5a8bb

Please sign in to comment.