diff --git a/.travis.yml b/.travis.yml index 1f41717..fde1cdb 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,6 +1,8 @@ language: go go: - - 1.5 + - 1.9.x + - 1.10.x + - 1.11.x notifications: email: false sudo: false diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..ab9007c --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module github.com/mreiferson/go-options + +go 1.11 diff --git a/options.go b/options.go index 3ae10f1..ea6efc9 100644 --- a/options.go +++ b/options.go @@ -7,7 +7,6 @@ import ( "fmt" "log" "reflect" - "regexp" "strconv" "strings" "time" @@ -84,9 +83,9 @@ func Resolve(options interface{}, flagSet *flag.FlagSet, cfg map[string]interfac // resolve the flags according to priority var v interface{} if hasArg(flagSet, flagName) { - v = flagInst.Value.String() + v = flagInst.Value.(flag.Getter).Get() } else if deprecatedFlagName != "" && hasArg(flagSet, deprecatedFlagName) { - v = deprecatedFlag.Value.String() + v = deprecatedFlag.Value.(flag.Getter).Get() log.Printf("WARNING: use of the --%s command line flag is deprecated (use --%s)", deprecatedFlagName, flagName) } else if cfgVal, ok := cfg[cfgName]; ok { @@ -95,24 +94,25 @@ func Resolve(options interface{}, flagSet *flag.FlagSet, cfg map[string]interfac // if the type has a Get() method, use that as the default value v = getter.Get() } else { - // otherwise, use the default value + // otherwise, use the struct's default value v = val.Field(i).Interface() } fieldVal := val.FieldByName(field.Name) - coerced, err := coerce(v, fieldVal.Interface(), field.Tag.Get("arg")) - if err != nil { - log.Fatalf("ERROR: option resolution failed to coerce %v for %s (%+v) - %s", - v, field.Name, fieldVal, err) + if fieldVal.Type() != reflect.TypeOf(v) { + newv, err := coerce(v, fieldVal.Interface()) + if err != nil { + log.Fatalf("ERROR: Resolve failed to coerce value %v (%+v) for field %s - %s", + v, fieldVal, field.Name, err) + } + v = newv } - fieldVal.Set(reflect.ValueOf(coerced)) + fieldVal.Set(reflect.ValueOf(v)) } } func coerceBool(v interface{}) (bool, error) { switch v.(type) { - case bool: - return v.(bool), nil case string: return strconv.ParseBool(v.(string)) case int, int16, uint16, int32, uint32, int64, uint64: @@ -143,28 +143,13 @@ func coerceFloat64(v interface{}) (float64, error) { return 0, fmt.Errorf("invalid float64 value type %T", v) } -func coerceDuration(v interface{}, arg string) (time.Duration, error) { +func coerceDuration(v interface{}) (time.Duration, error) { switch v.(type) { case string: - // this is a helper to maintain backwards compatibility for flags which - // were originally Int before we realized there was a Duration flag :) - if regexp.MustCompile(`^[0-9]+$`).MatchString(v.(string)) { - intVal, err := strconv.Atoi(v.(string)) - if err != nil { - return 0, err - } - mult, err := time.ParseDuration(arg) - if err != nil { - return 0, err - } - return time.Duration(intVal) * mult, nil - } return time.ParseDuration(v.(string)) case int, int16, uint16, int32, uint32, int64, uint64: // treat like ms return time.Duration(reflect.ValueOf(v).Int()) * time.Millisecond, nil - case time.Duration: - return v.(time.Duration), nil } return 0, fmt.Errorf("invalid time.Duration value type %T", v) } @@ -180,8 +165,6 @@ func coerceStringSlice(v interface{}) ([]string, error) { for _, si := range v.([]interface{}) { tmp = append(tmp, si.(string)) } - case []string: - tmp = v.([]string) } return tmp, nil } @@ -209,21 +192,15 @@ func coerceFloat64Slice(v interface{}) ([]float64, error) { } tmp = append(tmp, f) } - case []float64: - tmp = v.([]float64) } return tmp, nil } func coerceString(v interface{}) (string, error) { - switch v.(type) { - case string: - return v.(string), nil - } return fmt.Sprintf("%s", v), nil } -func coerce(v interface{}, opt interface{}, arg string) (interface{}, error) { +func coerce(v interface{}, opt interface{}) (interface{}, error) { switch opt.(type) { case bool: return coerceBool(v) @@ -265,6 +242,12 @@ func coerce(v interface{}, opt interface{}, arg string) (interface{}, error) { return nil, err } return uint64(i), nil + case float32: + i, err := coerceFloat64(v) + if err != nil { + return nil, err + } + return float32(i), nil case float64: i, err := coerceFloat64(v) if err != nil { @@ -274,7 +257,7 @@ func coerce(v interface{}, opt interface{}, arg string) (interface{}, error) { case string: return coerceString(v) case time.Duration: - return coerceDuration(v, arg) + return coerceDuration(v) case []string: return coerceStringSlice(v) case []float64: