Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: make validation error messages customizable #520

Merged
merged 5 commits into from
Jul 26, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 19 additions & 17 deletions schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ import (
"strconv"
"strings"
"time"

"github.com/danielgtaylor/huma/v2/validation"
)

// ErrSchemaInvalid is sent when there is a problem building the schema.
Expand Down Expand Up @@ -202,57 +204,57 @@ func (s *Schema) MarshalJSON() ([]byte, error) {
// PrecomputeMessages tries to precompute as many validation error messages
// as possible so that new strings aren't allocated during request validation.
func (s *Schema) PrecomputeMessages() {
s.msgEnum = "expected value to be one of \"" + strings.Join(mapTo(s.Enum, func(v any) string {
s.msgEnum = validation.ErrorFormatter(validation.MsgExpectedOneOf, strings.Join(mapTo(s.Enum, func(v any) string {
return fmt.Sprintf("%v", v)
}), ", ") + "\""
}), ", "))
if s.Minimum != nil {
s.msgMinimum = fmt.Sprintf("expected number >= %v", *s.Minimum)
s.msgMinimum = validation.ErrorFormatter(validation.MsgExpectedMinimumNumber, *s.Minimum)
}
if s.ExclusiveMinimum != nil {
s.msgExclusiveMinimum = fmt.Sprintf("expected number > %v", *s.ExclusiveMinimum)
s.msgExclusiveMinimum = validation.ErrorFormatter(validation.MsgExpectedExclusiveMinimumNumber, *s.ExclusiveMinimum)
}
if s.Maximum != nil {
s.msgMaximum = fmt.Sprintf("expected number <= %v", *s.Maximum)
s.msgMaximum = validation.ErrorFormatter(validation.MsgExpectedMaximumNumber, *s.Maximum)
}
if s.ExclusiveMaximum != nil {
s.msgExclusiveMaximum = fmt.Sprintf("expected number < %v", *s.ExclusiveMaximum)
s.msgExclusiveMaximum = validation.ErrorFormatter(validation.MsgExpectedExclusiveMaximumNumber, *s.ExclusiveMaximum)
}
if s.MultipleOf != nil {
s.msgMultipleOf = fmt.Sprintf("expected number to be a multiple of %v", *s.MultipleOf)
s.msgMultipleOf = validation.ErrorFormatter(validation.MsgExpectedNumberBeMultipleOf, *s.MultipleOf)
}
if s.MinLength != nil {
s.msgMinLength = fmt.Sprintf("expected length >= %d", *s.MinLength)
s.msgMinLength = validation.ErrorFormatter(validation.MsgExpectedMinLength, *s.MinLength)
}
if s.MaxLength != nil {
s.msgMaxLength = fmt.Sprintf("expected length <= %d", *s.MaxLength)
s.msgMaxLength = validation.ErrorFormatter(validation.MsgExpectedMaxLength, *s.MaxLength)
}
if s.Pattern != "" {
s.patternRe = regexp.MustCompile(s.Pattern)
if s.PatternDescription != "" {
s.msgPattern = "expected string to be " + s.PatternDescription
s.msgPattern = validation.ErrorFormatter(validation.MsgExpectedBePattern, s.PatternDescription)
} else {
s.msgPattern = "expected string to match pattern " + s.Pattern
s.msgPattern = validation.ErrorFormatter(validation.MsgExpectedMatchPattern, s.Pattern)
}
}
if s.MinItems != nil {
s.msgMinItems = fmt.Sprintf("expected array length >= %d", *s.MinItems)
s.msgMinItems = validation.ErrorFormatter(validation.MsgExpectedMinItems, *s.MinItems)
}
if s.MaxItems != nil {
s.msgMaxItems = fmt.Sprintf("expected array length <= %d", *s.MaxItems)
s.msgMaxItems = validation.ErrorFormatter(validation.MsgExpectedMaxItems, *s.MaxItems)
}
if s.MinProperties != nil {
s.msgMinProperties = fmt.Sprintf("expected object with at least %d properties", *s.MinProperties)
s.msgMinProperties = validation.ErrorFormatter(validation.MsgExpectedMinProperties, *s.MinProperties)
}
if s.MaxProperties != nil {
s.msgMaxProperties = fmt.Sprintf("expected object with at most %d properties", *s.MaxProperties)
s.msgMaxProperties = validation.ErrorFormatter(validation.MsgExpectedMaxProperties, *s.MaxProperties)
}

if s.Required != nil {
if s.msgRequired == nil {
s.msgRequired = map[string]string{}
}
for _, name := range s.Required {
s.msgRequired[name] = "expected required property " + name + " to be present"
s.msgRequired[name] = validation.ErrorFormatter(validation.MsgExpectedRequiredProperty, name)
}
}

Expand All @@ -265,7 +267,7 @@ func (s *Schema) PrecomputeMessages() {
if s.msgDependentRequired[name] == nil {
s.msgDependentRequired[name] = map[string]string{}
}
s.msgDependentRequired[name][dependent] = "expected property " + dependent + " to be present when " + name + " is present"
s.msgDependentRequired[name][dependent] = validation.ErrorFormatter(validation.MsgExpectedDependentRequiredProperty, dependent, name)
}
}
}
Expand Down
78 changes: 35 additions & 43 deletions validate.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ import (
"time"
"unicode/utf8"
"unsafe"

"github.com/danielgtaylor/huma/v2/validation"
)

// ValidateMode describes the direction of validation (server -> client or
Expand Down Expand Up @@ -166,16 +168,6 @@ func (r *ValidateResult) Add(path *PathBuffer, v any, msg string) {
})
}

// Addf adds an error to the validation result at the given path and with
// the given value, allowing for fmt.Printf-style formatting.
func (r *ValidateResult) Addf(path *PathBuffer, v any, format string, args ...any) {
smacker marked this conversation as resolved.
Show resolved Hide resolved
r.Errors = append(r.Errors, &ErrorDetail{
Message: fmt.Sprintf(format, args...),
Location: path.String(),
Value: v,
})
}

// Reset the validation error so it can be used again.
func (r *ValidateResult) Reset() {
r.Errors = r.Errors[:0]
Expand All @@ -192,69 +184,69 @@ func validateFormat(path *PathBuffer, str string, s *Schema, res *ValidateResult
}
}
if !found {
res.Add(path, str, "expected string to be RFC 3339 date-time")
res.Add(path, str, validation.MsgExpectedRFC3339DateTime)
}
case "date-time-http":
if _, err := time.Parse(time.RFC1123, str); err != nil {
res.Add(path, str, "expected string to be RFC 1123 date-time")
res.Add(path, str, validation.MsgExpectedRFC1123DateTime)
}
case "date":
if _, err := time.Parse("2006-01-02", str); err != nil {
res.Add(path, str, "expected string to be RFC 3339 date")
res.Add(path, str, validation.MsgExpectedRFC3339Date)
}
case "time":
if _, err := time.Parse("15:04:05", str); err != nil {
if _, err := time.Parse("15:04:05Z07:00", str); err != nil {
res.Add(path, str, "expected string to be RFC 3339 time")
res.Add(path, str, validation.MsgExpectedRFC3339Time)
}
}
// TODO: duration
case "email", "idn-email":
if _, err := mail.ParseAddress(str); err != nil {
res.Addf(path, str, "expected string to be RFC 5322 email: %v", err)
res.Add(path, str, validation.ErrorFormatter(validation.MsgExpectedRFC5322Email, err))
}
case "hostname":
if !(rxHostname.MatchString(str) && len(str) < 256) {
res.Add(path, str, "expected string to be RFC 5890 hostname")
res.Add(path, str, validation.MsgExpectedRFC5890Hostname)
}
// TODO: proper idn-hostname support... need to figure out how.
case "ipv4":
if ip := net.ParseIP(str); ip == nil || ip.To4() == nil {
res.Add(path, str, "expected string to be RFC 2673 ipv4")
res.Add(path, str, validation.MsgExpectedRFC2673IPv4)
}
case "ipv6":
if ip := net.ParseIP(str); ip == nil || ip.To16() == nil {
res.Add(path, str, "expected string to be RFC 2373 ipv6")
res.Add(path, str, validation.MsgExpectedRFC2373IPv6)
}
case "uri", "uri-reference", "iri", "iri-reference":
if _, err := url.Parse(str); err != nil {
res.Addf(path, str, "expected string to be RFC 3986 uri: %v", err)
res.Add(path, str, validation.ErrorFormatter(validation.MsgExpectedRFC3986URI, err))
}
// TODO: check if it's actually a reference?
case "uuid":
if err := validateUUID(str); err != nil {
res.Addf(path, str, "expected string to be RFC 4122 uuid: %v", err)
res.Add(path, str, validation.ErrorFormatter(validation.MsgExpectedRFC4122UUID, err))
}
case "uri-template":
u, err := url.Parse(str)
if err != nil {
res.Addf(path, str, "expected string to be RFC 3986 uri: %v", err)
res.Add(path, str, validation.ErrorFormatter(validation.MsgExpectedRFC3986URI, err))
return
}
if !rxURITemplate.MatchString(u.Path) {
res.Add(path, str, "expected string to be RFC 6570 uri-template")
res.Add(path, str, validation.MsgExpectedRFC6570URITemplate)
}
case "json-pointer":
if !rxJSONPointer.MatchString(str) {
res.Add(path, str, "expected string to be RFC 6901 json-pointer")
res.Add(path, str, validation.MsgExpectedRFC6901JSONPointer)
}
case "relative-json-pointer":
if !rxRelJSONPointer.MatchString(str) {
res.Add(path, str, "expected string to be RFC 6901 relative-json-pointer")
res.Add(path, str, validation.MsgExpectedRFC6901RelativeJSONPointer)
}
case "regex":
if _, err := regexp.Compile(str); err != nil {
res.Addf(path, str, "expected string to be regex: %v", err)
res.Add(path, str, validation.ErrorFormatter(validation.MsgExpectedRegexp, err))
}
}
}
Expand Down Expand Up @@ -289,7 +281,7 @@ func validateAnyOf(r Registry, s *Schema, path *PathBuffer, mode ValidateMode, v
}

if matches == 0 {
res.Add(path, v, "expected value to match at least one schema but matched none")
res.Add(path, v, validation.MsgExpectedMatchSchema)
}
}

Expand Down Expand Up @@ -333,7 +325,7 @@ func Validate(r Registry, s *Schema, path *PathBuffer, mode ValidateMode, v any,
subRes := &ValidateResult{}
Validate(r, s.Not, path, mode, v, subRes)
if len(subRes.Errors) == 0 {
res.Add(path, v, "expected value to not match schema")
res.Add(path, v, validation.MsgExpectedNotMatchSchema)
}
}

Expand All @@ -344,7 +336,7 @@ func Validate(r Registry, s *Schema, path *PathBuffer, mode ValidateMode, v any,
switch s.Type {
case TypeBoolean:
if _, ok := v.(bool); !ok {
res.Add(path, v, "expected boolean")
res.Add(path, v, validation.MsgExpectedBoolean)
return
}
case TypeNumber, TypeInteger:
Expand Down Expand Up @@ -376,18 +368,18 @@ func Validate(r Registry, s *Schema, path *PathBuffer, mode ValidateMode, v any,
case uint64:
num = float64(v)
default:
res.Add(path, v, "expected number")
res.Add(path, v, validation.MsgExpectedNumber)
return
}

if s.Minimum != nil {
if num < *s.Minimum {
res.Addf(path, v, s.msgMinimum)
res.Add(path, v, s.msgMinimum)
}
}
if s.ExclusiveMinimum != nil {
if num <= *s.ExclusiveMinimum {
res.Addf(path, v, s.msgExclusiveMinimum)
res.Add(path, v, s.msgExclusiveMinimum)
}
}
if s.Maximum != nil {
Expand All @@ -397,12 +389,12 @@ func Validate(r Registry, s *Schema, path *PathBuffer, mode ValidateMode, v any,
}
if s.ExclusiveMaximum != nil {
if num >= *s.ExclusiveMaximum {
res.Addf(path, v, s.msgExclusiveMaximum)
res.Add(path, v, s.msgExclusiveMaximum)
}
}
if s.MultipleOf != nil {
if math.Mod(num, *s.MultipleOf) != 0 {
res.Addf(path, v, s.msgMultipleOf)
res.Add(path, v, s.msgMultipleOf)
}
}
case TypeString:
Expand All @@ -411,14 +403,14 @@ func Validate(r Registry, s *Schema, path *PathBuffer, mode ValidateMode, v any,
if b, ok := v.([]byte); ok {
str = *(*string)(unsafe.Pointer(&b))
} else {
res.Add(path, v, "expected string")
res.Add(path, v, validation.MsgExpectedString)
return
}
}

if s.MinLength != nil {
if utf8.RuneCountInString(str) < *s.MinLength {
res.Addf(path, str, s.msgMinLength)
res.Add(path, str, s.msgMinLength)
}
}
if s.MaxLength != nil {
Expand All @@ -438,7 +430,7 @@ func Validate(r Registry, s *Schema, path *PathBuffer, mode ValidateMode, v any,

if s.ContentEncoding == "base64" {
if !rxBase64.MatchString(str) {
res.Add(path, str, "expected string to be base64 encoded")
res.Add(path, str, validation.MsgExpectedBase64String)
}
}
case TypeArray:
Expand Down Expand Up @@ -471,7 +463,7 @@ func Validate(r Registry, s *Schema, path *PathBuffer, mode ValidateMode, v any,
case []float64:
handleArray(r, s, path, mode, res, arr)
default:
res.Add(path, v, "expected array")
res.Add(path, v, validation.MsgExpectedArray)
return
}
case TypeObject:
Expand All @@ -480,7 +472,7 @@ func Validate(r Registry, s *Schema, path *PathBuffer, mode ValidateMode, v any,
} else if vv, ok := v.(map[any]any); ok {
handleMapAny(r, s, path, mode, vv, res)
} else {
res.Add(path, v, "expected object")
res.Add(path, v, validation.MsgExpectedObject)
return
}
}
Expand All @@ -502,20 +494,20 @@ func Validate(r Registry, s *Schema, path *PathBuffer, mode ValidateMode, v any,
func handleArray[T any](r Registry, s *Schema, path *PathBuffer, mode ValidateMode, res *ValidateResult, arr []T) {
if s.MinItems != nil {
if len(arr) < *s.MinItems {
res.Addf(path, arr, s.msgMinItems)
res.Add(path, arr, s.msgMinItems)
}
}
if s.MaxItems != nil {
if len(arr) > *s.MaxItems {
res.Addf(path, arr, s.msgMaxItems)
res.Add(path, arr, s.msgMaxItems)
}
}

if s.UniqueItems {
seen := make(map[any]struct{}, len(arr))
for _, item := range arr {
if _, ok := seen[item]; ok {
res.Add(path, arr, "expected array items to be unique")
res.Add(path, arr, validation.MsgExpectedArrayItemsUnique)
}
seen[item] = struct{}{}
}
Expand Down Expand Up @@ -602,7 +594,7 @@ func handleMapString(r Registry, s *Schema, path *PathBuffer, mode ValidateMode,
// No additional properties allowed.
if _, ok := s.Properties[k]; !ok {
path.Push(k)
res.Add(path, m, "unexpected property")
res.Add(path, m, validation.MsgUnexpectedProperty)
path.Pop()
}
}
Expand Down Expand Up @@ -702,7 +694,7 @@ func handleMapAny(r Registry, s *Schema, path *PathBuffer, mode ValidateMode, m
}
if _, ok := s.Properties[kStr]; !ok {
path.Push(kStr)
res.Add(path, m, "unexpected property")
res.Add(path, m, validation.MsgUnexpectedProperty)
path.Pop()
}
}
Expand Down
Loading
Loading