From a4ce23d64799bc21f45a0587b06fb400e8db2183 Mon Sep 17 00:00:00 2001 From: Thomas Miller Date: Wed, 22 Jun 2022 12:05:50 +1000 Subject: [PATCH] Adds AsType & HasType and removes IsType. The original implementation of IsType was wrong and asserted assign ability over comparability. Now we have an AsType that works the same As but creates it's own target with generics and a new HasType that disregards the target return of AsType. --- functions.go | 61 +++++++++++++++++++++------ functions_test.go | 104 ++++++++++++++++++++++++++++++++++++++-------- 2 files changed, 136 insertions(+), 29 deletions(-) diff --git a/functions.go b/functions.go index 4823ab57..952a6739 100644 --- a/functions.go +++ b/functions.go @@ -351,18 +351,17 @@ func Is(err, target error) bool { return stderrors.Is(err, target) } -// IsType is a convenience method for ascertaining if an error contains the -// target error type within its chain. This is aimed at ease of development -// where a more complicated error type wants to be to checked for existence but -// pointer var of that type is too much overhead. -func IsType[t error](err error) bool { - for err != nil { - if _, is := err.(t); is { - return true - } - err = stderrors.Unwrap(err) - } - return false +// HasType is a function wrapper around AsType dropping the where return value +// from AsType() making a function that can be used like this: +// +// return HasType[*MyError](err) +// +// Or +// +// if HasType[*MyError](err) {} +func HasType[T error](err error) bool { + _, rval := AsType[T](err) + return rval } // As is a proxy for the As function in Go's standard `errors` library @@ -371,6 +370,44 @@ func As(err error, target interface{}) bool { return stderrors.As(err, target) } +// AsType is a convenience method for checking and getting an error from within +// a chain that is of type T. If no error is found of type T in the chain the +// zero value of T is returned with false. If an error in the chain implementes +// As(any) bool then it's As method will be called if it's type is not of type T. + +// AsType finds the first error in err's chain that is assignable to type T, and +// if a match is found, returns that error value and true. Otherwise, it returns +// T's zero value and false. +// +// AsType is equivalent to errors.As, but uses a type parameter and returns +// the target, to avoid having to define a variable before the call. For +// example, callers can replace this: +// +// var pathError *fs.PathError +// if errors.As(err, &pathError) { +// fmt.Println("Failed at path:", pathError.Path) +// } +// +// With: +// +// if pathError, ok := errors.AsType[*fs.PathError](err); ok { +// fmt.Println("Failed at path:", pathError.Path) +// } +func AsType[T error](err error) (T, bool) { + for err != nil { + if e, is := err.(T); is { + return e, true + } + var res T + if x, ok := err.(interface{ As(any) bool }); ok && x.As(&res) { + return res, true + } + err = stderrors.Unwrap(err) + } + var zero T + return zero, false +} + // SetLocation takes a given error and records where in the stack SetLocation // was called from and returns the wrapped error with the location information // set. The returned error implements the Locationer interface. If err is nil diff --git a/functions_test.go b/functions_test.go index cd9502e0..9f1a11aa 100644 --- a/functions_test.go +++ b/functions_test.go @@ -403,11 +403,9 @@ func (*functionSuite) TestQuietWrappedErrorStillSatisfied(c *gc.C) { c.Assert(errors.Is(err, simpleTestError), gc.Equals, true) } -type FooError struct { -} - -func (*FooError) Error() string { - return "I am here boss" +type ComplexErrorMessage interface { + error + ComplexMessage() string } type complexError struct { @@ -418,30 +416,90 @@ func (c *complexError) Error() string { return c.Message } +func (c *complexError) ComplexMessage() string { + return c.Message +} + type complexErrorOther struct { Message string } +func (c *complexErrorOther) As(e any) bool { + if ce, ok := e.(**complexError); ok { + *ce = &complexError{ + Message: c.Message, + } + return true + } + return false +} + func (c *complexErrorOther) Error() string { return c.Message } -func (*functionSuite) TestIsType(c *gc.C) { +func (c *complexErrorOther) ComplexMessage() string { + return c.Message +} + +func (*functionSuite) TestHasType(c *gc.C) { complexErr := &complexError{Message: "complex error message"} wrapped1 := fmt.Errorf("wrapping1: %w", complexErr) wrapped2 := fmt.Errorf("wrapping2: %w", wrapped1) - c.Assert(errors.IsType[*complexError](complexErr), gc.Equals, true) - c.Assert(errors.IsType[*complexError](wrapped1), gc.Equals, true) - c.Assert(errors.IsType[*complexError](wrapped2), gc.Equals, true) - c.Assert(errors.IsType[*complexErrorOther](complexErr), gc.Equals, false) - c.Assert(errors.IsType[*complexErrorOther](wrapped1), gc.Equals, false) - c.Assert(errors.IsType[*complexErrorOther](wrapped2), gc.Equals, false) + c.Assert(errors.HasType[*complexError](complexErr), gc.Equals, true) + c.Assert(errors.HasType[*complexError](wrapped1), gc.Equals, true) + c.Assert(errors.HasType[*complexError](wrapped2), gc.Equals, true) + c.Assert(errors.HasType[ComplexErrorMessage](wrapped2), gc.Equals, true) + c.Assert(errors.HasType[*complexErrorOther](wrapped2), gc.Equals, false) + c.Assert(errors.HasType[*complexErrorOther](nil), gc.Equals, false) - err := errors.New("test") - c.Assert(errors.IsType[*complexErrorOther](err), gc.Equals, false) + complexErrOther := &complexErrorOther{Message: "another complex error"} + + c.Assert(errors.HasType[*complexError](complexErrOther), gc.Equals, true) - c.Assert(errors.IsType[*complexErrorOther](nil), gc.Equals, false) + wrapped2 = fmt.Errorf("wrapping1: %w", complexErrOther) + c.Assert(errors.HasType[*complexError](wrapped2), gc.Equals, true) +} + +func (*functionSuite) TestAsType(c *gc.C) { + complexErr := &complexError{Message: "complex error message"} + wrapped1 := fmt.Errorf("wrapping1: %w", complexErr) + wrapped2 := fmt.Errorf("wrapping2: %w", wrapped1) + + ce, ok := errors.AsType[*complexError](complexErr) + c.Assert(ok, gc.Equals, true) + c.Assert(ce.Message, gc.Equals, complexErr.Message) + + ce, ok = errors.AsType[*complexError](wrapped1) + c.Assert(ok, gc.Equals, true) + c.Assert(ce.Message, gc.Equals, complexErr.Message) + + ce, ok = errors.AsType[*complexError](wrapped2) + c.Assert(ok, gc.Equals, true) + c.Assert(ce.Message, gc.Equals, complexErr.Message) + + cem, ok := errors.AsType[ComplexErrorMessage](wrapped2) + c.Assert(ok, gc.Equals, true) + c.Assert(cem.ComplexMessage(), gc.Equals, complexErr.Message) + + ceo, ok := errors.AsType[*complexErrorOther](wrapped2) + c.Assert(ok, gc.Equals, false) + c.Assert(ceo, gc.Equals, (*complexErrorOther)(nil)) + + ceo, ok = errors.AsType[*complexErrorOther](nil) + c.Assert(ok, gc.Equals, false) + c.Assert(ceo, gc.Equals, (*complexErrorOther)(nil)) + + complexErrOther := &complexErrorOther{Message: "another complex error"} + ce, ok = errors.AsType[*complexError](complexErrOther) + c.Assert(ok, gc.Equals, true) + c.Assert(ce.Message, gc.Equals, complexErrOther.Message) + + wrapped2 = fmt.Errorf("wrapping1: %w", complexErrOther) + ce, ok = errors.AsType[*complexError](wrapped2) + c.Assert(ok, gc.Equals, true) + c.Assert(ce.Message, gc.Equals, complexErrOther.Message) } func ExampleHide() { @@ -464,12 +522,24 @@ func (m *MyError) Error() string { return m.Message } -func ExampleIsType() { +func ExampleHasType() { + myErr := &MyError{Message: "these are not the droids you're looking for"} + err := fmt.Errorf("wrapped: %w", myErr) + is := errors.HasType[*MyError](err) + fmt.Println(is) + + // Output: + // true +} + +func ExampleAsType() { myErr := &MyError{Message: "these are not the droids you're looking for"} err := fmt.Errorf("wrapped: %w", myErr) - is := errors.IsType[*MyError](err) + myErr, is := errors.AsType[*MyError](err) fmt.Println(is) + fmt.Println(myErr.Message) // Output: // true + // these are not the droids you're looking for }