Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support to invoke a function between multiple containers #289

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions dig.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,60 @@ func Group(group string) ProvideOption {
})
}

// GroupInvoke runs the given function after instantiating its dependencies within multiple containers
func GroupInvoke(function interface{}, containers ...*Container) error {
arguments, err := resolve(function, containers...)
if err != nil {
return err
}

reflect.ValueOf(function).Call(arguments)

return nil
}

func resolve(function interface{}, containers ...*Container) ([]reflect.Value, error) {
ftype := reflect.TypeOf(function)
arguments := ftype.NumIn()
result := make([]reflect.Value, arguments)
if ftype == nil {
return nil, errors.New("can't invoke an untyped nil")
}
if ftype.Kind() != reflect.Func {
return nil, errf("can't invoke non-function %v (type %v)", function, ftype)
}

pl, err := newParamList(ftype)
if err != nil {
return nil, err
}

for _, c := range containers {
if !c.isVerifiedAcyclic {
if err := c.verifyAcyclic(); err != nil {
return nil, err
}
}

args, err := pl.UnsafeBuildList(c)
if err != nil {
return nil, err
}

for i, a := range args {
if a.IsValid() {
result[i] = reflect.ValueOf(a.Interface())
}
}
}

if len(result) != arguments {
return nil, errors.New("parameters count does not match")
}

return result, nil
}

// ID is a unique integer representing the constructor node in the dependency graph.
type ID int

Expand Down
44 changes: 44 additions & 0 deletions dig_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3112,3 +3112,47 @@ func TestProvideInfoOption(t *testing.T) {
assert.Equal(t, "*dig.type4", info2.Outputs[0].String())
})
}

func TestGroupInvoke(t *testing.T) {
type TestParam struct {
Name string
Value string
}

type TestParam1 struct {
AdditionaInfo string
}

singletonIOC := New()
singletonIOC.Provide(func() *TestParam {
return &TestParam{
Name: "TestName",
Value: "TestValue",
}
})

customIOC := New()
customIOC.Provide(func() *TestParam1 {
return &TestParam1{
AdditionaInfo: "Some info",
}
})

function := func(p *TestParam, p1 *TestParam1) {
res1 := &TestParam{
Name: "TestName",
Value: "TestValue",
}

res2 := &TestParam1{
AdditionaInfo: "Some info",
}

assert.Equal(t, res1, p)
assert.Equal(t, res2, p1)
}

if err := GroupInvoke(function, singletonIOC, customIOC); err != nil {
assert.FailNow(t, err.Error())
}
}
14 changes: 14 additions & 0 deletions param.go
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,20 @@ func (pl paramList) BuildList(c containerStore) ([]reflect.Value, error) {
return args, nil
}

// UnsafeBuildList returns an ordered list of values which may be passed directly
// to the underlying constructor without interruption in case of missing field.
func (pl paramList) UnsafeBuildList(c containerStore) ([]reflect.Value, error) {
args := make([]reflect.Value, len(pl.Params))
for i, p := range pl.Params {
var err error
args[i], err = p.Build(c)
if err != nil {
continue
}
}
return args, nil
}

// paramSingle is an explicitly requested type, optionally with a name.
//
// This object must be present in the graph as-is unless it's specified as
Expand Down