Skip to content

Commit

Permalink
add InvokeOption for Names
Browse files Browse the repository at this point in the history
addesses uber-go#181
  • Loading branch information
tylersouthwick committed Oct 25, 2021
1 parent bc16c4d commit 1be65d1
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 10 deletions.
47 changes: 44 additions & 3 deletions dig.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ type provideOptions struct {
Info *ProvideInfo
As []interface{}
Location *digreflect.Func
Names []string
}

func (o *provideOptions) Validate() error {
Expand Down Expand Up @@ -282,10 +283,40 @@ func LocationForPC(pc uintptr) ProvideOption {
})
}

type invokeOptions struct {
Names []string
}

func (*invokeOptions) Validate() error {
return nil
}

// An InvokeOption modifies the default behavior of Invoke. It's included for
// future functionality; currently, there are no concrete implementations.
type InvokeOption interface {
unimplemented()
applyInvokeOption(*invokeOptions)
}

type invokeOptionFunc func(*invokeOptions)

func (f invokeOptionFunc) applyInvokeOption(opts *invokeOptions) { f(opts) }

type InvokeAndProvideOption interface {
InvokeOption
ProvideOption
}

type namesOption []string

func (n namesOption) applyInvokeOption(opts *invokeOptions) {
opts.Names = n
}
func (n namesOption) applyProvideOption(opts *provideOptions) {
opts.Names = n
}

func Names(names ...string) InvokeAndProvideOption {
return namesOption(names)
}

// Container is a directed acyclic graph of types and their dependencies.
Expand Down Expand Up @@ -566,7 +597,15 @@ func (c *Container) Invoke(function interface{}, opts ...InvokeOption) error {
return errf("can't invoke non-function %v (type %v)", function, ftype)
}

pl, err := newParamList(ftype)
var options invokeOptions
for _, o := range opts {
o.applyInvokeOption(&options)
}
if err := options.Validate(); err != nil {
return err
}

pl, err := newParamList(ftype, options.Names)
if err != nil {
return err
}
Expand Down Expand Up @@ -624,6 +663,7 @@ func (c *Container) provide(ctor interface{}, opts provideOptions) error {
ResultGroup: opts.Group,
ResultAs: opts.As,
Location: opts.Location,
ParamNames: opts.Names,
},
)
if err != nil {
Expand Down Expand Up @@ -842,14 +882,15 @@ type nodeOptions struct {
ResultGroup string
ResultAs []interface{}
Location *digreflect.Func
ParamNames []string
}

func newNode(ctor interface{}, opts nodeOptions) (*node, error) {
cval := reflect.ValueOf(ctor)
ctype := cval.Type()
cptr := cval.Pointer()

params, err := newParamList(ctype)
params, err := newParamList(ctype, opts.ParamNames)
if err != nil {
return nil, err
}
Expand Down
58 changes: 58 additions & 0 deletions dig_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,45 @@ func TestEndToEndSuccess(t *testing.T) {
}), "invoke should succeed, pulling out two named instances")
})

t.Run("named instances can be used to Provide another instance", func(t *testing.T) {
c := New()

type A struct{ idx int }

buildConstructor := func(idx int) func() A {
return func() A { return A{idx: idx} }
}

require.NoError(t, c.Provide(buildConstructor(1), Name("first")))
require.NoError(t, c.Provide(buildConstructor(2), Name("second")))
require.NoError(t, c.Provide(func(a A) int {
return a.idx + 5
}, Names("first")))

require.NoError(t, c.Invoke(func(i int) {
assert.Equal(t, 6, i)
}), "invoke should succeed, pulling out one named instances")
})

t.Run("named instances can be invoked Name option", func(t *testing.T) {
c := New()

type A struct{ idx int }

buildConstructor := func(idx int) func() A {
return func() A { return A{idx: idx} }
}

require.NoError(t, c.Provide(buildConstructor(1), Name("first")))
require.NoError(t, c.Provide(buildConstructor(2), Name("second")))
require.NoError(t, c.Provide(buildConstructor(3), Name("third")))

require.NoError(t, c.Invoke(func(a1 A, a3 A) {
assert.Equal(t, 1, a1.idx)
assert.Equal(t, 3, a3.idx)
}, Names("first", "third")), "invoke should succeed, using two named instances")
})

t.Run("named and unnamed instances coexist", func(t *testing.T) {
c := New()
type A struct{ idx int }
Expand All @@ -561,6 +600,25 @@ func TestEndToEndSuccess(t *testing.T) {
}))
})

t.Run("named and unnamed instances can be invoked with Names option", func(t *testing.T) {
c := New()

type A struct{ idx int }

buildConstructor := func(idx int) func() A {
return func() A { return A{idx: idx} }
}

require.NoError(t, c.Provide(buildConstructor(1), Name("first")))
require.NoError(t, c.Provide(buildConstructor(2), Name("second")))
require.NoError(t, c.Provide(buildConstructor(3)))

require.NoError(t, c.Invoke(func(a1 A, a3 A) {
assert.Equal(t, 1, a1.idx)
assert.Equal(t, 3, a3.idx)
}, Names("first")), "invoke should succeed, using two named instances")
})

t.Run("named instances recurse", func(t *testing.T) {
c := New()
type A struct{ idx int }
Expand Down
21 changes: 16 additions & 5 deletions param.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,14 @@ var (

// newParam builds a param from the given type. If the provided type is a
// dig.In struct, an paramObject will be returned.
func newParam(t reflect.Type) (param, error) {
func newParam(t reflect.Type, paramName string) (param, error) {
switch {
case IsOut(t) || (t.Kind() == reflect.Ptr && IsOut(t.Elem())) || embedsType(t, _outPtrType):
return nil, errf("cannot depend on result objects", "%v embeds a dig.Out", t)
case IsIn(t):
if paramName != "" {
return nil, errf("cannot have a paramName (%s) with a struct that has dig.In", paramName)
}
return newParamObject(t)
case embedsType(t, _inPtrType):
return nil, errf(
Expand All @@ -77,7 +80,7 @@ func newParam(t reflect.Type) (param, error) {
"cannot depend on a pointer to a parameter object, use a value instead",
"%v is a pointer to a struct that embeds dig.In", t)
default:
return paramSingle{Type: t}, nil
return paramSingle{Type: t, Name: paramName}, nil
}
}

Expand Down Expand Up @@ -158,7 +161,7 @@ func (pl paramList) DotParam() []*dot.Param {
//
// Variadic arguments of a constructor are ignored and not included as
// dependencies.
func newParamList(ctype reflect.Type) (paramList, error) {
func newParamList(ctype reflect.Type, names []string) (paramList, error) {
numArgs := ctype.NumIn()
if ctype.IsVariadic() {
// NOTE: If the function is variadic, we skip the last argument
Expand All @@ -171,8 +174,16 @@ func newParamList(ctype reflect.Type) (paramList, error) {
Params: make([]param, 0, numArgs),
}

if numArgs < len(names) {
return pl, errf("can't create a constructor with more names=%s than args=%s", names, ctype)
}

for i := 0; i < numArgs; i++ {
p, err := newParam(ctype.In(i))
name := ""
if i < len(names) {
name = names[i]
}
p, err := newParam(ctype.In(i), name)
if err != nil {
return pl, errf("bad argument %d", i+1, err)
}
Expand Down Expand Up @@ -370,7 +381,7 @@ func newParamObjectField(idx int, f reflect.StructField) (paramObjectField, erro

default:
var err error
p, err = newParam(f.Type)
p, err = newParam(f.Type, "")
if err != nil {
return pof, err
}
Expand Down
4 changes: 2 additions & 2 deletions param_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import (
)

func TestParamListBuild(t *testing.T) {
p, err := newParamList(reflect.TypeOf(func() io.Writer { return nil }))
p, err := newParamList(reflect.TypeOf(func() io.Writer { return nil }), []string{})
require.NoError(t, err)
assert.Panics(t, func() {
p.Build(New())
Expand Down Expand Up @@ -238,7 +238,7 @@ func TestParamVisitorChecksEverything(t *testing.T) {

pl, err := newParamList(reflect.TypeOf(func(io.Reader, params, io.Writer) {
t.Fatalf("this function should not be called")
}))
}), []string{})
require.NoError(t, err)

idx := 0
Expand Down

0 comments on commit 1be65d1

Please sign in to comment.