diff --git a/dig_test.go b/dig_test.go index 72381c43..84c2855c 100644 --- a/dig_test.go +++ b/dig_test.go @@ -1745,6 +1745,84 @@ func TestGroups(t *testing.T) { assert.ElementsMatch(t, []string{"a"}, param.Value) }) }) + /* map tests */ + t.Run("empty map received without provides", func(t *testing.T) { + c := digtest.New(t) + + type in struct { + dig.In + + Values map[string]int `group:"foo"` + } + + c.RequireInvoke(func(i in) { + require.Empty(t, i.Values) + }) + }) + t.Run("values are provided, map and name and slice", func(t *testing.T) { + c := digtest.New(t, dig.SetRand(rand.New(rand.NewSource(0)))) + type out struct { + dig.Out + + Value1 int `name:"value1" group:"val"` + Value2 int `name:"value2" group:"val"` + Value3 int `name:"value3" group:"val"` + } + + c.RequireProvide(func() out { + return out{Value1: 1, Value2: 2, Value3: 3} + }) + + type in struct { + dig.In + + Value1 int `name:"value1"` + Value2 int `name:"value2"` + Value3 int `name:"value3"` + Values []int `group:"val"` + ValueMap map[string]int `group:"val"` + } + + c.RequireInvoke(func(i in) { + assert.Equal(t, []int{2, 3, 1}, i.Values) + assert.Equal(t, i.ValueMap["value1"], 1) + assert.Equal(t, i.ValueMap["value2"], 2) + assert.Equal(t, i.ValueMap["value3"], 3) + assert.Equal(t, i.Value1, 1) + assert.Equal(t, i.Value2, 2) + assert.Equal(t, i.Value3, 3) + }) + }) + + t.Run("Every item used in a map must have a named key", func(t *testing.T) { + c := digtest.New(t, dig.SetRand(rand.New(rand.NewSource(0)))) + + type out struct { + dig.Out + + Value1 int `name:"value1" group:"val"` + Value2 int `name:"value2" group:"val"` + Value3 int `group:"val"` + } + + c.RequireProvide(func() out { + return out{Value1: 1, Value2: 2, Value3: 3} + }) + + type in struct { + dig.In + + ValueMap map[string]int `group:"val"` + } + var called = false + err := c.Invoke(func(i in) { called = true }) + dig.AssertErrorMatches(t, err, + `could not build arguments for function "go.uber.org/dig_test".TestGroups\S+`, + `dig_test.go:\d+`, // file:line + `every entry in a map value groups must have a name, group "val" is missing a name`) + assert.False(t, called, "shouldn't call invoked function when deps aren't available") + }) + } // --- END OF END TO END TESTS diff --git a/graph.go b/graph.go index e08f1f54..c52e78f7 100644 --- a/graph.go +++ b/graph.go @@ -28,7 +28,7 @@ type graphNode struct { } // graphHolder is the dependency graph of the container. -// It saves constructorNodes and paramGroupedSlice (value groups) +// It saves constructorNodes and paramGroupedCollection (value groups) // as nodes in the graph. // It implements the graph interface defined by internal/graph. // It has 1-1 correspondence with the Scope whose graph it represents. @@ -68,7 +68,7 @@ func (gh *graphHolder) EdgesFrom(u int) []int { for _, param := range w.paramList.Params { orders = append(orders, getParamOrder(gh, param)...) } - case *paramGroupedSlice: + case *paramGroupedCollection: providers := gh.s.getAllGroupProviders(w.Group, w.Type.Elem()) for _, provider := range providers { orders = append(orders, provider.Order(gh.s)) diff --git a/param.go b/param.go index e60501e1..e44b3b5c 100644 --- a/param.go +++ b/param.go @@ -38,10 +38,12 @@ import ( // paramSingle An explicitly requested type. // paramObject dig.In struct where each field in the struct can be another // param. -// paramGroupedSlice -// A slice consuming a value group. This will receive all +// paramGroupedCollection +// A slice or map consuming a value group. This will receive all // values produced with a `group:".."` tag with the same name -// as a slice. +// as a slice or map. For a map, every value produced with the +// same group name MUST have a name which will form the map key. + type param interface { fmt.Stringer @@ -59,7 +61,7 @@ var ( _ param = paramSingle{} _ param = paramObject{} _ param = paramList{} - _ param = paramGroupedSlice{} + _ param = paramGroupedCollection{} ) // newParam builds a param from the given type. If the provided type is a @@ -342,7 +344,7 @@ func getParamOrder(gh *graphHolder, param param) []int { for _, provider := range providers { orders = append(orders, provider.Order(gh.s)) } - case paramGroupedSlice: + case paramGroupedCollection: // value group parameters have nodes of their own. // We can directly return that here. orders = append(orders, p.orders[gh.s]) @@ -401,7 +403,7 @@ func (po paramObject) Build(c containerStore) (reflect.Value, error) { var softGroupsQueue []paramObjectField var fields []paramObjectField for _, f := range po.Fields { - if p, ok := f.Param.(paramGroupedSlice); ok && p.Soft { + if p, ok := f.Param.(paramGroupedCollection); ok && p.Soft { softGroupsQueue = append(softGroupsQueue, f) continue } @@ -451,7 +453,7 @@ func newParamObjectField(idx int, f reflect.StructField, c containerStore) (para case f.Tag.Get(_groupTag) != "": var err error - p, err = newParamGroupedSlice(f, c) + p, err = newParamGroupedCollection(f, c) if err != nil { return pof, err } @@ -488,13 +490,13 @@ func (pof paramObjectField) Build(c containerStore) (reflect.Value, error) { return v, nil } -// paramGroupedSlice is a param which produces a slice of values with the same +// paramGroupedCollection is a param which produces a slice or map of values with the same // group name. -type paramGroupedSlice struct { +type paramGroupedCollection struct { // Name of the group as specified in the `group:".."` tag. Group string - // Type of the slice. + // Type of the map or slice. Type reflect.Type // Soft is used to denote a soft dependency between this param and its @@ -502,15 +504,17 @@ type paramGroupedSlice struct { // provide another value requested in the graph Soft bool + isMap bool orders map[*Scope]int } -func (pt paramGroupedSlice) String() string { +func (pt paramGroupedCollection) String() string { // io.Reader[group="foo"] refers to a group of io.Readers called 'foo' return fmt.Sprintf("%v[group=%q]", pt.Type.Elem(), pt.Group) + // JQTODO, different string for map } -func (pt paramGroupedSlice) DotParam() []*dot.Param { +func (pt paramGroupedCollection) DotParam() []*dot.Param { return []*dot.Param{ { Node: &dot.Node{ @@ -521,18 +525,21 @@ func (pt paramGroupedSlice) DotParam() []*dot.Param { } } -// newParamGroupedSlice builds a paramGroupedSlice from the provided type with +// newParamGroupedCollection builds a paramGroupedCollection from the provided type with // the given name. // -// The type MUST be a slice type. -func newParamGroupedSlice(f reflect.StructField, c containerStore) (paramGroupedSlice, error) { +// The type MUST be a slice or map[string]T type. +func newParamGroupedCollection(f reflect.StructField, c containerStore) (paramGroupedCollection, error) { g, err := parseGroupString(f.Tag.Get(_groupTag)) if err != nil { - return paramGroupedSlice{}, err + return paramGroupedCollection{}, err } - pg := paramGroupedSlice{ + isMap := f.Type.Kind() == reflect.Map && f.Type.Key().Kind() == reflect.String + isSlice := f.Type.Kind() == reflect.Slice + pg := paramGroupedCollection{ Group: g.Name, Type: f.Type, + isMap: isMap, orders: make(map[*Scope]int), Soft: g.Soft, } @@ -540,9 +547,9 @@ func newParamGroupedSlice(f reflect.StructField, c containerStore) (paramGrouped name := f.Tag.Get(_nameTag) optional, _ := isFieldOptional(f) switch { - case f.Type.Kind() != reflect.Slice: + case !isMap && !isSlice: return pg, newErrInvalidInput( - fmt.Sprintf("value groups may be consumed as slices only: field %q (%v) is not a slice", f.Name, f.Type), nil) + fmt.Sprintf("value groups may be consumed as slices or string-keyed maps only: field %q (%v) is not a slice or string-keyed map", f.Name, f.Type), nil) case g.Flatten: return pg, newErrInvalidInput( fmt.Sprintf("cannot use flatten in parameter value groups: field %q (%v) specifies flatten", f.Name, f.Type), nil) @@ -560,7 +567,7 @@ func newParamGroupedSlice(f reflect.StructField, c containerStore) (paramGrouped // any of the parent Scopes. In the case where there are multiple scopes that // are decorating the same type, the closest scope in effect will be replacing // any decorated value groups provided in further scopes. -func (pt paramGroupedSlice) getDecoratedValues(c containerStore) (reflect.Value, bool) { +func (pt paramGroupedCollection) getDecoratedValues(c containerStore) (reflect.Value, bool) { for _, c := range c.storesToRoot() { if items, ok := c.getDecoratedValueGroup(pt.Group, pt.Type); ok { return items, true @@ -575,7 +582,7 @@ func (pt paramGroupedSlice) getDecoratedValues(c containerStore) (reflect.Value, // The order in which the decorators are invoked is from the top level scope to // the current scope, to account for decorators that decorate values that were // already decorated. -func (pt paramGroupedSlice) callGroupDecorators(c containerStore) error { +func (pt paramGroupedCollection) callGroupDecorators(c containerStore) error { stores := c.storesToRoot() for i := len(stores) - 1; i >= 0; i-- { c := stores[i] @@ -600,7 +607,7 @@ func (pt paramGroupedSlice) callGroupDecorators(c containerStore) error { // search the given container and its parent for matching group providers and // call them to commit values. If an error is encountered, return the number // of providers called and a non-nil error from the first provided. -func (pt paramGroupedSlice) callGroupProviders(c containerStore) (int, error) { +func (pt paramGroupedCollection) callGroupProviders(c containerStore) (int, error) { itemCount := 0 for _, c := range c.storesToRoot() { providers := c.getGroupProviders(pt.Group, pt.Type.Elem()) @@ -618,7 +625,7 @@ func (pt paramGroupedSlice) callGroupProviders(c containerStore) (int, error) { return itemCount, nil } -func (pt paramGroupedSlice) Build(c containerStore) (reflect.Value, error) { +func (pt paramGroupedCollection) Build(c containerStore) (reflect.Value, error) { // do not call this if we are already inside a decorator since // it will result in an infinite recursion. (i.e. decorate -> params.BuildList() -> Decorate -> params.BuildList...) // this is safe since a value can be decorated at most once in a given scope. @@ -644,6 +651,22 @@ func (pt paramGroupedSlice) Build(c containerStore) (reflect.Value, error) { } stores := c.storesToRoot() + if pt.isMap { + result := reflect.MakeMapWithSize(pt.Type, itemCount) + for _, c := range stores { + kgvs := c.getValueGroup(pt.Group, pt.Type.Elem()) + for _, kgv := range kgvs { + if kgv.key == "" { + return _noValue, newErrInvalidInput( + fmt.Sprintf("every entry in a map value groups must have a name, group \"%v\" is missing a name", pt.Group), + nil, + ) + } + result.SetMapIndex(reflect.ValueOf(kgv.key), kgv.value) + } + } + return result, nil + } result := reflect.MakeSlice(pt.Type, 0, itemCount) for _, c := range stores { kgvs := c.getValueGroup(pt.Group, pt.Type.Elem()) diff --git a/param_test.go b/param_test.go index 7a1f41ed..ce9e5c4d 100644 --- a/param_test.go +++ b/param_test.go @@ -179,21 +179,31 @@ func TestParamObjectFailure(t *testing.T) { }) } -func TestParamGroupSliceErrors(t *testing.T) { +func TestParamGroupCollectionErrors(t *testing.T) { tests := []struct { desc string shape interface{} wantErr string }{ { - desc: "non-slice type are disallowed", + desc: "non-slice or string-keyed map type are disallowed (slice)", shape: struct { In Foo string `group:"foo"` }{}, - wantErr: "value groups may be consumed as slices only: " + - `field "Foo" (string) is not a slice`, + wantErr: "value groups may be consumed as slices or string-keyed maps only: " + + `field "Foo" (string) is not a slice or string-keyed map`, + }, + { + desc: "non-slice or string-keyed map type are disallowed (string-keyed map)", + shape: struct { + In + + Foo map[int]int `group:"foo"` + }{}, + wantErr: "value groups may be consumed as slices or string-keyed maps only: " + + `field "Foo" (map[int]int) is not a slice or string-keyed map`, }, { desc: "cannot provide name for a group", diff --git a/result.go b/result.go index 22c537dd..24eb7d8d 100644 --- a/result.go +++ b/result.go @@ -87,7 +87,7 @@ func newResult(t reflect.Type, opts resultOptions, noGroup bool) (result, error) return nil, newErrInvalidInput( fmt.Sprintf("cannot parse group %q", opts.Group), err) } - rg := resultGrouped{Type: t, Group: g.Name, Flatten: g.Flatten} + rg := resultGrouped{Type: t, Key: opts.Name, Group: g.Name, Flatten: g.Flatten} if len(opts.As) > 0 { var asTypes []reflect.Type for _, as := range opts.As {