From 429f00d4495254c779e3e8de1514a1abb4e0010f Mon Sep 17 00:00:00 2001 From: Oleg Kovalov Date: Thu, 21 Jul 2022 23:44:10 +0200 Subject: [PATCH] Small fixes and improvements (#130) --- aconfig.go | 32 ++-- aconfig_test.go | 378 +++++++++++++++++------------------------------- reflection.go | 16 +- utils.go | 25 ++-- 4 files changed, 169 insertions(+), 282 deletions(-) diff --git a/aconfig.go b/aconfig.go index 2570b2a..39f2e60 100644 --- a/aconfig.go +++ b/aconfig.go @@ -9,17 +9,6 @@ import ( "strings" ) -const ( - defaultValueTag = "default" - usageTag = "usage" - jsonNameTag = "json" - yamlNameTag = "yaml" - tomlNameTag = "toml" - hclNameTag = "hcl" - envNameTag = "env" - flagNameTag = "flag" -) - // Loader of user configuration. type Loader struct { config Config @@ -85,6 +74,10 @@ type Config struct { // Files from which config should be loaded. Files []string + // Envs hold the environment variable from which envs will be parsed. + // By default is nil and then os.Environ() will be used. + Envs []string + // Args hold the command-line arguments from which flags will be parsed. // By default is nil and then os.Args will be used. // Unless loader.Flags() will be explicitly parsed by the user. @@ -164,6 +157,9 @@ func (l *Loader) init() { dec.Init(l.fsys) } + if l.config.Envs == nil { + l.config.Envs = os.Environ() + } if l.config.Args == nil { l.config.Args = os.Args[1:] } @@ -174,7 +170,7 @@ func (l *Loader) init() { if !l.config.SkipFlags { names := make(map[string]bool, len(l.fields)) for _, field := range l.fields { - flagName := l.fullTag(l.config.FlagPrefix, field, flagNameTag) + flagName := l.fullTag(l.config.FlagPrefix, field, "flag") if flagName == "" { continue } @@ -183,7 +179,7 @@ func (l *Loader) init() { return } names[flagName] = true - l.flagSet.String(flagName, field.Tag(defaultValueTag), field.Tag(usageTag)) + l.flagSet.String(flagName, field.Tag("default"), field.Tag("usage")) } } if l.config.FileFlag != "" { @@ -278,7 +274,7 @@ func (l *Loader) checkRequired() error { func (l *Loader) loadDefaults() error { for _, field := range l.fields { - defaultValue := field.Tag(defaultValueTag) + defaultValue := field.Tag("default") if err := l.setFieldData(field, defaultValue); err != nil { return err } @@ -376,11 +372,11 @@ func (l *Loader) loadFileFlag() error { } func (l *Loader) loadEnvironment() error { - actualEnvs := getEnv() + actualEnvs := getEnv(l.config.Envs) dupls := make(map[string]struct{}) for _, field := range l.fields { - envName := l.fullTag(l.config.EnvPrefix, field, envNameTag) + envName := l.fullTag(l.config.EnvPrefix, field, "env") if envName == "" { continue } @@ -411,7 +407,7 @@ func (l *Loader) loadFlags() error { dupls := make(map[string]struct{}) for _, field := range l.fields { - flagName := l.fullTag(l.config.FlagPrefix, field, flagNameTag) + flagName := l.fullTag(l.config.FlagPrefix, field, "flag") if flagName == "" { continue } @@ -430,7 +426,7 @@ func (l *Loader) postFlagCheck(values map[string]interface{}, dupls map[string]s delete(values, name) } for flag, value := range values { - if strings.HasPrefix(flag, l.config.EnvPrefix) { + if strings.HasPrefix(flag, l.config.FlagPrefix) { return fmt.Errorf("unknown flag %s=%v (see AllowUnknownFlags config param)", flag, value) } } diff --git a/aconfig_test.go b/aconfig_test.go index 2d130c3..622fa67 100644 --- a/aconfig_test.go +++ b/aconfig_test.go @@ -3,7 +3,7 @@ package aconfig import ( "embed" "fmt" - "io/ioutil" + "io" "net/url" "os" "reflect" @@ -27,7 +27,6 @@ func (l *LogLevel) UnmarshalText(text []byte) error { default: return fmt.Errorf("unknown log level: %s", text) } - return nil } @@ -38,9 +37,7 @@ func TestDefaults(t *testing.T) { SkipEnv: true, SkipFlags: true, }) - if err := loader.Load(); err != nil { - t.Fatal(err) - } + failIfErr(t, loader.Load()) want := TestConfig{ Str: "str-def", @@ -63,10 +60,7 @@ func TestDefaults(t *testing.T) { Em: "em-def", }, } - - if got := cfg; !reflect.DeepEqual(want, got) { - t.Fatalf("want %v, got %v", want, got) - } + mustEqual(t, cfg, want) } func TestDefaults_AllTypes(t *testing.T) { @@ -101,9 +95,7 @@ func TestDefaults_AllTypes(t *testing.T) { SkipEnv: true, SkipFlags: true, }) - if err := loader.Load(); err != nil { - t.Fatal(err) - } + failIfErr(t, loader.Load()) want := AllTypesConfig{ Bool: true, @@ -121,14 +113,11 @@ func TestDefaults_AllTypes(t *testing.T) { Float32: 1234.213, Float64: 1234.234, Dur: time.Hour + 2*time.Minute + 3*time.Second, - // TODO + // TODO: support time // Time :2000-04-05 10:20:30 +0000 UTC, Level: LogLevel(1), } - - if got := cfg; got != want { - t.Fatalf("want %v, got %v", want, got) - } + mustEqual(t, cfg, want) } func TestDefaults_OtherNumberFormats(t *testing.T) { @@ -150,9 +139,7 @@ func TestDefaults_OtherNumberFormats(t *testing.T) { SkipEnv: true, SkipFlags: true, }) - if err := loader.Load(); err != nil { - t.Fatal(err) - } + failIfErr(t, loader.Load()) want := OtherNumberFormats{ Int: 7, @@ -165,10 +152,7 @@ func TestDefaults_OtherNumberFormats(t *testing.T) { Uint16: 83, Uint32: 291, } - - if got := cfg; got != want { - t.Fatalf("want %v, got %v", want, got) - } + mustEqual(t, cfg, want) } func TestJSON(t *testing.T) { @@ -181,14 +165,10 @@ func TestJSON(t *testing.T) { SkipFlags: true, Files: []string{filepath}, }) - if err := loader.Load(); err != nil { - t.Fatal(err) - } + failIfErr(t, loader.Load()) want := wantConfig - if got := cfg; !reflect.DeepEqual(want, got) { - t.Fatalf("want %v, got %v", want, got) - } + mustEqual(t, cfg, want) } func TestJSONWithOmitempty(t *testing.T) { @@ -204,9 +184,7 @@ func TestJSONWithOmitempty(t *testing.T) { AllowUnknownFields: true, Files: []string{createTestFile(t)}, }) - if err := loader.Load(); err != nil { - t.Fatal(err) - } + failIfErr(t, loader.Load()) } func TestCustomFile(t *testing.T) { @@ -222,14 +200,10 @@ func TestCustomFile(t *testing.T) { ".config": &jsonDecoder{}, }, }) - if err := loader.Load(); err != nil { - t.Fatal(err) - } + failIfErr(t, loader.Load()) want := wantConfig - if got := cfg; !reflect.DeepEqual(want, got) { - t.Fatalf("want %v, got %v", want, got) - } + mustEqual(t, cfg, want) } func TestFile(t *testing.T) { @@ -242,9 +216,7 @@ func TestFile(t *testing.T) { SkipFlags: true, Files: []string{filepath}, }) - if err := loader.Load(); err != nil { - t.Fatal(err) - } + failIfErr(t, loader.Load()) want := TestConfig{ Str: "str-json", @@ -260,10 +232,7 @@ func TestFile(t *testing.T) { IsAnon: true, }, } - - if got := cfg; !reflect.DeepEqual(want, got) { - t.Fatalf("want %v, got %v", want, got) - } + mustEqual(t, cfg, want) } //go:embed testdata @@ -280,9 +249,7 @@ func TestFileEmbed(t *testing.T) { Files: []string{filepath}, FileSystem: configEmbed, }) - if err := loader.Load(); err != nil { - t.Fatal(err) - } + failIfErr(t, loader.Load()) want := TestConfig{ Str: "str-json", @@ -298,10 +265,7 @@ func TestFileEmbed(t *testing.T) { IsAnon: true, }, } - - if got := cfg; !reflect.DeepEqual(want, got) { - t.Fatalf("want %v, got %v", want, got) - } + mustEqual(t, cfg, want) } func TestFileMerging(t *testing.T) { @@ -317,9 +281,7 @@ func TestFileMerging(t *testing.T) { MergeFiles: true, Files: []string{file1, file2, file3}, }) - if err := loader.Load(); err != nil { - t.Fatal(err) - } + failIfErr(t, loader.Load()) want := TestConfig{ Str: "111", @@ -328,10 +290,7 @@ func TestFileMerging(t *testing.T) { Float: 333.333, }, } - - if got := cfg; !reflect.DeepEqual(want, got) { - t.Fatalf("want %v, got %v", want, got) - } + mustEqual(t, cfg, want) } func TestFileFlag(t *testing.T) { @@ -350,18 +309,13 @@ func TestFileFlag(t *testing.T) { Files: []string{file1}, Args: flags, }) - if err := loader.Load(); err != nil { - t.Fatal(err) - } + failIfErr(t, loader.Load()) want := TestConfig{ Str: "111", HTTPPort: 222, } - - if got := cfg; !reflect.DeepEqual(want, got) { - t.Fatalf("want %v, got %v", want, got) - } + mustEqual(t, cfg, want) } func TestBadFileFlag(t *testing.T) { @@ -376,9 +330,7 @@ func TestBadFileFlag(t *testing.T) { FileFlag: "file_flag", Args: flags, }) - if err := loader.Load(); err == nil { - t.Fatal("should be an error") - } + failIfOk(t, loader.Load()) } func TestNoFileFlagValue(t *testing.T) { @@ -392,18 +344,13 @@ func TestNoFileFlagValue(t *testing.T) { Files: []string{file1}, Args: []string{}, // no file_flag }) - if err := loader.Load(); err != nil { - t.Fatal(err) - } + failIfErr(t, loader.Load()) want := TestConfig{ Str: "111", HTTPPort: 111, } - - if got := cfg; !reflect.DeepEqual(want, got) { - t.Fatalf("want %v, got %v", want, got) - } + mustEqual(t, cfg, want) } func TestEnv(t *testing.T) { @@ -423,9 +370,7 @@ func TestEnv(t *testing.T) { SkipFlags: true, EnvPrefix: "TST", }) - if err := loader.Load(); err != nil { - t.Fatal(err) - } + failIfErr(t, loader.Load()) want := TestConfig{ Str: "str-env", @@ -445,9 +390,7 @@ func TestEnv(t *testing.T) { }, } - if got := cfg; !reflect.DeepEqual(want, got) { - t.Fatalf("want %v, got %v", want, got) - } + mustEqual(t, cfg, want) } func TestFlag(t *testing.T) { @@ -469,13 +412,9 @@ func TestFlag(t *testing.T) { "-tst.em=em-flag", } - if err := loader.Flags().Parse(flags); err != nil { - t.Fatal(err) - } + failIfErr(t, loader.Flags().Parse(flags)) - if err := loader.Load(); err != nil { - t.Fatal(err) - } + failIfErr(t, loader.Load()) want := TestConfig{ Str: "str-flag", @@ -495,9 +434,7 @@ func TestFlag(t *testing.T) { }, } - if got := cfg; !reflect.DeepEqual(want, got) { - t.Fatalf("want %v, got %v", want, got) - } + mustEqual(t, cfg, want) } func TestExactName(t *testing.T) { @@ -521,9 +458,7 @@ func TestExactName(t *testing.T) { AllowUnknownEnvs: true, EnvPrefix: "TST", }) - if err := loader.Load(); err != nil { - t.Fatal(err) - } + failIfErr(t, loader.Load()) want := ExactConfig{ Foo: Foo{ @@ -532,9 +467,7 @@ func TestExactName(t *testing.T) { Bar: "bar-env", } - if got := cfg; !reflect.DeepEqual(want, got) { - t.Fatalf("want %v, got %v", want, got) - } + mustEqual(t, cfg, want) } func TestSkipName(t *testing.T) { @@ -555,9 +488,7 @@ func TestSkipName(t *testing.T) { SkipFiles: true, SkipFlags: true, }) - if err := loader.Load(); err != nil { - t.Fatal(err) - } + failIfErr(t, loader.Load()) want := ExactConfig{ Foo: Foo{ @@ -566,9 +497,7 @@ func TestSkipName(t *testing.T) { Bar: "def", } - if got := cfg; !reflect.DeepEqual(want, got) { - t.Fatalf("want %v, got %v", want, got) - } + mustEqual(t, cfg, want) } func TestDuplicatedName(t *testing.T) { @@ -588,9 +517,7 @@ func TestDuplicatedName(t *testing.T) { SkipFlags: true, AllowDuplicates: true, }) - if err := loader.Load(); err != nil { - t.Error(err) - } + failIfErr(t, loader.Load()) want := ExactConfig{ Foo: Foo{ @@ -599,9 +526,7 @@ func TestDuplicatedName(t *testing.T) { FooBar: "str-env", } - if got := cfg; !reflect.DeepEqual(want, got) { - t.Fatalf("want %v, got %v", want, got) - } + mustEqual(t, cfg, want) } func TestFailOnDuplicatedName(t *testing.T) { @@ -619,6 +544,8 @@ func TestFailOnDuplicatedName(t *testing.T) { }) err := loader.Load() + failIfOk(t, err) + if !strings.Contains(err.Error(), "is duplicated") { t.Fatalf("got %s", err.Error()) } @@ -631,11 +558,10 @@ func TestFailOnDuplicatedFlag(t *testing.T) { } err := LoaderFor(&Foo{}, Config{}).Load() + failIfOk(t, err) want := `init loader: duplicate flag "yes"` - if got := err.Error(); got != want { - t.Fatalf("got %s want %s", got, want) - } + mustEqual(t, err.Error(), want) } func TestUsage(t *testing.T) { @@ -651,9 +577,7 @@ func TestUsage(t *testing.T) { use... em...field. (default "em-def") ` - if got != want { - t.Fatalf("want %v, got %v", want, got) - } + mustEqual(t, got, want) } func TestBadDefauts(t *testing.T) { @@ -665,9 +589,7 @@ func TestBadDefauts(t *testing.T) { SkipEnv: true, SkipFlags: true, }) - if err := loader.Load(); err == nil { - t.Fatal(err) - } + failIfOk(t, loader.Load()) } f(&struct { @@ -757,27 +679,21 @@ func TestBadFiles(t *testing.T) { FailOnFileNotFound: true, Files: []string{filepath}, }) - if err := loader.Load(); err == nil { - t.Fatal(err) - } + failIfOk(t, loader.Load()) } - t.Run("no_such_file.json", func(t *testing.T) { + t.Run("no_such_file.json", func(*testing.T) { f("no_such_file.json") }) t.Run("bad_config.json", func(t *testing.T) { filepath := t.TempDir() + "unknown.ext" file, err := os.Create(filepath) - if err != nil { - t.Fatal(err) - } + failIfErr(t, err) defer file.Close() _, err = file.WriteString(`{almost": "json`) - if err != nil { - t.Fatal(err) - } + failIfErr(t, err) f(filepath) }) @@ -785,9 +701,7 @@ func TestBadFiles(t *testing.T) { t.Run("unknown.ext", func(t *testing.T) { filepath := t.TempDir() + "unknown.ext" file, err := os.Create(filepath) - if err != nil { - t.Fatal(err) - } + failIfErr(t, err) defer file.Close() f(filepath) @@ -806,9 +720,7 @@ func TestFailOnFileNotFound(t *testing.T) { Files: []string{filepath}, }) - if err := loader.Load(); err != nil { - t.Fatal(err) - } + failIfErr(t, loader.Load()) } f("testdata/config.json") @@ -826,9 +738,7 @@ func TestBadEnvs(t *testing.T) { EnvPrefix: "TST", }) - if err := loader.Load(); err == nil { - t.Fatal(err) - } + failIfOk(t, loader.Load()) } func TestBadFlags(t *testing.T) { @@ -839,12 +749,10 @@ func TestBadFlags(t *testing.T) { FlagPrefix: "tst", }) - if err := loader.Flags().Parse([]string{"-tst.param=10a01"}); err != nil { - t.Fatal(err) - } - if err := loader.Load(); err == nil { - t.Fatal(err) - } + args := []string{"-tst.param=10a01"} + + failIfErr(t, loader.Flags().Parse(args)) + failIfOk(t, loader.Load()) } func TestUnknownFields(t *testing.T) { @@ -859,9 +767,8 @@ func TestUnknownFields(t *testing.T) { }) err := loader.Load() - if err == nil { - t.Fatal("must not be nil") - } + failIfOk(t, err) + if !strings.Contains(err.Error(), "unknown field in file") { t.Fatalf("got %s", err.Error()) } @@ -882,9 +789,8 @@ func TestUnknownEnvs(t *testing.T) { }) err := loader.Load() - if err == nil { - t.Fatal("must not be nil") - } + failIfOk(t, err) + if !strings.Contains(err.Error(), "unknown environment var") { t.Fatalf("got %s", err.Error()) } @@ -902,9 +808,7 @@ func TestUnknownEnvsWithEmptyPrefix(t *testing.T) { SkipFlags: true, }) - if err := loader.Load(); err != nil { - t.Fatal(err) - } + failIfErr(t, loader.Load()) } func TestUnknownFlags(t *testing.T) { @@ -923,20 +827,17 @@ func TestUnknownFlags(t *testing.T) { // just for tests flagSet := loader.Flags() - flagSet.SetOutput(ioutil.Discard) + flagSet.SetOutput(io.Discard) // define flag with a loader's prefix which is unknown flagSet.Int("tst.unknown", 42, "") flagSet.String("just_env", "just_def", "") - if err := flagSet.Parse(flags); err != nil { - t.Fatal(err) - } + failIfErr(t, flagSet.Parse(flags)) err := loader.Load() - if err == nil { - t.Fatal("must not be nil") - } + failIfOk(t, err) + if !strings.Contains(err.Error(), "unknown flag") { t.Fatalf("got %s", err.Error()) } @@ -956,18 +857,13 @@ func TestUnknownFlagsWithEmptyPrefix(t *testing.T) { // just for tests flagSet := loader.Flags() - flagSet.SetOutput(ioutil.Discard) + flagSet.SetOutput(io.Discard) // define flag with a loader's prefix which is unknown flagSet.Int("unknown", 42, "") - if err := flagSet.Parse(flags); err != nil { - t.Fatal(err) - } - - if err := loader.Load(); err != nil { - t.Fatal(err) - } + failIfErr(t, flagSet.Parse(flags)) + failIfErr(t, loader.Load()) } // flag.FlagSet already fails on undefined flag @@ -986,11 +882,26 @@ func TestUnknownFlagsStdlib(t *testing.T) { // just for tests flagSet := loader.Flags() - flagSet.SetOutput(ioutil.Discard) + flagSet.SetOutput(io.Discard) + + failIfOk(t, flagSet.Parse(flags)) +} + +func TestCustomEnvsAndArgs(t *testing.T) { + var cfg TestConfig + loader := LoaderFor(&cfg, Config{ + SkipDefaults: true, + Envs: []string{"PARAM=2"}, + Args: []string{"-str=4"}, + }) - if err := flagSet.Parse(flags); err == nil { - t.Fatal("must not be nil") + failIfErr(t, loader.Load()) + + want := TestConfig{ + Str: "4", + Param: 2, } + mustEqual(t, cfg, want) } func TestCustomNames(t *testing.T) { @@ -1000,28 +911,20 @@ func TestCustomNames(t *testing.T) { C int `default:"-1" env:"three" flag:"four"` } + setEnv(t, "ONE", "1") + setEnv(t, "three", "3") + defer os.Clearenv() + var cfg TestConfig loader := LoaderFor(&cfg, Config{ Args: []string{"-two=2", "-four=4"}, }) - setEnv(t, "ONE", "1") - setEnv(t, "three", "3") - defer os.Clearenv() + failIfErr(t, loader.Load()) - if err := loader.Load(); err != nil { - t.Fatal(err) - } - - if want := 1; cfg.A != want { - t.Errorf("got %#v, want %#v", cfg.A, want) - } - if want := 2; cfg.B != want { - t.Errorf("got %#v, want %#v", cfg.B, want) - } - if want := 4; cfg.C != want { - t.Errorf("got %#v, want %#v", cfg.C, want) - } + mustEqual(t, cfg.A, 1) + mustEqual(t, cfg.B, 2) + mustEqual(t, cfg.C, 4) } func TestDontGenerateTags(t *testing.T) { @@ -1051,7 +954,7 @@ func TestDontGenerateTags(t *testing.T) { DontGenerateTags: true, } LoaderFor(&testConfig{}, cfg).WalkFields(func(f Field) bool { - for _, tag := range []string{"json", "yaml", "toml", "env", "flag"} { + for _, tag := range []string{"json", "yaml", "env", "flag"} { k := f.Name() + "::" + tag if v, ok := want[k]; ok && v != f.Tag(tag) { t.Fatalf("%v: got %v, want %v", tag, f.Tag(tag), v) @@ -1105,26 +1008,18 @@ func TestWalkFields(t *testing.T) { LoaderFor(&TestConfig{}, Config{}).WalkFields(func(f Field) bool { wantFields := fields[i] - if f.Name() != wantFields.Name { - t.Fatalf("got name %v, want %v", f.Name(), wantFields.Name) - } - - if parent, ok := f.Parent(); ok && parent.Name() != wantFields.ParentName { - t.Fatalf("got name %v, want %v", parent.Name(), wantFields.ParentName) - } - if f.Tag("default") != wantFields.DefaultValue { - t.Fatalf("got default %#v, want %#v", f.Tag("default"), wantFields.DefaultValue) - } - if f.Tag("usage") != wantFields.Usage { - t.Fatalf("got usage %#v, want %#v", f.Tag("usage"), wantFields.Usage) + mustEqual(t, f.Name(), wantFields.Name) + mustEqual(t, f.Name(), wantFields.Name) + if parent, ok := f.Parent(); ok { + mustEqual(t, parent.Name(), wantFields.ParentName) } + mustEqual(t, f.Tag("default"), wantFields.DefaultValue) + mustEqual(t, f.Tag("usage"), wantFields.Usage) i++ return true }) - if want := 3; i != want { - t.Fatalf("got %v, want %v", i, want) - } + mustEqual(t, i, 3) i = 0 LoaderFor(&TestConfig{}, Config{}).WalkFields(func(f Field) bool { @@ -1147,9 +1042,7 @@ func TestDontFillFlagsIfDisabled(t *testing.T) { SkipFlags: true, Args: []string{}, }) - if err := loader.Load(); err != nil { - t.Error(err) - } + failIfErr(t, loader.Load()) if flags := loader.Flags().NFlag(); flags != 0 { t.Errorf("want empty, got %v", flags) @@ -1204,9 +1097,7 @@ func TestBadRequiredTag(t *testing.T) { } func setEnv(t *testing.T, key, value string) { - if err := os.Setenv(key, value); err != nil { - t.Fatal(err) - } + failIfErr(t, os.Setenv(key, value)) } func int32Ptr(a int32) *int32 { @@ -1215,9 +1106,7 @@ func int32Ptr(a int32) *int32 { func createTestFile(t *testing.T, name ...string) string { t.Helper() - if len(name) > 1 { - t.Fatal() - } + mustEqual(t, len(name) < 2, true) dir := t.TempDir() t.Cleanup(func() { @@ -1230,14 +1119,10 @@ func createTestFile(t *testing.T, name ...string) string { } f, err := os.Create(filepath) - if err != nil { - t.Fatal(err) - } + failIfErr(t, err) defer f.Close() _, err = f.WriteString(testfileContent) - if err != nil { - t.Fatal(err) - } + failIfErr(t, err) return filepath } @@ -1428,9 +1313,7 @@ func TestSliceStructs(t *testing.T) { Files: []string{"testdata/complex.json"}, }) - if err := loader.Load(); err != nil { - t.Fatal(err) - } + failIfErr(t, loader.Load()) want := ConfigTest{ VCenter: ConfigVCenter{ @@ -1451,9 +1334,7 @@ func TestSliceStructs(t *testing.T) { }, }, } - if got := cfg; !reflect.DeepEqual(want, got) { - t.Fatalf("want %v, got %v", want, got) - } + mustEqual(t, cfg, want) } func TestMapOfMap(t *testing.T) { @@ -1469,20 +1350,16 @@ func TestMapOfMap(t *testing.T) { Files: []string{"testdata/toy.json"}, }) - if err := loader.Load(); err != nil { - t.Fatal(err) - } + failIfErr(t, loader.Load()) - var want = TestConfig{ + want := TestConfig{ Options: map[string]float64{ "foo": 0.4, "bar": 0.25, }, } - if got := cfg; !reflect.DeepEqual(want, got) { - t.Fatalf("want %v, got %v", want, got) - } + mustEqual(t, cfg, want) } func TestBad(t *testing.T) { @@ -1496,21 +1373,17 @@ func TestBad(t *testing.T) { loader := LoaderFor(&cfg, Config{ SkipFlags: true, }) - if err := loader.Load(); err != nil { - t.Fatal(err) - } + failIfErr(t, loader.Load()) p, err := url.ParseQuery("foo=bar") if err != nil { t.Fatal(err) } - var want = TestConfig{ + want := TestConfig{ Params: p, } - if got := cfg; !reflect.DeepEqual(want, got) { - t.Fatalf("want %v, got %v", want, got) - } + mustEqual(t, cfg, want) } func TestFileConfigFlagDelim(t *testing.T) { @@ -1531,16 +1404,33 @@ func TestFileConfigFlagDelim(t *testing.T) { Files: []string{"testdata/toy.json"}, }) - if err := loader.Load(); err != nil { - t.Fatal(err) - } + failIfErr(t, loader.Load()) - var want = TestConfig{Options: struct { + want := TestConfig{Options: struct { Foo float64 Bar float64 }{0.4, 0.25}} - if got := cfg; !reflect.DeepEqual(want, got) { - t.Fatalf("want %v, got %v", want, got) + mustEqual(t, cfg, want) +} + +func failIfOk(t testing.TB, err error) { + t.Helper() + if err == nil { + t.Fatal("must be non-nil") + } +} + +func failIfErr(t testing.TB, err error) { + t.Helper() + if err != nil { + t.Fatal(err) + } +} + +func mustEqual(t testing.TB, got, want interface{}) { + t.Helper() + if !reflect.DeepEqual(got, want) { + t.Fatalf("\nhave %+v\nwant %+v", got, want) } } diff --git a/reflection.go b/reflection.go index 7ea1090..21f72c6 100644 --- a/reflection.go +++ b/reflection.go @@ -60,25 +60,25 @@ func (l *Loader) tagsForField(field reflect.StructField) map[string]string { words := splitNameByWords(field.Name) tags := map[string]string{ - defaultValueTag: field.Tag.Get(defaultValueTag), - usageTag: field.Tag.Get(usageTag), + "default": field.Tag.Get("default"), + "usage": field.Tag.Get("usage"), - envNameTag: l.makeTagValue(field, envNameTag, words), - flagNameTag: l.makeTagValue(field, flagNameTag, words), + "env": l.makeTagValue(field, "env", words), + "flag": l.makeTagValue(field, "flag", words), } - for _, tag := range []string{jsonNameTag, yamlNameTag, tomlNameTag, hclNameTag} { - tags[tag] = l.makeTagValue(field, tag, words) + for _, dec := range l.config.FileDecoders { + tags[dec.Format()] = l.makeTagValue(field, dec.Format(), words) } return tags } func (l *Loader) fullTag(prefix string, f *fieldData, tag string) string { sep := "." - if tag == flagNameTag { + if tag == "flag" { sep = l.config.FlagDelimiter } - if tag == envNameTag { + if tag == "env" { sep = l.config.envDelimiter } res := f.Tag(tag) diff --git a/utils.go b/utils.go index b15f248..9608c90 100644 --- a/utils.go +++ b/utils.go @@ -27,8 +27,7 @@ func assertStruct(x interface{}) { } } -func getEnv() map[string]interface{} { - env := os.Environ() +func getEnv(env []string) map[string]interface{} { res := make(map[string]interface{}, len(env)) for _, s := range env { @@ -73,15 +72,16 @@ func (l *Loader) makeTagValue(field reflect.StructField, tag string, words []str return v } - switch tag { - case jsonNameTag, yamlNameTag, tomlNameTag, hclNameTag: - if l.config.DontGenerateTags { - return field.Name + for _, dec := range l.config.FileDecoders { + if tag == dec.Format() { + if l.config.DontGenerateTags { + return field.Name + } } } name := strings.Join(words, "_") - if tag == envNameTag { + if tag == "env" { return strings.ToUpper(name) } return strings.ToLower(name) @@ -90,7 +90,7 @@ func (l *Loader) makeTagValue(field reflect.StructField, tag string, words []str // based on https://github.com/fatih/camelcase func splitNameByWords(src string) []string { var runes [][]rune - lastClass, class := 0, 0 + var lastClass, class int // split into fields based on class of unicode character for _, r := range src { @@ -132,7 +132,8 @@ func splitNameByWords(src string) []string { } // copy-paste until https://github.com/golang/go/issues/46336 is fixed -func cut(s, sep string) (before, after string, found bool) { +// returns: before, after, isFound +func cut(s, sep string) (_, _ string, _ bool) { if i := strings.Index(s, sep); i >= 0 { return s[:i], s[i+len(sep):], true } @@ -143,11 +144,11 @@ var _ fs.FS = &fsOrOS{} type fsOrOS struct{ fs.FS } -func (fs *fsOrOS) Open(name string) (fs.File, error) { - if fs.FS == nil { +func (f *fsOrOS) Open(name string) (fs.File, error) { + if f.FS == nil { return os.Open(name) } - return fs.FS.Open(name) + return f.FS.Open(name) } type jsonDecoder struct {