Skip to content

Commit

Permalink
Merge pull request #498 from csmarchbanks/exploded-query-parameters
Browse files Browse the repository at this point in the history
Implement exploded query parameters
  • Loading branch information
danielgtaylor authored Jul 15, 2024
2 parents b8be41a + 6961c0c commit 6ff17c6
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 19 deletions.
2 changes: 1 addition & 1 deletion docs/docs/features/request-inputs.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ The following parameter types are supported out of the box:
| `time.Time` | `2020-01-01T12:00:00Z` |
| slice, e.g. `[]int` | `1,2,3`, `tag1,tag2` |

For example, if the parameter is a query param and the type is `[]string` it might look like `?tags=tag1,tag2` in the URI.
For example, if the parameter is a query param and the type is `[]string` it might look like `?tags=tag1,tag2` in the URI. Query paramaters also support specifying the same parameter multiple times by setting the `explode` tag, e.g. `query:"tags,explode"` would parse a query string like `?tags=tag1&tags=tag2` instead of a comma separated list. The comma separated list is faster and recommended for most use cases.

For cookies, the default behavior is to read the cookie _value_ from the request and convert it to one of the types above. If you want to access the entire cookie, you can use `http.Cookie` as the type instead:

Expand Down
31 changes: 15 additions & 16 deletions huma.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ type paramFieldInfo struct {
Required bool
Default string
TimeFormat string
Explode bool
Schema *Schema
}

Expand Down Expand Up @@ -136,11 +137,14 @@ func findParams(registry Registry, op *Operation, t reflect.Type) *findResult[*p
pfi.Required = true
} else if q := f.Tag.Get("query"); q != "" {
pfi.Loc = "query"
name = q
split := strings.Split(q, ",")
name = split[0]
// If `in` is `query` then `explode` defaults to true. Parsing is *much*
// easier if we use comma-separated values, so we disable explode.
nope := false
explode = &nope
// easier if we use comma-separated values, so we disable explode by default.
if slicesContains(split[1:], "explode") {
pfi.Explode = true
}
explode = &pfi.Explode
} else if h := f.Tag.Get("header"); h != "" {
pfi.Loc = "header"
name = h
Expand Down Expand Up @@ -933,15 +937,20 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I)
pv = v
default:
if f.Type().Kind() == reflect.Slice {
var values []string
if p.Explode {
u := ctx.URL()
values = (&u).Query()[p.Name]
} else {
values = strings.Split(value, ",")
}
switch f.Type().Elem().Kind() {

case reflect.String:
values := strings.Split(value, ",")
f.Set(reflect.ValueOf(values))
pv = values

case reflect.Int:
values := strings.Split(value, ",")
vs, err := parseArrElement(values, func(s string) (int, error) {
val, err := strconv.ParseInt(s, 10, strconv.IntSize)
if err != nil {
Expand All @@ -957,7 +966,6 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I)
pv = vs

case reflect.Int8:
values := strings.Split(value, ",")
vs, err := parseArrElement(values, func(s string) (int8, error) {
val, err := strconv.ParseInt(s, 10, 8)
if err != nil {
Expand All @@ -973,7 +981,6 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I)
pv = vs

case reflect.Int16:
values := strings.Split(value, ",")
vs, err := parseArrElement(values, func(s string) (int16, error) {
val, err := strconv.ParseInt(s, 10, 16)
if err != nil {
Expand All @@ -989,7 +996,6 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I)
pv = vs

case reflect.Int32:
values := strings.Split(value, ",")
vs, err := parseArrElement(values, func(s string) (int32, error) {
val, err := strconv.ParseInt(s, 10, 32)
if err != nil {
Expand All @@ -1005,7 +1011,6 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I)
pv = vs

case reflect.Int64:
values := strings.Split(value, ",")
vs, err := parseArrElement(values, func(s string) (int64, error) {
val, err := strconv.ParseInt(s, 10, 64)
if err != nil {
Expand All @@ -1021,7 +1026,6 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I)
pv = vs

case reflect.Uint:
values := strings.Split(value, ",")
vs, err := parseArrElement(values, func(s string) (uint, error) {
val, err := strconv.ParseUint(s, 10, strconv.IntSize)
if err != nil {
Expand All @@ -1037,7 +1041,6 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I)
pv = vs

case reflect.Uint16:
values := strings.Split(value, ",")
vs, err := parseArrElement(values, func(s string) (uint16, error) {
val, err := strconv.ParseUint(s, 10, 16)
if err != nil {
Expand All @@ -1053,7 +1056,6 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I)
pv = vs

case reflect.Uint32:
values := strings.Split(value, ",")
vs, err := parseArrElement(values, func(s string) (uint32, error) {
val, err := strconv.ParseUint(s, 10, 32)
if err != nil {
Expand All @@ -1069,7 +1071,6 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I)
pv = vs

case reflect.Uint64:
values := strings.Split(value, ",")
vs, err := parseArrElement(values, func(s string) (uint64, error) {
val, err := strconv.ParseUint(s, 10, 64)
if err != nil {
Expand All @@ -1085,7 +1086,6 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I)
pv = vs

case reflect.Float32:
values := strings.Split(value, ",")
vs, err := parseArrElement(values, func(s string) (float32, error) {
val, err := strconv.ParseFloat(s, 32)
if err != nil {
Expand All @@ -1101,7 +1101,6 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I)
pv = vs

case reflect.Float64:
values := strings.Split(value, ",")
vs, err := parseArrElement(values, func(s string) (float64, error) {
val, err := strconv.ParseFloat(s, 64)
if err != nil {
Expand Down
6 changes: 4 additions & 2 deletions huma_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,7 @@ func TestFeatures(t *testing.T) {
QueryUints64 []uint64 `query:"uints64"`
QueryFloats32 []float32 `query:"floats32"`
QueryFloats64 []float64 `query:"floats64"`
QueryExploded []string `query:"exploded,explode"`
HeaderString string `header:"String"`
HeaderInt int `header:"Int"`
HeaderDate time.Time `header:"Date"`
Expand Down Expand Up @@ -401,17 +402,18 @@ func TestFeatures(t *testing.T) {
assert.Equal(t, "foo", input.CookieValue)
assert.Equal(t, 123, input.CookieInt)
assert.Equal(t, "bar", input.CookieFull.Value)
assert.Equal(t, []string{"foo", "bar"}, input.QueryExploded)
return nil, nil
})

// Docs should be available on the param object, not just the schema.
assert.Equal(t, "Some docs", api.OpenAPI().Paths["/test-params/{string}/{int}/{uuid}"].Get.Parameters[0].Description)

// `http.Cookie` should be treated as a string.
assert.Equal(t, "string", api.OpenAPI().Paths["/test-params/{string}/{int}/{uuid}"].Get.Parameters[27].Schema.Type)
assert.Equal(t, "string", api.OpenAPI().Paths["/test-params/{string}/{int}/{uuid}"].Get.Parameters[28].Schema.Type)
},
Method: http.MethodGet,
URL: "/test-params/foo/123/fba4f46b-4539-4d19-8e3f-a0e629a243b5?string=bar&int=456&before=2023-01-01T12:00:00Z&date=2023-01-01&uint=1&bool=true&strings=foo,bar&ints=2,3&ints8=4,5&ints16=4,5&ints32=4,5&ints64=4,5&uints=1,2&uints16=10,15&uints32=10,15&uints64=10,15&floats32=2.2,2.3&floats64=3.2,3.3",
URL: "/test-params/foo/123/fba4f46b-4539-4d19-8e3f-a0e629a243b5?string=bar&int=456&before=2023-01-01T12:00:00Z&date=2023-01-01&uint=1&bool=true&strings=foo,bar&ints=2,3&ints8=4,5&ints16=4,5&ints32=4,5&ints64=4,5&uints=1,2&uints16=10,15&uints32=10,15&uints64=10,15&floats32=2.2,2.3&floats64=3.2,3.3&exploded=foo&exploded=bar",
Headers: map[string]string{
"string": "baz",
"int": "789",
Expand Down

0 comments on commit 6ff17c6

Please sign in to comment.