Skip to content

Commit

Permalink
Implement a dig.As ProvideOption
Browse files Browse the repository at this point in the history
This brings back #233.

Per #235 (review),
the issues we need to resolve are,

1. `dig.As` seems to indicate that it's a total override of the provided
   type. The way it currently works is it appends other interfaces on
   top of whatever the constructor already returns
2. semantics of `dig.Provide(func New() (Foo, io.Reader, error), dig.As(new(io.Writer)`.
   It currently fails due to inability to case reader to writer, but
   we'd want some extra validation here. Perhaps `dig.As` is only
   supported for a single type, error tuple?

Closes #197
  • Loading branch information
alessandrozucca authored and abhinav committed Nov 14, 2019
1 parent 4fb70ce commit 287b330
Show file tree
Hide file tree
Showing 11 changed files with 439 additions and 55 deletions.
130 changes: 109 additions & 21 deletions dig.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,19 @@ func (f optionFunc) applyOption(c *Container) { f(c) }
type provideOptions struct {
Name string
Group string
As []interface{}
}

func (o *provideOptions) Validate() error {
if len(o.Group) > 0 && len(o.Name) > 0 {
return fmt.Errorf(
"cannot use named values with value groups: name:%q provided with group:%q", o.Name, o.Group)
if len(o.Group) > 0 {
if len(o.Name) > 0 {
return fmt.Errorf(
"cannot use named values with value groups: name:%q provided with group:%q", o.Name, o.Group)
}
if len(o.As) > 0 {
return fmt.Errorf(
"cannot use dig.As with value groups: dig.As provided with group:%q", o.Group)
}
}

// Names must be representable inside a backquoted string. The only
Expand All @@ -80,6 +87,23 @@ func (o *provideOptions) Validate() error {
if strings.ContainsRune(o.Group, '`') {
return fmt.Errorf("invalid dig.Group(%q): group names cannot contain backquotes", o.Group)
}

for _, i := range o.As {
t := reflect.TypeOf(i)

if t == nil {
return fmt.Errorf("invalid dig.As(nil): argument must be a pointer to an interface")
}

if t.Kind() != reflect.Ptr {
return fmt.Errorf("invalid dig.As(%v): argument must be a pointer to an interface", t)
}

pointingTo := t.Elem()
if pointingTo.Kind() != reflect.Interface {
return fmt.Errorf("invalid dig.As(*%v): argument must be a pointer to an interface", pointingTo)
}
}
return nil
}

Expand Down Expand Up @@ -127,6 +151,55 @@ func Group(group string) ProvideOption {
})
}

// As is a ProvideOption that specifies that the value produced by the
// constructor implements one or more other interfaces.
//
// As expects one or more pointers to the implemented interfaces. Values
// produced by constructors will be made available in the container as
// implementations of all of those interfaces.
//
// For example, the following will make the buffer available in the container
// as io.Reader and io.Writer.
//
// c.Provide(newBuffer, dig.As(new(io.Reader), new(io.Writer)))
//
// That is, the above is equivalent to the following.
//
// c.Provide(func(...) (*bytes.Buffer, io.Reader, io.Writer) {
// b := newBuffer(...)
// return b, b, b
// })
//
// If used with dig.Name, the type produced by the constructor and the types
// specified with dig.As will all use the same name. For example,
//
// c.Provide(newFile, dig.As(new(io.Reader)), dig.Name("temp"))
//
// The above is equivalent to the following.
//
// type Result struct {
// dig.Out
//
// File *os.File `name:"temp"`
// Reader io.Reader `name:"temp"`
// }
//
// c.Provide(func(...) Result {
// f := newFile(...)
// return Result{
// File: f,
// Reader: f,
// }
// })
//
// This option cannot be provided for constructors which produce result
// objects.
func As(i ...interface{}) ProvideOption {
return provideOptionFunc(func(opts *provideOptions) {
opts.As = append(opts.As, i...)
})
}

// An InvokeOption modifies the default behavior of Invoke. It's included for
// future functionality; currently, there are no concrete implementations.
type InvokeOption interface {
Expand Down Expand Up @@ -424,6 +497,7 @@ func (c *Container) provide(ctor interface{}, opts provideOptions) error {
nodeOptions{
ResultName: opts.Name,
ResultGroup: opts.Group,
ResultAs: opts.As,
},
)
if err != nil {
Expand Down Expand Up @@ -535,30 +609,23 @@ func (cv connectionVisitor) Visit(res result) resultVisitor {
path := strings.Join(cv.currentResultPath, ".")

switch r := res.(type) {

case resultSingle:
k := key{name: r.Name, t: r.Type}

if conflict, ok := cv.keyPaths[k]; ok {
*cv.err = fmt.Errorf(
"cannot provide %v from %v: already provided by %v",
k, path, conflict)
if err := cv.checkKey(k, path); err != nil {
*cv.err = err
return nil
}

if ps := cv.c.providers[k]; len(ps) > 0 {
cons := make([]string, len(ps))
for i, p := range ps {
cons[i] = fmt.Sprint(p.Location())
cv.keyPaths[k] = path
for _, asType := range r.As {
k := key{name: r.Name, t: asType}
if err := cv.checkKey(k, path); err != nil {
*cv.err = err
return nil
}

*cv.err = fmt.Errorf(
"cannot provide %v from %v: already provided by %v",
k, path, strings.Join(cons, "; "))
return nil
cv.keyPaths[k] = path
}

cv.keyPaths[k] = path

case resultGrouped:
// we don't really care about the path for this since conflicts are
// okay for group results. We'll track it for the sake of having a
Expand All @@ -570,6 +637,25 @@ func (cv connectionVisitor) Visit(res result) resultVisitor {
return cv
}

func (cv connectionVisitor) checkKey(k key, path string) error {
if conflict, ok := cv.keyPaths[k]; ok {
return fmt.Errorf(
"cannot provide %v from %v: already provided by %v",
k, path, conflict)
}
if ps := cv.c.providers[k]; len(ps) > 0 {
cons := make([]string, len(ps))
for i, p := range ps {
cons[i] = fmt.Sprint(p.Location())
}

return fmt.Errorf(
"cannot provide %v from %v: already provided by %v",
k, path, strings.Join(cons, "; "))
}
return nil
}

// node is a node in the dependency graph. Each node maps to a single
// constructor provided by the user.
//
Expand Down Expand Up @@ -598,9 +684,10 @@ type node struct {

type nodeOptions struct {
// If specified, all values produced by this node have the provided name
// or belong to the specified value group
// belong to the specified value group or implement any of the interfaces.
ResultName string
ResultGroup string
ResultAs []interface{}
}

func newNode(ctor interface{}, opts nodeOptions) (*node, error) {
Expand All @@ -618,6 +705,7 @@ func newNode(ctor interface{}, opts nodeOptions) (*node, error) {
resultOptions{
Name: opts.ResultName,
Group: opts.ResultGroup,
As: opts.ResultAs,
},
)
if err != nil {
Expand Down
Loading

0 comments on commit 287b330

Please sign in to comment.