diff --git a/README.md b/README.md index fcf6ffc..531702a 100644 --- a/README.md +++ b/README.md @@ -33,8 +33,10 @@ import "github.com/itzg/go-flagsfiller" - `net.IP` parse via net.ParseIP() - `net.IPNet` parse via net.ParseCIDR() - `net.HardwareAddr` parse via net.ParseMAC() + - and all types that implement encoding.TextUnmarshaler interface - Optionally set flag values from environment variables. Similar to flag names, environment variable names are derived automatically from the field names - New types could be supported via user code, via `RegisterSimpleType(ConvertFunc)`, check [time.go](time.go) and [net.go](net.go) to see how it works + - note: in case of a registered type also implements encoding.TextUnmarshaler, then registered type's ConvertFunc is preferred ## Quick example diff --git a/addtional_test.go b/addtional_test.go index 6fd033a..b50a475 100644 --- a/addtional_test.go +++ b/addtional_test.go @@ -3,6 +3,7 @@ package flagsfiller_test import ( "flag" "net" + "net/netip" "testing" "time" @@ -13,7 +14,7 @@ import ( func TestTime(t *testing.T) { type Config struct { - T time.Time `layout:"2006-Jan-02==15:04:05"` + T time.Time `default:"2010-Oct-01==10:02:03" layout:"2006-Jan-02==15:04:05"` } var config Config @@ -24,9 +25,15 @@ func TestTime(t *testing.T) { err := filler.Fill(&flagset, &config) require.NoError(t, err) + //test default tag + err = flagset.Parse([]string{}) + require.NoError(t, err) + expeted, _ := time.Parse("2006-Jan-02==15:04:05", "2010-Oct-01==10:02:03") + assert.Equal(t, expeted, config.T) + err = flagset.Parse([]string{"-t", "2016-Dec-13==16:03:02"}) require.NoError(t, err) - expeted, _ := time.Parse("2006-01-02 15:04:05", "2016-12-13 16:03:02") + expeted, _ = time.Parse("2006-01-02 15:04:05", "2016-12-13 16:03:02") assert.Equal(t, expeted, config.T) } @@ -86,3 +93,27 @@ func TestIPNet(t *testing.T) { _, expected, _ := net.ParseCIDR("192.168.1.0/24") assert.Equal(t, *expected, config.Prefix) } + +func TestTextUnmarshalerType(t *testing.T) { + type Config struct { + Addr netip.Addr `default:"9.9.9.9"` + } + + var config Config + + filler := flagsfiller.New() + + var flagset flag.FlagSet + err := filler.Fill(&flagset, &config) + require.NoError(t, err) + + //test default tag + err = flagset.Parse([]string{}) + require.NoError(t, err) + assert.Equal(t, netip.AddrFrom4([4]byte{9, 9, 9, 9}), config.Addr) + + err = flagset.Parse([]string{"-addr", "1.2.3.4"}) + require.NoError(t, err) + + assert.Equal(t, netip.AddrFrom4([4]byte{1, 2, 3, 4}), config.Addr) +} diff --git a/flagset.go b/flagset.go index f377dd5..6c85d28 100644 --- a/flagset.go +++ b/flagset.go @@ -1,6 +1,7 @@ package flagsfiller import ( + "encoding" "flag" "fmt" "os" @@ -55,13 +56,28 @@ func (f *FlagSetFiller) Fill(flagSet *flag.FlagSet, from interface{}) error { } } -func isSupportedStruct(name string) bool { - _, ok := extendedTypes[name] - return ok +func isSupportedStruct(in any) bool { + t := reflect.TypeOf(in) + _, ok := extendedTypes[getTypeName(t)] + if ok { + return true + } + if t.Kind() != reflect.Pointer { + val := reflect.ValueOf(in) + t = val.Addr().Type() + } + if t.Implements(reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()) { + RegisterTextUnmarshaler(in) + return true + } + return false } func getTypeName(t reflect.Type) string { - return t.PkgPath() + "." + t.Name() + if t.Kind() == reflect.Pointer { + t = t.Elem() + } + return fmt.Sprint(t) } func (f *FlagSetFiller) walkFields(flagSet *flag.FlagSet, prefix string, @@ -97,13 +113,15 @@ func (f *FlagSetFiller) walkFields(flagSet *flag.FlagSet, prefix string, switch field.Type.Kind() { case reflect.Struct: - fieldTypeName := getTypeName(field.Type) - if isSupportedStruct(fieldTypeName) { - err := handleDefault(field, fieldValue) - if err != nil { - return err + // fieldTypeName := getTypeName(field.Type) + if field.IsExported() { + if isSupportedStruct(fieldValue.Addr().Interface()) { + err := handleDefault(field, fieldValue) + if err != nil { + return err + } + continue } - continue } err := f.walkFields(flagSet, prefix+field.Name, fieldValue, field.Type) if err != nil { @@ -112,17 +130,19 @@ func (f *FlagSetFiller) walkFields(flagSet *flag.FlagSet, prefix string, case reflect.Ptr: if fieldValue.CanSet() && field.Type.Elem().Kind() == reflect.Struct { - fieldTypeName := getTypeName(field.Type.Elem()) + // fieldTypeName := getTypeName(field.Type.Elem()) // fill the pointer with a new struct of their type if it is nil if fieldValue.IsNil() { fieldValue.Set(reflect.New(field.Type.Elem())) } - if isSupportedStruct(fieldTypeName) { - err := handleDefault(field, fieldValue.Elem()) - if err != nil { - return err + if field.IsExported() { + if isSupportedStruct(fieldValue.Interface()) { + err := handleDefault(field, fieldValue.Elem()) + if err != nil { + return err + } + continue } - continue } err := f.walkFields(flagSet, field.Name, fieldValue.Elem(), field.Type.Elem()) @@ -175,11 +195,11 @@ func (f *FlagSetFiller) processField(flagSet *flag.FlagSet, fieldRef interface{} } else { renamed = f.options.renameLongName(name) } - typeName := getTypeName(t) - // go through all supported structs - if handler, ok := extendedTypes[typeName]; ok { + if isSupportedStruct(fieldRef) { + handler := extendedTypes[getTypeName(t)] err = handler(tag, fieldRef, hasDefaultTag, tagDefault, flagSet, renamed, usage, aliases) + } switch { diff --git a/general.go b/general.go index fce558d..9a92bc3 100644 --- a/general.go +++ b/general.go @@ -31,7 +31,6 @@ func processGeneral[T any](fieldRef interface{}, val flagVal[T], hasDefaultTag bool, tagDefault string, flagSet *flag.FlagSet, renamed string, usage string, aliases string) (err error) { - casted := fieldRef.(*T) if hasDefaultTag { *casted, err = val.StrConverter(tagDefault) diff --git a/simple.go b/simple.go index 66f1702..0e4f8e1 100644 --- a/simple.go +++ b/simple.go @@ -32,7 +32,7 @@ func (v *simpleType[T]) String() string { if v.val == nil { return fmt.Sprint(nil) } - return fmt.Sprintf("%v", *v.val) + return fmt.Sprint(v.val) } func (v *simpleType[T]) StrConverter(s string) (T, error) { diff --git a/txtunmarshaler.go b/txtunmarshaler.go new file mode 100644 index 0000000..6e325c7 --- /dev/null +++ b/txtunmarshaler.go @@ -0,0 +1,60 @@ +// This file implements support for all types that support interface encoding.TextUnmarshaler +package flagsfiller + +import ( + "encoding" + "flag" + "fmt" + "reflect" + "strings" +) + +// RegisterTextUnmarshaler use is optional, since flagsfiller will automatically register the types implement encoding.TextUnmarshaler it encounters +func RegisterTextUnmarshaler(in any) { + base := textUnmarshalerType{} + extendedTypes[getTypeName(reflect.TypeOf(in).Elem())] = base.process +} + +type textUnmarshalerType struct { + val encoding.TextUnmarshaler +} + +// String implements flag.Value interface +func (tv *textUnmarshalerType) String() string { + if tv.val == nil { + return fmt.Sprint(nil) + } + return fmt.Sprint(tv.val) +} + +// Set implements flag.Value interface +func (tv *textUnmarshalerType) Set(s string) error { + return tv.val.UnmarshalText([]byte(s)) +} + +func (tv *textUnmarshalerType) process(tag reflect.StructTag, fieldRef interface{}, + hasDefaultTag bool, tagDefault string, + flagSet *flag.FlagSet, renamed string, + usage string, aliases string) error { + v, ok := fieldRef.(encoding.TextUnmarshaler) + if !ok { + return fmt.Errorf("can't cast %v into encoding.TextUnmarshaler", fieldRef) + } + newval := textUnmarshalerType{ + val: v, + } + if hasDefaultTag { + err := newval.Set(tagDefault) + if err != nil { + return fmt.Errorf("failed to parse default value into %v: %w", reflect.TypeOf(fieldRef), err) + } + } + flagSet.Var(&newval, renamed, usage) + if aliases != "" { + for _, alias := range strings.Split(aliases, ",") { + flagSet.Var(&newval, alias, usage) + } + } + return nil + +}