Skip to content

Commit

Permalink
Merge pull request #118 from danielgtaylor/time-fix
Browse files Browse the repository at this point in the history
fix: param date/time parsing, better tests/coverage
  • Loading branch information
danielgtaylor authored Sep 1, 2023
2 parents 14b7ec8 + 26b47ed commit ed0a877
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 12 deletions.
1 change: 1 addition & 0 deletions codecov.yml
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
ignore:
- benchmark
- examples
- adapters
19 changes: 7 additions & 12 deletions huma.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,10 @@ func findParams(registry Registry, op *Operation, t reflect.Type) *findResult[*p
pfi.Name = name

if f.Type == timeType {
timeFormat := http.TimeFormat
timeFormat := time.RFC3339Nano
if pfi.Loc == "header" {
timeFormat = http.TimeFormat
}
if f := f.Tag.Get("timeFormat"); f != "" {
timeFormat = f
}
Expand Down Expand Up @@ -586,21 +589,13 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I)

// Special case: time.Time
if f.Type() == timeType {
timeFormat := time.RFC3339Nano
if p.Loc == "header" {
timeFormat = http.TimeFormat
}
if p.TimeFormat != "" {
timeFormat = p.TimeFormat
}

t, err := time.Parse(timeFormat, value)
t, err := time.Parse(p.TimeFormat, value)
if err != nil {
res.Add(pb, value, "invalid time")
res.Add(pb, value, "invalid date/time for format "+p.TimeFormat)
return
}
f.Set(reflect.ValueOf(t))
pv = t
pv = value
break
}

Expand Down
93 changes: 93 additions & 0 deletions huma_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (

"github.com/danielgtaylor/huma/v2/queryparam"
"github.com/go-chi/chi"
"github.com/goccy/go-yaml"
"github.com/mitchellh/mapstructure"
"github.com/stretchr/testify/assert"
)
Expand Down Expand Up @@ -115,6 +116,98 @@ func NewTestAdapter(r chi.Router, config Config) API {
return NewAPI(config, &testAdapter{router: r})
}

func TestFeatures(t *testing.T) {
for _, feature := range []struct {
Name string
Register func(t *testing.T, api API)
Method string
URL string
Headers map[string]string
Assert func(t *testing.T, resp *httptest.ResponseRecorder)
}{
{
Name: "params",
Register: func(t *testing.T, api API) {
Register(api, Operation{
Method: http.MethodGet,
Path: "/test-params/{string}/{int}",
}, func(ctx context.Context, input *struct {
PathString string `path:"string"`
PathInt int `path:"int"`
QueryString string `query:"string"`
QueryInt int `query:"int"`
QueryDefault float32 `query:"def" default:"135" example:"5"`
QueryBefore time.Time `query:"before"`
QueryDate time.Time `query:"date" timeFormat:"2006-01-02"`
HeaderString string `header:"String"`
HeaderInt int `header:"Int"`
}) (*struct{}, error) {
assert.Equal(t, "foo", input.PathString)
assert.Equal(t, 123, input.PathInt)
assert.Equal(t, "bar", input.QueryString)
assert.Equal(t, 456, input.QueryInt)
assert.Equal(t, float32(135), input.QueryDefault)
assert.True(t, input.QueryBefore.Equal(time.Date(2023, 1, 1, 12, 0, 0, 0, time.UTC)))
assert.True(t, input.QueryDate.Equal(time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC)))
assert.Equal(t, "baz", input.HeaderString)
assert.Equal(t, 789, input.HeaderInt)
return nil, nil
})
},
Method: http.MethodGet,
URL: "/test-params/foo/123?string=bar&int=456&before=2023-01-01T12:00:00Z&date=2023-01-01",
Headers: map[string]string{"string": "baz", "int": "789"},
},
{
Name: "response",
Register: func(t *testing.T, api API) {
type Resp struct {
Foo string `header:"foo"`
Body struct {
Greeting string `json:"greeting"`
}
}

Register(api, Operation{
Method: http.MethodGet,
Path: "/response",
}, func(ctx context.Context, input *struct{}) (*Resp, error) {
resp := &Resp{}
resp.Foo = "foo"
resp.Body.Greeting = "Hello, world!"
return resp, nil
})
},
Method: http.MethodGet,
URL: "/response",
Assert: func(t *testing.T, resp *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusOK, resp.Code)
assert.Equal(t, "foo", resp.Header().Get("Foo"))
assert.JSONEq(t, `{"$schema": "https:///schemas/RespBody.json", "greeting":"Hello, world!"}`, resp.Body.String())
},
},
} {
t.Run(feature.Name, func(t *testing.T) {
r := chi.NewRouter()
api := NewTestAdapter(r, DefaultConfig("Features Test API", "1.0.0"))
feature.Register(t, api)

req, _ := http.NewRequest(feature.Method, feature.URL, nil)
for k, v := range feature.Headers {
req.Header.Set(k, v)
}
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
b, _ := yaml.Marshal(api.OpenAPI())
t.Log(string(b))
assert.Less(t, w.Code, 300, w.Body.String())
if feature.Assert != nil {
feature.Assert(t, w)
}
})
}
}

type ExhaustiveErrorsInputBody struct {
Name string `json:"name" maxLength:"10"`
Count int `json:"count" minimum:"1"`
Expand Down
10 changes: 10 additions & 0 deletions schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,16 @@ func SchemaFromField(registry Registry, parent reflect.Type, f reflect.StructFie
if fmt := f.Tag.Get("format"); fmt != "" {
fs.Format = fmt
}
if timeFmt := f.Tag.Get("timeFormat"); timeFmt != "" {
switch timeFmt {
case "2006-01-02":
fs.Format = "date"
case "15:04:05":
fs.Format = "time"
default:
fs.Format = timeFmt
}
}
if enc := f.Tag.Get("encoding"); enc != "" {
fs.ContentEncoding = enc
}
Expand Down

0 comments on commit ed0a877

Please sign in to comment.