diff --git a/annotated.go b/annotated.go index 0e72860e3..def263caf 100644 --- a/annotated.go +++ b/annotated.go @@ -377,15 +377,12 @@ var _ Annotation = resultTagsAnnotation{} // If the tag is invalid and has mismatched quotation for example, // (`tag_name:"tag_value') , this will return an error. func (rt resultTagsAnnotation) apply(ann *annotated) error { - if len(ann.ResultTags) > 0 { - return errors.New("cannot apply more than one line of ResultTags") - } for _, tag := range rt.tags { if err := verifyAnnotateTag(tag); err != nil { return err } } - ann.ResultTags = rt.tags + ann.ResultTags = append(ann.ResultTags, rt.tags) return nil } @@ -426,12 +423,16 @@ func (rt resultTagsAnnotation) results(ann *annotated) ( // if there's no Out struct among the return types, there was no As annotation applied // just replace original result types with an Out struct and apply tags var ( - newOut outStructInfo - existingOuts []reflect.Type + newOut outStructInfo + existingOuts []reflect.Type + existingOutsMapping = make(map[int][]reflect.Type) ) newOut.Fields = []reflect.StructField{_outAnnotationField} newOut.Offsets = []int{} + // to prevent duplicate applying of the same tag to the same type, it is kept in the following format + // {"foo": {"fmt.Stringer": struct{}{}, "myStringer": struct{}{}}, "bar": {"fmt.Stringer": struct{}{}}} + retaggedTypeMap := map[string]map[string]struct{}{} for i, t := range types { if !isOut(t) { @@ -448,26 +449,92 @@ func (rt resultTagsAnnotation) results(ann *annotated) ( newOut.Fields = append(newOut.Fields, field) continue } - // this must be from an As annotation - // apply the tags to the existing type + // this must be from an As annotation or a ResultTags annotation + // apply the tags to the existing type if it comes from an As annotation taggedFields := make([]reflect.StructField, t.NumField()) taggedFields[0] = _outAnnotationField + // apply the tags to the existing type after replication the existing type if it comes from a ResultTags annotation + originalTaggedFields := make([]reflect.StructField, t.NumField()) + originalTaggedFields[0] = _outAnnotationField + + var newlyTagged bool + var hasBeenTagged bool for j, tag := range rt.tags { if j+1 < t.NumField() { field := t.Field(j + 1) + // if the field has already been tagged by ResultTags annotation, avoid overwriting it + _, taggedName := field.Tag.Lookup("name") + _, taggedGroup := field.Tag.Lookup("group") + if taggedName || taggedGroup { + originalTaggedFields[j+1] = reflect.StructField{ + Name: field.Name, + Type: field.Type, + Tag: field.Tag, + } + hasBeenTagged = true + } + + structTag := reflect.StructTag(tag) + if typeNames, ok := retaggedTypeMap[structTag.Get("name")]; ok { + if _, ok := typeNames[field.Type.String()]; ok { + continue + } + } + if typeNames, ok := retaggedTypeMap[structTag.Get("group")]; ok { + if _, ok := typeNames[field.Type.String()]; ok { + continue + } + } + if n, ok := structTag.Lookup("name"); ok { + typeNames, ok := retaggedTypeMap[n] + if !ok { + typeNames = make(map[string]struct{}) + retaggedTypeMap[n] = typeNames + } + typeNames[field.Type.String()] = struct{}{} + } + if g, ok := structTag.Lookup("group"); ok { + typeNames, ok := retaggedTypeMap[g] + if !ok { + typeNames = make(map[string]struct{}) + retaggedTypeMap[g] = typeNames + } + typeNames[field.Type.String()] = struct{}{} + } + + if hasBeenTagged && !taggedName && !taggedGroup { + // if other fields are already tagged and this field is untagged, apply the new tag + originalTaggedFields[j+1] = reflect.StructField{ + Name: field.Name, + Type: field.Type, + Tag: structTag, + } + continue + } taggedFields[j+1] = reflect.StructField{ Name: field.Name, Type: field.Type, - Tag: reflect.StructTag(tag), + Tag: structTag, } + newlyTagged = true } } - existingOuts = append(existingOuts, reflect.StructOf(taggedFields)) + currentTypeExistingOuts := make([]reflect.Type, 0, 2) + if hasBeenTagged { + currentTypeExistingOuts = append(currentTypeExistingOuts, reflect.StructOf(originalTaggedFields)) + } + if newlyTagged { + currentTypeExistingOuts = append(currentTypeExistingOuts, reflect.StructOf(taggedFields)) + } + existingOutsMapping[i] = currentTypeExistingOuts + existingOuts = append(existingOuts, currentTypeExistingOuts...) } resType := reflect.StructOf(newOut.Fields) - - outTypes := []reflect.Type{resType} + var outTypes []reflect.Type + if len(newOut.Fields) > 1 { + outTypes = append(outTypes, resType) + } // append existing outs back to outTypes outTypes = append(outTypes, existingOuts...) if hasError { @@ -479,9 +546,10 @@ func (rt resultTagsAnnotation) results(ann *annotated) ( outErr error outResults []reflect.Value ) - outResults = append(outResults, reflect.New(resType).Elem()) - - tIdx := 0 + if len(newOut.Fields) > 1 { + outResults = append(outResults, reflect.New(resType).Elem()) + } + existingOutResults := make([]reflect.Value, 0, len(existingOuts)) for i, r := range results { if i == len(results)-1 && hasError { // If hasError and this is the last item, @@ -501,21 +569,22 @@ func (rt resultTagsAnnotation) results(ann *annotated) ( // to prevent panic from setting fx.Out to // a value. outResults[0].Field(fieldIdx).Set(r) + continue } - continue } if isOut(r.Type()) { - tIdx++ - if tIdx < len(outTypes) { - newResult := reflect.New(outTypes[tIdx]).Elem() - for j := 1; j < outTypes[tIdx].NumField(); j++ { + for _, existingOuts := range existingOutsMapping[i] { + newResult := reflect.New(existingOuts).Elem() + for j := 1; j < existingOuts.NumField(); j++ { newResult.Field(j).Set(r.Field(j)) } - outResults = append(outResults, newResult) + existingOutResults = append(existingOutResults, newResult) } } } + outResults = append(outResults, existingOutResults...) + if hasError { if outErr != nil { outResults = append(outResults, reflect.ValueOf(outErr)) @@ -1528,7 +1597,7 @@ type annotated struct { Target interface{} Annotations []Annotation ParamTags []string - ResultTags []string + ResultTags [][]string As [][]asType From []reflect.Type FuncPtr uintptr diff --git a/annotated_test.go b/annotated_test.go index e7defa76c..c5b9143ef 100644 --- a/annotated_test.go +++ b/annotated_test.go @@ -1431,7 +1431,7 @@ func TestAnnotate(t *testing.T) { t.Run("specify two ResultTags", func(t *testing.T) { t.Parallel() - app := NewForTest(t, + app := fxtest.New(t, fx.Provide( // This should just leave newA as it is. fx.Annotate( @@ -1440,12 +1440,62 @@ func TestAnnotate(t *testing.T) { fx.ResultTags(`name:"AA"`), ), ), - fx.Invoke(newB), + fx.Invoke( + fx.Annotate(func(a, aa *a) (*b, *b) { + return newB(a), newB(aa) + }, fx.ParamTags(`name:"A"`, `name:"AA"`))), ) err := app.Err() - require.Error(t, err) - assert.Contains(t, err.Error(), "encountered error while applying annotation using fx.Annotate to go.uber.org/fx_test.TestAnnotate.func1(): cannot apply more than one line of ResultTags") + require.NoError(t, err) + defer app.RequireStart().RequireStop() + }) + + t.Run("specify two ResultTags containing multiple tags", func(t *testing.T) { + t.Parallel() + + app := fxtest.New(t, + fx.Provide( + fx.Annotate( + func() (*a, *b) { + return newA(), newB(&a{}) + }, + fx.ResultTags(`name:"A"`, `name:"B"`), + fx.ResultTags(`name:"AA"`, `name:"BB"`), + ), + ), + fx.Invoke( + fx.Annotate(func(a, aa *a, b, bb *b) (*b, *b, *c, *c) { + return newB(a), newB(aa), newC(b), newC(b) + }, fx.ParamTags(`name:"A"`, `name:"AA"`, `name:"B"`, `name:"BB"`))), + ) + + err := app.Err() + require.NoError(t, err) + defer app.RequireStart().RequireStop() + }) + + t.Run("specify Three ResultTags", func(t *testing.T) { + t.Parallel() + + app := fxtest.New(t, + fx.Provide( + fx.Annotate( + newA, + fx.ResultTags(`name:"A"`), + fx.ResultTags(`name:"AA"`), + fx.ResultTags(`name:"AAA"`), + ), + ), + fx.Invoke( + fx.Annotate(func(a, aa, aaa *a) (*b, *b, *b) { + return newB(a), newB(aa), newB(aaa) + }, fx.ParamTags(`name:"A"`, `name:"AA"`, `name:"AAA"`))), + ) + + err := app.Err() + require.NoError(t, err) + defer app.RequireStart().RequireStop() }) t.Run("annotate with a non-nil error", func(t *testing.T) { diff --git a/app_test.go b/app_test.go index 198000cdc..f476bcb2b 100644 --- a/app_test.go +++ b/app_test.go @@ -329,7 +329,7 @@ func TestNewApp(t *testing.T) { // cannot provide fx_test.t1[name="foo"] from [0].Field0: // already provided by "reflect".makeFuncStub (/.../reflect/asm_amd64.s:30) assert.Contains(t, err.Error(), `fx.Provide(fx.Annotate(`) - assert.Contains(t, err.Error(), `fx.ResultTags(["name:\"foo\""])`) + assert.Contains(t, err.Error(), `fx.ResultTags([["name:\"foo\""]])`) assert.Contains(t, err.Error(), "already provided") })