Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement a dig.As ProvideOption #252

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 109 additions & 21 deletions dig.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,19 @@ func (f optionFunc) applyOption(c *Container) { f(c) }
type provideOptions struct {
Name string
Group string
As []interface{}
}

func (o *provideOptions) Validate() error {
if len(o.Group) > 0 && len(o.Name) > 0 {
return fmt.Errorf(
"cannot use named values with value groups: name:%q provided with group:%q", o.Name, o.Group)
if len(o.Group) > 0 {
if len(o.Name) > 0 {
return fmt.Errorf(
"cannot use named values with value groups: name:%q provided with group:%q", o.Name, o.Group)
}
if len(o.As) > 0 {
return fmt.Errorf(
"cannot use dig.As with value groups: dig.As provided with group:%q", o.Group)
}
}

// Names must be representable inside a backquoted string. The only
Expand All @@ -80,6 +87,23 @@ func (o *provideOptions) Validate() error {
if strings.ContainsRune(o.Group, '`') {
return fmt.Errorf("invalid dig.Group(%q): group names cannot contain backquotes", o.Group)
}

for _, i := range o.As {
t := reflect.TypeOf(i)

if t == nil {
return fmt.Errorf("invalid dig.As(nil): argument must be a pointer to an interface")
}

if t.Kind() != reflect.Ptr {
return fmt.Errorf("invalid dig.As(%v): argument must be a pointer to an interface", t)
}

pointingTo := t.Elem()
if pointingTo.Kind() != reflect.Interface {
return fmt.Errorf("invalid dig.As(*%v): argument must be a pointer to an interface", pointingTo)
}
}
return nil
}

Expand Down Expand Up @@ -127,6 +151,55 @@ func Group(group string) ProvideOption {
})
}

// As is a ProvideOption that specifies that the value produced by the
// constructor implements one or more other interfaces.
//
// As expects one or more pointers to the implemented interfaces. Values
// produced by constructors will be made available in the container as
// implementations of all of those interfaces.
//
// For example, the following will make the buffer available in the container
// as io.Reader and io.Writer.
//
// c.Provide(newBuffer, dig.As(new(io.Reader), new(io.Writer)))
//
// That is, the above is equivalent to the following.
//
// c.Provide(func(...) (*bytes.Buffer, io.Reader, io.Writer) {
// b := newBuffer(...)
// return b, b, b
// })
//
// If used with dig.Name, the type produced by the constructor and the types
// specified with dig.As will all use the same name. For example,
//
// c.Provide(newFile, dig.As(new(io.Reader)), dig.Name("temp"))
//
// The above is equivalent to the following.
//
// type Result struct {
// dig.Out
//
// File *os.File `name:"temp"`
// Reader io.Reader `name:"temp"`
// }
//
// c.Provide(func(...) Result {
// f := newFile(...)
// return Result{
// File: f,
// Reader: f,
// }
// })
//
// This option cannot be provided for constructors which produce result
// objects.
func As(i ...interface{}) ProvideOption {
return provideOptionFunc(func(opts *provideOptions) {
opts.As = append(opts.As, i...)
})
}

// An InvokeOption modifies the default behavior of Invoke. It's included for
// future functionality; currently, there are no concrete implementations.
type InvokeOption interface {
Expand Down Expand Up @@ -424,6 +497,7 @@ func (c *Container) provide(ctor interface{}, opts provideOptions) error {
nodeOptions{
ResultName: opts.Name,
ResultGroup: opts.Group,
ResultAs: opts.As,
},
)
if err != nil {
Expand Down Expand Up @@ -535,30 +609,23 @@ func (cv connectionVisitor) Visit(res result) resultVisitor {
path := strings.Join(cv.currentResultPath, ".")

switch r := res.(type) {

case resultSingle:
k := key{name: r.Name, t: r.Type}

if conflict, ok := cv.keyPaths[k]; ok {
*cv.err = fmt.Errorf(
"cannot provide %v from %v: already provided by %v",
k, path, conflict)
if err := cv.checkKey(k, path); err != nil {
*cv.err = err
return nil
}

if ps := cv.c.providers[k]; len(ps) > 0 {
cons := make([]string, len(ps))
for i, p := range ps {
cons[i] = fmt.Sprint(p.Location())
cv.keyPaths[k] = path
for _, asType := range r.As {
k := key{name: r.Name, t: asType}
if err := cv.checkKey(k, path); err != nil {
*cv.err = err
return nil
}

*cv.err = fmt.Errorf(
"cannot provide %v from %v: already provided by %v",
k, path, strings.Join(cons, "; "))
return nil
cv.keyPaths[k] = path
}

cv.keyPaths[k] = path

case resultGrouped:
// we don't really care about the path for this since conflicts are
// okay for group results. We'll track it for the sake of having a
Expand All @@ -570,6 +637,25 @@ func (cv connectionVisitor) Visit(res result) resultVisitor {
return cv
}

func (cv connectionVisitor) checkKey(k key, path string) error {
if conflict, ok := cv.keyPaths[k]; ok {
return fmt.Errorf(
"cannot provide %v from %v: already provided by %v",
k, path, conflict)
}
if ps := cv.c.providers[k]; len(ps) > 0 {
cons := make([]string, len(ps))
for i, p := range ps {
cons[i] = fmt.Sprint(p.Location())
}

return fmt.Errorf(
"cannot provide %v from %v: already provided by %v",
k, path, strings.Join(cons, "; "))
}
return nil
}

// node is a node in the dependency graph. Each node maps to a single
// constructor provided by the user.
//
Expand Down Expand Up @@ -598,9 +684,10 @@ type node struct {

type nodeOptions struct {
// If specified, all values produced by this node have the provided name
// or belong to the specified value group
// belong to the specified value group or implement any of the interfaces.
ResultName string
ResultGroup string
ResultAs []interface{}
}

func newNode(ctor interface{}, opts nodeOptions) (*node, error) {
Expand All @@ -618,6 +705,7 @@ func newNode(ctor interface{}, opts nodeOptions) (*node, error) {
resultOptions{
Name: opts.ResultName,
Group: opts.ResultGroup,
As: opts.ResultAs,
},
)
if err != nil {
Expand Down
Loading