diff --git a/integration_test.go b/integration_test.go index 5f5c200..12af1e8 100644 --- a/integration_test.go +++ b/integration_test.go @@ -2,6 +2,7 @@ package oauth2_test import ( "context" + "crypto/ecdsa" "fmt" "log" "net" @@ -12,15 +13,23 @@ import ( "testing" "time" - "github.com/golang-jwt/jwt/v5" oauth2 "github.com/oxisto/oauth2go" "github.com/oxisto/oauth2go/login" + "github.com/oxisto/oauth2go/storage" + + "github.com/golang-jwt/jwt/v5" "golang.org/x/net/html" "golang.org/x/oauth2/clientcredentials" ) func TestIntegration(t *testing.T) { - srv := oauth2.NewServer(":0", oauth2.WithClient("client", "secret", "")) + srv := oauth2.NewServer( + ":0", + oauth2.WithClient("client", "secret", ""), + oauth2.WithSigningKeysFunc(func() map[int]*ecdsa.PrivateKey { + return storage.LoadSigningKeys("storage/testdata/ecdsa.pem", "changeme", false) + }), + ) ln, err := net.Listen("tcp", srv.Addr) if err != nil { t.Errorf("Error while listening key: %v", err) diff --git a/pkcs/pkcs5.go b/pkcs/pkcs5.go new file mode 100644 index 0000000..1ef786d --- /dev/null +++ b/pkcs/pkcs5.go @@ -0,0 +1,244 @@ +package pkcs + +import ( + "crypto" + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "crypto/sha256" + "crypto/x509" + "crypto/x509/pkix" + "encoding/asn1" + "encoding/pem" + "errors" + "fmt" + "io" + + "golang.org/x/crypto/pbkdf2" +) + +var ( + oidAES128CBC = asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 1, 2} + oidHMACWithSHA256 = asn1.ObjectIdentifier{1, 2, 840, 113549, 2, 9} + oidPBKDF2 = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 5, 12} + oidPBES2 = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 5, 13} +) + +const DefaultIterations = 10000 + +// PBKDF2Params are parameters for PBKDF2. See +// https://datatracker.ietf.org/doc/html/rfc8018#appendix-A.2. +type PBKDF2Params struct { + Salt []byte + IterationCount int + PRF pkix.AlgorithmIdentifier `asn1:"optional"` +} + +// KeyDerivationFunc is part of PBES2 and specify the key derivation function. +// See https://datatracker.ietf.org/doc/html/rfc8018#appendix-A.4. +type KeyDerivationFunc struct { + Algorithm asn1.ObjectIdentifier + PBKDF2Params PBKDF2Params +} + +// EncryptionScheme is part of PBES2 and specifies the encryption algorithm. See +// https://datatracker.ietf.org/doc/html/rfc8018#appendix-A.4. +type EncryptionScheme struct { + EncryptionAlgorithm asn1.ObjectIdentifier + IV []byte +} + +// PBES2Params are parameters for PBES2. See +// https://datatracker.ietf.org/doc/html/rfc8018#appendix-A.4. +type PBES2Params struct { + KeyDerivationFunc KeyDerivationFunc + EncryptionScheme EncryptionScheme +} + +// EncryptionAlgorithmIdentifier is the identifier for the encryption algorithm. +// See https://datatracker.ietf.org/doc/html/rfc5958#section-3. +type EncryptionAlgorithmIdentifier struct { + Algorithm asn1.ObjectIdentifier + Params PBES2Params +} + +// EncryptedPrivateKeyInfo contains meta-info about the encrypted private key. +// See https://datatracker.ietf.org/doc/html/rfc5958#section-3. +type EncryptedPrivateKeyInfo struct { + EncryptionAlgorithm EncryptionAlgorithmIdentifier + EncryptedData []byte +} + +// MarshalPKCS5PrivateKeyWithPassword marshals an private key protected with a +// password according to PKCS#5 into a byte array +func MarshalPKCS5PrivateKeyWithPassword(key crypto.PrivateKey, password []byte) (data []byte, err error) { + var decrypted []byte + decrypted, err = x509.MarshalPKCS8PrivateKey(key) + if err != nil { + // Directly return error here, because we are basically a wrapper around + // x509.MarshalPKCS8PrivateKey and we want our errors to be similar + return nil, err + } + + block, err := EncryptPEMBlock(rand.Reader, decrypted, password) + if err != nil { + return nil, fmt.Errorf("could not encrypt PEM block: %w", err) + } + + return pem.EncodeToMemory(block), nil +} + +// ParsePKCS5PrivateKeyWithPassword reads a private key protected with a +// password according to PKCS#5 from a byte array. +func ParsePKCS5PrivateKeyWithPassword(data []byte, password []byte) (key crypto.PrivateKey, err error) { + // Parse PEM block + var block *pem.Block + if block, _ = pem.Decode(data); block == nil { + return nil, errors.New("could not decode PEM") + } + + var decrypted []byte + if decrypted, err = DecryptPEMBlock(block, password); err != nil { + return nil, fmt.Errorf("could not decrypt PEM block: %w", err) + } + + parsedKey, err := x509.ParsePKCS8PrivateKey(decrypted) + if err != nil { + // Directly return error here, because we are basically a wrapper around + // x509.ParsePKCS8PrivateKey and we want our errors to be similar + return nil, err + } else { + // For backwards compatiblity ParsePKCS8PrivateKey does not return a + // crypto.PrivateKey, but "any". However, we can just cast this, since + // crypto.PrivateKey's underlying type is "any". + return (crypto.PrivateKey)(parsedKey), nil + } +} + +// EncryptPEMBlock encrypts a private key contained in data into a PEM block +// according to PKCS#8. +func EncryptPEMBlock(rand io.Reader, data, password []byte) (block *pem.Block, err error) { + // Although we do not do an extended check on the password, we want to + // enforce "any" kind of password, so it should at least not be empty. + if len(password) == 0 { + return nil, errors.New("empty password") + } + + var salt = make([]byte, 8) + if _, err = rand.Read(salt); err != nil { + return nil, fmt.Errorf("error creating salt: %w", err) + } + + var iv = make([]byte, 16) + if _, err = rand.Read(iv); err != nil { + return nil, fmt.Errorf("error creating IV: %w", err) + } + + var pad = 16 - len(data)%16 + + // Build EncryptedPrivateKeyInfo + keyInfo := EncryptedPrivateKeyInfo{ + EncryptionAlgorithm: EncryptionAlgorithmIdentifier{ + Algorithm: oidPBES2, + Params: PBES2Params{ + KeyDerivationFunc: KeyDerivationFunc{ + Algorithm: oidPBKDF2, + PBKDF2Params: PBKDF2Params{ + IterationCount: DefaultIterations, + Salt: salt, + PRF: pkix.AlgorithmIdentifier{ + Algorithm: oidHMACWithSHA256, + }, + }, + }, + EncryptionScheme: EncryptionScheme{ + EncryptionAlgorithm: oidAES128CBC, + IV: iv, + }, + }, + }, + EncryptedData: make([]byte, len(data), len(data)+pad), // We will encrypt this later + } + + // Derive key using PBKDF2 + key := pbkdf2.Key( + password, + salt, + keyInfo.EncryptionAlgorithm.Params.KeyDerivationFunc.PBKDF2Params.IterationCount, + 16, + sha256.New, + ) + + // Set up symmetric encryption of our block. We can safely ignore the errors + // here, because the only error which can occur in aes.NewCipher is an + // invalid key size and the above line makes sure we always have a 32 bytes + // key. + cipherBlock, _ := aes.NewCipher(key) + mode := cipher.NewCBCEncrypter(cipherBlock, keyInfo.EncryptionAlgorithm.Params.EncryptionScheme.IV) + + copy(keyInfo.EncryptedData, data) + for i := 0; i < pad; i++ { + keyInfo.EncryptedData = append(keyInfo.EncryptedData, byte(pad)) + } + + mode.CryptBlocks(keyInfo.EncryptedData, keyInfo.EncryptedData) + + block = &pem.Block{ + Type: "ENCRYPTED PRIVATE KEY", + Headers: make(map[string]string), + } + + // Marshal key info into ASN1 format, which is the payload of our PEM block + block.Bytes, err = asn1.Marshal(keyInfo) + if err != nil { + return nil, fmt.Errorf("could not marshal ASN1: %w", err) + } + + return +} + +// DecryptPEMBlock is a drop-in replacement for [x509.DecryptPEMBlock], which +// only supports state-of-the art algorithms such as PBES2. +func DecryptPEMBlock(block *pem.Block, password []byte) ([]byte, error) { + var ( + keyInfo EncryptedPrivateKeyInfo + prf pkix.AlgorithmIdentifier + err error + ) + + if block.Type != "ENCRYPTED PRIVATE KEY" { + return nil, errors.New("key is not a PKCS#8") + } + + _, err = asn1.Unmarshal(block.Bytes, &keyInfo) + if err != nil { + return nil, fmt.Errorf("failed to retrieve private key info: %w", err) + } + + if !keyInfo.EncryptionAlgorithm.Algorithm.Equal(oidPBES2) { + return nil, errors.New("unsupported encryption algorithm: only PBES2 is supported") + } + + if !keyInfo.EncryptionAlgorithm.Params.KeyDerivationFunc.Algorithm.Equal(oidPBKDF2) { + return nil, errors.New("unsupported key derivation algorithm: only PBKDF2 is supported") + } + + prf = keyInfo.EncryptionAlgorithm.Params.KeyDerivationFunc.PBKDF2Params.PRF + if prf.Algorithm != nil && !prf.Algorithm.Equal(oidHMACWithSHA256) { + return nil, errors.New("unsupported pseudo-random function: only HMACWithSHA256 is supported") + } + + keyParams := keyInfo.EncryptionAlgorithm.Params.KeyDerivationFunc.PBKDF2Params + keyHash := sha256.New + + symkey := pbkdf2.Key(password, keyParams.Salt, keyParams.IterationCount, 16, keyHash) + + // We can safely ignore the errors here, because the only error which can + // occur in aes.NewCipher is an invalid key size and the above line makes + // sure we always have a 32 bytes key. + cipherBlock, _ := aes.NewCipher(symkey) + mode := cipher.NewCBCDecrypter(cipherBlock, keyInfo.EncryptionAlgorithm.Params.EncryptionScheme.IV) + mode.CryptBlocks(keyInfo.EncryptedData, keyInfo.EncryptedData) + + return keyInfo.EncryptedData, nil +} diff --git a/pkcs/pkcs5_test.go b/pkcs/pkcs5_test.go new file mode 100644 index 0000000..082eb44 --- /dev/null +++ b/pkcs/pkcs5_test.go @@ -0,0 +1,351 @@ +package pkcs + +import ( + "bytes" + "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/x509/pkix" + "encoding/asn1" + "encoding/pem" + "io" + "reflect" + "testing" + "testing/iotest" +) + +func TestParsePKCS8PrivateKeyWithPassword(t *testing.T) { + type args struct { + data []byte + password []byte + } + tests := []struct { + name string + args args + wantKey func(*testing.T, crypto.PrivateKey) + wantErr bool + }{ + { + name: "private key with password", + args: args{ + data: []byte( + `-----BEGIN ENCRYPTED PRIVATE KEY----- +MIHqMFUGCSqGSIb3DQEFDTBIMCcGCSqGSIb3DQEFDDAaBAg/ry1F70gEOwICJxAw +CgYIKoZIhvcNAgkwHQYJYIZIAWUDBAECBBAssuVSH48KsMJ6RPl/mG9qBIGQii4G +54TH7t/WrIHgE9xB82RojLdQ8b2WAvjWFepY4RsunHNnDcljEKyFySnqe4f57cRy +3lfGKes6U5ubV5Bi/ffsb5/fApUD93GfIrHSW4yxb4oUKOa30ODwPbwx10sji8Vk +zpW8KFxMcSEgVROGQJFAKVHwbA8dOlOPmewQuh2DXiRqYucncbvxey1flMln +-----END ENCRYPTED PRIVATE KEY-----`), + password: []byte("changeme"), + }, + wantErr: false, + wantKey: func(tt *testing.T, got crypto.PrivateKey) { + if got == nil { + tt.Error("ParseECPrivateKeyFromPEMWithPassword() is nil") + } + }, + }, + { + name: "private key wrong password", + args: args{ + data: []byte( + `-----BEGIN ENCRYPTED PRIVATE KEY----- +MIHqMFUGCSqGSIb3DQEFDTBIMCcGCSqGSIb3DQEFDDAaBAg/ry1F70gEOwICJxAw +CgYIKoZIhvcNAgkwHQYJYIZIAWUDBAECBBAssuVSH48KsMJ6RPl/mG9qBIGQii4G +54TH7t/WrIHgE9xB82RojLdQ8b2WAvjWFepY4RsunHNnDcljEKyFySnqe4f57cRy +3lfGKes6U5ubV5Bi/ffsb5/fApUD93GfIrHSW4yxb4oUKOa30ODwPbwx10sji8Vk +zpW8KFxMcSEgVROGQJFAKVHwbA8dOlOPmewQuh2DXiRqYucncbvxey1flMln +-----END ENCRYPTED PRIVATE KEY-----`), + password: []byte("nottest"), + }, + wantErr: true, + }, + { + name: "not a private key", + args: args{ + data: []byte( + `-----BEGIN ENCRYPTED PRIVATE KEY----- +THIS IS NOT A PRIVATE KEY +-----END ENCRYPTED PRIVATE KEY-----`), + password: []byte("test"), + }, + wantErr: true, + }, + { + name: "not PEM", + args: args{ + data: []byte( + `NOTPEM`), + password: []byte("test"), + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotKey, err := ParsePKCS5PrivateKeyWithPassword(tt.args.data, tt.args.password) + if (err != nil) != tt.wantErr { + t.Errorf("ParsePKCS5PrivateKeyWithPassword() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if tt.wantKey != nil { + tt.wantKey(t, gotKey) + } + }) + } +} + +func TestMarshalPKCS5PrivateKeyWithPassword(t *testing.T) { + pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatal(err) + } + + type args struct { + key *ecdsa.PrivateKey + password []byte + } + tests := []struct { + name string + args args + wantData func(*testing.T, []byte) + wantErr bool + }{ + { + name: "Marshal EC key", + args: args{ + key: pk, + password: []byte("test"), + }, + wantData: func(tt *testing.T, data []byte) { + if len(data) == 0 { + tt.Error("MarshalPKCS5PrivateKeyWithPassword() is empty") + } + }, + }, + { + name: "Error while marshalling EC key", + args: args{ + key: &ecdsa.PrivateKey{}, + password: []byte("test"), + }, + wantErr: true, + wantData: nil, + }, + { + name: "Empty password", + args: args{ + key: pk, + password: []byte{}, + }, + wantErr: true, + wantData: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotData, err := MarshalPKCS5PrivateKeyWithPassword(tt.args.key, tt.args.password) + if (err != nil) != tt.wantErr { + t.Errorf("MarshalPKCS5PrivateKeyWithPassword() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if tt.wantData != nil { + tt.wantData(t, gotData) + } + }) + } +} + +func TestDecryptPEMBlock(t *testing.T) { + type args struct { + block *pem.Block + password []byte + } + tests := []struct { + name string + args args + want []byte + wantErr bool + }{ + { + name: "wrong type", + args: args{ + block: &pem.Block{ + Type: "SOMETYPE", + }, + }, + wantErr: true, + }, + { + name: "not ASN1", + args: args{ + block: &pem.Block{ + Type: "ENCRYPTED PRIVATE KEY", + Bytes: []byte{1, 2, 3}, + }, + }, + wantErr: true, + }, + { + name: "wrong encryption algorithm", + args: args{ + block: &pem.Block{ + Type: "ENCRYPTED PRIVATE KEY", + Bytes: func() []byte { + b, err := asn1.Marshal(EncryptedPrivateKeyInfo{ + EncryptionAlgorithm: EncryptionAlgorithmIdentifier{ + Algorithm: asn1.ObjectIdentifier{0, 0}, + Params: PBES2Params{ + KeyDerivationFunc: KeyDerivationFunc{ + Algorithm: oidPBKDF2, + }, + EncryptionScheme: EncryptionScheme{ + EncryptionAlgorithm: oidAES128CBC, + }, + }, + }, + }) + if err != nil { + t.Fatal(err) + } + return b + }(), + }, + }, + wantErr: true, + }, + { + name: "wrong key derivation algorithm", + args: args{ + block: &pem.Block{ + Type: "ENCRYPTED PRIVATE KEY", + Bytes: func() []byte { + b, err := asn1.Marshal(EncryptedPrivateKeyInfo{ + EncryptionAlgorithm: EncryptionAlgorithmIdentifier{ + Algorithm: oidPBES2, + Params: PBES2Params{ + KeyDerivationFunc: KeyDerivationFunc{ + Algorithm: asn1.ObjectIdentifier{0, 0}, + }, + EncryptionScheme: EncryptionScheme{ + EncryptionAlgorithm: oidAES128CBC, + }, + }, + }, + }) + if err != nil { + t.Fatal(err) + } + return b + }(), + }, + }, + wantErr: true, + }, + { + name: "wrong PRF", + args: args{ + block: &pem.Block{ + Type: "ENCRYPTED PRIVATE KEY", + Bytes: func() []byte { + b, err := asn1.Marshal(EncryptedPrivateKeyInfo{ + EncryptionAlgorithm: EncryptionAlgorithmIdentifier{ + Algorithm: oidPBES2, + Params: PBES2Params{ + KeyDerivationFunc: KeyDerivationFunc{ + Algorithm: oidPBKDF2, + PBKDF2Params: PBKDF2Params{ + PRF: pkix.AlgorithmIdentifier{ + Algorithm: asn1.ObjectIdentifier{0, 0}, + }, + }, + }, + EncryptionScheme: EncryptionScheme{ + EncryptionAlgorithm: oidAES128CBC, + }, + }, + }, + }) + if err != nil { + t.Fatal(err) + } + return b + }(), + }, + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := DecryptPEMBlock(tt.args.block, tt.args.password) + if (err != nil) != tt.wantErr { + t.Errorf("DecryptPEMBlock() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("DecryptPEMBlock() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestEncryptPEMBlock(t *testing.T) { + // manipulate OID to provoke an error + oidPBKDF2 = asn1.ObjectIdentifier{0} + + type args struct { + rand io.Reader + data []byte + password []byte + } + tests := []struct { + name string + args args + wantBlock *pem.Block + wantErr bool + }{ + { + name: "invalid rand", + args: args{ + rand: iotest.ErrReader(io.EOF), + password: []byte{1}, + }, + wantErr: true, + }, + { + name: "invalid rand", + args: args{ + rand: bytes.NewReader(make([]byte, 8)), + password: []byte{1}, + }, + wantErr: true, + }, + { + name: "invalid", + args: args{ + rand: bytes.NewReader(make([]byte, 16)), + password: []byte{1}, + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotBlock, err := EncryptPEMBlock(tt.args.rand, tt.args.data, tt.args.password) + if (err != nil) != tt.wantErr { + t.Errorf("EncryptPEMBlock() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(gotBlock, tt.wantBlock) { + t.Errorf("EncryptPEMBlock() = %v, want %v", gotBlock, tt.wantBlock) + } + }) + } +} diff --git a/server.go b/server.go index bc33b47..fa7d2b9 100644 --- a/server.go +++ b/server.go @@ -270,7 +270,6 @@ func (srv *AuthorizationServer) doRefreshTokenFlow(w http.ResponseWriter, r *htt return srv.PublicKeys()[int(kid)], nil }) if err != nil { - fmt.Printf("%+v", err) Error(w, ErrorInvalidGrant, http.StatusBadRequest) return } diff --git a/storage/key_loader.go b/storage/key_loader.go new file mode 100644 index 0000000..06b7443 --- /dev/null +++ b/storage/key_loader.go @@ -0,0 +1,146 @@ +package storage + +import ( + "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "errors" + "fmt" + "log/slog" + "os" + "path/filepath" + + "github.com/oxisto/oauth2go/pkcs" +) + +type keyLoader struct { + path string + password string + saveOnCreate bool +} + +var ErrNotECPrivateKey = errors.New("key is not a valid EC private key") + +// LoadSigningKeys implements a singing keys func for our internal authorization +// server. Please note that [path] already needs to be an expanded path, e.g., +// references to a home directory (~) already need to be expanded before-hand. +func LoadSigningKeys(path string, password string, saveOnCreate bool) map[int]*ecdsa.PrivateKey { + // create a key loader with our arguments + loader := keyLoader{ + path: path, + password: password, + saveOnCreate: saveOnCreate, + } + + return map[int]*ecdsa.PrivateKey{ + 0: loader.LoadKey(), + } +} + +func (l *keyLoader) LoadKey() (key *ecdsa.PrivateKey) { + var ( + err error + ) + + // Try to load the key from the given path + key, err = loadKeyFromFile(l.path, []byte(l.password)) + if err != nil { + key = l.recoverFromLoadKeyError(err) + } + + return +} + +// recoverFromLoadKeyError tries to recover from an error during key loading. +func (l *keyLoader) recoverFromLoadKeyError(err error) (key *ecdsa.PrivateKey) { + // In any case, create a new temporary private key + key, _ = ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + + if errors.Is(err, os.ErrNotExist) && l.saveOnCreate { + slog.Info("Private key does not exist at the location yet. We will create a new one") + + // Also make sure that the containing folder exists + err = ensureFolderExistence(filepath.Dir(l.path)) + // Error while error handling, meh + if err != nil { + goto savingerr + } + + // Also save the key in this case, so we can load it next time + err = saveKeyToFile(key, l.path, l.password) + + savingerr: + // Error while error handling, meh + if err != nil { + slog.Error("Error while saving the new private key", "err", err) + } + } else if err != nil { + slog.Error("Could not load key from file, continuing with a temporary key", "err", err) + } + + return key +} + +// loadKeyFromFile loads an ecdsa.PrivateKey from a path. The key must in PEM +// format and protected by a password using PKCS#8 with PBES2. +// +// Note: This only supports ECDSA keys (for now) since we only support ECDSA +// keys in the authorization server (a limitation in the JWKS implementation). +// However, the underyling functions such as [ParsePKCS8PrivateKeyWithPassword] +// support all kind of private keys, so we might also support all private keys +// in the future here. +func loadKeyFromFile(path string, password []byte) (key *ecdsa.PrivateKey, err error) { + var ( + k crypto.PrivateKey + ok bool + ) + + // Check, if we already have a persisted private key + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("error while reading key: %w", err) + } + + k, err = pkcs.ParsePKCS5PrivateKeyWithPassword(data, password) + if err != nil { + return nil, fmt.Errorf("error while parsing private key: %w", err) + } + + key, ok = k.(*ecdsa.PrivateKey) + if !ok { + return nil, ErrNotECPrivateKey + } + + return key, nil +} + +// saveKeyToFile saves an [crypto.PrivateKey] to a path. The key will be saved +// in PEM format and protected by a password using PKCS#8 with PBES2. +func saveKeyToFile(key crypto.PrivateKey, path string, password string) (err error) { + data, err := pkcs.MarshalPKCS5PrivateKeyWithPassword(key, []byte(password)) + if err != nil { + return fmt.Errorf("error while marshalling private key: %w", err) + } + + err = os.WriteFile(path, data, 0600) + if err != nil { + return fmt.Errorf("error while writing file content: %w", err) + } + + return nil +} + +// ensureFolderExistence ensures that the folder exists. +func ensureFolderExistence(path string) (err error) { + // Create the directory, if it not exists + _, err = os.Stat(path) + if errors.Is(err, os.ErrNotExist) { + err = os.Mkdir(path, os.ModePerm) + if err != nil { + return err + } + } + + return +} diff --git a/storage/key_loader_test.go b/storage/key_loader_test.go new file mode 100644 index 0000000..246735f --- /dev/null +++ b/storage/key_loader_test.go @@ -0,0 +1,323 @@ +package storage + +import ( + "crypto/ecdsa" + "io" + "os" + "reflect" + "testing" +) + +func Test_keyLoader_recoverFromLoadKeyError(t *testing.T) { + var tmpFile, _ = os.CreateTemp("", "private.key") + // Close it immediately , since we want to write to it + tmpFile.Close() + + defer func() { + os.Remove(tmpFile.Name()) + }() + + type fields struct { + path string + password string + saveOnCreate bool + } + type args struct { + err error + defaultPath bool + } + tests := []struct { + name string + fields fields + args args + wantKey func(*testing.T, *ecdsa.PrivateKey) + }{ + { + name: "Could not load key from custom path", + fields: fields{ + saveOnCreate: false, + path: "doesnotexist", + password: "test", + }, + args: args{ + err: os.ErrNotExist, + defaultPath: false, + }, + wantKey: func(tt *testing.T, got *ecdsa.PrivateKey) { + if got == nil { + tt.Error("keyLoader.recoverFromLoadKeyError() is nil") + } + }, + }, + { + name: "Could not load key from default path and save it", + fields: fields{ + saveOnCreate: true, + path: tmpFile.Name(), + password: "test", + }, + args: args{ + err: os.ErrNotExist, + defaultPath: true, + }, + wantKey: func(tt *testing.T, got *ecdsa.PrivateKey) { + if got == nil { + tt.Error("keyLoader.recoverFromLoadKeyError() is nil") + } + + f, _ := os.OpenFile(tmpFile.Name(), os.O_RDONLY, 0600) + // Our tmp file should also contain something now + data, _ := io.ReadAll(f) + + if len(data) == 0 { + tt.Error("keyLoader.recoverFromLoadKeyError() did not write key on file") + } + }, + }, + { + name: "error while recovering", + fields: fields{ + saveOnCreate: true, + path: "/youwillnotcreatethis/file", + password: "test", + }, + args: args{ + err: os.ErrNotExist, + }, + wantKey: func(tt *testing.T, got *ecdsa.PrivateKey) { + if got == nil { + tt.Error("keyLoader.recoverFromLoadKeyError() is nil") + } + }, + }, + { + name: "error while recovering", + fields: fields{ + saveOnCreate: true, + path: "/youwillnotcreatethis", + password: "test", + }, + args: args{ + err: os.ErrNotExist, + }, + wantKey: func(tt *testing.T, got *ecdsa.PrivateKey) { + if got == nil { + tt.Error("keyLoader.recoverFromLoadKeyError() is nil") + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + l := &keyLoader{ + path: tt.fields.path, + password: tt.fields.password, + saveOnCreate: tt.fields.saveOnCreate, + } + gotKey := l.recoverFromLoadKeyError(tt.args.err) + + if tt.wantKey != nil { + tt.wantKey(t, gotKey) + } + }) + } +} + +func Test_keyLoader_LoadKey(t *testing.T) { + type fields struct { + path string + password string + saveOnCreate bool + homeDirFunc func() (string, error) + } + tests := []struct { + name string + fields fields + wantKey func(*testing.T, *ecdsa.PrivateKey) + }{ + { + name: "happy path", + fields: fields{ + path: "testdata/ecdsa.pem", + password: "changeme", + homeDirFunc: os.UserHomeDir, + }, + wantKey: func(tt *testing.T, pk *ecdsa.PrivateKey) { + if pk == nil { + tt.Fatal("keyLoader.LoadKey() is nil") + } + if pk.X == nil { + tt.Fatal("keyLoader.LoadKey(): X is nil") + } + if pk.X.String() != "28234181521930490715768413662959502200866653406621902194676464978939371342802" { + tt.Fatal("keyLoader.LoadKey(): X is wrong") + } + }, + }, + { + name: "recovered path", + fields: fields{ + homeDirFunc: os.UserHomeDir, + }, + wantKey: func(tt *testing.T, pk *ecdsa.PrivateKey) { + if pk == nil { + tt.Error("keyLoader.LoadKey() is nil") + } + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + l := &keyLoader{ + path: tt.fields.path, + password: tt.fields.password, + saveOnCreate: tt.fields.saveOnCreate, + } + gotKey := l.LoadKey() + tt.wantKey(t, gotKey) + }) + } +} + +func TestLoadSigningKeys(t *testing.T) { + type args struct { + path string + password string + saveOnCreate bool + } + tests := []struct { + name string + args args + want func(*testing.T, map[int]*ecdsa.PrivateKey) + }{ + { + name: "happy path", + args: args{ + path: "testdata/ecdsa.pem", + password: "changeme", + }, + want: func(tt *testing.T, m map[int]*ecdsa.PrivateKey) { + if m[0].X == nil { + tt.Fatal("keyLoader.LoadKey(): X is nil") + } + if m[0].X.String() != "28234181521930490715768413662959502200866653406621902194676464978939371342802" { + tt.Fatal("keyLoader.LoadKey(): X is wrong") + } + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := LoadSigningKeys(tt.args.path, tt.args.password, tt.args.saveOnCreate) + + tt.want(t, got) + }) + } +} + +func Test_keyLoader_ensureFolderExistence(t *testing.T) { + type args struct { + path string + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "mkdir fail", + args: args{ + path: "/thisshouldnotwork/file", + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := ensureFolderExistence(tt.args.path); (err != nil) != tt.wantErr { + t.Errorf("keyLoader.ensureFolderExistence() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func Test_loadKeyFromFile(t *testing.T) { + type args struct { + path string + password []byte + } + tests := []struct { + name string + args args + wantKey *ecdsa.PrivateKey + wantErr bool + }{ + { + name: "not a PEM file", + args: args{ + path: "testdata/test.notkey", + }, + wantErr: true, + }, + { + name: "not an ECDSA key", + args: args{ + path: "testdata/rsa.pem", + password: []byte("test"), + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotKey, err := loadKeyFromFile(tt.args.path, tt.args.password) + if (err != nil) != tt.wantErr { + t.Errorf("loadKeyFromFile() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(gotKey, tt.wantKey) { + t.Errorf("loadKeyFromFile() = %v, want %v", gotKey, tt.wantKey) + } + }) + } +} + +func Test_saveKeyToFile(t *testing.T) { + var tmpFile, _ = os.CreateTemp("", "private.key") + // Close it immediately , since we want to write to it + tmpFile.Close() + + defer func() { + os.Remove(tmpFile.Name()) + }() + + type args struct { + key *ecdsa.PrivateKey + path string + password string + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "not a valid key", + args: args{ + path: tmpFile.Name(), + key: &ecdsa.PrivateKey{}, + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := saveKeyToFile(tt.args.key, tt.args.path, tt.args.password); (err != nil) != tt.wantErr { + t.Errorf("saveKeyToFile() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/storage/testdata/ecdsa.pem b/storage/testdata/ecdsa.pem new file mode 100644 index 0000000..ffde790 --- /dev/null +++ b/storage/testdata/ecdsa.pem @@ -0,0 +1,7 @@ +-----BEGIN ENCRYPTED PRIVATE KEY----- +MIHqMFUGCSqGSIb3DQEFDTBIMCcGCSqGSIb3DQEFDDAaBAgXblPIt3CqMAICJxAw +CgYIKoZIhvcNAgkwHQYJYIZIAWUDBAECBBDnogUoVkunUgAPpfIc5s5oBIGQth6F +75CTgtAl1FfPko8NJ8JRnv2PfjZV8xhow4s5TE32AMIxoZJyQNfiWy+OS5XLULl3 +/m/Ubg+dECz18rdXHum1VSJ8oxhe1IIvJ6KNHNvFx/8cTkasEImKYGPkf91qY2jo +YLpAxTNlH+F1JRymNZ3VPU7v2R9mZ4AgcO8gRFJSrtV7Cz4tXoJ12mY+gzbQ +-----END ENCRYPTED PRIVATE KEY----- diff --git a/storage/testdata/rsa.pem b/storage/testdata/rsa.pem new file mode 100644 index 0000000..3cc2532 --- /dev/null +++ b/storage/testdata/rsa.pem @@ -0,0 +1,30 @@ +-----BEGIN ENCRYPTED PRIVATE KEY----- +MIIFNTBfBgkqhkiG9w0BBQ0wUjAxBgkqhkiG9w0BBQwwJAQQ694p4vcYBePAZ64v +NK+fqQICCAAwDAYIKoZIhvcNAgkFADAdBglghkgBZQMEAQIEECH6SjP+31K/bMPS +EZIijS0EggTQJm6UsCZ3V27LBKg+eZLxLybBFWHaKsEE8O+CDFJqsw/ewdUPynae +CIbIEbYpaKWPqalk/sWHgbmw0tCD3Y6hNggVU5xVjVkHxVI5SkCAmTs5qoMaPQq/ +dCVgyfkybL5htJPgnsNbAt817GQLZo40iN5S6R+ZFzyyT53bpkhsXibHwy8mCJ8Q +Ezq7cn6Zou93WhqOTl5PalxkGcX5fa7TsTx6yK2vlbjOMtZBIEcO4gjLqMNNAAXs +Covon1o8MhlNnQkZLY2r5mPXJjrxjcyKXFwcCaxfOTLJ55tLQ+SbB510PULjwcmf +/jQEiRHRPcxewTj85YgoT1dRTHQEfjtR2s43smnMGQWLn2pZ8vfA8JEjSZVgbSP2 +GVncOfkwcIV+I3IqWxmFwTTMRn5xTO9Bm+PcmQ8hs7vd+MTdBJ/wHeYn+cl7w3Ki ++7zJcWFx/fYu3v+XDYm2vHcP0r2SeW9SyJJNZ0BGoqgu+Xz3EjcqtRn/4HR/We8g +9o+S/wx7PTUjHvAAqKdwFkEYv0g+QZOJUJhg5K9Kh9jbZgpwYe+L7XirNKyJjXSp +32/rLvfbQoObNfIdacAVeKB/cixnJEJ3CN8IUUGTg/Yvj1i8t0kIVyirJTw2Puvp +hTGiir5LfyziGYJnEwBzATsjHV/gO2VUzNd7anYUtEJMHYf96gAb+Wc01U4nhoKv +Ls+13Ny3rGNBjASJaKMgQ2CB+lyqjI0iUWE1fWfDCZ2s4qd8ezYMxvJpRahmGb71 +A0dtq9CoeC4a2GXbiWjn24G18tdprNJWjTue5VPJjCDMZbxUjC8lYOWuCX+bGjzf +9gY6o3cSKAVkU+yNrNVtAwGJoPRfVUP38T7FrtM09OWl6qEvM4ssFIn+ezWdvLdE +OHOqycQP3faI/8qDlOHAPWDMV6v/3PxdoJDu23LYBvuVZvaCv0k4CymBDGdo+3X+ +iaL4CkVgR6zE1fSXX6a8TFSjiIbkBFpbhvwbcH8+GUISAdiBOEvDr20DAzCVf0g9 +RxbO4le6h2PF90F4Nib9aR722aygsb5OND/q3j+7AxHxDrxDTjQIStV45qXKt0S0 +eONSOdJMslqxvQ2WhCf/Fcu1aVrKIHNKyFZODAa+1YhnOGxz7X83njij0d6Mmoh6 +JxT9rkIiiSvROrNHVdDrIcvRCANAnvDwEcA/Nc+ns/+yoBMWXla1roG+IPUUHoXZ +9BuhGx+TJbyg4c+l0MfiP/ht1AtH3U9HCc7EcSh4MCEYKvd5MsM361QVZbYvy5LW +VqwqDWvkdm+2fgj+0j416umUWhvbi3Bsm73I4CcsXruWgTXiId4bMtJ2O3oD6cDN +yz6YT6+OD3XSmuEvtKOETueF2Ef3w/TF3FgDsU6oQXD9nmp+J9dL58TPgpvqk0dv +y6R1FaWo16+8Yh/lVyfdJEt4CfRNNeOvnlyLZWbuJMkOJnWjCObIGpfb113NG5wf +34V4lw/SWBTAepaD6c4UH4y66W3x6cDorKuQiqC2oZwhN4YwUodxwWvwzSFUxhg6 +u3xHRqolwNycfgTUppTOcDsGg40OmkwvfniZ7UIhUaLayFN7hPofuUEvRJRga4tR +grzkd7B1SRxVJa1A+9O9Kko928C9BEt+QgML+H/Te0BRdhQQtL8s+eU= +-----END ENCRYPTED PRIVATE KEY----- diff --git a/storage/testdata/test.notkey b/storage/testdata/test.notkey new file mode 100644 index 0000000..fdab659 --- /dev/null +++ b/storage/testdata/test.notkey @@ -0,0 +1 @@ +this is not a key \ No newline at end of file