diff --git a/dig.go b/dig.go index 36df0873..2607e7db 100644 --- a/dig.go +++ b/dig.go @@ -62,12 +62,19 @@ func (f optionFunc) applyOption(c *Container) { f(c) } type provideOptions struct { Name string Group string + As []interface{} } func (o *provideOptions) Validate() error { - if len(o.Group) > 0 && len(o.Name) > 0 { - return fmt.Errorf( - "cannot use named values with value groups: name:%q provided with group:%q", o.Name, o.Group) + if len(o.Group) > 0 { + if len(o.Name) > 0 { + return fmt.Errorf( + "cannot use named values with value groups: name:%q provided with group:%q", o.Name, o.Group) + } + if len(o.As) > 0 { + return fmt.Errorf( + "cannot use dig.As with value groups: dig.As provided with group:%q", o.Group) + } } // Names must be representable inside a backquoted string. The only @@ -80,6 +87,23 @@ func (o *provideOptions) Validate() error { if strings.ContainsRune(o.Group, '`') { return fmt.Errorf("invalid dig.Group(%q): group names cannot contain backquotes", o.Group) } + + for _, i := range o.As { + t := reflect.TypeOf(i) + + if t == nil { + return fmt.Errorf("invalid dig.As(nil): argument must be a pointer to an interface") + } + + if t.Kind() != reflect.Ptr { + return fmt.Errorf("invalid dig.As(%v): argument must be a pointer to an interface", t) + } + + pointingTo := t.Elem() + if pointingTo.Kind() != reflect.Interface { + return fmt.Errorf("invalid dig.As(*%v): argument must be a pointer to an interface", pointingTo) + } + } return nil } @@ -127,6 +151,55 @@ func Group(group string) ProvideOption { }) } +// As is a ProvideOption that specifies that the value produced by the +// constructor implements one or more other interfaces. +// +// As expects one or more pointers to the implemented interfaces. Values +// produced by constructors will be made available in the container as +// implementations of all of those interfaces. +// +// For example, the following will make the buffer available in the container +// as io.Reader and io.Writer. +// +// c.Provide(newBuffer, dig.As(new(io.Reader), new(io.Writer))) +// +// That is, the above is equivalent to the following. +// +// c.Provide(func(...) (*bytes.Buffer, io.Reader, io.Writer) { +// b := newBuffer(...) +// return b, b, b +// }) +// +// If used with dig.Name, the type produced by the constructor and the types +// specified with dig.As will all use the same name. For example, +// +// c.Provide(newFile, dig.As(new(io.Reader)), dig.Name("temp")) +// +// The above is equivalent to the following. +// +// type Result struct { +// dig.Out +// +// File *os.File `name:"temp"` +// Reader io.Reader `name:"temp"` +// } +// +// c.Provide(func(...) Result { +// f := newFile(...) +// return Result{ +// File: f, +// Reader: f, +// } +// }) +// +// This option cannot be provided for constructors which produce result +// objects. +func As(i ...interface{}) ProvideOption { + return provideOptionFunc(func(opts *provideOptions) { + opts.As = append(opts.As, i...) + }) +} + // An InvokeOption modifies the default behavior of Invoke. It's included for // future functionality; currently, there are no concrete implementations. type InvokeOption interface { @@ -424,6 +497,7 @@ func (c *Container) provide(ctor interface{}, opts provideOptions) error { nodeOptions{ ResultName: opts.Name, ResultGroup: opts.Group, + ResultAs: opts.As, }, ) if err != nil { @@ -535,30 +609,23 @@ func (cv connectionVisitor) Visit(res result) resultVisitor { path := strings.Join(cv.currentResultPath, ".") switch r := res.(type) { + case resultSingle: k := key{name: r.Name, t: r.Type} - - if conflict, ok := cv.keyPaths[k]; ok { - *cv.err = fmt.Errorf( - "cannot provide %v from %v: already provided by %v", - k, path, conflict) + if err := cv.checkKey(k, path); err != nil { + *cv.err = err return nil } - - if ps := cv.c.providers[k]; len(ps) > 0 { - cons := make([]string, len(ps)) - for i, p := range ps { - cons[i] = fmt.Sprint(p.Location()) + cv.keyPaths[k] = path + for _, asType := range r.As { + k := key{name: r.Name, t: asType} + if err := cv.checkKey(k, path); err != nil { + *cv.err = err + return nil } - - *cv.err = fmt.Errorf( - "cannot provide %v from %v: already provided by %v", - k, path, strings.Join(cons, "; ")) - return nil + cv.keyPaths[k] = path } - cv.keyPaths[k] = path - case resultGrouped: // we don't really care about the path for this since conflicts are // okay for group results. We'll track it for the sake of having a @@ -570,6 +637,25 @@ func (cv connectionVisitor) Visit(res result) resultVisitor { return cv } +func (cv connectionVisitor) checkKey(k key, path string) error { + if conflict, ok := cv.keyPaths[k]; ok { + return fmt.Errorf( + "cannot provide %v from %v: already provided by %v", + k, path, conflict) + } + if ps := cv.c.providers[k]; len(ps) > 0 { + cons := make([]string, len(ps)) + for i, p := range ps { + cons[i] = fmt.Sprint(p.Location()) + } + + return fmt.Errorf( + "cannot provide %v from %v: already provided by %v", + k, path, strings.Join(cons, "; ")) + } + return nil +} + // node is a node in the dependency graph. Each node maps to a single // constructor provided by the user. // @@ -598,9 +684,10 @@ type node struct { type nodeOptions struct { // If specified, all values produced by this node have the provided name - // or belong to the specified value group + // belong to the specified value group or implement any of the interfaces. ResultName string ResultGroup string + ResultAs []interface{} } func newNode(ctor interface{}, opts nodeOptions) (*node, error) { @@ -618,6 +705,7 @@ func newNode(ctor interface{}, opts nodeOptions) (*node, error) { resultOptions{ Name: opts.ResultName, Group: opts.ResultGroup, + As: opts.ResultAs, }, ) if err != nil { diff --git a/dig_test.go b/dig_test.go index 6790d13e..82b519bd 100644 --- a/dig_test.go +++ b/dig_test.go @@ -65,12 +65,14 @@ func TestEndToEndSuccess(t *testing.T) { t.Run("struct constructor", func(t *testing.T) { c := New() - var buf bytes.Buffer - buf.WriteString("foo") - require.NoError(t, c.Provide(func() bytes.Buffer { return buf }), "provide failed") + require.NoError(t, c.Provide(func() bytes.Buffer { + var buf bytes.Buffer + buf.WriteString("foo") + return buf + }), "provide failed") require.NoError(t, c.Invoke(func(b bytes.Buffer) { // ensure we're getting back the buffer we put in - require.Equal(t, "foo", buf.String(), "invoke got new buffer") + require.Equal(t, "foo", b.String(), "invoke got new buffer") }), "invoke failed") }) @@ -603,6 +605,74 @@ func TestEndToEndSuccess(t *testing.T) { }), "both objects should be successfully resolved on Invoke") }) + t.Run("struct constructor with as interface option", func(t *testing.T) { + c := New() + + provider := c.Provide( + func() *bytes.Buffer { + var buf bytes.Buffer + buf.WriteString("foo") + return &buf + }, + As(new(fmt.Stringer), new(io.Reader)), + ) + + require.NoError(t, provider, "provide failed") + + require.NoError(t, c.Invoke( + func(s fmt.Stringer, r io.Reader) { + require.Equal(t, "foo", s.String(), "invoke got new buffer") + got, err := ioutil.ReadAll(r) + assert.NoError(t, err, "failed to read from reader") + require.Equal(t, "foo", string(got), "invoke got new buffer") + }, + ), "invoke failed") + }) + + t.Run("As with Name", func(t *testing.T) { + c := New() + + require.NoError(t, c.Provide( + func() *bytes.Buffer { + return bytes.NewBufferString("foo") + }, + As(new(io.Reader)), + Name("buff"), + ), "failed to provide") + + type in struct { + In + + Buffer *bytes.Buffer `name:"buff"` + Reader io.Reader `name:"buff"` + } + + require.NoError(t, c.Invoke(func(got in) { + assert.NotNil(t, got.Buffer, "buffer must not be nil") + + assert.True(t, got.Buffer == got.Reader, + "reader and buffer must be the same object") + + body, err := ioutil.ReadAll(got.Reader) + require.NoError(t, err, "failed to read buffer body") + assert.Equal(t, "foo", string(body)) + })) + }) + + t.Run("As same interface", func(t *testing.T) { + c := New() + require.NoError(t, c.Provide(func() io.Reader { + panic("this function should not be called") + }, As(new(io.Reader))), "failed to provide") + }) + + t.Run("As different interface", func(t *testing.T) { + c := New() + require.NoError(t, c.Provide(func() io.ReadCloser { + panic("this function should not be called") + }, As(new(io.Reader), new(io.Closer))), "failed to provide") + }) + t.Run("invoke on a type that depends on named parameters", func(t *testing.T) { c := New() type A struct{ idx int } @@ -1446,16 +1516,83 @@ func TestProvideInvalidGroup(t *testing.T) { assert.Contains(t, err.Error(), "invalid dig.Group(\"foo`bar\"): group names cannot contain backquotes") } -func TestProvideGroupAndName(t *testing.T) { +func TestProvideInvalidAs(t *testing.T) { + ptrToStruct := &struct { + name string + }{ + name: "example", + } + var nilInterface io.Reader + c := New() + tests := []struct { + name string + param interface{} + expectedErr string + }{ + { + name: "as param is not an type interface", + param: 123, + expectedErr: "invalid dig.As(int): argument must be a pointer to an interface", + }, + { + name: "as param is a pointer to struct", + param: ptrToStruct, + expectedErr: "invalid dig.As(*struct { name string }): argument must be a pointer to an interface", + }, + { + name: "as param is a nil interface", + param: nilInterface, + expectedErr: "invalid dig.As(nil): argument must be a pointer to an interface", + }, + { + name: "as param is a nil", + param: nil, + expectedErr: "invalid dig.As(nil): argument must be a pointer to an interface", + }, + { + name: "as param is a func", + param: func() {}, + expectedErr: "invalid dig.As(func()): argument must be a pointer to an interface", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := c.Provide( + func() *bytes.Buffer { + var buf bytes.Buffer + return &buf + }, + As(tt.param), + ) + + require.Error(t, err, "provide must fail") + assert.Contains(t, err.Error(), tt.expectedErr) + }) + } +} + +func TestProvideIncompatibleOptions(t *testing.T) { t.Parallel() - c := New() - err := c.Provide(func() io.Reader { - panic("this function must not be called") - }, Group("foo"), Name("bar")) - require.Error(t, err) - assert.Contains(t, err.Error(), "cannot use named values with value groups: "+ - "name:\"bar\" provided with group:\"foo\"") + t.Run("group and name", func(t *testing.T) { + c := New() + err := c.Provide(func() io.Reader { + panic("this function must not be called") + }, Group("foo"), Name("bar")) + require.Error(t, err) + assert.Contains(t, err.Error(), "cannot use named values with value groups: "+ + "name:\"bar\" provided with group:\"foo\"") + }) + + t.Run("group and As", func(t *testing.T) { + c := New() + err := c.Provide(func() *bytes.Buffer { + panic("this function must not be called") + }, Group("foo"), As(new(io.Reader))) + require.Error(t, err) + assert.Contains(t, err.Error(), "cannot use dig.As with value groups: "+ + `dig.As provided with group:"foo"`) + }) } func TestCantProvideUntypedNil(t *testing.T) { @@ -1825,6 +1962,32 @@ func TestProvideFailures(t *testing.T) { ) }) + t.Run("out returning multiple instances of the same type and As option", func(t *testing.T) { + c := New() + type A struct{ idx int } + type ret struct { + Out + + A1 A // same type A provided three times + A2 A + A3 A + } + + err := c.Provide(func() ret { + return ret{ + A1: A{idx: 1}, + A2: A{idx: 2}, + A3: A{idx: 3}, + } + }, As(new(interface{}))) + require.Error(t, err, "provide must return error") + assertErrorMatches(t, err, + `function "go.uber.org/dig".TestProvideFailures\S+ \(\S+:\d+\) cannot be provided:`, + `cannot provide dig.A from \[0\].A2:`, + `already provided by \[0\].A1`, + ) + }) + t.Run("provide multiple instances with the same name", func(t *testing.T) { c := New() type A struct{} @@ -1905,6 +2068,46 @@ func TestProvideFailures(t *testing.T) { `dig.out embeds \*dig.Out`, ) }) + + t.Run("provide the same implemented interface", func(t *testing.T) { + c := New() + err := c.Provide( + func() *bytes.Buffer { + var buf bytes.Buffer + return &buf + }, + As(new(io.Reader)), + As(new(io.Reader)), + ) + + require.Error(t, err, "provide must fail") + assert.Contains(t, err.Error(), "cannot provide io.Reader") + assert.Contains(t, err.Error(), "already provided") + }) + + t.Run("provide the same implementation with as interface", func(t *testing.T) { + c := New() + err := c.Provide( + func() *bytes.Buffer { + var buf bytes.Buffer + return &buf + }, + As(new(io.Reader)), + ) + require.NoError(t, err, "provide must not fail here") + + err = c.Provide( + func() *bytes.Buffer { + var buf bytes.Buffer + return &buf + }, + As(new(io.Reader)), + ) + + require.Error(t, err, "provide must fail") + assert.Contains(t, err.Error(), "cannot provide *bytes.Buffer") + assert.Contains(t, err.Error(), "already provided") + }) } func TestInvokeFailures(t *testing.T) { diff --git a/graph_test.go b/graph_test.go index 154b4edb..92c450e1 100644 --- a/graph_test.go +++ b/graph_test.go @@ -21,8 +21,11 @@ package dig import ( + "bytes" "fmt" + "io" "reflect" + "strings" "testing" "github.com/stretchr/testify/assert" @@ -58,11 +61,14 @@ func TestDotGraph(t *testing.T) { type t2 struct{} type t3 struct{} type t4 struct{} + type t5 strings.Reader type1 := reflect.TypeOf(t1{}) type2 := reflect.TypeOf(t2{}) type3 := reflect.TypeOf(t3{}) type4 := reflect.TypeOf(t4{}) + type5 := reflect.TypeOf(t5{}) + type6 := reflect.Indirect(reflect.ValueOf(new(io.Reader))).Type() p1 := tparam(type1, "", "", false) p2 := tparam(type2, "", "", false) @@ -73,6 +79,8 @@ func TestDotGraph(t *testing.T) { r2 := tresult(type2, "", "", 0) r3 := tresult(type3, "", "", 0) r4 := tresult(type4, "", "", 0) + r5 := tresult(type5, "", "", 0) + r6 := tresult(type6, "", "", 0) t.Parallel() @@ -91,6 +99,21 @@ func TestDotGraph(t *testing.T) { assertCtorsEqual(t, expected, dg.Ctors) }) + t.Run("create graph with one constructor and as interface option", func(t *testing.T) { + expected := []*dot.Ctor{ + { + Params: []*dot.Param{p1}, + Results: []*dot.Result{r5, r6}, + }, + } + + c := New() + c.Provide(func(A t1) t5 { return t5{} }, As(new(io.Reader))) + + dg := c.createGraph() + assertCtorsEqual(t, expected, dg.Ctors) + }) + t.Run("create graph with multple constructors", func(t *testing.T) { expected := []*dot.Ctor{ { @@ -417,6 +440,19 @@ func TestVisualize(t *testing.T) { VerifyVisualization(t, "named", c) }) + t.Run("dig.As two types", func(t *testing.T) { + c := New() + + require.NoError(t, c.Provide( + func() *bytes.Buffer { + panic("this function should not be called") + }, + As(new(io.Reader), new(io.Writer)), + )) + + VerifyVisualization(t, "dig_as_two", c) + }) + t.Run("optional params", func(t *testing.T) { c := New() diff --git a/result.go b/result.go index d113e260..226b8c7a 100644 --- a/result.go +++ b/result.go @@ -37,6 +37,7 @@ import ( // another result. // resultGrouped A value produced by a constructor that is part of a value // group. + type result interface { // Extracts the values for this result from the provided value and // stores them into the provided containerWriter. @@ -61,6 +62,7 @@ type resultOptions struct { // For Result Objects, name:".." tags on fields override this. Name string Group string + As []interface{} } // newResult builds a result from the given type. @@ -83,7 +85,7 @@ func newResult(t reflect.Type, opts resultOptions) (result, error) { case len(opts.Group) > 0: return resultGrouped{Type: t, Group: opts.Group}, nil default: - return resultSingle{Type: t, Name: opts.Name}, nil + return newResultSingle(t, opts) } } @@ -180,14 +182,15 @@ func (rl resultList) DotResult() []*dot.Result { } func newResultList(ctype reflect.Type, opts resultOptions) (resultList, error) { + numOut := ctype.NumOut() rl := resultList{ ctype: ctype, - Results: make([]result, 0, ctype.NumOut()), - resultIndexes: make([]int, ctype.NumOut()), + Results: make([]result, 0, numOut), + resultIndexes: make([]int, numOut), } resultIdx := 0 - for i := 0; i < ctype.NumOut(); i++ { + for i := 0; i < numOut; i++ { t := ctype.Out(i) if isError(t) { rl.resultIndexes[i] = -1 @@ -236,21 +239,59 @@ func (rl resultList) ExtractList(cw containerWriter, values []reflect.Value) err type resultSingle struct { Name string Type reflect.Type + + // If specified, this is a list of types which the value will be made + // available as, in addition to its own type. + As []reflect.Type +} + +func newResultSingle(t reflect.Type, opts resultOptions) (resultSingle, error) { + r := resultSingle{ + Type: t, + Name: opts.Name, + } + + for _, as := range opts.As { + ifaceType := reflect.TypeOf(as).Elem() + if !t.Implements(ifaceType) { + return r, fmt.Errorf("invalid dig.As: %v does not implement %v", t, ifaceType) + } + if ifaceType == t { + // Special case: + // c.Provide(func() io.Reader, As(new(io.Reader))) + // Ignore instead of erroring out. + continue + } + r.As = append(r.As, ifaceType) + } + + return r, nil } func (rs resultSingle) DotResult() []*dot.Result { - return []*dot.Result{ - { - Node: &dot.Node{ - Type: rs.Type, - Name: rs.Name, - }, + dotResults := make([]*dot.Result, 0, len(rs.As)+1) + dotResults = append(dotResults, &dot.Result{ + Node: &dot.Node{ + Type: rs.Type, + Name: rs.Name, }, + }) + + for _, asType := range rs.As { + dotResults = append(dotResults, &dot.Result{ + Node: &dot.Node{Type: asType, Name: rs.Name}, + }) } + + return dotResults } func (rs resultSingle) Extract(cw containerWriter, v reflect.Value) { cw.setValue(rs.Name, rs.Type, v) + + for _, asType := range rs.As { + cw.setValue(rs.Name, asType, v) + } } // resultObject is a dig.Out struct where each field is another result. diff --git a/testdata/dig_as_two.dot b/testdata/dig_as_two.dot new file mode 100644 index 00000000..e599c3b2 --- /dev/null +++ b/testdata/dig_as_two.dot @@ -0,0 +1,16 @@ +digraph { + rankdir=RL; + graph [compound=true]; + + subgraph cluster_0 { + constructor_0 [shape=plaintext label="TestVisualize.func4.1"]; + + "*bytes.Buffer" [label=<*bytes.Buffer>]; + "io.Reader" [label=]; + "io.Writer" [label=]; + + } + + + +} \ No newline at end of file diff --git a/testdata/error.dot b/testdata/error.dot index 59f52533..247a2559 100644 --- a/testdata/error.dot +++ b/testdata/error.dot @@ -10,7 +10,7 @@ digraph { subgraph cluster_0 { - constructor_0 [shape=plaintext label="TestVisualize.func6.1"]; + constructor_0 [shape=plaintext label="TestVisualize.func7.1"]; color=orange; "dig.t3[name=n3]" [label=Name: n3>]; "dig.t2[group=g2]0" [label=Group: g2>]; @@ -21,7 +21,7 @@ digraph { constructor_0 -> "[type=dig.t1 group=g1]" [ltail=cluster_0]; subgraph cluster_1 { - constructor_1 [shape=plaintext label="TestVisualize.func6.2"]; + constructor_1 [shape=plaintext label="TestVisualize.func7.2"]; color=orange; "dig.t4" [label=]; @@ -33,7 +33,7 @@ digraph { constructor_1 -> "[type=dig.t2 group=g2]" [ltail=cluster_1]; subgraph cluster_2 { - constructor_2 [shape=plaintext label="TestVisualize.func6.4"]; + constructor_2 [shape=plaintext label="TestVisualize.func7.4"]; color=red; "dig.t1[group=g1]0" [label=Group: g1>]; "dig.t2[group=g2]2" [label=Group: g2>]; diff --git a/testdata/grouped.dot b/testdata/grouped.dot index 320ad611..9f059fe1 100644 --- a/testdata/grouped.dot +++ b/testdata/grouped.dot @@ -7,7 +7,7 @@ digraph { subgraph cluster_0 { - constructor_0 [shape=plaintext label="TestVisualize.func5.1"]; + constructor_0 [shape=plaintext label="TestVisualize.func6.1"]; "dig.t3[group=foo]0" [label=Group: foo>]; @@ -15,7 +15,7 @@ digraph { subgraph cluster_1 { - constructor_1 [shape=plaintext label="TestVisualize.func5.2"]; + constructor_1 [shape=plaintext label="TestVisualize.func6.2"]; "dig.t3[group=foo]1" [label=Group: foo>]; @@ -23,7 +23,7 @@ digraph { subgraph cluster_2 { - constructor_2 [shape=plaintext label="TestVisualize.func5.3"]; + constructor_2 [shape=plaintext label="TestVisualize.func6.3"]; "dig.t2" [label=]; diff --git a/testdata/missing.dot b/testdata/missing.dot index 6329f34e..80e4412a 100644 --- a/testdata/missing.dot +++ b/testdata/missing.dot @@ -3,7 +3,7 @@ digraph { graph [compound=true]; subgraph cluster_0 { - constructor_0 [shape=plaintext label="TestVisualize.func7.1"]; + constructor_0 [shape=plaintext label="TestVisualize.func8.1"]; color=orange; "dig.t4" [label=]; diff --git a/testdata/optional.dot b/testdata/optional.dot index 02b4171d..376467a0 100644 --- a/testdata/optional.dot +++ b/testdata/optional.dot @@ -3,7 +3,7 @@ digraph { graph [compound=true]; subgraph cluster_0 { - constructor_0 [shape=plaintext label="TestVisualize.func4.1"]; + constructor_0 [shape=plaintext label="TestVisualize.func5.1"]; "dig.t1" [label=]; @@ -11,7 +11,7 @@ digraph { subgraph cluster_1 { - constructor_1 [shape=plaintext label="TestVisualize.func4.2"]; + constructor_1 [shape=plaintext label="TestVisualize.func5.2"]; "dig.t2" [label=]; diff --git a/testdata/prune_constructor_result.dot b/testdata/prune_constructor_result.dot index 565bc278..d827252d 100644 --- a/testdata/prune_constructor_result.dot +++ b/testdata/prune_constructor_result.dot @@ -6,7 +6,7 @@ digraph { subgraph cluster_0 { - constructor_0 [shape=plaintext label="TestVisualize.func6.6.1.2"]; + constructor_0 [shape=plaintext label="TestVisualize.func7.6.1.2"]; color=orange; "dig.t4" [label=]; @@ -16,7 +16,7 @@ digraph { constructor_0 -> "[type=dig.t2 group=g2]" [ltail=cluster_0]; subgraph cluster_1 { - constructor_1 [shape=plaintext label="TestVisualize.func6.6.1.3"]; + constructor_1 [shape=plaintext label="TestVisualize.func7.6.1.3"]; color=red; "dig.t2[group=g2]1" [label=Group: g2>]; diff --git a/testdata/prune_non_root_nodes.dot b/testdata/prune_non_root_nodes.dot index 62138c9b..50e9f3d9 100644 --- a/testdata/prune_non_root_nodes.dot +++ b/testdata/prune_non_root_nodes.dot @@ -3,7 +3,7 @@ digraph { graph [compound=true]; subgraph cluster_0 { - constructor_0 [shape=plaintext label="TestVisualize.func6.6.2.2"]; + constructor_0 [shape=plaintext label="TestVisualize.func7.6.2.2"]; color=red; "dig.t4" [label=];