From 23338b1e2922079f1c8a9bc3efd8de8adcd8c2ac Mon Sep 17 00:00:00 2001 From: Jeremy Quirke Date: Mon, 6 Mar 2023 22:04:58 -0800 Subject: [PATCH] Support map value groups This revision allows dig to specify value groups of map type. For example: ``` type Params struct { dig.In Things []int `group:"foogroup"` MapOfThings map[string]int `group:"foogroup"` } type Result struct { dig.Out Int1 int `name:"foo1" group:"foogroup"` Int2 int `name:"foo2" group:"foogroup"` Int3 int `name:"foo3" group:"foogroup"` } c.Provide(func() Result { return Result{Int1: 1, Int2: 2, Int3: 3} }) c.Invoke(func(p Params) { }) ``` p.Things will be a value group slice as per usual, containing the elements {1,2,3} in an arbitrary order. p.MapOfThings will be a key-value pairing of {"foo1":1, "foo2":2, "foo3":3}. --- decorate.go | 4 +- decorate_test.go | 99 ++++++++++++++++++++++++++++++ dig_test.go | 155 ++++++++++++++++++++++++++++++++++++++++++++++- graph.go | 4 +- param.go | 69 ++++++++++++++------- param_test.go | 18 ++++-- result.go | 2 +- 7 files changed, 319 insertions(+), 32 deletions(-) diff --git a/decorate.go b/decorate.go index 3a7114c5..c57e817a 100644 --- a/decorate.go +++ b/decorate.go @@ -282,7 +282,9 @@ func findResultKeys(r resultList) ([]key, error) { case resultSingle: keys = append(keys, key{t: innerResult.Type, name: innerResult.Name}) case resultGrouped: - if innerResult.Type.Kind() != reflect.Slice { + isMap := innerResult.Type.Kind() == reflect.Map && innerResult.Type.Key().Kind() == reflect.String + isSlice := innerResult.Type.Kind() == reflect.Slice + if !isMap && !isSlice { return nil, newErrInvalidInput("decorating a value group requires decorating the entire value group, not a single value", nil) } keys = append(keys, key{t: innerResult.Type.Elem(), group: innerResult.Group}) diff --git a/decorate_test.go b/decorate_test.go index 721ba801..2bec237c 100644 --- a/decorate_test.go +++ b/decorate_test.go @@ -215,6 +215,64 @@ func TestDecorateSuccess(t *testing.T) { })) }) + t.Run("map is treated as an ordinary dependency without group tag, named or unnamed, and passes through multiple scopes", func(t *testing.T) { + type params struct { + dig.In + + Strings1 map[string]string + Strings2 map[string]string `name:"strings2"` + } + + type childResult struct { + dig.Out + + Strings1 map[string]string + Strings2 map[string]string `name:"strings2"` + } + + type A map[string]string + type B map[string]string + + parent := digtest.New(t) + parent.RequireProvide(func() map[string]string { return map[string]string{"key1": "val1", "key2": "val2"} }) + parent.RequireProvide(func() map[string]string { return map[string]string{"key1": "val21", "key2": "val22"} }, dig.Name("strings2")) + + parent.RequireProvide(func(p params) A { return A(p.Strings1) }) + parent.RequireProvide(func(p params) B { return B(p.Strings2) }) + + child := parent.Scope("child") + + parent.RequireDecorate(func(p params) childResult { + res := childResult{Strings1: make(map[string]string, len(p.Strings1))} + for k, s := range p.Strings1 { + res.Strings1[k] = strings.ToUpper(s) + } + res.Strings2 = p.Strings2 + return res + }) + + child.RequireDecorate(func(p params) childResult { + res := childResult{Strings2: make(map[string]string, len(p.Strings2))} + for k, s := range p.Strings2 { + res.Strings2[k] = strings.ToUpper(s) + } + res.Strings1 = p.Strings1 + res.Strings1["key3"] = "newval" + return res + }) + + require.NoError(t, child.Invoke(func(p params) { + require.Len(t, p.Strings1, 3) + assert.Equal(t, "VAL1", p.Strings1["key1"]) + assert.Equal(t, "VAL2", p.Strings1["key2"]) + assert.Equal(t, "newval", p.Strings1["key3"]) + require.Len(t, p.Strings2, 2) + assert.Equal(t, "VAL21", p.Strings2["key1"]) + assert.Equal(t, "VAL22", p.Strings2["key2"]) + + })) + + }) t.Run("decorate values in soft group", func(t *testing.T) { type params struct { dig.In @@ -393,6 +451,46 @@ func TestDecorateSuccess(t *testing.T) { assert.Equal(t, `[]string[group = "animals"]`, info.Inputs[0].String()) }) + t.Run("decorate with map value groups", func(t *testing.T) { + type Params struct { + dig.In + + Animals map[string]string `group:"animals"` + } + + type Result struct { + dig.Out + + Animals map[string]string `group:"animals"` + } + + c := digtest.New(t) + c.RequireProvide(func() string { return "dog" }, dig.Name("animal1"), dig.Group("animals")) + c.RequireProvide(func() string { return "cat" }, dig.Name("animal2"), dig.Group("animals")) + c.RequireProvide(func() string { return "gopher" }, dig.Name("animal3"), dig.Group("animals")) + + var info dig.DecorateInfo + c.RequireDecorate(func(p Params) Result { + animals := p.Animals + for k, v := range animals { + animals[k] = "good " + v + } + return Result{ + Animals: animals, + } + }, dig.FillDecorateInfo(&info)) + + c.RequireInvoke(func(p Params) { + assert.Len(t, p.Animals, 3) + assert.Equal(t, "good dog", p.Animals["animal1"]) + assert.Equal(t, "good cat", p.Animals["animal2"]) + assert.Equal(t, "good gopher", p.Animals["animal3"]) + }) + + require.Equal(t, 1, len(info.Inputs)) + assert.Equal(t, `map[string]string[group = "animals"]`, info.Inputs[0].String()) + }) + t.Run("decorate with optional parameter", func(t *testing.T) { c := digtest.New(t) @@ -918,6 +1016,7 @@ func TestMultipleDecorates(t *testing.T) { assert.ElementsMatch(t, []int{2, 3, 4}, a.Values) }) }) + } func TestFillDecorateInfoString(t *testing.T) { diff --git a/dig_test.go b/dig_test.go index 72381c43..adab0cfc 100644 --- a/dig_test.go +++ b/dig_test.go @@ -1241,6 +1241,27 @@ func TestGroups(t *testing.T) { }) }) + t.Run("provide multiple with the same name and group but different type", func(t *testing.T) { + c := digtest.New(t) + type A struct{} + type B struct{} + type ret1 struct { + dig.Out + *A `name:"foo" group:"foos"` + } + type ret2 struct { + dig.Out + *B `name:"foo" group:"foos"` + } + c.RequireProvide(func() ret1 { + return ret1{A: &A{}} + }) + + c.RequireProvide(func() ret2 { + return ret2{B: &B{}} + }) + }) + t.Run("different types may be grouped", func(t *testing.T) { c := digtest.New(t, dig.SetRand(rand.New(rand.NewSource(0)))) @@ -1745,6 +1766,118 @@ func TestGroups(t *testing.T) { assert.ElementsMatch(t, []string{"a"}, param.Value) }) }) + /* map tests */ + t.Run("empty map received without provides", func(t *testing.T) { + c := digtest.New(t) + + type in struct { + dig.In + + Values map[string]int `group:"foo"` + } + + c.RequireInvoke(func(i in) { + require.Empty(t, i.Values) + }) + }) + + t.Run("map value group using dig.Name and dig.Group", func(t *testing.T) { + c := digtest.New(t, dig.SetRand(rand.New(rand.NewSource(0)))) + + c.RequireProvide(func() int { + return 1 + }, dig.Name("value1"), dig.Group("val")) + c.RequireProvide(func() int { + return 2 + }, dig.Name("value2"), dig.Group("val")) + c.RequireProvide(func() int { + return 3 + }, dig.Name("value3"), dig.Group("val")) + + type in struct { + dig.In + + Value1 int `name:"value1"` + Value2 int `name:"value2"` + Value3 int `name:"value3"` + Values []int `group:"val"` + ValueMap map[string]int `group:"val"` + } + + c.RequireInvoke(func(i in) { + assert.Equal(t, []int{2, 3, 1}, i.Values) + assert.Equal(t, i.ValueMap["value1"], 1) + assert.Equal(t, i.ValueMap["value2"], 2) + assert.Equal(t, i.ValueMap["value3"], 3) + assert.Equal(t, i.Value1, 1) + assert.Equal(t, i.Value2, 2) + assert.Equal(t, i.Value3, 3) + }) + }) + t.Run("values are provided, map and name and slice", func(t *testing.T) { + c := digtest.New(t, dig.SetRand(rand.New(rand.NewSource(0)))) + type out struct { + dig.Out + + Value1 int `name:"value1" group:"val"` + Value2 int `name:"value2" group:"val"` + Value3 int `name:"value3" group:"val"` + } + + c.RequireProvide(func() out { + return out{Value1: 1, Value2: 2, Value3: 3} + }) + + type in struct { + dig.In + + Value1 int `name:"value1"` + Value2 int `name:"value2"` + Value3 int `name:"value3"` + Values []int `group:"val"` + ValueMap map[string]int `group:"val"` + } + + c.RequireInvoke(func(i in) { + assert.Equal(t, []int{2, 3, 1}, i.Values) + assert.Equal(t, i.ValueMap["value1"], 1) + assert.Equal(t, i.ValueMap["value2"], 2) + assert.Equal(t, i.ValueMap["value3"], 3) + assert.Equal(t, i.Value1, 1) + assert.Equal(t, i.Value2, 2) + assert.Equal(t, i.Value3, 3) + }) + }) + + t.Run("Every item used in a map must have a named key", func(t *testing.T) { + c := digtest.New(t, dig.SetRand(rand.New(rand.NewSource(0)))) + + type out struct { + dig.Out + + Value1 int `name:"value1" group:"val"` + Value2 int `name:"value2" group:"val"` + Value3 int `group:"val"` + } + + c.RequireProvide(func() out { + return out{Value1: 1, Value2: 2, Value3: 3} + }) + + type in struct { + dig.In + + ValueMap map[string]int `group:"val"` + } + var called = false + err := c.Invoke(func(i in) { called = true }) + dig.AssertErrorMatches(t, err, + `could not build arguments for function "go.uber.org/dig_test".TestGroups\S+`, + `dig_test.go:\d+`, // file:line + `every entry in a map value groups must have a name, group "val" is missing a name`) + assert.False(t, called, "shouldn't call invoked function when deps aren't available") + }) + } // --- END OF END TO END TESTS @@ -2753,7 +2886,27 @@ func testProvideFailures(t *testing.T, dryRun bool) { ) }) - t.Run("provide multiple instances with the same name but different group", func(t *testing.T) { + t.Run("provide multiple instances with the same name and same group using options", func(t *testing.T) { + c := digtest.New(t, dig.DryRun(dryRun)) + type A struct{} + + c.RequireProvide(func() *A { + return &A{} + }, dig.Group("foos"), dig.Name("foo")) + + err := c.Provide(func() *A { + return &A{} + }, dig.Group("foos"), dig.Name("foo")) + require.Error(t, err, "expected error on the second provide") + dig.AssertErrorMatches(t, err, + `cannot provide function "go.uber.org/dig_test".testProvideFailures\S+`, + `dig_test.go:\d+`, // file:line + `cannot provide \*dig_test.A\[name="foo"\] from \[1\]:`, + `already provided by "go.uber.org/dig_test".testProvideFailures\S+`, + ) + }) + + t.Run("provide multiple instances with the same name and type but different group", func(t *testing.T) { c := digtest.New(t, dig.DryRun(dryRun)) type A struct{} type ret1 struct { diff --git a/graph.go b/graph.go index e08f1f54..c52e78f7 100644 --- a/graph.go +++ b/graph.go @@ -28,7 +28,7 @@ type graphNode struct { } // graphHolder is the dependency graph of the container. -// It saves constructorNodes and paramGroupedSlice (value groups) +// It saves constructorNodes and paramGroupedCollection (value groups) // as nodes in the graph. // It implements the graph interface defined by internal/graph. // It has 1-1 correspondence with the Scope whose graph it represents. @@ -68,7 +68,7 @@ func (gh *graphHolder) EdgesFrom(u int) []int { for _, param := range w.paramList.Params { orders = append(orders, getParamOrder(gh, param)...) } - case *paramGroupedSlice: + case *paramGroupedCollection: providers := gh.s.getAllGroupProviders(w.Group, w.Type.Elem()) for _, provider := range providers { orders = append(orders, provider.Order(gh.s)) diff --git a/param.go b/param.go index e60501e1..e44b3b5c 100644 --- a/param.go +++ b/param.go @@ -38,10 +38,12 @@ import ( // paramSingle An explicitly requested type. // paramObject dig.In struct where each field in the struct can be another // param. -// paramGroupedSlice -// A slice consuming a value group. This will receive all +// paramGroupedCollection +// A slice or map consuming a value group. This will receive all // values produced with a `group:".."` tag with the same name -// as a slice. +// as a slice or map. For a map, every value produced with the +// same group name MUST have a name which will form the map key. + type param interface { fmt.Stringer @@ -59,7 +61,7 @@ var ( _ param = paramSingle{} _ param = paramObject{} _ param = paramList{} - _ param = paramGroupedSlice{} + _ param = paramGroupedCollection{} ) // newParam builds a param from the given type. If the provided type is a @@ -342,7 +344,7 @@ func getParamOrder(gh *graphHolder, param param) []int { for _, provider := range providers { orders = append(orders, provider.Order(gh.s)) } - case paramGroupedSlice: + case paramGroupedCollection: // value group parameters have nodes of their own. // We can directly return that here. orders = append(orders, p.orders[gh.s]) @@ -401,7 +403,7 @@ func (po paramObject) Build(c containerStore) (reflect.Value, error) { var softGroupsQueue []paramObjectField var fields []paramObjectField for _, f := range po.Fields { - if p, ok := f.Param.(paramGroupedSlice); ok && p.Soft { + if p, ok := f.Param.(paramGroupedCollection); ok && p.Soft { softGroupsQueue = append(softGroupsQueue, f) continue } @@ -451,7 +453,7 @@ func newParamObjectField(idx int, f reflect.StructField, c containerStore) (para case f.Tag.Get(_groupTag) != "": var err error - p, err = newParamGroupedSlice(f, c) + p, err = newParamGroupedCollection(f, c) if err != nil { return pof, err } @@ -488,13 +490,13 @@ func (pof paramObjectField) Build(c containerStore) (reflect.Value, error) { return v, nil } -// paramGroupedSlice is a param which produces a slice of values with the same +// paramGroupedCollection is a param which produces a slice or map of values with the same // group name. -type paramGroupedSlice struct { +type paramGroupedCollection struct { // Name of the group as specified in the `group:".."` tag. Group string - // Type of the slice. + // Type of the map or slice. Type reflect.Type // Soft is used to denote a soft dependency between this param and its @@ -502,15 +504,17 @@ type paramGroupedSlice struct { // provide another value requested in the graph Soft bool + isMap bool orders map[*Scope]int } -func (pt paramGroupedSlice) String() string { +func (pt paramGroupedCollection) String() string { // io.Reader[group="foo"] refers to a group of io.Readers called 'foo' return fmt.Sprintf("%v[group=%q]", pt.Type.Elem(), pt.Group) + // JQTODO, different string for map } -func (pt paramGroupedSlice) DotParam() []*dot.Param { +func (pt paramGroupedCollection) DotParam() []*dot.Param { return []*dot.Param{ { Node: &dot.Node{ @@ -521,18 +525,21 @@ func (pt paramGroupedSlice) DotParam() []*dot.Param { } } -// newParamGroupedSlice builds a paramGroupedSlice from the provided type with +// newParamGroupedCollection builds a paramGroupedCollection from the provided type with // the given name. // -// The type MUST be a slice type. -func newParamGroupedSlice(f reflect.StructField, c containerStore) (paramGroupedSlice, error) { +// The type MUST be a slice or map[string]T type. +func newParamGroupedCollection(f reflect.StructField, c containerStore) (paramGroupedCollection, error) { g, err := parseGroupString(f.Tag.Get(_groupTag)) if err != nil { - return paramGroupedSlice{}, err + return paramGroupedCollection{}, err } - pg := paramGroupedSlice{ + isMap := f.Type.Kind() == reflect.Map && f.Type.Key().Kind() == reflect.String + isSlice := f.Type.Kind() == reflect.Slice + pg := paramGroupedCollection{ Group: g.Name, Type: f.Type, + isMap: isMap, orders: make(map[*Scope]int), Soft: g.Soft, } @@ -540,9 +547,9 @@ func newParamGroupedSlice(f reflect.StructField, c containerStore) (paramGrouped name := f.Tag.Get(_nameTag) optional, _ := isFieldOptional(f) switch { - case f.Type.Kind() != reflect.Slice: + case !isMap && !isSlice: return pg, newErrInvalidInput( - fmt.Sprintf("value groups may be consumed as slices only: field %q (%v) is not a slice", f.Name, f.Type), nil) + fmt.Sprintf("value groups may be consumed as slices or string-keyed maps only: field %q (%v) is not a slice or string-keyed map", f.Name, f.Type), nil) case g.Flatten: return pg, newErrInvalidInput( fmt.Sprintf("cannot use flatten in parameter value groups: field %q (%v) specifies flatten", f.Name, f.Type), nil) @@ -560,7 +567,7 @@ func newParamGroupedSlice(f reflect.StructField, c containerStore) (paramGrouped // any of the parent Scopes. In the case where there are multiple scopes that // are decorating the same type, the closest scope in effect will be replacing // any decorated value groups provided in further scopes. -func (pt paramGroupedSlice) getDecoratedValues(c containerStore) (reflect.Value, bool) { +func (pt paramGroupedCollection) getDecoratedValues(c containerStore) (reflect.Value, bool) { for _, c := range c.storesToRoot() { if items, ok := c.getDecoratedValueGroup(pt.Group, pt.Type); ok { return items, true @@ -575,7 +582,7 @@ func (pt paramGroupedSlice) getDecoratedValues(c containerStore) (reflect.Value, // The order in which the decorators are invoked is from the top level scope to // the current scope, to account for decorators that decorate values that were // already decorated. -func (pt paramGroupedSlice) callGroupDecorators(c containerStore) error { +func (pt paramGroupedCollection) callGroupDecorators(c containerStore) error { stores := c.storesToRoot() for i := len(stores) - 1; i >= 0; i-- { c := stores[i] @@ -600,7 +607,7 @@ func (pt paramGroupedSlice) callGroupDecorators(c containerStore) error { // search the given container and its parent for matching group providers and // call them to commit values. If an error is encountered, return the number // of providers called and a non-nil error from the first provided. -func (pt paramGroupedSlice) callGroupProviders(c containerStore) (int, error) { +func (pt paramGroupedCollection) callGroupProviders(c containerStore) (int, error) { itemCount := 0 for _, c := range c.storesToRoot() { providers := c.getGroupProviders(pt.Group, pt.Type.Elem()) @@ -618,7 +625,7 @@ func (pt paramGroupedSlice) callGroupProviders(c containerStore) (int, error) { return itemCount, nil } -func (pt paramGroupedSlice) Build(c containerStore) (reflect.Value, error) { +func (pt paramGroupedCollection) Build(c containerStore) (reflect.Value, error) { // do not call this if we are already inside a decorator since // it will result in an infinite recursion. (i.e. decorate -> params.BuildList() -> Decorate -> params.BuildList...) // this is safe since a value can be decorated at most once in a given scope. @@ -644,6 +651,22 @@ func (pt paramGroupedSlice) Build(c containerStore) (reflect.Value, error) { } stores := c.storesToRoot() + if pt.isMap { + result := reflect.MakeMapWithSize(pt.Type, itemCount) + for _, c := range stores { + kgvs := c.getValueGroup(pt.Group, pt.Type.Elem()) + for _, kgv := range kgvs { + if kgv.key == "" { + return _noValue, newErrInvalidInput( + fmt.Sprintf("every entry in a map value groups must have a name, group \"%v\" is missing a name", pt.Group), + nil, + ) + } + result.SetMapIndex(reflect.ValueOf(kgv.key), kgv.value) + } + } + return result, nil + } result := reflect.MakeSlice(pt.Type, 0, itemCount) for _, c := range stores { kgvs := c.getValueGroup(pt.Group, pt.Type.Elem()) diff --git a/param_test.go b/param_test.go index 7a1f41ed..ce9e5c4d 100644 --- a/param_test.go +++ b/param_test.go @@ -179,21 +179,31 @@ func TestParamObjectFailure(t *testing.T) { }) } -func TestParamGroupSliceErrors(t *testing.T) { +func TestParamGroupCollectionErrors(t *testing.T) { tests := []struct { desc string shape interface{} wantErr string }{ { - desc: "non-slice type are disallowed", + desc: "non-slice or string-keyed map type are disallowed (slice)", shape: struct { In Foo string `group:"foo"` }{}, - wantErr: "value groups may be consumed as slices only: " + - `field "Foo" (string) is not a slice`, + wantErr: "value groups may be consumed as slices or string-keyed maps only: " + + `field "Foo" (string) is not a slice or string-keyed map`, + }, + { + desc: "non-slice or string-keyed map type are disallowed (string-keyed map)", + shape: struct { + In + + Foo map[int]int `group:"foo"` + }{}, + wantErr: "value groups may be consumed as slices or string-keyed maps only: " + + `field "Foo" (map[int]int) is not a slice or string-keyed map`, }, { desc: "cannot provide name for a group", diff --git a/result.go b/result.go index 22c537dd..24eb7d8d 100644 --- a/result.go +++ b/result.go @@ -87,7 +87,7 @@ func newResult(t reflect.Type, opts resultOptions, noGroup bool) (result, error) return nil, newErrInvalidInput( fmt.Sprintf("cannot parse group %q", opts.Group), err) } - rg := resultGrouped{Type: t, Group: g.Name, Flatten: g.Flatten} + rg := resultGrouped{Type: t, Key: opts.Name, Group: g.Name, Flatten: g.Flatten} if len(opts.As) > 0 { var asTypes []reflect.Type for _, as := range opts.As {