From d52126ca69ee1bb5ef0fbd94bd058d9f3a84e48b Mon Sep 17 00:00:00 2001 From: Joe Tsai Date: Wed, 30 Aug 2023 23:30:06 -0700 Subject: [PATCH] Add cmpopts.EquateComparable This helper function makes it easier to specify that comparable types are safe to directly compare with the == operator in Go. The API does not use generics as it follows existing options like cmp.AllowUnexported, cmpopts.IgnoreUnexported, or cmpopts.IgnoreTypes. While generics provides type safety, the user experience is not as nice. Our current API allows multiple types to be specified: cmpopts.EquateComparable(netip.Addr{}, netip.Prefix{}) While generics would not allow variadic arguments: cmpopts.EquateComparable[netip.Addr]() cmpopts.EquateComparable[netip.Prefix]() Fixes #339 --- cmp/cmpopts/equate.go | 29 +++++++++++++++++++++++++++++ cmp/cmpopts/util_test.go | 31 +++++++++++++++++++++++++++++++ cmp/options.go | 4 +++- 3 files changed, 63 insertions(+), 1 deletion(-) diff --git a/cmp/cmpopts/equate.go b/cmp/cmpopts/equate.go index 90974e6..3d8d0cd 100644 --- a/cmp/cmpopts/equate.go +++ b/cmp/cmpopts/equate.go @@ -7,6 +7,7 @@ package cmpopts import ( "errors" + "fmt" "math" "reflect" "time" @@ -154,3 +155,31 @@ func compareErrors(x, y interface{}) bool { ye := y.(error) return errors.Is(xe, ye) || errors.Is(ye, xe) } + +// EquateComparable returns a [cmp.Option] that determines equality +// of comparable types by directly comparing them using the == operator in Go. +// The types to compare are specified by passing a value of that type. +// This option should only be used on types that are documented as being +// safe for direct == comparison. For example, [net/netip.Addr] is documented +// as being semantically safe to use with ==, while [time.Time] is documented +// to discourage the use of == on time values. +func EquateComparable(typs ...interface{}) cmp.Option { + types := make(typesFilter) + for _, typ := range typs { + switch t := reflect.TypeOf(typ); { + case !t.Comparable(): + panic(fmt.Sprintf("%T is not a comparable Go type", typ)) + case types[t]: + panic(fmt.Sprintf("%T is already specified", typ)) + default: + types[t] = true + } + } + return cmp.FilterPath(types.filter, cmp.Comparer(equateAny)) +} + +type typesFilter map[reflect.Type]bool + +func (tf typesFilter) filter(p cmp.Path) bool { return tf[p.Last().Type()] } + +func equateAny(x, y interface{}) bool { return x == y } diff --git a/cmp/cmpopts/util_test.go b/cmp/cmpopts/util_test.go index 7adeb9b..6a7c300 100644 --- a/cmp/cmpopts/util_test.go +++ b/cmp/cmpopts/util_test.go @@ -10,6 +10,7 @@ import ( "fmt" "io" "math" + "net/netip" "reflect" "strings" "sync" @@ -676,6 +677,36 @@ func TestOptions(t *testing.T) { opts: []cmp.Option{EquateErrors()}, wantEqual: false, reason: "AnyError is not equal to nil value", + }, { + label: "EquateComparable", + x: []struct{ P netip.Addr }{ + {netip.AddrFrom4([4]byte{1, 2, 3, 4})}, + {netip.AddrFrom4([4]byte{1, 2, 3, 5})}, + {netip.AddrFrom4([4]byte{1, 2, 3, 6})}, + }, + y: []struct{ P netip.Addr }{ + {netip.AddrFrom4([4]byte{1, 2, 3, 4})}, + {netip.AddrFrom4([4]byte{1, 2, 3, 5})}, + {netip.AddrFrom4([4]byte{1, 2, 3, 6})}, + }, + opts: []cmp.Option{EquateComparable(netip.Addr{})}, + wantEqual: true, + reason: "equal because all IP addresses are the same", + }, { + label: "EquateComparable", + x: []struct{ P netip.Addr }{ + {netip.AddrFrom4([4]byte{1, 2, 3, 4})}, + {netip.AddrFrom4([4]byte{1, 2, 3, 5})}, + {netip.AddrFrom4([4]byte{1, 2, 3, 6})}, + }, + y: []struct{ P netip.Addr }{ + {netip.AddrFrom4([4]byte{1, 2, 3, 4})}, + {netip.AddrFrom4([4]byte{1, 2, 3, 7})}, + {netip.AddrFrom4([4]byte{1, 2, 3, 6})}, + }, + opts: []cmp.Option{EquateComparable(netip.Addr{})}, + wantEqual: false, + reason: "not equal because second IP address is different", }, { label: "IgnoreFields", x: Bar1{Foo3{&Foo2{&Foo1{Alpha: 5}}}}, diff --git a/cmp/options.go b/cmp/options.go index 518b6ac..392a1ce 100644 --- a/cmp/options.go +++ b/cmp/options.go @@ -232,7 +232,9 @@ func (validator) apply(s *state, vx, vy reflect.Value) { if t := s.curPath.Index(-2).Type(); t.Name() != "" { // Named type with unexported fields. name = fmt.Sprintf("%q.%v", t.PkgPath(), t.Name()) // e.g., "path/to/package".MyType - if _, ok := reflect.New(t).Interface().(error); ok { + if t.Comparable() { + help = "consider using cmpopts.EquateComparable to compare comparable Go types" + } else if _, ok := reflect.New(t).Interface().(error); ok { help = "consider using cmpopts.EquateErrors to compare error values" } } else {