Skip to content

Commit

Permalink
feat: allow customizing error formatting func
Browse files Browse the repository at this point in the history
  • Loading branch information
smacker committed Jul 24, 2024
1 parent 20a669a commit 9ff9231
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 38 deletions.
32 changes: 16 additions & 16 deletions schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,57 +204,57 @@ func (s *Schema) MarshalJSON() ([]byte, error) {
// PrecomputeMessages tries to precompute as many validation error messages
// as possible so that new strings aren't allocated during request validation.
func (s *Schema) PrecomputeMessages() {
s.msgEnum = fmt.Sprintf(validation.MsgExpectedOneOf, strings.Join(mapTo(s.Enum, func(v any) string {
s.msgEnum = validation.ErrorFormatter(validation.MsgExpectedOneOf, strings.Join(mapTo(s.Enum, func(v any) string {
return fmt.Sprintf("%v", v)
}), ", "))
if s.Minimum != nil {
s.msgMinimum = fmt.Sprintf(validation.MsgExpectedMinimumNumber, *s.Minimum)
s.msgMinimum = validation.ErrorFormatter(validation.MsgExpectedMinimumNumber, *s.Minimum)
}
if s.ExclusiveMinimum != nil {
s.msgExclusiveMinimum = fmt.Sprintf(validation.MsgExpectedExclusiveMinimumNumber, *s.ExclusiveMinimum)
s.msgExclusiveMinimum = validation.ErrorFormatter(validation.MsgExpectedExclusiveMinimumNumber, *s.ExclusiveMinimum)
}
if s.Maximum != nil {
s.msgMaximum = fmt.Sprintf(validation.MsgExpectedMaximumNumber, *s.Maximum)
s.msgMaximum = validation.ErrorFormatter(validation.MsgExpectedMaximumNumber, *s.Maximum)
}
if s.ExclusiveMaximum != nil {
s.msgExclusiveMaximum = fmt.Sprintf(validation.MsgExpectedExclusiveMaximumNumber, *s.ExclusiveMaximum)
s.msgExclusiveMaximum = validation.ErrorFormatter(validation.MsgExpectedExclusiveMaximumNumber, *s.ExclusiveMaximum)
}
if s.MultipleOf != nil {
s.msgMultipleOf = fmt.Sprintf(validation.MsgExpectedNumberBeMultipleOf, *s.MultipleOf)
s.msgMultipleOf = validation.ErrorFormatter(validation.MsgExpectedNumberBeMultipleOf, *s.MultipleOf)
}
if s.MinLength != nil {
s.msgMinLength = fmt.Sprintf(validation.MsgExpectedMinLength, *s.MinLength)
s.msgMinLength = validation.ErrorFormatter(validation.MsgExpectedMinLength, *s.MinLength)
}
if s.MaxLength != nil {
s.msgMaxLength = fmt.Sprintf(validation.MsgExpectedMaxLength, *s.MaxLength)
s.msgMaxLength = validation.ErrorFormatter(validation.MsgExpectedMaxLength, *s.MaxLength)
}
if s.Pattern != "" {
s.patternRe = regexp.MustCompile(s.Pattern)
if s.PatternDescription != "" {
s.msgPattern = fmt.Sprintf(validation.MsgExpectedBePattern, s.PatternDescription)
s.msgPattern = validation.ErrorFormatter(validation.MsgExpectedBePattern, s.PatternDescription)
} else {
s.msgPattern = fmt.Sprintf(validation.MsgExpectedMatchPattern, s.Pattern)
s.msgPattern = validation.ErrorFormatter(validation.MsgExpectedMatchPattern, s.Pattern)
}
}
if s.MinItems != nil {
s.msgMinItems = fmt.Sprintf(validation.MsgExpectedMinItems, *s.MinItems)
s.msgMinItems = validation.ErrorFormatter(validation.MsgExpectedMinItems, *s.MinItems)
}
if s.MaxItems != nil {
s.msgMaxItems = fmt.Sprintf(validation.MsgExpectedMaxItems, *s.MaxItems)
s.msgMaxItems = validation.ErrorFormatter(validation.MsgExpectedMaxItems, *s.MaxItems)
}
if s.MinProperties != nil {
s.msgMinProperties = fmt.Sprintf(validation.MsgExpectedMinProperties, *s.MinProperties)
s.msgMinProperties = validation.ErrorFormatter(validation.MsgExpectedMinProperties, *s.MinProperties)
}
if s.MaxProperties != nil {
s.msgMaxProperties = fmt.Sprintf(validation.MsgExpectedMaxProperties, *s.MaxProperties)
s.msgMaxProperties = validation.ErrorFormatter(validation.MsgExpectedMaxProperties, *s.MaxProperties)
}

if s.Required != nil {
if s.msgRequired == nil {
s.msgRequired = map[string]string{}
}
for _, name := range s.Required {
s.msgRequired[name] = fmt.Sprintf(validation.MsgExpectedRequiredProperty, name)
s.msgRequired[name] = validation.ErrorFormatter(validation.MsgExpectedRequiredProperty, name)
}
}

Expand All @@ -267,7 +267,7 @@ func (s *Schema) PrecomputeMessages() {
if s.msgDependentRequired[name] == nil {
s.msgDependentRequired[name] = map[string]string{}
}
s.msgDependentRequired[name][dependent] = fmt.Sprintf(validation.MsgExpectedDependentRequiredProperty, dependent, name)
s.msgDependentRequired[name][dependent] = validation.ErrorFormatter(validation.MsgExpectedDependentRequiredProperty, dependent, name)
}
}
}
Expand Down
34 changes: 12 additions & 22 deletions validate.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,16 +168,6 @@ func (r *ValidateResult) Add(path *PathBuffer, v any, msg string) {
})
}

// Addf adds an error to the validation result at the given path and with
// the given value, allowing for fmt.Printf-style formatting.
func (r *ValidateResult) Addf(path *PathBuffer, v any, format string, args ...any) {
r.Errors = append(r.Errors, &ErrorDetail{
Message: fmt.Sprintf(format, args...),
Location: path.String(),
Value: v,
})
}

// Reset the validation error so it can be used again.
func (r *ValidateResult) Reset() {
r.Errors = r.Errors[:0]
Expand Down Expand Up @@ -213,7 +203,7 @@ func validateFormat(path *PathBuffer, str string, s *Schema, res *ValidateResult
// TODO: duration
case "email", "idn-email":
if _, err := mail.ParseAddress(str); err != nil {
res.Addf(path, str, validation.MsgExpectedRFC5322Email, err)
res.Add(path, str, validation.ErrorFormatter(validation.MsgExpectedRFC5322Email, err))
}
case "hostname":
if !(rxHostname.MatchString(str) && len(str) < 256) {
Expand All @@ -230,17 +220,17 @@ func validateFormat(path *PathBuffer, str string, s *Schema, res *ValidateResult
}
case "uri", "uri-reference", "iri", "iri-reference":
if _, err := url.Parse(str); err != nil {
res.Addf(path, str, validation.MsgExpectedRFC3986URI, err)
res.Add(path, str, validation.ErrorFormatter(validation.MsgExpectedRFC3986URI, err))
}
// TODO: check if it's actually a reference?
case "uuid":
if err := validateUUID(str); err != nil {
res.Addf(path, str, validation.MsgExpectedRFC4122UUID, err)
res.Add(path, str, validation.ErrorFormatter(validation.MsgExpectedRFC4122UUID, err))
}
case "uri-template":
u, err := url.Parse(str)
if err != nil {
res.Addf(path, str, validation.MsgExpectedRFC3986URI, err)
res.Add(path, str, validation.ErrorFormatter(validation.MsgExpectedRFC3986URI, err))
return
}
if !rxURITemplate.MatchString(u.Path) {
Expand All @@ -256,7 +246,7 @@ func validateFormat(path *PathBuffer, str string, s *Schema, res *ValidateResult
}
case "regex":
if _, err := regexp.Compile(str); err != nil {
res.Addf(path, str, validation.MsgExpectedRegexp, err)
res.Add(path, str, validation.ErrorFormatter(validation.MsgExpectedRegexp, err))
}
}
}
Expand Down Expand Up @@ -384,12 +374,12 @@ func Validate(r Registry, s *Schema, path *PathBuffer, mode ValidateMode, v any,

if s.Minimum != nil {
if num < *s.Minimum {
res.Addf(path, v, s.msgMinimum)
res.Add(path, v, s.msgMinimum)
}
}
if s.ExclusiveMinimum != nil {
if num <= *s.ExclusiveMinimum {
res.Addf(path, v, s.msgExclusiveMinimum)
res.Add(path, v, s.msgExclusiveMinimum)
}
}
if s.Maximum != nil {
Expand All @@ -399,12 +389,12 @@ func Validate(r Registry, s *Schema, path *PathBuffer, mode ValidateMode, v any,
}
if s.ExclusiveMaximum != nil {
if num >= *s.ExclusiveMaximum {
res.Addf(path, v, s.msgExclusiveMaximum)
res.Add(path, v, s.msgExclusiveMaximum)
}
}
if s.MultipleOf != nil {
if math.Mod(num, *s.MultipleOf) != 0 {
res.Addf(path, v, s.msgMultipleOf)
res.Add(path, v, s.msgMultipleOf)
}
}
case TypeString:
Expand All @@ -420,7 +410,7 @@ func Validate(r Registry, s *Schema, path *PathBuffer, mode ValidateMode, v any,

if s.MinLength != nil {
if utf8.RuneCountInString(str) < *s.MinLength {
res.Addf(path, str, s.msgMinLength)
res.Add(path, str, s.msgMinLength)
}
}
if s.MaxLength != nil {
Expand Down Expand Up @@ -504,12 +494,12 @@ func Validate(r Registry, s *Schema, path *PathBuffer, mode ValidateMode, v any,
func handleArray[T any](r Registry, s *Schema, path *PathBuffer, mode ValidateMode, res *ValidateResult, arr []T) {
if s.MinItems != nil {
if len(arr) < *s.MinItems {
res.Addf(path, arr, s.msgMinItems)
res.Add(path, arr, s.msgMinItems)
}
}
if s.MaxItems != nil {
if len(arr) > *s.MaxItems {
res.Addf(path, arr, s.msgMaxItems)
res.Add(path, arr, s.msgMaxItems)
}
}

Expand Down
6 changes: 6 additions & 0 deletions validation/messages.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
package validation

import "fmt"

// List of built-in validation error messages
var (
MsgUnexpectedProperty = "unexpected property"
MsgExpectedRFC3339DateTime = "expected string to be RFC 3339 date-time"
Expand Down Expand Up @@ -42,3 +45,6 @@ var (
MsgExpectedRequiredProperty = "expected required property %s to be present"
MsgExpectedDependentRequiredProperty = "expected property %s to be present when %s is present"
)

// ErrorFormatter is a function that formats an error message
var ErrorFormatter func(format string, a ...any) string = fmt.Sprintf

0 comments on commit 9ff9231

Please sign in to comment.