Skip to content

Commit

Permalink
Add support for types with encoding.TextUnmarshaler (#25)
Browse files Browse the repository at this point in the history
  • Loading branch information
hujun-open authored Jun 2, 2023
1 parent 1824349 commit 80502ef
Show file tree
Hide file tree
Showing 6 changed files with 135 additions and 23 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,10 @@ import "github.com/itzg/go-flagsfiller"
- `net.IP` parse via net.ParseIP()
- `net.IPNet` parse via net.ParseCIDR()
- `net.HardwareAddr` parse via net.ParseMAC()
- and all types that implement encoding.TextUnmarshaler interface
- Optionally set flag values from environment variables. Similar to flag names, environment variable names are derived automatically from the field names
- New types could be supported via user code, via `RegisterSimpleType(ConvertFunc)`, check [time.go](time.go) and [net.go](net.go) to see how it works
- note: in case of a registered type also implements encoding.TextUnmarshaler, then registered type's ConvertFunc is preferred

## Quick example

Expand Down
35 changes: 33 additions & 2 deletions addtional_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package flagsfiller_test
import (
"flag"
"net"
"net/netip"
"testing"
"time"

Expand All @@ -13,7 +14,7 @@ import (

func TestTime(t *testing.T) {
type Config struct {
T time.Time `layout:"2006-Jan-02==15:04:05"`
T time.Time `default:"2010-Oct-01==10:02:03" layout:"2006-Jan-02==15:04:05"`
}

var config Config
Expand All @@ -24,9 +25,15 @@ func TestTime(t *testing.T) {
err := filler.Fill(&flagset, &config)
require.NoError(t, err)

//test default tag
err = flagset.Parse([]string{})
require.NoError(t, err)
expeted, _ := time.Parse("2006-Jan-02==15:04:05", "2010-Oct-01==10:02:03")
assert.Equal(t, expeted, config.T)

err = flagset.Parse([]string{"-t", "2016-Dec-13==16:03:02"})
require.NoError(t, err)
expeted, _ := time.Parse("2006-01-02 15:04:05", "2016-12-13 16:03:02")
expeted, _ = time.Parse("2006-01-02 15:04:05", "2016-12-13 16:03:02")
assert.Equal(t, expeted, config.T)
}

Expand Down Expand Up @@ -86,3 +93,27 @@ func TestIPNet(t *testing.T) {
_, expected, _ := net.ParseCIDR("192.168.1.0/24")
assert.Equal(t, *expected, config.Prefix)
}

func TestTextUnmarshalerType(t *testing.T) {
type Config struct {
Addr netip.Addr `default:"9.9.9.9"`
}

var config Config

filler := flagsfiller.New()

var flagset flag.FlagSet
err := filler.Fill(&flagset, &config)
require.NoError(t, err)

//test default tag
err = flagset.Parse([]string{})
require.NoError(t, err)
assert.Equal(t, netip.AddrFrom4([4]byte{9, 9, 9, 9}), config.Addr)

err = flagset.Parse([]string{"-addr", "1.2.3.4"})
require.NoError(t, err)

assert.Equal(t, netip.AddrFrom4([4]byte{1, 2, 3, 4}), config.Addr)
}
58 changes: 39 additions & 19 deletions flagset.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package flagsfiller

import (
"encoding"
"flag"
"fmt"
"os"
Expand Down Expand Up @@ -55,13 +56,28 @@ func (f *FlagSetFiller) Fill(flagSet *flag.FlagSet, from interface{}) error {
}
}

func isSupportedStruct(name string) bool {
_, ok := extendedTypes[name]
return ok
func isSupportedStruct(in any) bool {
t := reflect.TypeOf(in)
_, ok := extendedTypes[getTypeName(t)]
if ok {
return true
}
if t.Kind() != reflect.Pointer {
val := reflect.ValueOf(in)
t = val.Addr().Type()
}
if t.Implements(reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()) {
RegisterTextUnmarshaler(in)
return true
}
return false
}

func getTypeName(t reflect.Type) string {
return t.PkgPath() + "." + t.Name()
if t.Kind() == reflect.Pointer {
t = t.Elem()
}
return fmt.Sprint(t)
}

func (f *FlagSetFiller) walkFields(flagSet *flag.FlagSet, prefix string,
Expand Down Expand Up @@ -97,13 +113,15 @@ func (f *FlagSetFiller) walkFields(flagSet *flag.FlagSet, prefix string,

switch field.Type.Kind() {
case reflect.Struct:
fieldTypeName := getTypeName(field.Type)
if isSupportedStruct(fieldTypeName) {
err := handleDefault(field, fieldValue)
if err != nil {
return err
// fieldTypeName := getTypeName(field.Type)
if field.IsExported() {
if isSupportedStruct(fieldValue.Addr().Interface()) {
err := handleDefault(field, fieldValue)
if err != nil {
return err
}
continue
}
continue
}
err := f.walkFields(flagSet, prefix+field.Name, fieldValue, field.Type)
if err != nil {
Expand All @@ -112,17 +130,19 @@ func (f *FlagSetFiller) walkFields(flagSet *flag.FlagSet, prefix string,

case reflect.Ptr:
if fieldValue.CanSet() && field.Type.Elem().Kind() == reflect.Struct {
fieldTypeName := getTypeName(field.Type.Elem())
// fieldTypeName := getTypeName(field.Type.Elem())
// fill the pointer with a new struct of their type if it is nil
if fieldValue.IsNil() {
fieldValue.Set(reflect.New(field.Type.Elem()))
}
if isSupportedStruct(fieldTypeName) {
err := handleDefault(field, fieldValue.Elem())
if err != nil {
return err
if field.IsExported() {
if isSupportedStruct(fieldValue.Interface()) {
err := handleDefault(field, fieldValue.Elem())
if err != nil {
return err
}
continue
}
continue
}

err := f.walkFields(flagSet, field.Name, fieldValue.Elem(), field.Type.Elem())
Expand Down Expand Up @@ -175,11 +195,11 @@ func (f *FlagSetFiller) processField(flagSet *flag.FlagSet, fieldRef interface{}
} else {
renamed = f.options.renameLongName(name)
}
typeName := getTypeName(t)

// go through all supported structs
if handler, ok := extendedTypes[typeName]; ok {
if isSupportedStruct(fieldRef) {
handler := extendedTypes[getTypeName(t)]
err = handler(tag, fieldRef, hasDefaultTag, tagDefault, flagSet, renamed, usage, aliases)

}

switch {
Expand Down
1 change: 0 additions & 1 deletion general.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ func processGeneral[T any](fieldRef interface{}, val flagVal[T],
hasDefaultTag bool, tagDefault string,
flagSet *flag.FlagSet, renamed string,
usage string, aliases string) (err error) {

casted := fieldRef.(*T)
if hasDefaultTag {
*casted, err = val.StrConverter(tagDefault)
Expand Down
2 changes: 1 addition & 1 deletion simple.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func (v *simpleType[T]) String() string {
if v.val == nil {
return fmt.Sprint(nil)
}
return fmt.Sprintf("%v", *v.val)
return fmt.Sprint(v.val)
}

func (v *simpleType[T]) StrConverter(s string) (T, error) {
Expand Down
60 changes: 60 additions & 0 deletions txtunmarshaler.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// This file implements support for all types that support interface encoding.TextUnmarshaler
package flagsfiller

import (
"encoding"
"flag"
"fmt"
"reflect"
"strings"
)

// RegisterTextUnmarshaler use is optional, since flagsfiller will automatically register the types implement encoding.TextUnmarshaler it encounters
func RegisterTextUnmarshaler(in any) {
base := textUnmarshalerType{}
extendedTypes[getTypeName(reflect.TypeOf(in).Elem())] = base.process
}

type textUnmarshalerType struct {
val encoding.TextUnmarshaler
}

// String implements flag.Value interface
func (tv *textUnmarshalerType) String() string {
if tv.val == nil {
return fmt.Sprint(nil)
}
return fmt.Sprint(tv.val)
}

// Set implements flag.Value interface
func (tv *textUnmarshalerType) Set(s string) error {
return tv.val.UnmarshalText([]byte(s))
}

func (tv *textUnmarshalerType) process(tag reflect.StructTag, fieldRef interface{},
hasDefaultTag bool, tagDefault string,
flagSet *flag.FlagSet, renamed string,
usage string, aliases string) error {
v, ok := fieldRef.(encoding.TextUnmarshaler)
if !ok {
return fmt.Errorf("can't cast %v into encoding.TextUnmarshaler", fieldRef)
}
newval := textUnmarshalerType{
val: v,
}
if hasDefaultTag {
err := newval.Set(tagDefault)
if err != nil {
return fmt.Errorf("failed to parse default value into %v: %w", reflect.TypeOf(fieldRef), err)
}
}
flagSet.Var(&newval, renamed, usage)
if aliases != "" {
for _, alias := range strings.Split(aliases, ",") {
flagSet.Var(&newval, alias, usage)
}
}
return nil

}

0 comments on commit 80502ef

Please sign in to comment.