Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: handling files from multipart/form-data request #415

Merged
merged 15 commits into from
Jun 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
252 changes: 252 additions & 0 deletions formdata.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,252 @@
package huma

import (
"fmt"
"io"
"mime/multipart"
"net/http"
"reflect"
"slices"
"strings"
)

// FormFile is a single uploaded file taken from a multipart/form-data request.
// It embeds multipart.File, so the content is read directly from the value.
type FormFile struct {
multipart.File
// ContentType is the Content-Type declared on the multipart form part, or,
// when the part declares none, the type detected while parsing the request.
ContentType string
// IsSet reports whether content was received for this field. It stays false
// for an optional file that was not sent.
IsSet bool
}

// MultipartFormFiles pairs a parsed multipart form with a value of type T
// that Decode fills from the form's uploaded files.
type MultipartFormFiles[T any] struct {
// Form is the parsed multipart form the files are read from.
Form *multipart.Form
// data is the decoded value; Decode assigns it and Data returns it.
data *T
}

// MimeTypeValidator checks the media type of an uploaded file against a list
// of accepted media types.
type MimeTypeValidator struct {
// accept lists the accepted media types. Validate treats an entry ending in
// "/*" as a prefix wildcard, and an entry equal to "text/plain" or
// "application/octet-stream" as accepting any file.
accept []string
}

func NewMimeTypeValidator(encoding *Encoding) MimeTypeValidator {
var mimeTypes = strings.Split(encoding.ContentType, ",")
for i := range mimeTypes {
mimeTypes[i] = strings.Trim(mimeTypes[i], " ")
}
if len(mimeTypes) == 0 {
mimeTypes = []string{"application/octet-stream"}

Check warning on line 34 in formdata.go

View check run for this annotation

Codecov / codecov/patch

formdata.go#L34

Added line #L34 was not covered by tests
}
return MimeTypeValidator{accept: mimeTypes}
Comment on lines +28 to +36
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider handling empty ContentType more robustly.

The current implementation defaults to "application/octet-stream" only when no content types are specified. It might be beneficial to handle cases where the content type is provided but empty, ensuring robustness.

- if len(mimeTypes) == 0 {
+ if len(mimeTypes) == 0 || (len(mimeTypes) == 1 && mimeTypes[0] == "") {
    mimeTypes = []string{"application/octet-stream"}
  }

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
func NewMimeTypeValidator(encoding *Encoding) MimeTypeValidator {
var mimeTypes = strings.Split(encoding.ContentType, ",")
for i := range mimeTypes {
mimeTypes[i] = strings.Trim(mimeTypes[i], " ")
}
if len(mimeTypes) == 0 {
mimeTypes = []string{"application/octet-stream"}
}
return MimeTypeValidator{accept: mimeTypes}
func NewMimeTypeValidator(encoding *Encoding) MimeTypeValidator {
var mimeTypes = strings.Split(encoding.ContentType, ",")
for i := range mimeTypes {
mimeTypes[i] = strings.Trim(mimeTypes[i], " ")
}
if len(mimeTypes) == 0 || (len(mimeTypes) == 1 && mimeTypes[0] == "") {
mimeTypes = []string{"application/octet-stream"}
}
return MimeTypeValidator{accept: mimeTypes}

}

// Validate checks the mime type of the provided file against the expected content type.
// In the absence of a Content-Type file header, the mime type is detected using [http.DetectContentType].
func (v MimeTypeValidator) Validate(fh *multipart.FileHeader, location string) (string, *ErrorDetail) {
file, err := fh.Open()
if err != nil {
return "", &ErrorDetail{Message: "Failed to open file", Location: location}

Check warning on line 44 in formdata.go

View check run for this annotation

Codecov / codecov/patch

formdata.go#L44

Added line #L44 was not covered by tests
}

mimeType := fh.Header.Get("Content-Type")
if mimeType == "" {
var buffer = make([]byte, 1000)
if _, err := file.Read(buffer); err != nil {
return "", &ErrorDetail{Message: "Failed to infer file media type", Location: location}

Check warning on line 51 in formdata.go

View check run for this annotation

Codecov / codecov/patch

formdata.go#L51

Added line #L51 was not covered by tests
}
file.Seek(int64(0), io.SeekStart)
mimeType = http.DetectContentType(buffer)
}
accept := slices.ContainsFunc(v.accept, func(m string) bool {
if m == "text/plain" || m == "application/octet-stream" {
return true
}
if strings.HasSuffix(m, "/*") &&
strings.HasPrefix(mimeType, strings.TrimRight(m, "*")) {
return true
}
if mimeType == m {
return true
}
return false
})

if accept {
return mimeType, nil
} else {
return mimeType, &ErrorDetail{
Message: fmt.Sprintf(
"Invalid mime type: got %v, expected %v",
mimeType, strings.Join(v.accept, ","),
),
Location: location,
Value: mimeType,
}
}
danielgtaylor marked this conversation as resolved.
Show resolved Hide resolved
}

func (m *MultipartFormFiles[T]) readFile(
fh *multipart.FileHeader,
location string,
validator MimeTypeValidator,
) (FormFile, *ErrorDetail) {
f, err := fh.Open()
if err != nil {
return FormFile{}, &ErrorDetail{Message: "Failed to open file", Location: location}

Check warning on line 91 in formdata.go

View check run for this annotation

Codecov / codecov/patch

formdata.go#L91

Added line #L91 was not covered by tests
}
contentType, validationErr := validator.Validate(fh, location)
if validationErr != nil {
return FormFile{}, validationErr
}
return FormFile{File: f, ContentType: contentType, IsSet: true}, nil
}

// readSingleFile returns the one file uploaded under key.
//
// When no file was sent, the result depends on the schema: a required key
// yields a "File required" error, an optional key yields the zero FormFile
// and no error. Exactly one file is opened and validated against the content
// types declared in opMediaType.Encoding[key]. More than one file under the
// same key is always an error.
func (m *MultipartFormFiles[T]) readSingleFile(key string, opMediaType *MediaType) (FormFile, *ErrorDetail) {
	headers := m.Form.File[key]
	switch len(headers) {
	case 0:
		if !opMediaType.Schema.requiredMap[key] {
			return FormFile{}, nil
		}
		return FormFile{}, &ErrorDetail{Message: "File required", Location: key}
	case 1:
		return m.readFile(headers[0], key, NewMimeTypeValidator(opMediaType.Encoding[key]))
	default:
		return FormFile{}, &ErrorDetail{
			Message:  "Multiple files received but only one was expected",
			Location: key,
		}
	}
}
danielgtaylor marked this conversation as resolved.
Show resolved Hide resolved

// readMultipleFiles opens and validates every file uploaded under key.
//
// If the schema marks key as required and no file was sent, a single
// "At least one file is required" error is returned. Otherwise each file is
// validated against the content types declared in opMediaType.Encoding[key];
// the error location of the i-th file is "key[i]".
//
// When any file fails, the files that were already opened are closed and a
// nil slice is returned together with all collected errors. This keeps file
// handles from leaking and prevents callers from seeing a partially filled
// result. On success the returned error slice is nil.
func (m *MultipartFormFiles[T]) readMultipleFiles(key string, opMediaType *MediaType) ([]FormFile, []error) {
	fileHeaders := m.Form.File[key]
	if len(fileHeaders) == 0 && opMediaType.Schema.requiredMap[key] {
		return nil, []error{&ErrorDetail{Message: "At least one file is required", Location: key}}
	}

	var (
		files = make([]FormFile, 0, len(fileHeaders))
		errs  []error
	)
	validator := NewMimeTypeValidator(opMediaType.Encoding[key])
	for i, fh := range fileHeaders {
		file, err := m.readFile(fh, fmt.Sprintf("%s[%d]", key, i), validator)
		if err != nil {
			errs = append(errs, err)
			continue
		}
		files = append(files, file)
	}

	if len(errs) > 0 {
		for _, file := range files {
			// Best-effort cleanup: the validation errors are what the caller
			// needs, a failure to close adds nothing actionable.
			_ = file.Close()
		}
		return nil, errs
	}
	return files, nil
}

// Data returns the pointer to the decoded value of type T that Decode stored.
func (m *MultipartFormFiles[T]) Data() *T {
return m.data
}

// Decode builds a new value of type T from the files in m.Form and stores it
// so that Data returns it.
//
// Every field of T whose type is FormFile or []FormFile is looked up in the
// form under the name given by formDataFieldName (the `form` tag, or the Go
// field name when the tag is absent or empty). Fields of any other type are
// skipped. opMediaType supplies the schema (which keys are required) and the
// encoding (accepted content types) used for validation.
//
// It returns the validation errors collected across all fields, or nil when
// there are none. A field that fails validation keeps its zero value, and the
// decoded value is stored even when errors are returned.
func (m *MultipartFormFiles[T]) Decode(opMediaType *MediaType) []error {
	var (
		dataType = reflect.TypeOf(m.data).Elem()
		value    = reflect.New(dataType)
		// Resolved once rather than on every loop iteration.
		fileType      = reflect.TypeOf(FormFile{})
		fileSliceType = reflect.TypeOf([]FormFile{})
		errs          []error
	)
	for i := 0; i < dataType.NumField(); i++ {
		field := value.Elem().Field(i)
		// Same naming rule as the generated schema and encoding maps.
		key := formDataFieldName(dataType.Field(i))
		switch field.Type() {
		case fileType:
			file, err := m.readSingleFile(key, opMediaType)
			if err != nil {
				errs = append(errs, err)
				continue
			}
			field.Set(reflect.ValueOf(file))
		case fileSliceType:
			files, fileErrs := m.readMultipleFiles(key, opMediaType)
			if len(fileErrs) > 0 {
				errs = append(errs, fileErrs...)
				continue
			}
			field.Set(reflect.ValueOf(files))
		}
	}
	m.data = value.Interface().(*T)
	return errs
}
danielgtaylor marked this conversation as resolved.
Show resolved Hide resolved

// formDataFieldName returns the multipart form key for a struct field: the
// value of its `form` tag when that tag is present and non-empty, otherwise
// the Go field name.
func formDataFieldName(f reflect.StructField) string {
	if tagged, ok := f.Tag.Lookup("form"); ok && tagged != "" {
		return tagged
	}
	return f.Name
}

// multiPartFormFileSchema builds the object schema describing the file fields
// of struct type t for a multipart/form-data request body.
//
// A FormFile field becomes a binary string property and a []FormFile field
// becomes an array of binary strings; fields of any other type are left out.
// Property names come from formDataFieldName. A supported field whose
// `required` tag is true is recorded both in Required and in requiredMap.
func multiPartFormFileSchema(t reflect.Type) *Schema {
	nFields := t.NumField()
	schema := &Schema{
		Type:        "object",
		Properties:  make(map[string]*Schema, nFields),
		requiredMap: make(map[string]bool, nFields),
	}
	// Start empty and append: indexing into a slice pre-sized to nFields
	// would leave an empty-string entry for every field that is skipped or
	// not required.
	requiredFields := make([]string, 0, nFields)
	for i := 0; i < nFields; i++ {
		f := t.Field(i)
		name := formDataFieldName(f)

		switch {
		case f.Type == reflect.TypeOf(FormFile{}):
			schema.Properties[name] = multiPartFileSchema(f)
		case f.Type == reflect.TypeOf([]FormFile{}):
			schema.Properties[name] = &Schema{
				Type:  "array",
				Items: multiPartFileSchema(f),
			}
		default:
			// Should we panic if [T] struct defines fields with unsupported types ?
			continue
		}

		if _, ok := f.Tag.Lookup("required"); ok && boolTag(f, "required") {
			requiredFields = append(requiredFields, name)
			schema.requiredMap[name] = true
		}
	}
	schema.Required = requiredFields
	return schema
}

// multiPartFileSchema returns the schema for one uploaded file: a string with
// binary format and binary content encoding, described by the field's `doc`
// tag.
func multiPartFileSchema(f reflect.StructField) *Schema {
	fileSchema := new(Schema)
	fileSchema.Type = "string"
	fileSchema.Format = "binary"
	fileSchema.Description = f.Tag.Get("doc")
	fileSchema.ContentEncoding = "binary"
	return fileSchema
}

// multiPartContentEncoding builds the per-property encoding map for the file
// fields of struct type t.
//
// Only FormFile and []FormFile fields get an entry, matching the properties
// that multiPartFormFileSchema emits, because an encoding key must name a
// property that exists in the schema. The content type comes from the field's
// `contentType` tag and defaults to "application/octet-stream" when the tag
// is absent or empty. Keys come from formDataFieldName.
func multiPartContentEncoding(t reflect.Type) map[string]*Encoding {
	nFields := t.NumField()
	encoding := make(map[string]*Encoding, nFields)
	for i := 0; i < nFields; i++ {
		f := t.Field(i)
		if f.Type != reflect.TypeOf(FormFile{}) && f.Type != reflect.TypeOf([]FormFile{}) {
			// Not part of the generated schema, so it must not appear here.
			continue
		}
		contentType := f.Tag.Get("contentType")
		if contentType == "" {
			contentType = "application/octet-stream"
		}
		encoding[formDataFieldName(f)] = &Encoding{
			ContentType: contentType,
		}
	}
	return encoding
}
65 changes: 50 additions & 15 deletions huma.go
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,7 @@
}
rawBodyIndex := -1
rawBodyMultipart := false
rawBodyDecodedMultipart := false
danielgtaylor marked this conversation as resolved.
Show resolved Hide resolved
if f, ok := inputType.FieldByName("RawBody"); ok {
rawBodyIndex = f.Index[0]
if op.RequestBody == nil {
Expand All @@ -625,28 +626,48 @@
contentType = "multipart/form-data"
rawBodyMultipart = true
}
if strings.HasPrefix(f.Type.Name(), "MultipartFormFiles") {
contentType = "multipart/form-data"
rawBodyDecodedMultipart = true
}

if c := f.Tag.Get("contentType"); c != "" {
contentType = c
}

switch contentType {
case "multipart/form-data":
op.RequestBody.Content["multipart/form-data"] = &MediaType{
Schema: &Schema{
Type: "object",
Properties: map[string]*Schema{
"name": {
Type: "string",
Description: "general purpose name for multipart form value",
},
"filename": {
Type: "string",
Format: "binary",
Description: "filename of the file being uploaded",
if op.RequestBody.Content["multipart/form-data"] != nil {
break

Check warning on line 641 in huma.go

View check run for this annotation

Codecov / codecov/patch

huma.go#L641

Added line #L641 was not covered by tests
}
if rawBodyMultipart {
op.RequestBody.Content["multipart/form-data"] = &MediaType{
Schema: &Schema{
Type: "object",
Properties: map[string]*Schema{
"name": {
Type: "string",
Description: "general purpose name for multipart form value",
},
"filename": {
Type: "string",
Format: "binary",
Description: "filename of the file being uploaded",
},
},
},
},
}
}
if rawBodyDecodedMultipart {
dataField, ok := f.Type.FieldByName("data")
if !ok {
panic("Expected type MultipartFormFiles[T] to have a 'data *T' generic pointer field")

Check warning on line 664 in huma.go

View check run for this annotation

Codecov / codecov/patch

huma.go#L664

Added line #L664 was not covered by tests
}
op.RequestBody.Content["multipart/form-data"] = &MediaType{
Schema: multiPartFormFileSchema(dataField.Type.Elem()),
Encoding: multiPartContentEncoding(dataField.Type.Elem()),
}
op.RequestBody.Required = false
}
default:
op.RequestBody.Content[contentType] = &MediaType{
Expand Down Expand Up @@ -1132,7 +1153,7 @@
ctx.SetReadDeadline(time.Time{})
}

if rawBodyMultipart {
if rawBodyMultipart || rawBodyDecodedMultipart {
form, err := ctx.GetMultipartForm()
if err != nil || form == nil {
res.Errors = append(res.Errors, &ErrorDetail{
Expand All @@ -1141,7 +1162,21 @@
})
} else {
f := v.Field(rawBodyIndex)
f.Set(reflect.ValueOf(*form))
if rawBodyMultipart {
f.Set(reflect.ValueOf(*form))
} else {
f.FieldByName("Form").Set(reflect.ValueOf(form))
r := f.Addr().
MethodByName("Decode").
Call([]reflect.Value{
reflect.ValueOf(op.RequestBody.Content["multipart/form-data"]),
})
errs := r[0].Interface().([]error)
if errs != nil {
WriteErr(api, ctx, http.StatusUnprocessableEntity, "validation failed", errs...)
return
}
}
Comment on lines +1165 to +1179
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider refactoring this block to reduce complexity and improve readability.

- if rawBodyMultipart {
-     f.Set(reflect.ValueOf(*form))
- } else {
-     f.FieldByName("Form").Set(reflect.ValueOf(form))
-     r := f.Addr().
-         MethodByName("Decode").
-         Call([]reflect.Value{
-             reflect.ValueOf(op.RequestBody.Content["multipart/form-data"]),
-         })
-     errs := r[0].Interface().([]error)
-     if errs != nil {
-         WriteErr(api, ctx, http.StatusUnprocessableEntity, "validation failed", errs...)
-         return
-     }
- }
+ handleMultipartForm(f, form, rawBodyMultipart, op, api, ctx)

Consider extracting this logic into a separate function handleMultipartForm to handle the multipart form based on the rawBodyMultipart flag.


Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
if rawBodyMultipart {
f.Set(reflect.ValueOf(*form))
} else {
f.FieldByName("Form").Set(reflect.ValueOf(form))
r := f.Addr().
MethodByName("Decode").
Call([]reflect.Value{
reflect.ValueOf(op.RequestBody.Content["multipart/form-data"]),
})
errs := r[0].Interface().([]error)
if errs != nil {
WriteErr(api, ctx, http.StatusUnprocessableEntity, "validation failed", errs...)
return
}
}
if rawBodyMultipart {
f.Set(reflect.ValueOf(*form))
} else {
f.FieldByName("Form").Set(reflect.ValueOf(form))
r := f.Addr().
MethodByName("Decode").
Call([]reflect.Value{
reflect.ValueOf(op.RequestBody.Content["multipart/form-data"]),
})
errs := r[0].Interface().([]error)
if errs != nil {
WriteErr(api, ctx, http.StatusUnprocessableEntity, "validation failed", errs...)
return
}
}
```
</details>
<!-- suggestion_end -->
<!-- This is an auto-generated comment by CodeRabbit -->

}
} else {
buf := bufPool.Get().(*bytes.Buffer)
Expand Down
Loading
Loading