diff --git a/functions.go b/functions.go index 4823ab57..952a6739 100644 --- a/functions.go +++ b/functions.go @@ -351,18 +351,17 @@ func Is(err, target error) bool { return stderrors.Is(err, target) } -// IsType is a convenience method for ascertaining if an error contains the -// target error type within its chain. This is aimed at ease of development -// where a more complicated error type wants to be to checked for existence but -// pointer var of that type is too much overhead. -func IsType[t error](err error) bool { - for err != nil { - if _, is := err.(t); is { - return true - } - err = stderrors.Unwrap(err) - } - return false +// HasType is a function wrapper around AsType dropping the where return value +// from AsType() making a function that can be used like this: +// +// return HasType[*MyError](err) +// +// Or +// +// if HasType[*MyError](err) {} +func HasType[T error](err error) bool { + _, rval := AsType[T](err) + return rval } // As is a proxy for the As function in Go's standard `errors` library @@ -371,6 +370,44 @@ func As(err error, target interface{}) bool { return stderrors.As(err, target) } +// AsType is a convenience method for checking and getting an error from within +// a chain that is of type T. If no error is found of type T in the chain the +// zero value of T is returned with false. If an error in the chain implementes +// As(any) bool then it's As method will be called if it's type is not of type T. + +// AsType finds the first error in err's chain that is assignable to type T, and +// if a match is found, returns that error value and true. Otherwise, it returns +// T's zero value and false. +// +// AsType is equivalent to errors.As, but uses a type parameter and returns +// the target, to avoid having to define a variable before the call. For +// example, callers can replace this: +// +// var pathError *fs.PathError +// if errors.As(err, &pathError) { +// fmt.Println("Failed at path:", pathError.Path) +// } +// +// With: +// +// if pathError, ok := errors.AsType[*fs.PathError](err); ok { +// fmt.Println("Failed at path:", pathError.Path) +// } +func AsType[T error](err error) (T, bool) { + for err != nil { + if e, is := err.(T); is { + return e, true + } + var res T + if x, ok := err.(interface{ As(any) bool }); ok && x.As(&res) { + return res, true + } + err = stderrors.Unwrap(err) + } + var zero T + return zero, false +} + // SetLocation takes a given error and records where in the stack SetLocation // was called from and returns the wrapped error with the location information // set. The returned error implements the Locationer interface. If err is nil diff --git a/functions_test.go b/functions_test.go index cd9502e0..9f1a11aa 100644 --- a/functions_test.go +++ b/functions_test.go @@ -403,11 +403,9 @@ func (*functionSuite) TestQuietWrappedErrorStillSatisfied(c *gc.C) { c.Assert(errors.Is(err, simpleTestError), gc.Equals, true) } -type FooError struct { -} - -func (*FooError) Error() string { - return "I am here boss" +type ComplexErrorMessage interface { + error + ComplexMessage() string } type complexError struct { @@ -418,30 +416,90 @@ func (c *complexError) Error() string { return c.Message } +func (c *complexError) ComplexMessage() string { + return c.Message +} + type complexErrorOther struct { Message string } +func (c *complexErrorOther) As(e any) bool { + if ce, ok := e.(**complexError); ok { + *ce = &complexError{ + Message: c.Message, + } + return true + } + return false +} + func (c *complexErrorOther) Error() string { return c.Message } -func (*functionSuite) TestIsType(c *gc.C) { +func (c *complexErrorOther) ComplexMessage() string { + return c.Message +} + +func (*functionSuite) TestHasType(c *gc.C) { complexErr := &complexError{Message: "complex error message"} wrapped1 := fmt.Errorf("wrapping1: %w", complexErr) wrapped2 := fmt.Errorf("wrapping2: %w", wrapped1) - c.Assert(errors.IsType[*complexError](complexErr), gc.Equals, true) - c.Assert(errors.IsType[*complexError](wrapped1), gc.Equals, true) - c.Assert(errors.IsType[*complexError](wrapped2), gc.Equals, true) - c.Assert(errors.IsType[*complexErrorOther](complexErr), gc.Equals, false) - c.Assert(errors.IsType[*complexErrorOther](wrapped1), gc.Equals, false) - c.Assert(errors.IsType[*complexErrorOther](wrapped2), gc.Equals, false) + c.Assert(errors.HasType[*complexError](complexErr), gc.Equals, true) + c.Assert(errors.HasType[*complexError](wrapped1), gc.Equals, true) + c.Assert(errors.HasType[*complexError](wrapped2), gc.Equals, true) + c.Assert(errors.HasType[ComplexErrorMessage](wrapped2), gc.Equals, true) + c.Assert(errors.HasType[*complexErrorOther](wrapped2), gc.Equals, false) + c.Assert(errors.HasType[*complexErrorOther](nil), gc.Equals, false) - err := errors.New("test") - c.Assert(errors.IsType[*complexErrorOther](err), gc.Equals, false) + complexErrOther := &complexErrorOther{Message: "another complex error"} + + c.Assert(errors.HasType[*complexError](complexErrOther), gc.Equals, true) - c.Assert(errors.IsType[*complexErrorOther](nil), gc.Equals, false) + wrapped2 = fmt.Errorf("wrapping1: %w", complexErrOther) + c.Assert(errors.HasType[*complexError](wrapped2), gc.Equals, true) +} + +func (*functionSuite) TestAsType(c *gc.C) { + complexErr := &complexError{Message: "complex error message"} + wrapped1 := fmt.Errorf("wrapping1: %w", complexErr) + wrapped2 := fmt.Errorf("wrapping2: %w", wrapped1) + + ce, ok := errors.AsType[*complexError](complexErr) + c.Assert(ok, gc.Equals, true) + c.Assert(ce.Message, gc.Equals, complexErr.Message) + + ce, ok = errors.AsType[*complexError](wrapped1) + c.Assert(ok, gc.Equals, true) + c.Assert(ce.Message, gc.Equals, complexErr.Message) + + ce, ok = errors.AsType[*complexError](wrapped2) + c.Assert(ok, gc.Equals, true) + c.Assert(ce.Message, gc.Equals, complexErr.Message) + + cem, ok := errors.AsType[ComplexErrorMessage](wrapped2) + c.Assert(ok, gc.Equals, true) + c.Assert(cem.ComplexMessage(), gc.Equals, complexErr.Message) + + ceo, ok := errors.AsType[*complexErrorOther](wrapped2) + c.Assert(ok, gc.Equals, false) + c.Assert(ceo, gc.Equals, (*complexErrorOther)(nil)) + + ceo, ok = errors.AsType[*complexErrorOther](nil) + c.Assert(ok, gc.Equals, false) + c.Assert(ceo, gc.Equals, (*complexErrorOther)(nil)) + + complexErrOther := &complexErrorOther{Message: "another complex error"} + ce, ok = errors.AsType[*complexError](complexErrOther) + c.Assert(ok, gc.Equals, true) + c.Assert(ce.Message, gc.Equals, complexErrOther.Message) + + wrapped2 = fmt.Errorf("wrapping1: %w", complexErrOther) + ce, ok = errors.AsType[*complexError](wrapped2) + c.Assert(ok, gc.Equals, true) + c.Assert(ce.Message, gc.Equals, complexErrOther.Message) } func ExampleHide() { @@ -464,12 +522,24 @@ func (m *MyError) Error() string { return m.Message } -func ExampleIsType() { +func ExampleHasType() { + myErr := &MyError{Message: "these are not the droids you're looking for"} + err := fmt.Errorf("wrapped: %w", myErr) + is := errors.HasType[*MyError](err) + fmt.Println(is) + + // Output: + // true +} + +func ExampleAsType() { myErr := &MyError{Message: "these are not the droids you're looking for"} err := fmt.Errorf("wrapped: %w", myErr) - is := errors.IsType[*MyError](err) + myErr, is := errors.AsType[*MyError](err) fmt.Println(is) + fmt.Println(myErr.Message) // Output: // true + // these are not the droids you're looking for }