Skip to content

Commit

Permalink
random updates:
Browse files Browse the repository at this point in the history
* use Get() rather than String() for flag values
* don't coerce values of the same type (enables support for custom flag
  types)
* remove nsq-specific (deprecated) Duration handling
  • Loading branch information
mreiferson committed Mar 2, 2019
1 parent 5fcea60 commit 0c63f02
Showing 1 changed file with 20 additions and 37 deletions.
57 changes: 20 additions & 37 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"fmt"
"log"
"reflect"
"regexp"
"strconv"
"strings"
"time"
Expand Down Expand Up @@ -84,9 +83,9 @@ func Resolve(options interface{}, flagSet *flag.FlagSet, cfg map[string]interfac
// resolve the flags according to priority
var v interface{}
if hasArg(flagSet, flagName) {
v = flagInst.Value.String()
v = flagInst.Value.(flag.Getter).Get()
} else if deprecatedFlagName != "" && hasArg(flagSet, deprecatedFlagName) {
v = deprecatedFlag.Value.String()
v = deprecatedFlag.Value.(flag.Getter).Get()
log.Printf("WARNING: use of the --%s command line flag is deprecated (use --%s)",
deprecatedFlagName, flagName)
} else if cfgVal, ok := cfg[cfgName]; ok {
Expand All @@ -95,24 +94,25 @@ func Resolve(options interface{}, flagSet *flag.FlagSet, cfg map[string]interfac
// if the type has a Get() method, use that as the default value
v = getter.Get()
} else {
// otherwise, use the default value
// otherwise, use the struct's default value
v = val.Field(i).Interface()
}

fieldVal := val.FieldByName(field.Name)
coerced, err := coerce(v, fieldVal.Interface(), field.Tag.Get("arg"))
if err != nil {
log.Fatalf("ERROR: option resolution failed to coerce %v for %s (%+v) - %s",
v, field.Name, fieldVal, err)
if fieldVal.Type() != reflect.TypeOf(v) {
newv, err := coerce(v, fieldVal.Interface())
if err != nil {
log.Fatalf("ERROR: Resolve failed to coerce value %v (%+v) for field %s - %s",
v, fieldVal, field.Name, err)
}
v = newv
}
fieldVal.Set(reflect.ValueOf(coerced))
fieldVal.Set(reflect.ValueOf(v))
}
}

func coerceBool(v interface{}) (bool, error) {
switch v.(type) {
case bool:
return v.(bool), nil
case string:
return strconv.ParseBool(v.(string))
case int, int16, uint16, int32, uint32, int64, uint64:
Expand Down Expand Up @@ -143,28 +143,13 @@ func coerceFloat64(v interface{}) (float64, error) {
return 0, fmt.Errorf("invalid float64 value type %T", v)
}

func coerceDuration(v interface{}, arg string) (time.Duration, error) {
func coerceDuration(v interface{}) (time.Duration, error) {
switch v.(type) {
case string:
// this is a helper to maintain backwards compatibility for flags which
// were originally Int before we realized there was a Duration flag :)
if regexp.MustCompile(`^[0-9]+$`).MatchString(v.(string)) {
intVal, err := strconv.Atoi(v.(string))
if err != nil {
return 0, err
}
mult, err := time.ParseDuration(arg)
if err != nil {
return 0, err
}
return time.Duration(intVal) * mult, nil
}
return time.ParseDuration(v.(string))
case int, int16, uint16, int32, uint32, int64, uint64:
// treat like ms
return time.Duration(reflect.ValueOf(v).Int()) * time.Millisecond, nil
case time.Duration:
return v.(time.Duration), nil
}
return 0, fmt.Errorf("invalid time.Duration value type %T", v)
}
Expand All @@ -180,8 +165,6 @@ func coerceStringSlice(v interface{}) ([]string, error) {
for _, si := range v.([]interface{}) {
tmp = append(tmp, si.(string))
}
case []string:
tmp = v.([]string)
}
return tmp, nil
}
Expand Down Expand Up @@ -209,21 +192,15 @@ func coerceFloat64Slice(v interface{}) ([]float64, error) {
}
tmp = append(tmp, f)
}
case []float64:
tmp = v.([]float64)
}
return tmp, nil
}

func coerceString(v interface{}) (string, error) {
switch v.(type) {
case string:
return v.(string), nil
}
return fmt.Sprintf("%s", v), nil
}

func coerce(v interface{}, opt interface{}, arg string) (interface{}, error) {
func coerce(v interface{}, opt interface{}) (interface{}, error) {
switch opt.(type) {
case bool:
return coerceBool(v)
Expand Down Expand Up @@ -265,6 +242,12 @@ func coerce(v interface{}, opt interface{}, arg string) (interface{}, error) {
return nil, err
}
return uint64(i), nil
case float32:
i, err := coerceFloat64(v)
if err != nil {
return nil, err
}
return float32(i), nil
case float64:
i, err := coerceFloat64(v)
if err != nil {
Expand All @@ -274,7 +257,7 @@ func coerce(v interface{}, opt interface{}, arg string) (interface{}, error) {
case string:
return coerceString(v)
case time.Duration:
return coerceDuration(v, arg)
return coerceDuration(v)
case []string:
return coerceStringSlice(v)
case []float64:
Expand Down

0 comments on commit 0c63f02

Please sign in to comment.