From 20a669ac17018f44fe375e430ee311184fb6c378 Mon Sep 17 00:00:00 2001 From: Maxim Sukharev Date: Wed, 24 Jul 2024 15:38:46 +0800 Subject: [PATCH 1/5] feat: make error messages customizable --- schema.go | 36 +++++++++++++++------------- validate.go | 54 ++++++++++++++++++++++-------------------- validation/messages.go | 44 ++++++++++++++++++++++++++++++++++ 3 files changed, 91 insertions(+), 43 deletions(-) create mode 100644 validation/messages.go diff --git a/schema.go b/schema.go index f65de67a..a57894a6 100644 --- a/schema.go +++ b/schema.go @@ -14,6 +14,8 @@ import ( "strconv" "strings" "time" + + "github.com/danielgtaylor/huma/v2/validation" ) // ErrSchemaInvalid is sent when there is a problem building the schema. @@ -202,49 +204,49 @@ func (s *Schema) MarshalJSON() ([]byte, error) { // PrecomputeMessages tries to precompute as many validation error messages // as possible so that new strings aren't allocated during request validation. func (s *Schema) PrecomputeMessages() { - s.msgEnum = "expected value to be one of \"" + strings.Join(mapTo(s.Enum, func(v any) string { + s.msgEnum = fmt.Sprintf(validation.MsgExpectedOneOf, strings.Join(mapTo(s.Enum, func(v any) string { return fmt.Sprintf("%v", v) - }), ", ") + "\"" + }), ", ")) if s.Minimum != nil { - s.msgMinimum = fmt.Sprintf("expected number >= %v", *s.Minimum) + s.msgMinimum = fmt.Sprintf(validation.MsgExpectedMinimumNumber, *s.Minimum) } if s.ExclusiveMinimum != nil { - s.msgExclusiveMinimum = fmt.Sprintf("expected number > %v", *s.ExclusiveMinimum) + s.msgExclusiveMinimum = fmt.Sprintf(validation.MsgExpectedExclusiveMinimumNumber, *s.ExclusiveMinimum) } if s.Maximum != nil { - s.msgMaximum = fmt.Sprintf("expected number <= %v", *s.Maximum) + s.msgMaximum = fmt.Sprintf(validation.MsgExpectedMaximumNumber, *s.Maximum) } if s.ExclusiveMaximum != nil { - s.msgExclusiveMaximum = fmt.Sprintf("expected number < %v", *s.ExclusiveMaximum) + s.msgExclusiveMaximum = fmt.Sprintf(validation.MsgExpectedExclusiveMaximumNumber, *s.ExclusiveMaximum) } if s.MultipleOf != nil { - s.msgMultipleOf = fmt.Sprintf("expected number to be a multiple of %v", *s.MultipleOf) + s.msgMultipleOf = fmt.Sprintf(validation.MsgExpectedNumberBeMultipleOf, *s.MultipleOf) } if s.MinLength != nil { - s.msgMinLength = fmt.Sprintf("expected length >= %d", *s.MinLength) + s.msgMinLength = fmt.Sprintf(validation.MsgExpectedMinLength, *s.MinLength) } if s.MaxLength != nil { - s.msgMaxLength = fmt.Sprintf("expected length <= %d", *s.MaxLength) + s.msgMaxLength = fmt.Sprintf(validation.MsgExpectedMaxLength, *s.MaxLength) } if s.Pattern != "" { s.patternRe = regexp.MustCompile(s.Pattern) if s.PatternDescription != "" { - s.msgPattern = "expected string to be " + s.PatternDescription + s.msgPattern = fmt.Sprintf(validation.MsgExpectedBePattern, s.PatternDescription) } else { - s.msgPattern = "expected string to match pattern " + s.Pattern + s.msgPattern = fmt.Sprintf(validation.MsgExpectedMatchPattern, s.Pattern) } } if s.MinItems != nil { - s.msgMinItems = fmt.Sprintf("expected array length >= %d", *s.MinItems) + s.msgMinItems = fmt.Sprintf(validation.MsgExpectedMinItems, *s.MinItems) } if s.MaxItems != nil { - s.msgMaxItems = fmt.Sprintf("expected array length <= %d", *s.MaxItems) + s.msgMaxItems = fmt.Sprintf(validation.MsgExpectedMaxItems, *s.MaxItems) } if s.MinProperties != nil { - s.msgMinProperties = fmt.Sprintf("expected object with at least %d properties", *s.MinProperties) + s.msgMinProperties = fmt.Sprintf(validation.MsgExpectedMinProperties, *s.MinProperties) } if s.MaxProperties != nil { - s.msgMaxProperties = fmt.Sprintf("expected object with at most %d properties", *s.MaxProperties) + s.msgMaxProperties = fmt.Sprintf(validation.MsgExpectedMaxProperties, *s.MaxProperties) } if s.Required != nil { @@ -252,7 +254,7 @@ func (s *Schema) PrecomputeMessages() { s.msgRequired = map[string]string{} } for _, name := range s.Required { - s.msgRequired[name] = "expected required property " + name + " to be present" + s.msgRequired[name] = fmt.Sprintf(validation.MsgExpectedRequiredProperty, name) } } @@ -265,7 +267,7 @@ func (s *Schema) PrecomputeMessages() { if s.msgDependentRequired[name] == nil { s.msgDependentRequired[name] = map[string]string{} } - s.msgDependentRequired[name][dependent] = "expected property " + dependent + " to be present when " + name + " is present" + s.msgDependentRequired[name][dependent] = fmt.Sprintf(validation.MsgExpectedDependentRequiredProperty, dependent, name) } } } diff --git a/validate.go b/validate.go index 9c404ac9..187f8d10 100644 --- a/validate.go +++ b/validate.go @@ -14,6 +14,8 @@ import ( "time" "unicode/utf8" "unsafe" + + "github.com/danielgtaylor/huma/v2/validation" ) // ValidateMode describes the direction of validation (server -> client or @@ -192,69 +194,69 @@ func validateFormat(path *PathBuffer, str string, s *Schema, res *ValidateResult } } if !found { - res.Add(path, str, "expected string to be RFC 3339 date-time") + res.Add(path, str, validation.MsgExpectedRFC3339DateTime) } case "date-time-http": if _, err := time.Parse(time.RFC1123, str); err != nil { - res.Add(path, str, "expected string to be RFC 1123 date-time") + res.Add(path, str, validation.MsgExpectedRFC1123DateTime) } case "date": if _, err := time.Parse("2006-01-02", str); err != nil { - res.Add(path, str, "expected string to be RFC 3339 date") + res.Add(path, str, validation.MsgExpectedRFC3339Date) } case "time": if _, err := time.Parse("15:04:05", str); err != nil { if _, err := time.Parse("15:04:05Z07:00", str); err != nil { - res.Add(path, str, "expected string to be RFC 3339 time") + res.Add(path, str, validation.MsgExpectedRFC3339Time) } } // TODO: duration case "email", "idn-email": if _, err := mail.ParseAddress(str); err != nil { - res.Addf(path, str, "expected string to be RFC 5322 email: %v", err) + res.Addf(path, str, validation.MsgExpectedRFC5322Email, err) } case "hostname": if !(rxHostname.MatchString(str) && len(str) < 256) { - res.Add(path, str, "expected string to be RFC 5890 hostname") + res.Add(path, str, validation.MsgExpectedRFC5890Hostname) } // TODO: proper idn-hostname support... need to figure out how. case "ipv4": if ip := net.ParseIP(str); ip == nil || ip.To4() == nil { - res.Add(path, str, "expected string to be RFC 2673 ipv4") + res.Add(path, str, validation.MsgExpectedRFC2673IPv4) } case "ipv6": if ip := net.ParseIP(str); ip == nil || ip.To16() == nil { - res.Add(path, str, "expected string to be RFC 2373 ipv6") + res.Add(path, str, validation.MsgExpectedRFC2373IPv6) } case "uri", "uri-reference", "iri", "iri-reference": if _, err := url.Parse(str); err != nil { - res.Addf(path, str, "expected string to be RFC 3986 uri: %v", err) + res.Addf(path, str, validation.MsgExpectedRFC3986URI, err) } // TODO: check if it's actually a reference? case "uuid": if err := validateUUID(str); err != nil { - res.Addf(path, str, "expected string to be RFC 4122 uuid: %v", err) + res.Addf(path, str, validation.MsgExpectedRFC4122UUID, err) } case "uri-template": u, err := url.Parse(str) if err != nil { - res.Addf(path, str, "expected string to be RFC 3986 uri: %v", err) + res.Addf(path, str, validation.MsgExpectedRFC3986URI, err) return } if !rxURITemplate.MatchString(u.Path) { - res.Add(path, str, "expected string to be RFC 6570 uri-template") + res.Add(path, str, validation.MsgExpectedRFC6570URITemplate) } case "json-pointer": if !rxJSONPointer.MatchString(str) { - res.Add(path, str, "expected string to be RFC 6901 json-pointer") + res.Add(path, str, validation.MsgExpectedRFC6901JSONPointer) } case "relative-json-pointer": if !rxRelJSONPointer.MatchString(str) { - res.Add(path, str, "expected string to be RFC 6901 relative-json-pointer") + res.Add(path, str, validation.MsgExpectedRFC6901RelativeJSONPointer) } case "regex": if _, err := regexp.Compile(str); err != nil { - res.Addf(path, str, "expected string to be regex: %v", err) + res.Addf(path, str, validation.MsgExpectedRegexp, err) } } } @@ -289,7 +291,7 @@ func validateAnyOf(r Registry, s *Schema, path *PathBuffer, mode ValidateMode, v } if matches == 0 { - res.Add(path, v, "expected value to match at least one schema but matched none") + res.Add(path, v, validation.MsgExpectedMatchSchema) } } @@ -333,7 +335,7 @@ func Validate(r Registry, s *Schema, path *PathBuffer, mode ValidateMode, v any, subRes := &ValidateResult{} Validate(r, s.Not, path, mode, v, subRes) if len(subRes.Errors) == 0 { - res.Add(path, v, "expected value to not match schema") + res.Add(path, v, validation.MsgExpectedNotMatchSchema) } } @@ -344,7 +346,7 @@ func Validate(r Registry, s *Schema, path *PathBuffer, mode ValidateMode, v any, switch s.Type { case TypeBoolean: if _, ok := v.(bool); !ok { - res.Add(path, v, "expected boolean") + res.Add(path, v, validation.MsgExpectedBoolean) return } case TypeNumber, TypeInteger: @@ -376,7 +378,7 @@ func Validate(r Registry, s *Schema, path *PathBuffer, mode ValidateMode, v any, case uint64: num = float64(v) default: - res.Add(path, v, "expected number") + res.Add(path, v, validation.MsgExpectedNumber) return } @@ -411,7 +413,7 @@ func Validate(r Registry, s *Schema, path *PathBuffer, mode ValidateMode, v any, if b, ok := v.([]byte); ok { str = *(*string)(unsafe.Pointer(&b)) } else { - res.Add(path, v, "expected string") + res.Add(path, v, validation.MsgExpectedString) return } } @@ -438,7 +440,7 @@ func Validate(r Registry, s *Schema, path *PathBuffer, mode ValidateMode, v any, if s.ContentEncoding == "base64" { if !rxBase64.MatchString(str) { - res.Add(path, str, "expected string to be base64 encoded") + res.Add(path, str, validation.MsgExpectedBase64String) } } case TypeArray: @@ -471,7 +473,7 @@ func Validate(r Registry, s *Schema, path *PathBuffer, mode ValidateMode, v any, case []float64: handleArray(r, s, path, mode, res, arr) default: - res.Add(path, v, "expected array") + res.Add(path, v, validation.MsgExpectedArray) return } case TypeObject: @@ -480,7 +482,7 @@ func Validate(r Registry, s *Schema, path *PathBuffer, mode ValidateMode, v any, } else if vv, ok := v.(map[any]any); ok { handleMapAny(r, s, path, mode, vv, res) } else { - res.Add(path, v, "expected object") + res.Add(path, v, validation.MsgExpectedObject) return } } @@ -515,7 +517,7 @@ func handleArray[T any](r Registry, s *Schema, path *PathBuffer, mode ValidateMo seen := make(map[any]struct{}, len(arr)) for _, item := range arr { if _, ok := seen[item]; ok { - res.Add(path, arr, "expected array items to be unique") + res.Add(path, arr, validation.MsgExpectedArrayItemsUnique) } seen[item] = struct{}{} } @@ -602,7 +604,7 @@ func handleMapString(r Registry, s *Schema, path *PathBuffer, mode ValidateMode, // No additional properties allowed. if _, ok := s.Properties[k]; !ok { path.Push(k) - res.Add(path, m, "unexpected property") + res.Add(path, m, validation.MsgUnexpectedProperty) path.Pop() } } @@ -702,7 +704,7 @@ func handleMapAny(r Registry, s *Schema, path *PathBuffer, mode ValidateMode, m } if _, ok := s.Properties[kStr]; !ok { path.Push(kStr) - res.Add(path, m, "unexpected property") + res.Add(path, m, validation.MsgUnexpectedProperty) path.Pop() } } diff --git a/validation/messages.go b/validation/messages.go new file mode 100644 index 00000000..965d7bd5 --- /dev/null +++ b/validation/messages.go @@ -0,0 +1,44 @@ +package validation + +var ( + MsgUnexpectedProperty = "unexpected property" + MsgExpectedRFC3339DateTime = "expected string to be RFC 3339 date-time" + MsgExpectedRFC1123DateTime = "expected string to be RFC 1123 date-time" + MsgExpectedRFC3339Date = "expected string to be RFC 3339 date" + MsgExpectedRFC3339Time = "expected string to be RFC 3339 time" + MsgExpectedRFC5322Email = "expected string to be RFC 5322 email: %v" + MsgExpectedRFC5890Hostname = "expected string to be RFC 5890 hostname" + MsgExpectedRFC2673IPv4 = "expected string to be RFC 2673 ipv4" + MsgExpectedRFC2373IPv6 = "expected string to be RFC 2373 ipv6" + MsgExpectedRFC3986URI = "expected string to be RFC 3986 uri: %v" + MsgExpectedRFC4122UUID = "expected string to be RFC 4122 uuid: %v" + MsgExpectedRFC6570URITemplate = "expected string to be RFC 6570 uri-template" + MsgExpectedRFC6901JSONPointer = "expected string to be RFC 6901 json-pointer" + MsgExpectedRFC6901RelativeJSONPointer = "expected string to be RFC 6901 relative-json-pointer" + MsgExpectedRegexp = "expected string to be regex: %v" + MsgExpectedMatchSchema = "expected value to match at least one schema but matched none" + MsgExpectedNotMatchSchema = "expected value to not match schema" + MsgExpectedBoolean = "expected boolean" + MsgExpectedNumber = "expected number" + MsgExpectedString = "expected string" + MsgExpectedBase64String = "expected string to be base64 encoded" + MsgExpectedArray = "expected array" + MsgExpectedObject = "expected object" + MsgExpectedArrayItemsUnique = "expected array items to be unique" + MsgExpectedOneOf = "expected value to be one of \"%s\"" + MsgExpectedMinimumNumber = "expected number >= %v" + MsgExpectedExclusiveMinimumNumber = "expected number > %v" + MsgExpectedMaximumNumber = "expected number <= %v" + MsgExpectedExclusiveMaximumNumber = "expected number < %v" + MsgExpectedNumberBeMultipleOf = "expected number to be a multiple of %v" + MsgExpectedMinLength = "expected length >= %d" + MsgExpectedMaxLength = "expected length <= %d" + MsgExpectedBePattern = "expected string to be %s" + MsgExpectedMatchPattern = "expected string to match pattern %s" + MsgExpectedMinItems = "expected array length >= %d" + MsgExpectedMaxItems = "expected array length <= %d" + MsgExpectedMinProperties = "expected object with at least %d properties" + MsgExpectedMaxProperties = "expected object with at most %d properties" + MsgExpectedRequiredProperty = "expected required property %s to be present" + MsgExpectedDependentRequiredProperty = "expected property %s to be present when %s is present" +) From 9ff9231aaa94414b3a5354e7d8f083ab0ead965d Mon Sep 17 00:00:00 2001 From: Maxim Sukharev Date: Wed, 24 Jul 2024 16:23:44 +0800 Subject: [PATCH 2/5] feat: allow customizing error formatting func --- schema.go | 32 ++++++++++++++++---------------- validate.go | 34 ++++++++++++---------------------- validation/messages.go | 6 ++++++ 3 files changed, 34 insertions(+), 38 deletions(-) diff --git a/schema.go b/schema.go index a57894a6..95d4a517 100644 --- a/schema.go +++ b/schema.go @@ -204,49 +204,49 @@ func (s *Schema) MarshalJSON() ([]byte, error) { // PrecomputeMessages tries to precompute as many validation error messages // as possible so that new strings aren't allocated during request validation. func (s *Schema) PrecomputeMessages() { - s.msgEnum = fmt.Sprintf(validation.MsgExpectedOneOf, strings.Join(mapTo(s.Enum, func(v any) string { + s.msgEnum = validation.ErrorFormatter(validation.MsgExpectedOneOf, strings.Join(mapTo(s.Enum, func(v any) string { return fmt.Sprintf("%v", v) }), ", ")) if s.Minimum != nil { - s.msgMinimum = fmt.Sprintf(validation.MsgExpectedMinimumNumber, *s.Minimum) + s.msgMinimum = validation.ErrorFormatter(validation.MsgExpectedMinimumNumber, *s.Minimum) } if s.ExclusiveMinimum != nil { - s.msgExclusiveMinimum = fmt.Sprintf(validation.MsgExpectedExclusiveMinimumNumber, *s.ExclusiveMinimum) + s.msgExclusiveMinimum = validation.ErrorFormatter(validation.MsgExpectedExclusiveMinimumNumber, *s.ExclusiveMinimum) } if s.Maximum != nil { - s.msgMaximum = fmt.Sprintf(validation.MsgExpectedMaximumNumber, *s.Maximum) + s.msgMaximum = validation.ErrorFormatter(validation.MsgExpectedMaximumNumber, *s.Maximum) } if s.ExclusiveMaximum != nil { - s.msgExclusiveMaximum = fmt.Sprintf(validation.MsgExpectedExclusiveMaximumNumber, *s.ExclusiveMaximum) + s.msgExclusiveMaximum = validation.ErrorFormatter(validation.MsgExpectedExclusiveMaximumNumber, *s.ExclusiveMaximum) } if s.MultipleOf != nil { - s.msgMultipleOf = fmt.Sprintf(validation.MsgExpectedNumberBeMultipleOf, *s.MultipleOf) + s.msgMultipleOf = validation.ErrorFormatter(validation.MsgExpectedNumberBeMultipleOf, *s.MultipleOf) } if s.MinLength != nil { - s.msgMinLength = fmt.Sprintf(validation.MsgExpectedMinLength, *s.MinLength) + s.msgMinLength = validation.ErrorFormatter(validation.MsgExpectedMinLength, *s.MinLength) } if s.MaxLength != nil { - s.msgMaxLength = fmt.Sprintf(validation.MsgExpectedMaxLength, *s.MaxLength) + s.msgMaxLength = validation.ErrorFormatter(validation.MsgExpectedMaxLength, *s.MaxLength) } if s.Pattern != "" { s.patternRe = regexp.MustCompile(s.Pattern) if s.PatternDescription != "" { - s.msgPattern = fmt.Sprintf(validation.MsgExpectedBePattern, s.PatternDescription) + s.msgPattern = validation.ErrorFormatter(validation.MsgExpectedBePattern, s.PatternDescription) } else { - s.msgPattern = fmt.Sprintf(validation.MsgExpectedMatchPattern, s.Pattern) + s.msgPattern = validation.ErrorFormatter(validation.MsgExpectedMatchPattern, s.Pattern) } } if s.MinItems != nil { - s.msgMinItems = fmt.Sprintf(validation.MsgExpectedMinItems, *s.MinItems) + s.msgMinItems = validation.ErrorFormatter(validation.MsgExpectedMinItems, *s.MinItems) } if s.MaxItems != nil { - s.msgMaxItems = fmt.Sprintf(validation.MsgExpectedMaxItems, *s.MaxItems) + s.msgMaxItems = validation.ErrorFormatter(validation.MsgExpectedMaxItems, *s.MaxItems) } if s.MinProperties != nil { - s.msgMinProperties = fmt.Sprintf(validation.MsgExpectedMinProperties, *s.MinProperties) + s.msgMinProperties = validation.ErrorFormatter(validation.MsgExpectedMinProperties, *s.MinProperties) } if s.MaxProperties != nil { - s.msgMaxProperties = fmt.Sprintf(validation.MsgExpectedMaxProperties, *s.MaxProperties) + s.msgMaxProperties = validation.ErrorFormatter(validation.MsgExpectedMaxProperties, *s.MaxProperties) } if s.Required != nil { @@ -254,7 +254,7 @@ func (s *Schema) PrecomputeMessages() { s.msgRequired = map[string]string{} } for _, name := range s.Required { - s.msgRequired[name] = fmt.Sprintf(validation.MsgExpectedRequiredProperty, name) + s.msgRequired[name] = validation.ErrorFormatter(validation.MsgExpectedRequiredProperty, name) } } @@ -267,7 +267,7 @@ func (s *Schema) PrecomputeMessages() { if s.msgDependentRequired[name] == nil { s.msgDependentRequired[name] = map[string]string{} } - s.msgDependentRequired[name][dependent] = fmt.Sprintf(validation.MsgExpectedDependentRequiredProperty, dependent, name) + s.msgDependentRequired[name][dependent] = validation.ErrorFormatter(validation.MsgExpectedDependentRequiredProperty, dependent, name) } } } diff --git a/validate.go b/validate.go index 187f8d10..02895d0f 100644 --- a/validate.go +++ b/validate.go @@ -168,16 +168,6 @@ func (r *ValidateResult) Add(path *PathBuffer, v any, msg string) { }) } -// Addf adds an error to the validation result at the given path and with -// the given value, allowing for fmt.Printf-style formatting. -func (r *ValidateResult) Addf(path *PathBuffer, v any, format string, args ...any) { - r.Errors = append(r.Errors, &ErrorDetail{ - Message: fmt.Sprintf(format, args...), - Location: path.String(), - Value: v, - }) -} - // Reset the validation error so it can be used again. func (r *ValidateResult) Reset() { r.Errors = r.Errors[:0] @@ -213,7 +203,7 @@ func validateFormat(path *PathBuffer, str string, s *Schema, res *ValidateResult // TODO: duration case "email", "idn-email": if _, err := mail.ParseAddress(str); err != nil { - res.Addf(path, str, validation.MsgExpectedRFC5322Email, err) + res.Add(path, str, validation.ErrorFormatter(validation.MsgExpectedRFC5322Email, err)) } case "hostname": if !(rxHostname.MatchString(str) && len(str) < 256) { @@ -230,17 +220,17 @@ func validateFormat(path *PathBuffer, str string, s *Schema, res *ValidateResult } case "uri", "uri-reference", "iri", "iri-reference": if _, err := url.Parse(str); err != nil { - res.Addf(path, str, validation.MsgExpectedRFC3986URI, err) + res.Add(path, str, validation.ErrorFormatter(validation.MsgExpectedRFC3986URI, err)) } // TODO: check if it's actually a reference? case "uuid": if err := validateUUID(str); err != nil { - res.Addf(path, str, validation.MsgExpectedRFC4122UUID, err) + res.Add(path, str, validation.ErrorFormatter(validation.MsgExpectedRFC4122UUID, err)) } case "uri-template": u, err := url.Parse(str) if err != nil { - res.Addf(path, str, validation.MsgExpectedRFC3986URI, err) + res.Add(path, str, validation.ErrorFormatter(validation.MsgExpectedRFC3986URI, err)) return } if !rxURITemplate.MatchString(u.Path) { @@ -256,7 +246,7 @@ func validateFormat(path *PathBuffer, str string, s *Schema, res *ValidateResult } case "regex": if _, err := regexp.Compile(str); err != nil { - res.Addf(path, str, validation.MsgExpectedRegexp, err) + res.Add(path, str, validation.ErrorFormatter(validation.MsgExpectedRegexp, err)) } } } @@ -384,12 +374,12 @@ func Validate(r Registry, s *Schema, path *PathBuffer, mode ValidateMode, v any, if s.Minimum != nil { if num < *s.Minimum { - res.Addf(path, v, s.msgMinimum) + res.Add(path, v, s.msgMinimum) } } if s.ExclusiveMinimum != nil { if num <= *s.ExclusiveMinimum { - res.Addf(path, v, s.msgExclusiveMinimum) + res.Add(path, v, s.msgExclusiveMinimum) } } if s.Maximum != nil { @@ -399,12 +389,12 @@ func Validate(r Registry, s *Schema, path *PathBuffer, mode ValidateMode, v any, } if s.ExclusiveMaximum != nil { if num >= *s.ExclusiveMaximum { - res.Addf(path, v, s.msgExclusiveMaximum) + res.Add(path, v, s.msgExclusiveMaximum) } } if s.MultipleOf != nil { if math.Mod(num, *s.MultipleOf) != 0 { - res.Addf(path, v, s.msgMultipleOf) + res.Add(path, v, s.msgMultipleOf) } } case TypeString: @@ -420,7 +410,7 @@ func Validate(r Registry, s *Schema, path *PathBuffer, mode ValidateMode, v any, if s.MinLength != nil { if utf8.RuneCountInString(str) < *s.MinLength { - res.Addf(path, str, s.msgMinLength) + res.Add(path, str, s.msgMinLength) } } if s.MaxLength != nil { @@ -504,12 +494,12 @@ func Validate(r Registry, s *Schema, path *PathBuffer, mode ValidateMode, v any, func handleArray[T any](r Registry, s *Schema, path *PathBuffer, mode ValidateMode, res *ValidateResult, arr []T) { if s.MinItems != nil { if len(arr) < *s.MinItems { - res.Addf(path, arr, s.msgMinItems) + res.Add(path, arr, s.msgMinItems) } } if s.MaxItems != nil { if len(arr) > *s.MaxItems { - res.Addf(path, arr, s.msgMaxItems) + res.Add(path, arr, s.msgMaxItems) } } diff --git a/validation/messages.go b/validation/messages.go index 965d7bd5..e7b6f3cc 100644 --- a/validation/messages.go +++ b/validation/messages.go @@ -1,5 +1,8 @@ package validation +import "fmt" + +// List of built-in validation error messages var ( MsgUnexpectedProperty = "unexpected property" MsgExpectedRFC3339DateTime = "expected string to be RFC 3339 date-time" @@ -42,3 +45,6 @@ var ( MsgExpectedRequiredProperty = "expected required property %s to be present" MsgExpectedDependentRequiredProperty = "expected property %s to be present when %s is present" ) + +// ErrorFormatter is a function that formats an error message +var ErrorFormatter func(format string, a ...any) string = fmt.Sprintf From 084c351d723a3a7acfbb5180f9ee03f43172758f Mon Sep 17 00:00:00 2001 From: Maxim Sukharev Date: Fri, 26 Jul 2024 10:55:33 +0800 Subject: [PATCH 3/5] fix: don't remove ValidateResult.Addf to avoid breaking change --- validate.go | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/validate.go b/validate.go index 02895d0f..c60e75e5 100644 --- a/validate.go +++ b/validate.go @@ -168,6 +168,16 @@ func (r *ValidateResult) Add(path *PathBuffer, v any, msg string) { }) } +// Addf adds an error to the validation result at the given path and with +// the given value, allowing for fmt.Printf-style formatting. +func (r *ValidateResult) Addf(path *PathBuffer, v any, format string, args ...any) { + r.Errors = append(r.Errors, &ErrorDetail{ + Message: fmt.Sprintf(format, args...), + Location: path.String(), + Value: v, + }) +} + // Reset the validation error so it can be used again. func (r *ValidateResult) Reset() { r.Errors = r.Errors[:0] From 8033e05efe39f138bc1f9f466dd6f3a97fd21244 Mon Sep 17 00:00:00 2001 From: Maxim Sukharev Date: Fri, 26 Jul 2024 10:56:04 +0800 Subject: [PATCH 4/5] chore: move ErrorFormatter to the root of huma --- error.go | 3 +++ schema.go | 32 ++++++++++++++++---------------- validate.go | 10 +++++----- validation/messages.go | 5 ----- 4 files changed, 24 insertions(+), 26 deletions(-) diff --git a/error.go b/error.go index 67017c29..51f091c5 100644 --- a/error.go +++ b/error.go @@ -362,3 +362,6 @@ func Error503ServiceUnavailable(msg string, errs ...error) StatusError { func Error504GatewayTimeout(msg string, errs ...error) StatusError { return NewError(http.StatusGatewayTimeout, msg, errs...) } + +// ErrorFormatter is a function that formats an error message +var ErrorFormatter func(format string, a ...any) string = fmt.Sprintf diff --git a/schema.go b/schema.go index 95d4a517..6d2f4064 100644 --- a/schema.go +++ b/schema.go @@ -204,49 +204,49 @@ func (s *Schema) MarshalJSON() ([]byte, error) { // PrecomputeMessages tries to precompute as many validation error messages // as possible so that new strings aren't allocated during request validation. func (s *Schema) PrecomputeMessages() { - s.msgEnum = validation.ErrorFormatter(validation.MsgExpectedOneOf, strings.Join(mapTo(s.Enum, func(v any) string { + s.msgEnum = ErrorFormatter(validation.MsgExpectedOneOf, strings.Join(mapTo(s.Enum, func(v any) string { return fmt.Sprintf("%v", v) }), ", ")) if s.Minimum != nil { - s.msgMinimum = validation.ErrorFormatter(validation.MsgExpectedMinimumNumber, *s.Minimum) + s.msgMinimum = ErrorFormatter(validation.MsgExpectedMinimumNumber, *s.Minimum) } if s.ExclusiveMinimum != nil { - s.msgExclusiveMinimum = validation.ErrorFormatter(validation.MsgExpectedExclusiveMinimumNumber, *s.ExclusiveMinimum) + s.msgExclusiveMinimum = ErrorFormatter(validation.MsgExpectedExclusiveMinimumNumber, *s.ExclusiveMinimum) } if s.Maximum != nil { - s.msgMaximum = validation.ErrorFormatter(validation.MsgExpectedMaximumNumber, *s.Maximum) + s.msgMaximum = ErrorFormatter(validation.MsgExpectedMaximumNumber, *s.Maximum) } if s.ExclusiveMaximum != nil { - s.msgExclusiveMaximum = validation.ErrorFormatter(validation.MsgExpectedExclusiveMaximumNumber, *s.ExclusiveMaximum) + s.msgExclusiveMaximum = ErrorFormatter(validation.MsgExpectedExclusiveMaximumNumber, *s.ExclusiveMaximum) } if s.MultipleOf != nil { - s.msgMultipleOf = validation.ErrorFormatter(validation.MsgExpectedNumberBeMultipleOf, *s.MultipleOf) + s.msgMultipleOf = ErrorFormatter(validation.MsgExpectedNumberBeMultipleOf, *s.MultipleOf) } if s.MinLength != nil { - s.msgMinLength = validation.ErrorFormatter(validation.MsgExpectedMinLength, *s.MinLength) + s.msgMinLength = ErrorFormatter(validation.MsgExpectedMinLength, *s.MinLength) } if s.MaxLength != nil { - s.msgMaxLength = validation.ErrorFormatter(validation.MsgExpectedMaxLength, *s.MaxLength) + s.msgMaxLength = ErrorFormatter(validation.MsgExpectedMaxLength, *s.MaxLength) } if s.Pattern != "" { s.patternRe = regexp.MustCompile(s.Pattern) if s.PatternDescription != "" { - s.msgPattern = validation.ErrorFormatter(validation.MsgExpectedBePattern, s.PatternDescription) + s.msgPattern = ErrorFormatter(validation.MsgExpectedBePattern, s.PatternDescription) } else { - s.msgPattern = validation.ErrorFormatter(validation.MsgExpectedMatchPattern, s.Pattern) + s.msgPattern = ErrorFormatter(validation.MsgExpectedMatchPattern, s.Pattern) } } if s.MinItems != nil { - s.msgMinItems = validation.ErrorFormatter(validation.MsgExpectedMinItems, *s.MinItems) + s.msgMinItems = ErrorFormatter(validation.MsgExpectedMinItems, *s.MinItems) } if s.MaxItems != nil { - s.msgMaxItems = validation.ErrorFormatter(validation.MsgExpectedMaxItems, *s.MaxItems) + s.msgMaxItems = ErrorFormatter(validation.MsgExpectedMaxItems, *s.MaxItems) } if s.MinProperties != nil { - s.msgMinProperties = validation.ErrorFormatter(validation.MsgExpectedMinProperties, *s.MinProperties) + s.msgMinProperties = ErrorFormatter(validation.MsgExpectedMinProperties, *s.MinProperties) } if s.MaxProperties != nil { - s.msgMaxProperties = validation.ErrorFormatter(validation.MsgExpectedMaxProperties, *s.MaxProperties) + s.msgMaxProperties = ErrorFormatter(validation.MsgExpectedMaxProperties, *s.MaxProperties) } if s.Required != nil { @@ -254,7 +254,7 @@ func (s *Schema) PrecomputeMessages() { s.msgRequired = map[string]string{} } for _, name := range s.Required { - s.msgRequired[name] = validation.ErrorFormatter(validation.MsgExpectedRequiredProperty, name) + s.msgRequired[name] = ErrorFormatter(validation.MsgExpectedRequiredProperty, name) } } @@ -267,7 +267,7 @@ func (s *Schema) PrecomputeMessages() { if s.msgDependentRequired[name] == nil { s.msgDependentRequired[name] = map[string]string{} } - s.msgDependentRequired[name][dependent] = validation.ErrorFormatter(validation.MsgExpectedDependentRequiredProperty, dependent, name) + s.msgDependentRequired[name][dependent] = ErrorFormatter(validation.MsgExpectedDependentRequiredProperty, dependent, name) } } } diff --git a/validate.go b/validate.go index c60e75e5..47da6c65 100644 --- a/validate.go +++ b/validate.go @@ -213,7 +213,7 @@ func validateFormat(path *PathBuffer, str string, s *Schema, res *ValidateResult // TODO: duration case "email", "idn-email": if _, err := mail.ParseAddress(str); err != nil { - res.Add(path, str, validation.ErrorFormatter(validation.MsgExpectedRFC5322Email, err)) + res.Add(path, str, ErrorFormatter(validation.MsgExpectedRFC5322Email, err)) } case "hostname": if !(rxHostname.MatchString(str) && len(str) < 256) { @@ -230,17 +230,17 @@ func validateFormat(path *PathBuffer, str string, s *Schema, res *ValidateResult } case "uri", "uri-reference", "iri", "iri-reference": if _, err := url.Parse(str); err != nil { - res.Add(path, str, validation.ErrorFormatter(validation.MsgExpectedRFC3986URI, err)) + res.Add(path, str, ErrorFormatter(validation.MsgExpectedRFC3986URI, err)) } // TODO: check if it's actually a reference? case "uuid": if err := validateUUID(str); err != nil { - res.Add(path, str, validation.ErrorFormatter(validation.MsgExpectedRFC4122UUID, err)) + res.Add(path, str, ErrorFormatter(validation.MsgExpectedRFC4122UUID, err)) } case "uri-template": u, err := url.Parse(str) if err != nil { - res.Add(path, str, validation.ErrorFormatter(validation.MsgExpectedRFC3986URI, err)) + res.Add(path, str, ErrorFormatter(validation.MsgExpectedRFC3986URI, err)) return } if !rxURITemplate.MatchString(u.Path) { @@ -256,7 +256,7 @@ func validateFormat(path *PathBuffer, str string, s *Schema, res *ValidateResult } case "regex": if _, err := regexp.Compile(str); err != nil { - res.Add(path, str, validation.ErrorFormatter(validation.MsgExpectedRegexp, err)) + res.Add(path, str, ErrorFormatter(validation.MsgExpectedRegexp, err)) } } } diff --git a/validation/messages.go b/validation/messages.go index e7b6f3cc..cb68214c 100644 --- a/validation/messages.go +++ b/validation/messages.go @@ -1,7 +1,5 @@ package validation -import "fmt" - // List of built-in validation error messages var ( MsgUnexpectedProperty = "unexpected property" @@ -45,6 +43,3 @@ var ( MsgExpectedRequiredProperty = "expected required property %s to be present" MsgExpectedDependentRequiredProperty = "expected property %s to be present when %s is present" ) - -// ErrorFormatter is a function that formats an error message -var ErrorFormatter func(format string, a ...any) string = fmt.Sprintf From 71fac327c83f548056e8053b0de99ed93a559425 Mon Sep 17 00:00:00 2001 From: Maxim Sukharev Date: Fri, 26 Jul 2024 11:32:12 +0800 Subject: [PATCH 5/5] test: add unit test for custom formatter --- validate_test.go | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/validate_test.go b/validate_test.go index 735a2f9b..eeb3c62b 100644 --- a/validate_test.go +++ b/validate_test.go @@ -1300,6 +1300,28 @@ func TestValidate(t *testing.T) { } } +func TestValidateCustomFormatter(t *testing.T) { + originalFormatter := huma.ErrorFormatter + defer func() { + huma.ErrorFormatter = originalFormatter + }() + + huma.ErrorFormatter = func(format string, a ...any) string { + return fmt.Sprintf("custom: %v", a) + } + + registry := huma.NewMapRegistry("#/components/schemas/", huma.DefaultSchemaNamer) + s := registry.Schema(reflect.TypeOf(struct { + Value string `json:"value" format:"email"` + }{}), true, "TestInput") + pb := huma.NewPathBuffer([]byte(""), 0) + res := &huma.ValidateResult{} + + huma.Validate(registry, s, pb, huma.ModeReadFromServer, map[string]any{"value": "alice"}, res) + assert.Len(t, res.Errors, 1) + assert.Equal(t, "custom: [mail: missing '@' or angle-addr] (value: alice)", res.Errors[0].Error()) +} + func ExampleModelValidator() { // Define a type you want to validate. type Model struct {