Skip to content

Commit

Permalink
Merge pull request #131 from danielgtaylor/any-all-one-of
Browse files Browse the repository at this point in the history
feat: oneOf, anyOf, allOf, not schema support
  • Loading branch information
danielgtaylor authored Sep 22, 2023
2 parents 8ff720b + 9c157ac commit c81c4ba
Show file tree
Hide file tree
Showing 5 changed files with 279 additions and 26 deletions.
59 changes: 34 additions & 25 deletions huma.go
Original file line number Diff line number Diff line change
Expand Up @@ -365,14 +365,12 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I)
}
inputParams := findParams(registry, &op, inputType)
inputBodyIndex := -1
var inSchema *Schema
if f, ok := inputType.FieldByName("Body"); ok {
inputBodyIndex = f.Index[0]
inSchema = registry.Schema(f.Type, true, getHint(inputType, f.Name, op.OperationID+"Request"))
op.RequestBody = &RequestBody{
Content: map[string]*MediaType{
"application/json": {
Schema: inSchema,
Schema: registry.Schema(f.Type, true, getHint(inputType, f.Name, op.OperationID+"Request")),
},
},
}
Expand All @@ -391,6 +389,12 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I)
if f, ok := inputType.FieldByName("RawBody"); ok {
rawBodyIndex = f.Index[0]
}

var inSchema *Schema
if op.RequestBody != nil && op.RequestBody.Content != nil && op.RequestBody.Content["application/json"] != nil && op.RequestBody.Content["application/json"].Schema != nil {
inSchema = op.RequestBody.Content["application/json"].Schema
}

resolvers := findResolvers(resolverType, inputType)
defaults := findDefaults(inputType)

Expand Down Expand Up @@ -628,7 +632,7 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I)
})

// Read input body if defined.
if inputBodyIndex != -1 {
if inputBodyIndex != -1 || rawBodyIndex != -1 {
if op.BodyReadTimeout > 0 {
ctx.SetReadDeadline(time.Now().Add(op.BodyReadTimeout))
} else if op.BodyReadTimeout < 0 {
Expand Down Expand Up @@ -676,7 +680,10 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I)
}

if len(body) == 0 {
kind := v.Field(inputBodyIndex).Kind()
kind := reflect.Slice // []byte by default for raw body
if inputBodyIndex != -1 {
kind = v.Field(inputBodyIndex).Kind()
}
if kind != reflect.Ptr && kind != reflect.Interface {
buf.Reset()
bufPool.Put(buf)
Expand Down Expand Up @@ -711,28 +718,30 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I)
}
}

// We need to get the body into the correct type now that it has been
// validated. Benchmarks on Go 1.20 show that using `json.Unmarshal` a
// second time is faster than `mapstructure.Decode` or any of the other
// common reflection-based approaches when using real-world medium-sized
// JSON payloads with lots of strings.
f := v.Field(inputBodyIndex)
if err := api.Unmarshal(ctx.Header("Content-Type"), body, f.Addr().Interface()); err != nil {
if parseErrCount == 0 {
// Hmm, this should have worked... validator missed something?
res.Errors = append(res.Errors, &ErrorDetail{
Location: "body",
Message: err.Error(),
Value: string(body),
if inputBodyIndex != -1 {
// We need to get the body into the correct type now that it has been
// validated. Benchmarks on Go 1.20 show that using `json.Unmarshal` a
// second time is faster than `mapstructure.Decode` or any of the other
// common reflection-based approaches when using real-world medium-sized
// JSON payloads with lots of strings.
f := v.Field(inputBodyIndex)
if err := api.Unmarshal(ctx.Header("Content-Type"), body, f.Addr().Interface()); err != nil {
if parseErrCount == 0 {
// Hmm, this should have worked... validator missed something?
res.Errors = append(res.Errors, &ErrorDetail{
Location: "body",
Message: err.Error(),
Value: string(body),
})
}
} else {
// Set defaults for any fields that were not in the input.
defaults.Every(v, func(item reflect.Value, def any) {
if item.IsZero() {
item.Set(reflect.Indirect(reflect.ValueOf(def)))
}
})
}
} else {
// Set defaults for any fields that were not in the input.
defaults.Every(v, func(item reflect.Value, def any) {
if item.IsZero() {
item.Set(reflect.Indirect(reflect.ValueOf(def)))
}
})
}

buf.Reset()
Expand Down
58 changes: 58 additions & 0 deletions huma_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,64 @@ func TestFeatures(t *testing.T) {
assert.Equal(t, 256, resp.Code)
},
},
{
Name: "one-of input",
Register: func(t *testing.T, api API) {
// Step 1: create a custom schema
customSchema := &Schema{
OneOf: []*Schema{
{
Type: TypeObject,
Properties: map[string]*Schema{
"foo": {Type: TypeString},
},
},
{
Type: TypeArray,
Items: &Schema{
Type: TypeObject,
Properties: map[string]*Schema{
"foo": {Type: TypeString},
},
},
},
},
}
customSchema.PrecomputeMessages()

Register(api, Operation{
Method: http.MethodPut,
Path: "/one-of",
// Step 2: register an operation with a custom schema
RequestBody: &RequestBody{
Required: true,
Content: map[string]*MediaType{
"application/json": {
Schema: customSchema,
},
},
},
}, func(ctx context.Context, input *struct {
// Step 3: only take in raw bytes
RawBody []byte
}) (*struct{}, error) {
// Step 4: determine which it is and parse it into the right type.
// We will check the first byte but there are other ways to do this.
assert.EqualValues(t, '[', input.RawBody[0])
var parsed []struct {
Foo string `json:"foo"`
}
assert.NoError(t, json.Unmarshal(input.RawBody, &parsed))
assert.Len(t, parsed, 2)
assert.Equal(t, "first", parsed[0].Foo)
assert.Equal(t, "second", parsed[1].Foo)
return nil, nil
})
},
Method: http.MethodPut,
URL: "/one-of",
Body: `[{"foo": "first"}, {"foo": "second"}]`,
},
} {
t.Run(feature.Name, func(t *testing.T) {
r := chi.NewRouter()
Expand Down
21 changes: 21 additions & 0 deletions schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,11 @@ type Schema struct {
Deprecated bool `yaml:"deprecated,omitempty"`
Extensions map[string]any `yaml:",inline"`

OneOf []*Schema `yaml:"oneOf,omitempty"`
AnyOf []*Schema `yaml:"anyOf,omitempty"`
AllOf []*Schema `yaml:"allOf,omitempty"`
Not *Schema `yaml:"not,omitempty"`

patternRe *regexp.Regexp `yaml:"-"`
requiredMap map[string]bool `yaml:"-"`
propertyNames []string `yaml:"-"`
Expand Down Expand Up @@ -162,6 +167,22 @@ func (s *Schema) PrecomputeMessages() {
s.msgRequired[name] = "expected required property " + name + " to be present"
}
}

for _, sub := range s.OneOf {
sub.PrecomputeMessages()
}

for _, sub := range s.AnyOf {
sub.PrecomputeMessages()
}

for _, sub := range s.AllOf {
sub.PrecomputeMessages()
}

if sub := s.Not; sub != nil {
sub.PrecomputeMessages()
}
}

// MarshalJSON marshals the schema into JSON, respecting the `Extensions` map
Expand Down
56 changes: 56 additions & 0 deletions validate.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,40 @@ func validateFormat(path *PathBuffer, str string, s *Schema, res *ValidateResult
}
}

func validateOneOf(r Registry, s *Schema, path *PathBuffer, mode ValidateMode, v any, res *ValidateResult) {
found := false
subRes := &ValidateResult{}
for _, sub := range s.OneOf {
Validate(r, sub, path, mode, v, subRes)
if len(subRes.Errors) == 0 {
if found {
res.Add(path, v, "expected value to match exactly one schema but matched multiple")
}
found = true
}
subRes.Reset()
}
if !found {
res.Add(path, v, "expected value to match exactly one schema but matched none")
}
}

func validateAnyOf(r Registry, s *Schema, path *PathBuffer, mode ValidateMode, v any, res *ValidateResult) {
matches := 0
subRes := &ValidateResult{}
for _, sub := range s.AnyOf {
Validate(r, sub, path, mode, v, subRes)
if len(subRes.Errors) == 0 {
matches++
}
subRes.Reset()
}

if matches == 0 {
res.Add(path, v, "expected value to match at least one schema but matched none")
}
}

// Validate an input value against a schema, collecting errors in the validation
// result object. If successful, `res.Errors` will be empty. It is suggested
// to use a `sync.Pool` to reuse the PathBuffer and ValidateResult objects,
Expand All @@ -284,6 +318,28 @@ func Validate(r Registry, s *Schema, path *PathBuffer, mode ValidateMode, v any,
s = r.SchemaFromRef(s.Ref)
}

if s.OneOf != nil {
validateOneOf(r, s, path, mode, v, res)
}

if s.AnyOf != nil {
validateAnyOf(r, s, path, mode, v, res)
}

if s.AllOf != nil {
for _, sub := range s.AllOf {
Validate(r, sub, path, mode, v, res)
}
}

if s.Not != nil {
subRes := &ValidateResult{}
Validate(r, s.Not, path, mode, v, subRes)
if len(subRes.Errors) == 0 {
res.Add(path, v, "expected value to not match schema")
}
}

switch s.Type {
case TypeBoolean:
if _, ok := v.(bool); !ok {
Expand Down
Loading

0 comments on commit c81c4ba

Please sign in to comment.