diff --git a/auth/token.go b/auth/token.go index f91185ae..81c285cf 100644 --- a/auth/token.go +++ b/auth/token.go @@ -2,7 +2,13 @@ package auth import ( "crypto/rand" + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" + "io" "math/big" + "strings" ) var ( @@ -24,6 +30,64 @@ func randIntn(n int) int { return int(res.Int64()) } +// Convert a Token to its hashed representation. +func HashToken(s string, salt []byte) (string, error) { + saltHex := fmt.Sprintf("%x", salt) + sha := sha256.New() + _, err := sha.Write(salt) + if err != nil { + return "", err + } + _, err = sha.Write([]byte(s)) + if err != nil { + return "", err + } + hashed := fmt.Sprintf("%x", sha.Sum(nil)) + return fmt.Sprintf("%s$%s", saltHex, hashed), nil +} + +// CompareToken compares a token with a hashed representation, optionally upgrading the hash if necessary. +func CompareToken(s string, hashed string) (bool, *string, error) { + if len(s) != randomTokenLength+1 /* prefix */ { + return false, nil, errors.New("invalid token length") + } + + split := strings.SplitN(hashed, "$", 2) + + // determine if we need to upgrade the hash + if len(split) == 1 { + match := s == hashed + if match { + var salt [16]byte + _, err := io.ReadFull(randReader, salt[:]) + if err != nil { + return false, nil, err + } + hashed, err := HashToken(s, salt[:]) + if err != nil { + return false, nil, err + } + return true, &hashed, nil + } else { + return false, nil, nil + } + } + + if len(split) == 2 { + salt, err := hex.DecodeString(split[0]) + if err != nil { + return false, nil, err + } + inputHashed, err := HashToken(s, salt) + if err != nil { + return false, nil, err + } + return inputHashed == hashed, nil, nil + } + + return false, nil, errors.New("invalid hash format") +} + // GenerateNotExistingToken receives a token generation func and a func to check whether the token exists, returns a unique token. func GenerateNotExistingToken(generateToken func() string, tokenExists func(token string) bool) string { for { diff --git a/auth/token_test.go b/auth/token_test.go index 8ad21c7c..58796465 100644 --- a/auth/token_test.go +++ b/auth/token_test.go @@ -19,6 +19,46 @@ func TestTokenHavePrefix(t *testing.T) { } } +func TestHashTokenStable(t *testing.T) { + salt1 := []byte("salt") + salt2 := []byte("pepper") + seen := make(map[string]bool) + for _, plain := range []string{"", "a", "b", "c", "a\x00", "a\n"} { + hash1, err := HashToken(plain, salt1) + assert.NoError(t, err) + hash1Again, err := HashToken(plain, salt1) + assert.NoError(t, err) + assert.Equal(t, hash1, hash1Again) + hash2, err := HashToken(plain, salt2) + assert.NoError(t, err) + hash2Again, err := HashToken(plain, salt2) + assert.NoError(t, err) + assert.Equal(t, hash2, hash2Again) + + assert.NotEqual(t, hash1, hash2) + assert.False(t, seen[hash1]) + assert.False(t, seen[hash2]) + seen[hash1] = true + seen[hash2] = true + } +} + +func TestCompareToken(t *testing.T) { + salt := []byte("salt") + tokenPlain := GenerateApplicationToken() + hashed, err := HashToken(tokenPlain, salt) + assert.NoError(t, err) + cmpPlain, upgPlain, err := CompareToken(tokenPlain, tokenPlain) + assert.NoError(t, err) + assert.True(t, cmpPlain) + assert.NotEmpty(t, *upgPlain) + + cmpHashed, upgHashed, err := CompareToken(tokenPlain, hashed) + assert.NoError(t, err) + assert.True(t, cmpHashed) + assert.Nil(t, upgHashed) +} + func TestGenerateNotExistingToken(t *testing.T) { count := 5 token := GenerateNotExistingToken(func() string {