From ed92c646594047e2f815d99d631748afdb6e5bbd Mon Sep 17 00:00:00 2001 From: Roberto Ferro Date: Thu, 19 Aug 2021 16:54:15 +0200 Subject: [PATCH 1/5] Implemented GroupInvoke --- dig.go | 54 +++++++++++++++++++++++++++++++++++++++++++++++++++++ dig_test.go | 35 ++++++++++++++++++++++++++++++++++ param.go | 14 ++++++++++++++ 3 files changed, 103 insertions(+) diff --git a/dig.go b/dig.go index dc3c2363..0a4da625 100644 --- a/dig.go +++ b/dig.go @@ -23,6 +23,7 @@ package dig import ( "errors" "fmt" + "log" "math/rand" "reflect" "sort" @@ -129,6 +130,59 @@ func Group(group string) ProvideOption { }) } +// GroupInvoke runs the given function after instantiating its dependencies within multiple containers +func GroupInvoke(function interface{}, containers ...*Container) error { + arguments, err := resolve(function, containers...) + if err != nil { + return err + } + + reflect.ValueOf(function).Call(arguments) + return nil +} + +func resolve(function interface{}, containers ...*Container) ([]reflect.Value, error) { + var result []reflect.Value + ftype := reflect.TypeOf(function) + arguments := ftype.NumIn() + if ftype == nil { + return nil, errors.New("can't invoke an untyped nil") + } + if ftype.Kind() != reflect.Func { + return nil, errf("can't invoke non-function %v (type %v)", function, ftype) + } + + pl, err := newParamList(ftype) + if err != nil { + return nil, err + } + + for _, c := range containers { + if !c.isVerifiedAcyclic { + if err := c.verifyAcyclic(); err != nil { + return nil, err + } + } + + args, err := pl.UnsafeBuildList(c) + if err != nil { + return nil, err + } + + for _, a := range args { + if a.IsValid() { + result = append(result, reflect.ValueOf(a.Interface())) + } + } + } + + if len(result) != arguments { + log.Fatal("Parameters count does not match") + } + + return result, nil +} + // ID is a unique integer representing the constructor node in the dependency graph. type ID int diff --git a/dig_test.go b/dig_test.go index a907f619..973d930f 100644 --- a/dig_test.go +++ b/dig_test.go @@ -26,6 +26,7 @@ import ( "fmt" "io" "io/ioutil" + "log" "math/rand" "os" "reflect" @@ -3112,3 +3113,37 @@ func TestProvideInfoOption(t *testing.T) { assert.Equal(t, "*dig.type4", info2.Outputs[0].String()) }) } + +func TestGroupInvoke(t *testing.T) { + type TestParam struct { + Name string + Value string + } + + type TestParam1 struct { + AdditionaInfo string + } + + singletonIOC := New() + singletonIOC.Provide(func() *TestParam { + return &TestParam{ + Name: "TestName", + Value: "TestValue", + } + }) + + customIOC := New() + customIOC.Provide(func() *TestParam1 { + return &TestParam1{ + AdditionaInfo: "Some info", + } + }) + + function := func(p *TestParam, p1 *TestParam1) { + fmt.Print("Test") + } + + if err := GroupInvoke(function, singletonIOC, customIOC); err != nil { + log.Fatal(err) + } +} diff --git a/param.go b/param.go index 0979228a..e8def84f 100644 --- a/param.go +++ b/param.go @@ -206,6 +206,20 @@ func (pl paramList) BuildList(c containerStore) ([]reflect.Value, error) { return args, nil } +// UnsafeBuildList returns an ordered list of values which may be passed directly +// to the underlying constructor without interruption in case of missing field. +func (pl paramList) UnsafeBuildList(c containerStore) ([]reflect.Value, error) { + args := make([]reflect.Value, len(pl.Params)) + for i, p := range pl.Params { + var err error + args[i], err = p.Build(c) + if err != nil { + continue + } + } + return args, nil +} + // paramSingle is an explicitly requested type, optionally with a name. // // This object must be present in the graph as-is unless it's specified as From 1add9d6756f9fba4828df86bed31d04c2e6c3d89 Mon Sep 17 00:00:00 2001 From: Roberto Ferro Date: Thu, 19 Aug 2021 17:00:39 +0200 Subject: [PATCH 2/5] Fix GroupInvoke test --- dig_test.go | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/dig_test.go b/dig_test.go index 973d930f..64c9ff86 100644 --- a/dig_test.go +++ b/dig_test.go @@ -3140,7 +3140,17 @@ func TestGroupInvoke(t *testing.T) { }) function := func(p *TestParam, p1 *TestParam1) { - fmt.Print("Test") + res1 := &TestParam{ + Name: "TestName", + Value: "TestValue", + } + + res2 := &TestParam1{ + AdditionaInfo: "Some info", + } + + assert.Equal(t, res1, p) + assert.Equal(t, res2, p1) } if err := GroupInvoke(function, singletonIOC, customIOC); err != nil { From 72e741b9b0ee53c542a872b96ec342ff6305494a Mon Sep 17 00:00:00 2001 From: Roberto Ferro Date: Thu, 19 Aug 2021 17:18:29 +0200 Subject: [PATCH 3/5] Fix GroupInvoke test --- dig_test.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/dig_test.go b/dig_test.go index 64c9ff86..277af74b 100644 --- a/dig_test.go +++ b/dig_test.go @@ -26,7 +26,6 @@ import ( "fmt" "io" "io/ioutil" - "log" "math/rand" "os" "reflect" @@ -3154,6 +3153,6 @@ func TestGroupInvoke(t *testing.T) { } if err := GroupInvoke(function, singletonIOC, customIOC); err != nil { - log.Fatal(err) + assert.FailNow(t, err.Error()) } } From 8cd3b2cd6df41c73433bc089c46dc46f507a020b Mon Sep 17 00:00:00 2001 From: Roberto Ferro Date: Thu, 19 Aug 2021 17:23:06 +0200 Subject: [PATCH 4/5] Removed log --- dig.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dig.go b/dig.go index 0a4da625..ae85b1af 100644 --- a/dig.go +++ b/dig.go @@ -23,7 +23,6 @@ package dig import ( "errors" "fmt" - "log" "math/rand" "reflect" "sort" @@ -138,6 +137,7 @@ func GroupInvoke(function interface{}, containers ...*Container) error { } reflect.ValueOf(function).Call(arguments) + return nil } @@ -177,7 +177,7 @@ func resolve(function interface{}, containers ...*Container) ([]reflect.Value, e } if len(result) != arguments { - log.Fatal("Parameters count does not match") + return nil, errors.New("parameters count does not match") } return result, nil From a51d0737ab5d1074aa75cc839a4731dc18d9a754 Mon Sep 17 00:00:00 2001 From: Roberto Ferro Date: Fri, 20 Aug 2021 11:02:41 +0200 Subject: [PATCH 5/5] Fix argument order issue --- dig.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dig.go b/dig.go index ae85b1af..18baf2f9 100644 --- a/dig.go +++ b/dig.go @@ -142,9 +142,9 @@ func GroupInvoke(function interface{}, containers ...*Container) error { } func resolve(function interface{}, containers ...*Container) ([]reflect.Value, error) { - var result []reflect.Value ftype := reflect.TypeOf(function) arguments := ftype.NumIn() + result := make([]reflect.Value, arguments) if ftype == nil { return nil, errors.New("can't invoke an untyped nil") } @@ -169,9 +169,9 @@ func resolve(function interface{}, containers ...*Container) ([]reflect.Value, e return nil, err } - for _, a := range args { + for i, a := range args { if a.IsValid() { - result = append(result, reflect.ValueOf(a.Interface())) + result[i] = reflect.ValueOf(a.Interface()) } } }