diff --git a/transform.go b/transform.go index f2c90ff..8ba9449 100644 --- a/transform.go +++ b/transform.go @@ -106,7 +106,7 @@ func Transform(name string, s interface{}) error { return t.Transform(name, s) } -// NewTransformer +// NewTransformer ... func NewTransformer(opts ...TransformerOpt) *TransformerImpl { t := new(TransformerImpl) t.TagName = defaultTagName @@ -126,6 +126,10 @@ func (t *TransformerImpl) Transform(name string, s interface{}) error { return ErrNoPointer } + if val.IsNil() { + return nil // bail out if nil + } + val = val.Elem() if !val.CanAddr() { return ErrNoAddressable @@ -140,8 +144,6 @@ func (t *TransformerImpl) transform(val reflect.Value, field ...FieldLevel) erro valKind := getKind(reflect.Indirect(val)) - fmt.Println(field) - if len(field) > 0 { valKind = getKind(field[0].Field()) } @@ -151,8 +153,6 @@ func (t *TransformerImpl) transform(val reflect.Value, field ...FieldLevel) erro err = t.transformType(field[0]) case reflect.Struct: err = t.transformStruct(val) - case reflect.Interface: - return t.transform(reflect.ValueOf(val.Interface())) default: // we have to work on here for value to pointed to return fmt.Errorf("transformer: unsupported type %s", valKind) @@ -163,14 +163,18 @@ func (t *TransformerImpl) transform(val reflect.Value, field ...FieldLevel) erro // transcodeType func (t *TransformerImpl) transformType(field FieldLevel) error { - fmt.Println("transformType") + for _, f := range field.Funcs() { + fn, ok := internalTransformers[f] + if !ok { + return fmt.Errorf("transformer: function %s does not exist", f) + } - fn, ok := internalTransformers["trim"] - if !ok { - return nil + if err := fn(field); err != nil { + return err + } } - return fn(field) + return nil } // transdecodeStruct @@ -209,8 +213,6 @@ func (t *TransformerImpl) transformStruct(val reflect.Value) error { tag := field.Tag.Get(t.TagName) tag = strings.SplitN(tag, ",", 2)[0] - fmt.Println("here", val.CanAddr()) - if !val.CanAddr() { continue } @@ -230,8 +232,6 @@ func (t *TransformerImpl) transformStruct(val reflect.Value) error { continue } - fmt.Print("tag: ", tag, "\n") - if err := t.transform(val, f); err != nil { return err } diff --git a/transform_test.go b/transform_test.go index 88f3a40..a982a56 100644 --- a/transform_test.go +++ b/transform_test.go @@ -11,10 +11,14 @@ import ( func TestStruct(t *testing.T) { trans := transform.NewTransformer() + type testStruct struct { + Name string `transform:"trim"` + } + tests := []struct { name string - in interface{} - out interface{} + in *testStruct + out *testStruct }{ { name: "nil", @@ -23,19 +27,15 @@ func TestStruct(t *testing.T) { }, { name: "empty", - in: struct{}{}, - out: struct{}{}, + in: &testStruct{}, + out: &testStruct{}, }, { name: "string", - in: struct { - Name string `transform:"trim,lowercase"` - }{ + in: &testStruct{ Name: " test ", }, - out: struct { - Name string `transform:"trim,lowercase"` - }{ + out: &testStruct{ Name: "test", }, }, @@ -43,7 +43,7 @@ func TestStruct(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := trans.Transform(tt.name, &tt.in) + err := trans.Transform(tt.name, tt.in) require.NoError(t, err) require.Equal(t, tt.out, tt.in) })