From 60b30aa08b576f6e92e8bb03a23e0aea39d8fb04 Mon Sep 17 00:00:00 2001 From: Joe Tsai Date: Wed, 23 Oct 2024 09:29:25 -0700 Subject: [PATCH] Support compare functions with SortSlices and SortMaps The SortSlices and SortMaps options predate generics and accept an interface{}, so it is possible with reflection to support other function signatures than "func(T, T) bool". In particular, the Go ecosystem is increasingly moving towards "func(T, T) int" as the signature for ordering as evidenced by the newer slices.SortFunc function in stdlib. Thus, modernize cmpopts by supporting "func(T, T) int". Also, bump the minimum version to Go 1.21 to match the minimum supported version of google.golang.org/protobuf. Fixes #365 --- .github/workflows/test.yml | 2 +- cmp/cmpopts/sort.go | 64 ++++++++++++++++++++++++----------- cmp/cmpopts/util_test.go | 48 ++++++++++++++++++-------- cmp/internal/function/func.go | 7 ++++ go.mod | 2 +- 5 files changed, 87 insertions(+), 36 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index e21ebfa..712196a 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -6,7 +6,7 @@ jobs: test: strategy: matrix: - go-version: [1.18.x, 1.19.x, 1.20.x, 1.21.x] + go-version: [1.21.x] os: [ubuntu-latest, macos-latest] runs-on: ${{ matrix.os }} steps: diff --git a/cmp/cmpopts/sort.go b/cmp/cmpopts/sort.go index c6d09da..4a8a2a6 100644 --- a/cmp/cmpopts/sort.go +++ b/cmp/cmpopts/sort.go @@ -14,22 +14,29 @@ import ( ) // SortSlices returns a [cmp.Transformer] option that sorts all []V. -// The less function must be of the form "func(T, T) bool" which is used to -// sort any slice with element type V that is assignable to T. +// The lessOrCompareFunc function must be either +// a less function of the form "func(T, T) bool" or +// a compare function of the format "func(T, T) int" +// which is used to sort any slice with element type V that is assignable to T. // -// The less function must be: +// A less function must be: // - Deterministic: less(x, y) == less(x, y) // - Irreflexive: !less(x, x) // - Transitive: if !less(x, y) and !less(y, z), then !less(x, z) // -// The less function does not have to be "total". That is, if !less(x, y) and -// !less(y, x) for two elements x and y, their relative order is maintained. +// A compare function must be: +// - Deterministic: compare(x, y) == compare(x, y) +// - Irreflexive: compare(x, x) == 0 +// - Transitive: if !less(x, y) and !less(y, z), then !less(x, z) +// +// The function does not have to be "total". That is, if x != y, but +// less reports false or compare reports 0, their relative order is maintained. // // SortSlices can be used in conjunction with [EquateEmpty]. -func SortSlices(lessFunc interface{}) cmp.Option { - vf := reflect.ValueOf(lessFunc) - if !function.IsType(vf.Type(), function.Less) || vf.IsNil() { - panic(fmt.Sprintf("invalid less function: %T", lessFunc)) +func SortSlices(lessOrCompareFunc interface{}) cmp.Option { + vf := reflect.ValueOf(lessOrCompareFunc) + if (!function.IsType(vf.Type(), function.Less) && !function.IsType(vf.Type(), function.Compare)) || vf.IsNil() { + panic(fmt.Sprintf("invalid less or compare function: %T", lessOrCompareFunc)) } ss := sliceSorter{vf.Type().In(0), vf} return cmp.FilterValues(ss.filter, cmp.Transformer("cmpopts.SortSlices", ss.sort)) @@ -79,28 +86,40 @@ func (ss sliceSorter) checkSort(v reflect.Value) { } func (ss sliceSorter) less(v reflect.Value, i, j int) bool { vx, vy := v.Index(i), v.Index(j) - return ss.fnc.Call([]reflect.Value{vx, vy})[0].Bool() + vo := ss.fnc.Call([]reflect.Value{vx, vy})[0] + if vo.Kind() == reflect.Bool { + return vo.Bool() + } else { + return vo.Int() < 0 + } } -// SortMaps returns a [cmp.Transformer] option that flattens map[K]V types to be a -// sorted []struct{K, V}. The less function must be of the form -// "func(T, T) bool" which is used to sort any map with key K that is -// assignable to T. +// SortMaps returns a [cmp.Transformer] option that flattens map[K]V types to be +// a sorted []struct{K, V}. The lessOrCompareFunc function must be either +// a less function of the form "func(T, T) bool" or +// a compare function of the format "func(T, T) int" +// which is used to sort any map with key K that is assignable to T. // // Flattening the map into a slice has the property that [cmp.Equal] is able to // use [cmp.Comparer] options on K or the K.Equal method if it exists. // -// The less function must be: +// A less function must be: // - Deterministic: less(x, y) == less(x, y) // - Irreflexive: !less(x, x) // - Transitive: if !less(x, y) and !less(y, z), then !less(x, z) // - Total: if x != y, then either less(x, y) or less(y, x) // +// A compare function must be: +// - Deterministic: compare(x, y) == compare(x, y) +// - Irreflexive: compare(x, x) == 0 +// - Transitive: if compare(x, y) < 0 and compare(y, z) < 0, then compare(x, z) < 0 +// - Total: if x != y, then compare(x, y) != 0 +// // SortMaps can be used in conjunction with [EquateEmpty]. -func SortMaps(lessFunc interface{}) cmp.Option { - vf := reflect.ValueOf(lessFunc) - if !function.IsType(vf.Type(), function.Less) || vf.IsNil() { - panic(fmt.Sprintf("invalid less function: %T", lessFunc)) +func SortMaps(lessOrCompareFunc interface{}) cmp.Option { + vf := reflect.ValueOf(lessOrCompareFunc) + if (!function.IsType(vf.Type(), function.Less) && !function.IsType(vf.Type(), function.Compare)) || vf.IsNil() { + panic(fmt.Sprintf("invalid less or compare function: %T", lessOrCompareFunc)) } ms := mapSorter{vf.Type().In(0), vf} return cmp.FilterValues(ms.filter, cmp.Transformer("cmpopts.SortMaps", ms.sort)) @@ -143,5 +162,10 @@ func (ms mapSorter) checkSort(v reflect.Value) { } func (ms mapSorter) less(v reflect.Value, i, j int) bool { vx, vy := v.Index(i).Field(0), v.Index(j).Field(0) - return ms.fnc.Call([]reflect.Value{vx, vy})[0].Bool() + vo := ms.fnc.Call([]reflect.Value{vx, vy})[0] + if vo.Kind() == reflect.Bool { + return vo.Bool() + } else { + return vo.Int() < 0 + } } diff --git a/cmp/cmpopts/util_test.go b/cmp/cmpopts/util_test.go index 6a7c300..8c74d65 100644 --- a/cmp/cmpopts/util_test.go +++ b/cmp/cmpopts/util_test.go @@ -130,6 +130,23 @@ func TestOptions(t *testing.T) { opts: []cmp.Option{SortSlices(func(x, y int) bool { return x < y })}, wantEqual: true, reason: "equal because SortSlices sorts the slices", + }, { + label: "SortSlices", + x: []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, + y: []int{1, 0, 5, 2, 8, 9, 4, 3, 6, 7}, + opts: []cmp.Option{SortSlices(func(x, y int) int { + // TODO(Go1.22): Use cmp.Compare. + switch { + case x < y: + return -1 + case y > x: + return +1 + default: + return 0 + } + })}, + wantEqual: true, + reason: "equal because SortSlices sorts the slices", }, { label: "SortSlices", x: []MyInt{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, @@ -201,6 +218,21 @@ func TestOptions(t *testing.T) { opts: []cmp.Option{SortMaps(func(x, y time.Time) bool { return x.Before(y) })}, wantEqual: true, reason: "equal because SortMaps flattens to a slice where Time.Equal can be used", + }, { + label: "SortMaps", + x: map[time.Time]string{ + time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC): "0th birthday", + time.Date(2010, time.November, 10, 23, 0, 0, 0, time.UTC): "1st birthday", + time.Date(2011, time.November, 10, 23, 0, 0, 0, time.UTC): "2nd birthday", + }, + y: map[time.Time]string{ + time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC).In(time.Local): "0th birthday", + time.Date(2010, time.November, 10, 23, 0, 0, 0, time.UTC).In(time.Local): "1st birthday", + time.Date(2011, time.November, 10, 23, 0, 0, 0, time.UTC).In(time.Local): "2nd birthday", + }, + opts: []cmp.Option{SortMaps(func(x, y time.Time) int { return time.Time.Compare(x, y) })}, + wantEqual: true, + reason: "equal because SortMaps flattens to a slice where Time.Equal can be used", }, { label: "SortMaps", x: map[MyTime]string{ @@ -1184,29 +1216,17 @@ func TestPanic(t *testing.T) { args: args(time.Duration(-1)), wantPanic: "margin must be a non-negative number", reason: "negative duration is invalid", - }, { - label: "SortSlices", - fnc: SortSlices, - args: args(strings.Compare), - wantPanic: "invalid less function", - reason: "func(x, y string) int is wrong signature for less", }, { label: "SortSlices", fnc: SortSlices, args: args((func(_, _ int) bool)(nil)), - wantPanic: "invalid less function", + wantPanic: "invalid less or compare function", reason: "nil value is not valid", - }, { - label: "SortMaps", - fnc: SortMaps, - args: args(strings.Compare), - wantPanic: "invalid less function", - reason: "func(x, y string) int is wrong signature for less", }, { label: "SortMaps", fnc: SortMaps, args: args((func(_, _ int) bool)(nil)), - wantPanic: "invalid less function", + wantPanic: "invalid less or compare function", reason: "nil value is not valid", }, { label: "IgnoreFields", diff --git a/cmp/internal/function/func.go b/cmp/internal/function/func.go index d127d43..def01a6 100644 --- a/cmp/internal/function/func.go +++ b/cmp/internal/function/func.go @@ -19,6 +19,7 @@ const ( tbFunc // func(T) bool ttbFunc // func(T, T) bool + ttiFunc // func(T, T) int trbFunc // func(T, R) bool tibFunc // func(T, I) bool trFunc // func(T) R @@ -28,11 +29,13 @@ const ( Transformer = trFunc // func(T) R ValueFilter = ttbFunc // func(T, T) bool Less = ttbFunc // func(T, T) bool + Compare = ttiFunc // func(T, T) int ValuePredicate = tbFunc // func(T) bool KeyValuePredicate = trbFunc // func(T, R) bool ) var boolType = reflect.TypeOf(true) +var intType = reflect.TypeOf(0) // IsType reports whether the reflect.Type is of the specified function type. func IsType(t reflect.Type, ft funcType) bool { @@ -49,6 +52,10 @@ func IsType(t reflect.Type, ft funcType) bool { if ni == 2 && no == 1 && t.In(0) == t.In(1) && t.Out(0) == boolType { return true } + case ttiFunc: // func(T, T) int + if ni == 2 && no == 1 && t.In(0) == t.In(1) && t.Out(0) == intType { + return true + } case trbFunc: // func(T, R) bool if ni == 2 && no == 1 && t.Out(0) == boolType { return true diff --git a/go.mod b/go.mod index f55cea6..e183d48 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,3 @@ module github.com/google/go-cmp -go 1.13 +go 1.21