diff --git a/passfile/passfile.go b/passfile/passfile.go index ca86a73..3d0a9b5 100644 --- a/passfile/passfile.go +++ b/passfile/passfile.go @@ -104,7 +104,11 @@ func ParseFile(file string) ([]Entry, error) { func (entry Entry) Equals(v Entry, protocols ...string) bool { return (entry.Protocol == "*" || contains(protocols, entry.Protocol)) && (entry.Host == "*" || entry.Host == v.Host) && - (entry.Port == "*" || entry.Port == v.Port) + (entry.Port == "*" || entry.Port == v.Port) && + (entry.DBName == "*" || entry.DBName == v.DBName) && + (entry.Username == v.Username || + (entry.Username != "*" && v.Username == "") || + (entry.Username == "*" && v.Username != "")) } // MatchEntries returns a Userinfo when the normalized v is found in entries. @@ -118,11 +122,12 @@ func MatchEntries(u *dburl.URL, entries []Entry, protocols ...string) (*url.User } } // find matching entry - n := strings.SplitN(u.Normalize(":", "", 3), ":", 6) + n := strings.SplitN(u.Normalize(":", "", 4), ":", 6) if len(n) < 3 { return nil, ErrUnableToNormalizeURL } m := NewEntry(n) + m.Username = username for _, entry := range entries { if entry.Equals(m, protocols...) { u := entry.Username diff --git a/passfile/passfile_test.go b/passfile/passfile_test.go index 91f84fa..7e3fc51 100644 --- a/passfile/passfile_test.go +++ b/passfile/passfile_test.go @@ -4,8 +4,46 @@ import ( "reflect" "strings" "testing" + + "github.com/xo/dburl" ) +func expectedUserPassword(t *testing.T, url string, user, pass string) { + entries, _ := Parse(strings.NewReader(matchdata)) + parsedURL, _ := dburl.Parse(url) + ui, _ := MatchEntries(parsedURL, entries, "postgres") + + if ui.Username() != user { + t.Fatalf("expected user %s, got %s", user, ui.Username()) + } + + url_pass, ok := ui.Password() + + if !ok { + url_pass = "" + } + + if url_pass != pass { + t.Fatalf("expected pass %s, got %s", pass, url_pass) + } +} + +func TestMatching(t *testing.T) { + expectedUserPassword(t, "postgres://user@host:1/db", "user", "pass1") + expectedUserPassword(t, "postgres://user2@host:1/db", "user2", "pass2") + expectedUserPassword(t, "postgres://user@host:2/db", "user", "pass3") + expectedUserPassword(t, "postgres://user2@host:2/db", "user2", "pass4") + expectedUserPassword(t, "postgres://user@host2:1/db", "user", "pass5") + expectedUserPassword(t, "postgres://user2@host2:1/db", "user2", "pass6") + expectedUserPassword(t, "postgres://user@host:1/db2", "user", "pass7") + expectedUserPassword(t, "postgres://user2@host:1/db2", "user2", "pass8") + expectedUserPassword(t, "postgres://user@host:2/db2", "user", "pass7") + expectedUserPassword(t, "postgres://user2@host:2/db2", "user2", "pass10") + expectedUserPassword(t, "postgres://host:1/db", "user", "pass1") + expectedUserPassword(t, "postgres://host:2/db2", "user2", "pass10") + expectedUserPassword(t, "postgres://user3@host:1/db", "user3", "pass11") +} + func TestParse(t *testing.T) { entries, err := Parse(strings.NewReader(passfile)) if err != nil { @@ -47,3 +85,30 @@ pgx:*:*:*:postgres:P4ssw0rd sqlserver:*:*:*:sa:Adm1nP@ssw0rd vertica:*:*:*:dbadmin:P4ssw0rd ` + +const matchdata = ` +#All fields entered +postgres:host:1:db:user:pass1 +postgres:host:1:db:user2:pass2 + +#Any port +postgres:host:*:db:user:pass3 +postgres:host:*:db:user2:pass4 + +#Any host/port +postgres:*:*:db:user:pass5 +postgres:*:*:db:user2:pass6 + +#Order matters (will get here) +postgres:host:2:db2:user2:pass10 + +#Any database +postgres:*:*:*:user:pass7 +postgres:*:*:*:user2:pass8 + +#Order matters (won't get here) +postgres:host:2:db2:user:pass9 + +#Default password +postgres:*:*:*:*:pass11 +`