From 5a5b50fe5b81e137f79e831bc20f16403596664d Mon Sep 17 00:00:00 2001 From: Jeremy Quirke Date: Sun, 5 Mar 2023 22:09:07 -0800 Subject: [PATCH 1/3] Support simultanous name and group tags As per Dig issue: https://github.com/uber-go/dig/issues/380 In order to support Fx feature requests https://github.com/uber-go/fx/issues/998 https://github.com/uber-go/fx/issues/1036 We need to be able to drop the restriction, both in terms of options dig.Name and dig.Group and dig.Out struct annotations on `name` and `group` being mutually exclusive. In a future PR, this can then be exploited to populate value group maps where the 'name' tag becomes the key of a map[string][T] --- decorate.go | 2 +- dig_test.go | 251 ++++++++++++++++++++++++++++++++++++++++++++++--- provide.go | 6 -- result.go | 93 ++++++++++++------ result_test.go | 54 ++++++----- 5 files changed, 332 insertions(+), 74 deletions(-) diff --git a/decorate.go b/decorate.go index 8c4105e9..3a7114c5 100644 --- a/decorate.go +++ b/decorate.go @@ -288,7 +288,7 @@ func findResultKeys(r resultList) ([]key, error) { keys = append(keys, key{t: innerResult.Type.Elem(), group: innerResult.Group}) case resultObject: for _, f := range innerResult.Fields { - q = append(q, f.Result) + q = append(q, f.Results...) } case resultList: q = append(q, innerResult.Results...) diff --git a/dig_test.go b/dig_test.go index 69b10f9d..72381c43 100644 --- a/dig_test.go +++ b/dig_test.go @@ -749,6 +749,53 @@ func TestEndToEndSuccess(t *testing.T) { assert.ElementsMatch(t, actualStrs, expectedStrs, "list of strings provided must match") }) + t.Run("multiple As with Group and Name", func(t *testing.T) { + c := digtest.New(t) + expectedNames := []string{"inst1", "inst2"} + expectedStrs := []string{"foo", "bar"} + for i, s := range expectedStrs { + s := s + c.RequireProvide(func() *bytes.Buffer { + return bytes.NewBufferString(s) + }, dig.Group("buffs"), dig.Name(expectedNames[i]), + dig.As(new(io.Reader), new(io.Writer))) + } + + type in struct { + dig.In + + Reader1 io.Reader `name:"inst1"` + Reader2 io.Reader `name:"inst2"` + Readers []io.Reader `group:"buffs"` + Writers []io.Writer `group:"buffs"` + } + + var actualStrs []string + var actualStrsName []string + + c.RequireInvoke(func(got in) { + require.Len(t, got.Readers, 2) + buf := make([]byte, 3) + for i, r := range got.Readers { + _, err := r.Read(buf) + require.NoError(t, err) + actualStrs = append(actualStrs, string(buf)) + // put the text back + got.Writers[i].Write(buf) + } + _, err := got.Reader1.Read(buf) + require.NoError(t, err) + actualStrsName = append(actualStrsName, string(buf)) + _, err = got.Reader2.Read(buf) + require.NoError(t, err) + actualStrsName = append(actualStrsName, string(buf)) + require.Len(t, got.Writers, 2) + }) + + assert.ElementsMatch(t, actualStrs, expectedStrs, "list of strings provided must match") + assert.ElementsMatch(t, actualStrsName, expectedStrs, "names: list of strings provided must match") + }) + t.Run("As same interface", func(t *testing.T) { c := digtest.New(t) c.RequireProvide(func() io.Reader { @@ -1098,6 +1145,48 @@ func TestGroups(t *testing.T) { }) }) + t.Run("values are provided; coexist with name", func(t *testing.T) { + c := digtest.New(t, dig.SetRand(rand.New(rand.NewSource(0)))) + + type out struct { + dig.Out + + Value int `group:"val"` + } + + type out2 struct { + dig.Out + + Value int `name:"inst1" group:"val"` + } + + provide := func(i int) { + c.RequireProvide(func() out { + return out{Value: i} + }) + } + + provide(1) + provide(2) + provide(3) + + c.RequireProvide(func() out2 { + return out2{Value: 4} + }) + + type in struct { + dig.In + + SingleValue int `name:"inst1"` + Values []int `group:"val"` + } + + c.RequireInvoke(func(i in) { + assert.Equal(t, []int{1, 2, 3, 4}, i.Values) + assert.Equal(t, 4, i.SingleValue) + }) + }) + t.Run("groups are provided via option", func(t *testing.T) { c := digtest.New(t, dig.SetRand(rand.New(rand.NewSource(0)))) @@ -1122,6 +1211,36 @@ func TestGroups(t *testing.T) { }) }) + t.Run("groups are provided via option; coexist with name", func(t *testing.T) { + c := digtest.New(t, dig.SetRand(rand.New(rand.NewSource(0)))) + + provide := func(i int) { + c.RequireProvide(func() int { + return i + }, dig.Group("val")) + } + + provide(1) + provide(2) + provide(3) + + c.RequireProvide(func() int { + return 4 + }, dig.Group("val"), dig.Name("inst1")) + + type in struct { + dig.In + + SingleValue int `name:"inst1"` + Values []int `group:"val"` + } + + c.RequireInvoke(func(i in) { + assert.Equal(t, []int{1, 2, 3, 4}, i.Values) + assert.Equal(t, 4, i.SingleValue) + }) + }) + t.Run("different types may be grouped", func(t *testing.T) { c := digtest.New(t, dig.SetRand(rand.New(rand.NewSource(0)))) @@ -1429,6 +1548,44 @@ func TestGroups(t *testing.T) { }) }) + t.Run("flatten collects slices but also handles name", func(t *testing.T) { + c := digtest.New(t, dig.SetRand(rand.New(rand.NewSource(0)))) + + type out1 struct { + dig.Out + + Value []int `name:"foo1" group:"val,flatten"` + } + + type out2 struct { + dig.Out + + Value []int `name:"foo2" group:"val,flatten"` + } + + c.RequireProvide(func() out1 { + return out1{Value: []int{1, 2}} + }) + + c.RequireProvide(func() out2 { + return out2{Value: []int{3, 4}} + }) + + type in struct { + dig.In + + NotFlattenedSlice1 []int `name:"foo1"` + NotFlattenedSlice2 []int `name:"foo2"` + Values []int `group:"val"` + } + + c.RequireInvoke(func(i in) { + assert.Equal(t, []int{2, 3, 4, 1}, i.Values) + assert.Equal(t, []int{1, 2}, i.NotFlattenedSlice1) + assert.Equal(t, []int{3, 4}, i.NotFlattenedSlice2) + }) + }) + t.Run("flatten via option", func(t *testing.T) { c := digtest.New(t, dig.SetRand(rand.New(rand.NewSource(0)))) c.RequireProvide(func() []int { @@ -1446,6 +1603,31 @@ func TestGroups(t *testing.T) { }) }) + t.Run("flatten via option also handles name", func(t *testing.T) { + c := digtest.New(t, dig.SetRand(rand.New(rand.NewSource(0)))) + c.RequireProvide(func() []int { + return []int{1, 2} + }, dig.Group("val,flatten"), dig.Name("foo1")) + + c.RequireProvide(func() []int { + return []int{3} + }, dig.Group("val,flatten"), dig.Name("foo2")) + + type in struct { + dig.In + + NotFlattenedSlice1 []int `name:"foo1"` + NotFlattenedSlice2 []int `name:"foo2"` + Values []int `group:"val"` + } + + c.RequireInvoke(func(i in) { + assert.Equal(t, []int{2, 3, 1}, i.Values) + assert.Equal(t, []int{1, 2}, i.NotFlattenedSlice1) + assert.Equal(t, []int{3}, i.NotFlattenedSlice2) + }) + }) + t.Run("flatten via option error if not a slice", func(t *testing.T) { c := digtest.New(t, dig.SetRand(rand.New(rand.NewSource(0)))) err := c.Provide(func() int { return 1 }, dig.Group("val,flatten")) @@ -1998,21 +2180,6 @@ func TestAsExpectingOriginalType(t *testing.T) { }) } -func TestProvideIncompatibleOptions(t *testing.T) { - t.Parallel() - - t.Run("group and name", func(t *testing.T) { - c := digtest.New(t) - err := c.Provide(func() io.Reader { - t.Fatal("this function must not be called") - return nil - }, dig.Group("foo"), dig.Name("bar")) - require.Error(t, err) - assert.Contains(t, err.Error(), "cannot use named values with value groups: "+ - `name:"bar" provided with group:"foo"`) - }) -} - type testStruct struct{} func (testStruct) TestMethod(x int) float64 { return float64(x) } @@ -2559,6 +2726,60 @@ func testProvideFailures(t *testing.T, dryRun bool) { ) }) + t.Run("provide multiple instances with the same name and same group", func(t *testing.T) { + c := digtest.New(t, dig.DryRun(dryRun)) + type A struct{} + type ret1 struct { + dig.Out + *A `name:"foo" group:"foos"` + } + type ret2 struct { + dig.Out + *A `name:"foo" group:"foos"` + } + c.RequireProvide(func() ret1 { + return ret1{A: &A{}} + }) + + err := c.Provide(func() ret2 { + return ret2{A: &A{}} + }) + require.Error(t, err, "expected error on the second provide") + dig.AssertErrorMatches(t, err, + `cannot provide function "go.uber.org/dig_test".testProvideFailures\S+`, + `dig_test.go:\d+`, // file:line + `cannot provide \*dig_test.A\[name="foo"\] from \[0\].A:`, + `already provided by "go.uber.org/dig_test".testProvideFailures\S+`, + ) + }) + + t.Run("provide multiple instances with the same name but different group", func(t *testing.T) { + c := digtest.New(t, dig.DryRun(dryRun)) + type A struct{} + type ret1 struct { + dig.Out + *A `name:"foo" group:"foos"` + } + type ret2 struct { + dig.Out + *A `name:"foo" group:"foosss"` + } + c.RequireProvide(func() ret1 { + return ret1{A: &A{}} + }) + + err := c.Provide(func() ret2 { + return ret2{A: &A{}} + }) + require.Error(t, err, "expected error on the second provide") + dig.AssertErrorMatches(t, err, + `cannot provide function "go.uber.org/dig_test".testProvideFailures\S+`, + `dig_test.go:\d+`, // file:line + `cannot provide \*dig_test.A\[name="foo"\] from \[0\].A:`, + `already provided by "go.uber.org/dig_test".testProvideFailures\S+`, + ) + }) + t.Run("out with unexported field should error", func(t *testing.T) { c := digtest.New(t, dig.DryRun(dryRun)) diff --git a/provide.go b/provide.go index 91a4a920..277c1b3d 100644 --- a/provide.go +++ b/provide.go @@ -46,12 +46,6 @@ type provideOptions struct { } func (o *provideOptions) Validate() error { - if len(o.Group) > 0 { - if len(o.Name) > 0 { - return newErrInvalidInput( - fmt.Sprintf("cannot use named values with value groups: name:%q provided with group:%q", o.Name, o.Group), nil) - } - } // Names must be representable inside a backquoted string. The only // limitation for raw string literals as per diff --git a/result.go b/result.go index 369cd218..a4e5ba38 100644 --- a/result.go +++ b/result.go @@ -66,7 +66,7 @@ type resultOptions struct { } // newResult builds a result from the given type. -func newResult(t reflect.Type, opts resultOptions) (result, error) { +func newResult(t reflect.Type, opts resultOptions, noGroup bool) (result, error) { switch { case IsIn(t) || (t.Kind() == reflect.Ptr && IsIn(t.Elem())) || embedsType(t, _inPtrType): return nil, newErrInvalidInput(fmt.Sprintf( @@ -81,7 +81,7 @@ func newResult(t reflect.Type, opts resultOptions) (result, error) { case t.Kind() == reflect.Ptr && IsOut(t.Elem()): return nil, newErrInvalidInput(fmt.Sprintf( "cannot return a pointer to a result object, use a value instead: %v is a pointer to a struct that embeds dig.Out", t), nil) - case len(opts.Group) > 0: + case len(opts.Group) > 0 && !noGroup: g, err := parseGroupString(opts.Group) if err != nil { return nil, newErrInvalidInput( @@ -176,7 +176,9 @@ func walkResult(r result, v resultVisitor) { w := v for _, f := range res.Fields { if v := w.AnnotateWithField(f); v != nil { - walkResult(f.Result, v) + for _, r := range f.Results { + walkResult(r, v) + } } } case resultList: @@ -200,7 +202,7 @@ type resultList struct { // For each item at index i returned by the constructor, resultIndexes[i] // is the index in .Results for the corresponding result object. // resultIndexes[i] is -1 for errors returned by constructors. - resultIndexes []int + resultIndexes [][]int } func (rl resultList) DotResult() []*dot.Result { @@ -216,25 +218,45 @@ func newResultList(ctype reflect.Type, opts resultOptions) (resultList, error) { rl := resultList{ ctype: ctype, Results: make([]result, 0, numOut), - resultIndexes: make([]int, numOut), + resultIndexes: make([][]int, numOut), } resultIdx := 0 for i := 0; i < numOut; i++ { t := ctype.Out(i) if isError(t) { - rl.resultIndexes[i] = -1 + rl.resultIndexes[i] = append(rl.resultIndexes[i], -1) continue } - r, err := newResult(t, opts) - if err != nil { - return rl, newErrInvalidInput(fmt.Sprintf("bad result %d", i+1), err) + addResult := func(nogroup bool) error { + r, err := newResult(t, opts, nogroup) + if err != nil { + return newErrInvalidInput(fmt.Sprintf("bad result %d", i+1), err) + } + + rl.Results = append(rl.Results, r) + rl.resultIndexes[i] = append(rl.resultIndexes[i], resultIdx) + resultIdx++ + return nil + } + + // special case, its added as a group and a name using options alone + if len(opts.Name) > 0 && len(opts.Group) > 0 && !IsOut(t) { + // add as a group + if err := addResult(false); err != nil { + return rl, err + } + // add as single + err := addResult(true) + return rl, err + } + + // add as normal + if err := addResult(false); err != nil { + return rl, err } - rl.Results = append(rl.Results, r) - rl.resultIndexes[i] = resultIdx - resultIdx++ } return rl, nil @@ -246,8 +268,14 @@ func (resultList) Extract(containerWriter, bool, reflect.Value) { func (rl resultList) ExtractList(cw containerWriter, decorated bool, values []reflect.Value) error { for i, v := range values { - if resultIdx := rl.resultIndexes[i]; resultIdx >= 0 { - rl.Results[resultIdx].Extract(cw, decorated, v) + isNonErrorResult := false + for _, resultIdx := range rl.resultIndexes[i] { + if resultIdx >= 0 { + rl.Results[resultIdx].Extract(cw, decorated, v) + isNonErrorResult = true + } + } + if isNonErrorResult { continue } @@ -384,7 +412,9 @@ func newResultObject(t reflect.Type, opts resultOptions) (resultObject, error) { func (ro resultObject) Extract(cw containerWriter, decorated bool, v reflect.Value) { for _, f := range ro.Fields { - f.Result.Extract(cw, decorated, v.Field(f.FieldIndex)) + for _, r := range f.Results { + r.Extract(cw, decorated, v.Field(f.FieldIndex)) + } } } @@ -399,12 +429,16 @@ type resultObjectField struct { // map to results. FieldIndex int - // Result produced by this field. - Result result + // Results produced by this field. + Results []result } func (rof resultObjectField) DotResult() []*dot.Result { - return rof.Result.DotResult() + results := make([]*dot.Result, 0, len(rof.Results)) + for _, r := range rof.Results { + results = append(results, r.DotResult()...) + } + return results } // newResultObjectField(i, f, opts) builds a resultObjectField from the field @@ -414,7 +448,11 @@ func newResultObjectField(idx int, f reflect.StructField, opts resultOptions) (r FieldName: f.Name, FieldIndex: idx, } - + name := f.Tag.Get(_nameTag) + if len(name) > 0 { + // can modify in-place because options are passed-by-value. + opts.Name = name + } var r result switch { case f.PkgPath != "": @@ -427,20 +465,21 @@ func newResultObjectField(idx int, f reflect.StructField, opts resultOptions) (r if err != nil { return rof, err } + rof.Results = append(rof.Results, r) + if len(name) == 0 { + break + } + fallthrough default: var err error - if name := f.Tag.Get(_nameTag); len(name) > 0 { - // can modify in-place because options are passed-by-value. - opts.Name = name - } - r, err = newResult(f.Type, opts) + r, err = newResult(f.Type, opts, false) if err != nil { return rof, err } + rof.Results = append(rof.Results, r) } - rof.Result = r return rof, nil } @@ -493,7 +532,6 @@ func newResultGrouped(f reflect.StructField) (resultGrouped, error) { Flatten: g.Flatten, Type: f.Type, } - name := f.Tag.Get(_nameTag) optional, _ := isFieldOptional(f) switch { case g.Flatten && f.Type.Kind() != reflect.Slice: @@ -502,9 +540,6 @@ func newResultGrouped(f reflect.StructField) (resultGrouped, error) { case g.Soft: return rg, newErrInvalidInput(fmt.Sprintf( "cannot use soft with result value groups: soft was used with group %q", rg.Group), nil) - case name != "": - return rg, newErrInvalidInput(fmt.Sprintf( - "cannot use named values with value groups: name:%q provided with group:%q", name, rg.Group), nil) case optional: return rg, newErrInvalidInput("value groups cannot be optional", nil) } diff --git a/result_test.go b/result_test.go index c19db20d..db58ce2b 100644 --- a/result_test.go +++ b/result_test.go @@ -108,7 +108,7 @@ func TestNewResultErrors(t *testing.T) { for _, tt := range tests { give := reflect.TypeOf(tt.give) t.Run(fmt.Sprint(give), func(t *testing.T) { - _, err := newResult(give, resultOptions{}) + _, err := newResult(give, resultOptions{}, false) require.Error(t, err) assert.Contains(t, err.Error(), tt.err) }) @@ -139,12 +139,12 @@ func TestNewResultObject(t *testing.T) { { FieldName: "Reader", FieldIndex: 1, - Result: resultSingle{Type: typeOfReader}, + Results: []result{resultSingle{Type: typeOfReader}}, }, { FieldName: "Writer", FieldIndex: 2, - Result: resultSingle{Type: typeOfWriter}, + Results: []result{resultSingle{Type: typeOfWriter}}, }, }, }, @@ -160,12 +160,12 @@ func TestNewResultObject(t *testing.T) { { FieldName: "A", FieldIndex: 1, - Result: resultSingle{Name: "stream-a", Type: typeOfWriter}, + Results: []result{resultSingle{Name: "stream-a", Type: typeOfWriter}}, }, { FieldName: "B", FieldIndex: 2, - Result: resultSingle{Name: "stream-b", Type: typeOfWriter}, + Results: []result{resultSingle{Name: "stream-b", Type: typeOfWriter}}, }, }, }, @@ -180,7 +180,25 @@ func TestNewResultObject(t *testing.T) { { FieldName: "Writer", FieldIndex: 1, - Result: resultGrouped{Group: "writers", Type: typeOfWriter}, + Results: []result{resultGrouped{Group: "writers", Type: typeOfWriter}}, + }, + }, + }, + { + desc: "group and name tag", + give: struct { + Out + + Writer io.Writer `name:"writer1" group:"writers"` + }{}, + wantFields: []resultObjectField{ + { + FieldName: "Writer", + FieldIndex: 1, + Results: []result{ + resultGrouped{Group: "writers", Type: typeOfWriter}, + resultSingle{Name: "writer1", Type: typeOfWriter}, + }, }, }, }, @@ -229,16 +247,6 @@ func TestNewResultObjectErrors(t *testing.T) { }{}, err: `bad field "Nested"`, }, - { - desc: "group with name should fail", - give: struct { - Out - - Foo string `group:"foo" name:"bar"` - }{}, - err: "cannot use named values with value groups: " + - `name:"bar" provided with group:"foo"`, - }, { desc: "group marked as optional", give: struct { @@ -414,31 +422,31 @@ func TestWalkResult(t *testing.T) { { AnnotateWithField: &ro.Fields[0], Return: fakeResultVisits{ - {Visit: ro.Fields[0].Result}, + {Visit: ro.Fields[0].Results[0]}, }, }, { AnnotateWithField: &ro.Fields[1], Return: fakeResultVisits{ - {Visit: ro.Fields[1].Result}, + {Visit: ro.Fields[1].Results[0]}, }, }, { AnnotateWithField: &ro.Fields[2], Return: fakeResultVisits{ { - Visit: ro.Fields[2].Result, + Visit: ro.Fields[2].Results[0], Return: fakeResultVisits{ { - AnnotateWithField: &ro.Fields[2].Result.(resultObject).Fields[0], + AnnotateWithField: &ro.Fields[2].Results[0].(resultObject).Fields[0], Return: fakeResultVisits{ - {Visit: ro.Fields[2].Result.(resultObject).Fields[0].Result}, + {Visit: ro.Fields[2].Results[0].(resultObject).Fields[0].Results[0]}, }, }, { - AnnotateWithField: &ro.Fields[2].Result.(resultObject).Fields[1], + AnnotateWithField: &ro.Fields[2].Results[0].(resultObject).Fields[1], Return: fakeResultVisits{ - {Visit: ro.Fields[2].Result.(resultObject).Fields[1].Result}, + {Visit: ro.Fields[2].Results[0].(resultObject).Fields[1].Results[0]}, }, }, }, From e78175786554fa3f6acb93b02b34c6958b49ae83 Mon Sep 17 00:00:00 2001 From: Jeremy Quirke Date: Mon, 6 Mar 2023 18:34:10 -0800 Subject: [PATCH 2/3] Track the key of any named group objects. As part of https://github.com/uber-go/dig/issues/380 we allowed names and groups tags/options to co-exist to ultimately support Fx feature request https://github.com/uber-go/fx/issues/998. We now intend to support Map value groups as per https://github.com/uber-go/fx/issues/1036. We will do this in 2 steps. 1. This PR will begin tracking any names passed into value groups with out changing any external facing functionality. 2. a subsequent PR will exploit this structure to support Map value groups. --- constructor.go | 16 ++++++++-------- container.go | 14 +++++++------- param.go | 6 +++++- result.go | 14 ++++++++++---- result_test.go | 2 +- scope.go | 31 ++++++++++++++++++------------- 6 files changed, 49 insertions(+), 34 deletions(-) diff --git a/constructor.go b/constructor.go index 7cf0c8ef..9dd0ff39 100644 --- a/constructor.go +++ b/constructor.go @@ -181,7 +181,7 @@ func (n *constructorNode) Call(c containerStore) (err error) { // would be made to a containerWriter and defers them until Commit is called. type stagingContainerWriter struct { values map[key]reflect.Value - groups map[key][]reflect.Value + groups map[key][]keyedGroupValue } var _ containerWriter = (*stagingContainerWriter)(nil) @@ -189,7 +189,7 @@ var _ containerWriter = (*stagingContainerWriter)(nil) func newStagingContainerWriter() *stagingContainerWriter { return &stagingContainerWriter{ values: make(map[key]reflect.Value), - groups: make(map[key][]reflect.Value), + groups: make(map[key][]keyedGroupValue), } } @@ -201,12 +201,12 @@ func (sr *stagingContainerWriter) setDecoratedValue(_ string, _ reflect.Type, _ digerror.BugPanicf("stagingContainerWriter.setDecoratedValue must never be called") } -func (sr *stagingContainerWriter) submitGroupedValue(group string, t reflect.Type, v reflect.Value) { +func (sr *stagingContainerWriter) submitGroupedValue(group, mapKey string, t reflect.Type, v reflect.Value) { k := key{t: t, group: group} - sr.groups[k] = append(sr.groups[k], v) + sr.groups[k] = append(sr.groups[k], keyedGroupValue{key: mapKey, value: v}) } -func (sr *stagingContainerWriter) submitDecoratedGroupedValue(_ string, _ reflect.Type, _ reflect.Value) { +func (sr *stagingContainerWriter) submitDecoratedGroupedValue(_, _ string, _ reflect.Type, _ reflect.Value) { digerror.BugPanicf("stagingContainerWriter.submitDecoratedGroupedValue must never be called") } @@ -216,9 +216,9 @@ func (sr *stagingContainerWriter) Commit(cw containerWriter) { cw.setValue(k.name, k.t, v) } - for k, vs := range sr.groups { - for _, v := range vs { - cw.submitGroupedValue(k.group, k.t, v) + for k, kgvs := range sr.groups { + for _, kgv := range kgvs { + cw.submitGroupedValue(k.group, kgv.key, k.t, kgv.value) } } } diff --git a/container.go b/container.go index 983fd3f9..4fd6a238 100644 --- a/container.go +++ b/container.go @@ -81,12 +81,12 @@ type containerWriter interface { setDecoratedValue(name string, t reflect.Type, v reflect.Value) // submitGroupedValue submits a value to the value group with the provided - // name. - submitGroupedValue(name string, t reflect.Type, v reflect.Value) + // name and optional map key. + submitGroupedValue(name, mapKey string, t reflect.Type, v reflect.Value) // submitDecoratedGroupedValue submits a decorated value to the value group - // with the provided name. - submitDecoratedGroupedValue(name string, t reflect.Type, v reflect.Value) + // with the provided name and optional map key. + submitDecoratedGroupedValue(name, mapKey string, t reflect.Type, v reflect.Value) } // containerStore provides access to the Container's underlying data store. @@ -108,7 +108,7 @@ type containerStore interface { // Retrieves all values for the provided group and type. // // The order in which the values are returned is undefined. - getValueGroup(name string, t reflect.Type) []reflect.Value + getValueGroup(name string, t reflect.Type) []keyedGroupValue // Retrieves all decorated values for the provided group and type, if any. getDecoratedValueGroup(name string, t reflect.Type) (reflect.Value, bool) @@ -273,8 +273,8 @@ func (bs byTypeName) Swap(i int, j int) { bs[i], bs[j] = bs[j], bs[i] } -func shuffledCopy(rand *rand.Rand, items []reflect.Value) []reflect.Value { - newItems := make([]reflect.Value, len(items)) +func shuffledCopy(rand *rand.Rand, items []keyedGroupValue) []keyedGroupValue { + newItems := make([]keyedGroupValue, len(items)) for i, j := range rand.Perm(len(items)) { newItems[i] = items[j] } diff --git a/param.go b/param.go index d584fc23..e60501e1 100644 --- a/param.go +++ b/param.go @@ -627,6 +627,7 @@ func (pt paramGroupedSlice) Build(c containerStore) (reflect.Value, error) { } // Check if we have decorated values + // qjeremy(how to handle this with maps?) if decoratedItems, ok := pt.getDecoratedValues(c); ok { return decoratedItems, nil } @@ -645,7 +646,10 @@ func (pt paramGroupedSlice) Build(c containerStore) (reflect.Value, error) { stores := c.storesToRoot() result := reflect.MakeSlice(pt.Type, 0, itemCount) for _, c := range stores { - result = reflect.Append(result, c.getValueGroup(pt.Group, pt.Type.Elem())...) + kgvs := c.getValueGroup(pt.Group, pt.Type.Elem()) + for _, kgv := range kgvs { + result = reflect.Append(result, kgv.value) + } } return result, nil } diff --git a/result.go b/result.go index a4e5ba38..22c537dd 100644 --- a/result.go +++ b/result.go @@ -491,6 +491,9 @@ type resultGrouped struct { // Name of the group as specified in the `group:".."` tag. Group string + // Key if a name tag or option was provided, for populating maps + Key string + // Type of value produced. Type reflect.Type @@ -527,8 +530,10 @@ func newResultGrouped(f reflect.StructField) (resultGrouped, error) { if err != nil { return resultGrouped{}, err } + name := f.Tag.Get(_nameTag) rg := resultGrouped{ Group: g.Name, + Key: name, Flatten: g.Flatten, Type: f.Type, } @@ -553,18 +558,19 @@ func newResultGrouped(f reflect.StructField) (resultGrouped, error) { func (rt resultGrouped) Extract(cw containerWriter, decorated bool, v reflect.Value) { // Decorated values are always flattened. if !decorated && !rt.Flatten { - cw.submitGroupedValue(rt.Group, rt.Type, v) + cw.submitGroupedValue(rt.Group, rt.Key, rt.Type, v) for _, asType := range rt.As { - cw.submitGroupedValue(rt.Group, asType, v) + cw.submitGroupedValue(rt.Group, rt.Key, asType, v) } return } if decorated { - cw.submitDecoratedGroupedValue(rt.Group, rt.Type, v) + cw.submitDecoratedGroupedValue(rt.Group, rt.Key, rt.Type, v) return } + // it's not possible to provide a key for the flattening case for i := 0; i < v.Len(); i++ { - cw.submitGroupedValue(rt.Group, rt.Type, v.Index(i)) + cw.submitGroupedValue(rt.Group, "", rt.Type, v.Index(i)) } } diff --git a/result_test.go b/result_test.go index db58ce2b..974e9d8c 100644 --- a/result_test.go +++ b/result_test.go @@ -196,7 +196,7 @@ func TestNewResultObject(t *testing.T) { FieldName: "Writer", FieldIndex: 1, Results: []result{ - resultGrouped{Group: "writers", Type: typeOfWriter}, + resultGrouped{Group: "writers", Key: "writer1", Type: typeOfWriter}, resultSingle{Name: "writer1", Type: typeOfWriter}, }, }, diff --git a/scope.go b/scope.go index 216cf18a..d28806b2 100644 --- a/scope.go +++ b/scope.go @@ -35,6 +35,11 @@ type ScopeOption interface { noScopeOption() //yet } +type keyedGroupValue struct { + key string + value reflect.Value +} + // Scope is a scoped DAG of types and their dependencies. // A Scope may also have one or more child Scopes that inherit // from it. @@ -61,10 +66,10 @@ type Scope struct { values map[key]reflect.Value // Values groups that generated directly in the Scope. - groups map[key][]reflect.Value + groups map[key][]keyedGroupValue // Values groups that generated via decoraters in the Scope. - decoratedGroups map[key]reflect.Value + decoratedGroups map[key]keyedGroupValue // Source of randomness. rand *rand.Rand @@ -98,8 +103,8 @@ func newScope() *Scope { decorators: make(map[key]*decoratorNode), values: make(map[key]reflect.Value), decoratedValues: make(map[key]reflect.Value), - groups: make(map[key][]reflect.Value), - decoratedGroups: make(map[key]reflect.Value), + groups: make(map[key][]keyedGroupValue), + decoratedGroups: make(map[key]keyedGroupValue), invokerFn: defaultInvoker, rand: rand.New(rand.NewSource(time.Now().UnixNano())), } @@ -190,7 +195,7 @@ func (s *Scope) setDecoratedValue(name string, t reflect.Type, v reflect.Value) s.decoratedValues[key{name: name, t: t}] = v } -func (s *Scope) getValueGroup(name string, t reflect.Type) []reflect.Value { +func (s *Scope) getValueGroup(name string, t reflect.Type) []keyedGroupValue { items := s.groups[key{group: name, t: t}] // shuffle the list so users don't rely on the ordering of grouped values return shuffledCopy(s.rand, items) @@ -198,17 +203,17 @@ func (s *Scope) getValueGroup(name string, t reflect.Type) []reflect.Value { func (s *Scope) getDecoratedValueGroup(name string, t reflect.Type) (reflect.Value, bool) { items, ok := s.decoratedGroups[key{group: name, t: t}] - return items, ok + return items.value, ok } -func (s *Scope) submitGroupedValue(name string, t reflect.Type, v reflect.Value) { +func (s *Scope) submitGroupedValue(name, mapKey string, t reflect.Type, v reflect.Value) { k := key{group: name, t: t} - s.groups[k] = append(s.groups[k], v) + s.groups[k] = append(s.groups[k], keyedGroupValue{key: mapKey, value: v}) } -func (s *Scope) submitDecoratedGroupedValue(name string, t reflect.Type, v reflect.Value) { +func (s *Scope) submitDecoratedGroupedValue(name, mapKey string, t reflect.Type, v reflect.Value) { k := key{group: name, t: t} - s.decoratedGroups[k] = v + s.decoratedGroups[k] = keyedGroupValue{key: mapKey, value: v} } func (s *Scope) getValueProviders(name string, t reflect.Type) []provider { @@ -310,9 +315,9 @@ func (s *Scope) String() string { for k, v := range s.values { fmt.Fprintln(b, "\t", k, "=>", v) } - for k, vs := range s.groups { - for _, v := range vs { - fmt.Fprintln(b, "\t", k, "=>", v) + for k, kgvs := range s.groups { + for _, kgv := range kgvs { + fmt.Fprintln(b, "\t", k, "=>", kgv.value) } } fmt.Fprintln(b, "}") From 23338b1e2922079f1c8a9bc3efd8de8adcd8c2ac Mon Sep 17 00:00:00 2001 From: Jeremy Quirke Date: Mon, 6 Mar 2023 22:04:58 -0800 Subject: [PATCH 3/3] Support map value groups This revision allows dig to specify value groups of map type. For example: ``` type Params struct { dig.In Things []int `group:"foogroup"` MapOfThings map[string]int `group:"foogroup"` } type Result struct { dig.Out Int1 int `name:"foo1" group:"foogroup"` Int2 int `name:"foo2" group:"foogroup"` Int3 int `name:"foo3" group:"foogroup"` } c.Provide(func() Result { return Result{Int1: 1, Int2: 2, Int3: 3} }) c.Invoke(func(p Params) { }) ``` p.Things will be a value group slice as per usual, containing the elements {1,2,3} in an arbitrary order. p.MapOfThings will be a key-value pairing of {"foo1":1, "foo2":2, "foo3":3}. --- decorate.go | 4 +- decorate_test.go | 99 ++++++++++++++++++++++++++++++ dig_test.go | 155 ++++++++++++++++++++++++++++++++++++++++++++++- graph.go | 4 +- param.go | 69 ++++++++++++++------- param_test.go | 18 ++++-- result.go | 2 +- 7 files changed, 319 insertions(+), 32 deletions(-) diff --git a/decorate.go b/decorate.go index 3a7114c5..c57e817a 100644 --- a/decorate.go +++ b/decorate.go @@ -282,7 +282,9 @@ func findResultKeys(r resultList) ([]key, error) { case resultSingle: keys = append(keys, key{t: innerResult.Type, name: innerResult.Name}) case resultGrouped: - if innerResult.Type.Kind() != reflect.Slice { + isMap := innerResult.Type.Kind() == reflect.Map && innerResult.Type.Key().Kind() == reflect.String + isSlice := innerResult.Type.Kind() == reflect.Slice + if !isMap && !isSlice { return nil, newErrInvalidInput("decorating a value group requires decorating the entire value group, not a single value", nil) } keys = append(keys, key{t: innerResult.Type.Elem(), group: innerResult.Group}) diff --git a/decorate_test.go b/decorate_test.go index 721ba801..2bec237c 100644 --- a/decorate_test.go +++ b/decorate_test.go @@ -215,6 +215,64 @@ func TestDecorateSuccess(t *testing.T) { })) }) + t.Run("map is treated as an ordinary dependency without group tag, named or unnamed, and passes through multiple scopes", func(t *testing.T) { + type params struct { + dig.In + + Strings1 map[string]string + Strings2 map[string]string `name:"strings2"` + } + + type childResult struct { + dig.Out + + Strings1 map[string]string + Strings2 map[string]string `name:"strings2"` + } + + type A map[string]string + type B map[string]string + + parent := digtest.New(t) + parent.RequireProvide(func() map[string]string { return map[string]string{"key1": "val1", "key2": "val2"} }) + parent.RequireProvide(func() map[string]string { return map[string]string{"key1": "val21", "key2": "val22"} }, dig.Name("strings2")) + + parent.RequireProvide(func(p params) A { return A(p.Strings1) }) + parent.RequireProvide(func(p params) B { return B(p.Strings2) }) + + child := parent.Scope("child") + + parent.RequireDecorate(func(p params) childResult { + res := childResult{Strings1: make(map[string]string, len(p.Strings1))} + for k, s := range p.Strings1 { + res.Strings1[k] = strings.ToUpper(s) + } + res.Strings2 = p.Strings2 + return res + }) + + child.RequireDecorate(func(p params) childResult { + res := childResult{Strings2: make(map[string]string, len(p.Strings2))} + for k, s := range p.Strings2 { + res.Strings2[k] = strings.ToUpper(s) + } + res.Strings1 = p.Strings1 + res.Strings1["key3"] = "newval" + return res + }) + + require.NoError(t, child.Invoke(func(p params) { + require.Len(t, p.Strings1, 3) + assert.Equal(t, "VAL1", p.Strings1["key1"]) + assert.Equal(t, "VAL2", p.Strings1["key2"]) + assert.Equal(t, "newval", p.Strings1["key3"]) + require.Len(t, p.Strings2, 2) + assert.Equal(t, "VAL21", p.Strings2["key1"]) + assert.Equal(t, "VAL22", p.Strings2["key2"]) + + })) + + }) t.Run("decorate values in soft group", func(t *testing.T) { type params struct { dig.In @@ -393,6 +451,46 @@ func TestDecorateSuccess(t *testing.T) { assert.Equal(t, `[]string[group = "animals"]`, info.Inputs[0].String()) }) + t.Run("decorate with map value groups", func(t *testing.T) { + type Params struct { + dig.In + + Animals map[string]string `group:"animals"` + } + + type Result struct { + dig.Out + + Animals map[string]string `group:"animals"` + } + + c := digtest.New(t) + c.RequireProvide(func() string { return "dog" }, dig.Name("animal1"), dig.Group("animals")) + c.RequireProvide(func() string { return "cat" }, dig.Name("animal2"), dig.Group("animals")) + c.RequireProvide(func() string { return "gopher" }, dig.Name("animal3"), dig.Group("animals")) + + var info dig.DecorateInfo + c.RequireDecorate(func(p Params) Result { + animals := p.Animals + for k, v := range animals { + animals[k] = "good " + v + } + return Result{ + Animals: animals, + } + }, dig.FillDecorateInfo(&info)) + + c.RequireInvoke(func(p Params) { + assert.Len(t, p.Animals, 3) + assert.Equal(t, "good dog", p.Animals["animal1"]) + assert.Equal(t, "good cat", p.Animals["animal2"]) + assert.Equal(t, "good gopher", p.Animals["animal3"]) + }) + + require.Equal(t, 1, len(info.Inputs)) + assert.Equal(t, `map[string]string[group = "animals"]`, info.Inputs[0].String()) + }) + t.Run("decorate with optional parameter", func(t *testing.T) { c := digtest.New(t) @@ -918,6 +1016,7 @@ func TestMultipleDecorates(t *testing.T) { assert.ElementsMatch(t, []int{2, 3, 4}, a.Values) }) }) + } func TestFillDecorateInfoString(t *testing.T) { diff --git a/dig_test.go b/dig_test.go index 72381c43..adab0cfc 100644 --- a/dig_test.go +++ b/dig_test.go @@ -1241,6 +1241,27 @@ func TestGroups(t *testing.T) { }) }) + t.Run("provide multiple with the same name and group but different type", func(t *testing.T) { + c := digtest.New(t) + type A struct{} + type B struct{} + type ret1 struct { + dig.Out + *A `name:"foo" group:"foos"` + } + type ret2 struct { + dig.Out + *B `name:"foo" group:"foos"` + } + c.RequireProvide(func() ret1 { + return ret1{A: &A{}} + }) + + c.RequireProvide(func() ret2 { + return ret2{B: &B{}} + }) + }) + t.Run("different types may be grouped", func(t *testing.T) { c := digtest.New(t, dig.SetRand(rand.New(rand.NewSource(0)))) @@ -1745,6 +1766,118 @@ func TestGroups(t *testing.T) { assert.ElementsMatch(t, []string{"a"}, param.Value) }) }) + /* map tests */ + t.Run("empty map received without provides", func(t *testing.T) { + c := digtest.New(t) + + type in struct { + dig.In + + Values map[string]int `group:"foo"` + } + + c.RequireInvoke(func(i in) { + require.Empty(t, i.Values) + }) + }) + + t.Run("map value group using dig.Name and dig.Group", func(t *testing.T) { + c := digtest.New(t, dig.SetRand(rand.New(rand.NewSource(0)))) + + c.RequireProvide(func() int { + return 1 + }, dig.Name("value1"), dig.Group("val")) + c.RequireProvide(func() int { + return 2 + }, dig.Name("value2"), dig.Group("val")) + c.RequireProvide(func() int { + return 3 + }, dig.Name("value3"), dig.Group("val")) + + type in struct { + dig.In + + Value1 int `name:"value1"` + Value2 int `name:"value2"` + Value3 int `name:"value3"` + Values []int `group:"val"` + ValueMap map[string]int `group:"val"` + } + + c.RequireInvoke(func(i in) { + assert.Equal(t, []int{2, 3, 1}, i.Values) + assert.Equal(t, i.ValueMap["value1"], 1) + assert.Equal(t, i.ValueMap["value2"], 2) + assert.Equal(t, i.ValueMap["value3"], 3) + assert.Equal(t, i.Value1, 1) + assert.Equal(t, i.Value2, 2) + assert.Equal(t, i.Value3, 3) + }) + }) + t.Run("values are provided, map and name and slice", func(t *testing.T) { + c := digtest.New(t, dig.SetRand(rand.New(rand.NewSource(0)))) + type out struct { + dig.Out + + Value1 int `name:"value1" group:"val"` + Value2 int `name:"value2" group:"val"` + Value3 int `name:"value3" group:"val"` + } + + c.RequireProvide(func() out { + return out{Value1: 1, Value2: 2, Value3: 3} + }) + + type in struct { + dig.In + + Value1 int `name:"value1"` + Value2 int `name:"value2"` + Value3 int `name:"value3"` + Values []int `group:"val"` + ValueMap map[string]int `group:"val"` + } + + c.RequireInvoke(func(i in) { + assert.Equal(t, []int{2, 3, 1}, i.Values) + assert.Equal(t, i.ValueMap["value1"], 1) + assert.Equal(t, i.ValueMap["value2"], 2) + assert.Equal(t, i.ValueMap["value3"], 3) + assert.Equal(t, i.Value1, 1) + assert.Equal(t, i.Value2, 2) + assert.Equal(t, i.Value3, 3) + }) + }) + + t.Run("Every item used in a map must have a named key", func(t *testing.T) { + c := digtest.New(t, dig.SetRand(rand.New(rand.NewSource(0)))) + + type out struct { + dig.Out + + Value1 int `name:"value1" group:"val"` + Value2 int `name:"value2" group:"val"` + Value3 int `group:"val"` + } + + c.RequireProvide(func() out { + return out{Value1: 1, Value2: 2, Value3: 3} + }) + + type in struct { + dig.In + + ValueMap map[string]int `group:"val"` + } + var called = false + err := c.Invoke(func(i in) { called = true }) + dig.AssertErrorMatches(t, err, + `could not build arguments for function "go.uber.org/dig_test".TestGroups\S+`, + `dig_test.go:\d+`, // file:line + `every entry in a map value groups must have a name, group "val" is missing a name`) + assert.False(t, called, "shouldn't call invoked function when deps aren't available") + }) + } // --- END OF END TO END TESTS @@ -2753,7 +2886,27 @@ func testProvideFailures(t *testing.T, dryRun bool) { ) }) - t.Run("provide multiple instances with the same name but different group", func(t *testing.T) { + t.Run("provide multiple instances with the same name and same group using options", func(t *testing.T) { + c := digtest.New(t, dig.DryRun(dryRun)) + type A struct{} + + c.RequireProvide(func() *A { + return &A{} + }, dig.Group("foos"), dig.Name("foo")) + + err := c.Provide(func() *A { + return &A{} + }, dig.Group("foos"), dig.Name("foo")) + require.Error(t, err, "expected error on the second provide") + dig.AssertErrorMatches(t, err, + `cannot provide function "go.uber.org/dig_test".testProvideFailures\S+`, + `dig_test.go:\d+`, // file:line + `cannot provide \*dig_test.A\[name="foo"\] from \[1\]:`, + `already provided by "go.uber.org/dig_test".testProvideFailures\S+`, + ) + }) + + t.Run("provide multiple instances with the same name and type but different group", func(t *testing.T) { c := digtest.New(t, dig.DryRun(dryRun)) type A struct{} type ret1 struct { diff --git a/graph.go b/graph.go index e08f1f54..c52e78f7 100644 --- a/graph.go +++ b/graph.go @@ -28,7 +28,7 @@ type graphNode struct { } // graphHolder is the dependency graph of the container. -// It saves constructorNodes and paramGroupedSlice (value groups) +// It saves constructorNodes and paramGroupedCollection (value groups) // as nodes in the graph. // It implements the graph interface defined by internal/graph. // It has 1-1 correspondence with the Scope whose graph it represents. @@ -68,7 +68,7 @@ func (gh *graphHolder) EdgesFrom(u int) []int { for _, param := range w.paramList.Params { orders = append(orders, getParamOrder(gh, param)...) } - case *paramGroupedSlice: + case *paramGroupedCollection: providers := gh.s.getAllGroupProviders(w.Group, w.Type.Elem()) for _, provider := range providers { orders = append(orders, provider.Order(gh.s)) diff --git a/param.go b/param.go index e60501e1..e44b3b5c 100644 --- a/param.go +++ b/param.go @@ -38,10 +38,12 @@ import ( // paramSingle An explicitly requested type. // paramObject dig.In struct where each field in the struct can be another // param. -// paramGroupedSlice -// A slice consuming a value group. This will receive all +// paramGroupedCollection +// A slice or map consuming a value group. This will receive all // values produced with a `group:".."` tag with the same name -// as a slice. +// as a slice or map. For a map, every value produced with the +// same group name MUST have a name which will form the map key. + type param interface { fmt.Stringer @@ -59,7 +61,7 @@ var ( _ param = paramSingle{} _ param = paramObject{} _ param = paramList{} - _ param = paramGroupedSlice{} + _ param = paramGroupedCollection{} ) // newParam builds a param from the given type. If the provided type is a @@ -342,7 +344,7 @@ func getParamOrder(gh *graphHolder, param param) []int { for _, provider := range providers { orders = append(orders, provider.Order(gh.s)) } - case paramGroupedSlice: + case paramGroupedCollection: // value group parameters have nodes of their own. // We can directly return that here. orders = append(orders, p.orders[gh.s]) @@ -401,7 +403,7 @@ func (po paramObject) Build(c containerStore) (reflect.Value, error) { var softGroupsQueue []paramObjectField var fields []paramObjectField for _, f := range po.Fields { - if p, ok := f.Param.(paramGroupedSlice); ok && p.Soft { + if p, ok := f.Param.(paramGroupedCollection); ok && p.Soft { softGroupsQueue = append(softGroupsQueue, f) continue } @@ -451,7 +453,7 @@ func newParamObjectField(idx int, f reflect.StructField, c containerStore) (para case f.Tag.Get(_groupTag) != "": var err error - p, err = newParamGroupedSlice(f, c) + p, err = newParamGroupedCollection(f, c) if err != nil { return pof, err } @@ -488,13 +490,13 @@ func (pof paramObjectField) Build(c containerStore) (reflect.Value, error) { return v, nil } -// paramGroupedSlice is a param which produces a slice of values with the same +// paramGroupedCollection is a param which produces a slice or map of values with the same // group name. -type paramGroupedSlice struct { +type paramGroupedCollection struct { // Name of the group as specified in the `group:".."` tag. Group string - // Type of the slice. + // Type of the map or slice. Type reflect.Type // Soft is used to denote a soft dependency between this param and its @@ -502,15 +504,17 @@ type paramGroupedSlice struct { // provide another value requested in the graph Soft bool + isMap bool orders map[*Scope]int } -func (pt paramGroupedSlice) String() string { +func (pt paramGroupedCollection) String() string { // io.Reader[group="foo"] refers to a group of io.Readers called 'foo' return fmt.Sprintf("%v[group=%q]", pt.Type.Elem(), pt.Group) + // JQTODO, different string for map } -func (pt paramGroupedSlice) DotParam() []*dot.Param { +func (pt paramGroupedCollection) DotParam() []*dot.Param { return []*dot.Param{ { Node: &dot.Node{ @@ -521,18 +525,21 @@ func (pt paramGroupedSlice) DotParam() []*dot.Param { } } -// newParamGroupedSlice builds a paramGroupedSlice from the provided type with +// newParamGroupedCollection builds a paramGroupedCollection from the provided type with // the given name. // -// The type MUST be a slice type. -func newParamGroupedSlice(f reflect.StructField, c containerStore) (paramGroupedSlice, error) { +// The type MUST be a slice or map[string]T type. +func newParamGroupedCollection(f reflect.StructField, c containerStore) (paramGroupedCollection, error) { g, err := parseGroupString(f.Tag.Get(_groupTag)) if err != nil { - return paramGroupedSlice{}, err + return paramGroupedCollection{}, err } - pg := paramGroupedSlice{ + isMap := f.Type.Kind() == reflect.Map && f.Type.Key().Kind() == reflect.String + isSlice := f.Type.Kind() == reflect.Slice + pg := paramGroupedCollection{ Group: g.Name, Type: f.Type, + isMap: isMap, orders: make(map[*Scope]int), Soft: g.Soft, } @@ -540,9 +547,9 @@ func newParamGroupedSlice(f reflect.StructField, c containerStore) (paramGrouped name := f.Tag.Get(_nameTag) optional, _ := isFieldOptional(f) switch { - case f.Type.Kind() != reflect.Slice: + case !isMap && !isSlice: return pg, newErrInvalidInput( - fmt.Sprintf("value groups may be consumed as slices only: field %q (%v) is not a slice", f.Name, f.Type), nil) + fmt.Sprintf("value groups may be consumed as slices or string-keyed maps only: field %q (%v) is not a slice or string-keyed map", f.Name, f.Type), nil) case g.Flatten: return pg, newErrInvalidInput( fmt.Sprintf("cannot use flatten in parameter value groups: field %q (%v) specifies flatten", f.Name, f.Type), nil) @@ -560,7 +567,7 @@ func newParamGroupedSlice(f reflect.StructField, c containerStore) (paramGrouped // any of the parent Scopes. In the case where there are multiple scopes that // are decorating the same type, the closest scope in effect will be replacing // any decorated value groups provided in further scopes. -func (pt paramGroupedSlice) getDecoratedValues(c containerStore) (reflect.Value, bool) { +func (pt paramGroupedCollection) getDecoratedValues(c containerStore) (reflect.Value, bool) { for _, c := range c.storesToRoot() { if items, ok := c.getDecoratedValueGroup(pt.Group, pt.Type); ok { return items, true @@ -575,7 +582,7 @@ func (pt paramGroupedSlice) getDecoratedValues(c containerStore) (reflect.Value, // The order in which the decorators are invoked is from the top level scope to // the current scope, to account for decorators that decorate values that were // already decorated. -func (pt paramGroupedSlice) callGroupDecorators(c containerStore) error { +func (pt paramGroupedCollection) callGroupDecorators(c containerStore) error { stores := c.storesToRoot() for i := len(stores) - 1; i >= 0; i-- { c := stores[i] @@ -600,7 +607,7 @@ func (pt paramGroupedSlice) callGroupDecorators(c containerStore) error { // search the given container and its parent for matching group providers and // call them to commit values. If an error is encountered, return the number // of providers called and a non-nil error from the first provided. -func (pt paramGroupedSlice) callGroupProviders(c containerStore) (int, error) { +func (pt paramGroupedCollection) callGroupProviders(c containerStore) (int, error) { itemCount := 0 for _, c := range c.storesToRoot() { providers := c.getGroupProviders(pt.Group, pt.Type.Elem()) @@ -618,7 +625,7 @@ func (pt paramGroupedSlice) callGroupProviders(c containerStore) (int, error) { return itemCount, nil } -func (pt paramGroupedSlice) Build(c containerStore) (reflect.Value, error) { +func (pt paramGroupedCollection) Build(c containerStore) (reflect.Value, error) { // do not call this if we are already inside a decorator since // it will result in an infinite recursion. (i.e. decorate -> params.BuildList() -> Decorate -> params.BuildList...) // this is safe since a value can be decorated at most once in a given scope. @@ -644,6 +651,22 @@ func (pt paramGroupedSlice) Build(c containerStore) (reflect.Value, error) { } stores := c.storesToRoot() + if pt.isMap { + result := reflect.MakeMapWithSize(pt.Type, itemCount) + for _, c := range stores { + kgvs := c.getValueGroup(pt.Group, pt.Type.Elem()) + for _, kgv := range kgvs { + if kgv.key == "" { + return _noValue, newErrInvalidInput( + fmt.Sprintf("every entry in a map value groups must have a name, group \"%v\" is missing a name", pt.Group), + nil, + ) + } + result.SetMapIndex(reflect.ValueOf(kgv.key), kgv.value) + } + } + return result, nil + } result := reflect.MakeSlice(pt.Type, 0, itemCount) for _, c := range stores { kgvs := c.getValueGroup(pt.Group, pt.Type.Elem()) diff --git a/param_test.go b/param_test.go index 7a1f41ed..ce9e5c4d 100644 --- a/param_test.go +++ b/param_test.go @@ -179,21 +179,31 @@ func TestParamObjectFailure(t *testing.T) { }) } -func TestParamGroupSliceErrors(t *testing.T) { +func TestParamGroupCollectionErrors(t *testing.T) { tests := []struct { desc string shape interface{} wantErr string }{ { - desc: "non-slice type are disallowed", + desc: "non-slice or string-keyed map type are disallowed (slice)", shape: struct { In Foo string `group:"foo"` }{}, - wantErr: "value groups may be consumed as slices only: " + - `field "Foo" (string) is not a slice`, + wantErr: "value groups may be consumed as slices or string-keyed maps only: " + + `field "Foo" (string) is not a slice or string-keyed map`, + }, + { + desc: "non-slice or string-keyed map type are disallowed (string-keyed map)", + shape: struct { + In + + Foo map[int]int `group:"foo"` + }{}, + wantErr: "value groups may be consumed as slices or string-keyed maps only: " + + `field "Foo" (map[int]int) is not a slice or string-keyed map`, }, { desc: "cannot provide name for a group", diff --git a/result.go b/result.go index 22c537dd..24eb7d8d 100644 --- a/result.go +++ b/result.go @@ -87,7 +87,7 @@ func newResult(t reflect.Type, opts resultOptions, noGroup bool) (result, error) return nil, newErrInvalidInput( fmt.Sprintf("cannot parse group %q", opts.Group), err) } - rg := resultGrouped{Type: t, Group: g.Name, Flatten: g.Flatten} + rg := resultGrouped{Type: t, Key: opts.Name, Group: g.Name, Flatten: g.Flatten} if len(opts.As) > 0 { var asTypes []reflect.Type for _, as := range opts.As {