diff --git a/storage/key_loader.go b/storage/key_loader.go new file mode 100644 index 0000000..7599dac --- /dev/null +++ b/storage/key_loader.go @@ -0,0 +1,211 @@ +package storage + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "errors" + "fmt" + "log/slog" + "os" + "os/user" + "path/filepath" + "strings" +) + +const ( + // DefaultPrivateKeySaveOnCreate specifies whether a created private key + // will be saved. This is useful to turn off in unit tests, where we only + // want a temporary key. + DefaultPrivateKeySaveOnCreate = true + + // DefaultPrivateKeyPassword is the default password to protect the private + // key. + DefaultPrivateKeyPassword = "changeme" + + // DefaultPrivateKeyPath is the default path for the private key. + DefaultPrivateKeyPath = DefaultConfigDirectory + "/private.key" + + // DefaultConfigDirectory is the default path for the oauth2go + // configuration, such as keys. + DefaultConfigDirectory = "~/.oauth2go" +) + +type keyLoader struct { + path string + password string + saveOnCreate bool +} + +// LoadSigningKeys implements a singing keys func for our internal authorization server +func LoadSigningKeys(path string, password string, saveOnCreate bool) map[int]*ecdsa.PrivateKey { + // create a key loader with our arguments + loader := keyLoader{ + path: path, + password: password, + saveOnCreate: saveOnCreate, + } + + return map[int]*ecdsa.PrivateKey{ + 0: loader.LoadKey(), + } +} + +func (l *keyLoader) LoadKey() (key *ecdsa.PrivateKey) { + var ( + err error + ) + + // Try to load the key from the given path + key, err = loadKeyFromFile(l.path, []byte(l.password)) + if err != nil { + key = l.recoverFromLoadApiKeyError(err, l.path == DefaultPrivateKeyPath) + } + + return +} + +// recoverFromLoadApiKeyError tries to recover from an error during key loading. +// We treat different errors differently. For example if the path is the default +// path and the error is [os.ErrNotExist], this could be just the first start of +// Clouditor. So we only treat this as an information that we will create a new +// key, which we will save, based on the config. +// +// If the user specifies a custom path and this one does not exist, we will +// report an error here. +func (l *keyLoader) recoverFromLoadApiKeyError(err error, defaultPath bool) (key *ecdsa.PrivateKey) { + // In any case, create a new temporary API key + key, _ = ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + + if defaultPath && errors.Is(err, os.ErrNotExist) { + slog.Info("API key does not exist at the default location yet. We will create a new one") + + if l.saveOnCreate { + // Also make sure that default config path exists + err = ensureConfigFolderExistence() + // Error while error handling, meh + if err != nil { + return + } + + // Also save the key in this case, so we can load it next time + err = saveKeyToFile(key, l.path, l.password) + + // Error while error handling, meh + if err != nil { + slog.Error("Error while saving the new API key", "err", err) + } + } + } else if err != nil { + slog.Error("Could not load key from file, continuing with a temporary key", "err", err) + } + + return key +} + +// loadKeyFromFile loads an ecdsa.PrivateKey from a path. The key must in PEM format and protected by +// a password using PKCS#8 with PBES2. +func loadKeyFromFile(path string, password []byte) (key *ecdsa.PrivateKey, err error) { + var ( + keyFile string + ) + + keyFile, err = expandPath(path) + if err != nil { + return nil, fmt.Errorf("error while expanding path: %w", err) + } + + if _, err = os.Stat(keyFile); os.IsNotExist(err) { + return nil, fmt.Errorf("file does not exist (yet): %w", err) + } + + // Check, if we already have a persisted API key + data, err := os.ReadFile(keyFile) + if err != nil { + return nil, fmt.Errorf("error while reading key: %w", err) + } + + key, err = ParseECPrivateKeyFromPEMWithPassword(data, password) + if err != nil { + return nil, fmt.Errorf("error while parsing private key: %w", err) + } + + return key, nil +} + +// saveKeyToFile saves an ecdsa.PrivateKey to a path. The key will be saved in +// PEM format and protected by a password using PKCS#8 with PBES2. +func saveKeyToFile(apiKey *ecdsa.PrivateKey, keyPath string, password string) (err error) { + keyPath, err = expandPath(keyPath) + if err != nil { + return fmt.Errorf("error while expanding path: %w", err) + } + + // Check, if we already have a persisted API key + f, err := os.OpenFile(keyPath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600) + if err != nil { + return fmt.Errorf("error while opening the file: %w", err) + } + defer func() { + _ = f.Close() + }() + + data, err := MarshalECPrivateKeyWithPassword(apiKey, []byte(password)) + if err != nil { + return fmt.Errorf("error while marshalling private key: %w", err) + } + + _, err = f.Write(data) + if err != nil { + return fmt.Errorf("error while writing file content: %w", err) + } + + return nil +} + +// expandPath expands a path that possible contains a tilde (~) character into +// the home directory of the user +func expandPath(path string) (out string, err error) { + var ( + u *user.User + ) + + // Fetch the current user home directory + u, err = user.Current() + if err != nil { + return path, fmt.Errorf("could not find retrieve current user: %w", err) + } + + if path == "~" { + return u.HomeDir, nil + } else if strings.HasPrefix(path, "~") { + // We only allow ~ at the beginning of the path + return filepath.Join(u.HomeDir, path[2:]), nil + } + + return path, nil +} + +// ensureConfigesFolderExistence ensures that the config folder exists. +func ensureConfigFolderExistence() (err error) { + var configPath string + + // Expand the config directory, if it contains any ~ characters. + configPath, err = expandPath(DefaultConfigDirectory) + if err != nil { + // Directly return the error here, no need for additional wrapping + return err + } + + // Create the directory, if it not exists + _, err = os.Stat(configPath) + if errors.Is(err, os.ErrNotExist) { + err = os.Mkdir(configPath, os.ModePerm) + if err != nil { + // Directly return the error here, no need for additional wrapping + return err + } + } + + return +} diff --git a/storage/key_loader_test.go b/storage/key_loader_test.go new file mode 100644 index 0000000..adc7d0f --- /dev/null +++ b/storage/key_loader_test.go @@ -0,0 +1,92 @@ +package storage + +import ( + "crypto/ecdsa" + "io" + "os" + "testing" +) + +func Test_keyLoader_recoverFromLoadApiKeyError(t *testing.T) { + var tmpFile, _ = os.CreateTemp("", "api.key") + // Close it immediately , since we want to write to it + tmpFile.Close() + + defer func() { + os.Remove(tmpFile.Name()) + }() + + type fields struct { + path string + password string + saveOnCreate bool + } + type args struct { + err error + defaultPath bool + } + tests := []struct { + name string + fields fields + args args + wantKey func(*testing.T, *ecdsa.PrivateKey) + }{ + { + name: "Could not load key from custom path", + fields: fields{ + saveOnCreate: false, + path: "doesnotexist", + password: "test", + }, + args: args{ + err: os.ErrNotExist, + defaultPath: false, + }, + wantKey: func(tt *testing.T, got *ecdsa.PrivateKey) { + if got == nil { + tt.Error("keyLoader.recoverFromLoadApiKeyError() is nil") + } + }, + }, + { + name: "Could not load key from default path and save it", + fields: fields{ + saveOnCreate: true, + path: tmpFile.Name(), + password: "test", + }, + args: args{ + err: os.ErrNotExist, + defaultPath: true, + }, + wantKey: func(tt *testing.T, got *ecdsa.PrivateKey) { + if got == nil { + tt.Error("keyLoader.recoverFromLoadApiKeyError() is nil") + } + + f, _ := os.OpenFile(tmpFile.Name(), os.O_RDONLY, 0600) + // Our tmp file should also contain something now + data, _ := io.ReadAll(f) + + if len(data) == 0 { + tt.Error("keyLoader.recoverFromLoadApiKeyError() did not write key on file") + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + l := &keyLoader{ + path: tt.fields.path, + password: tt.fields.password, + saveOnCreate: tt.fields.saveOnCreate, + } + gotKey := l.recoverFromLoadApiKeyError(tt.args.err, tt.args.defaultPath) + + if tt.wantKey != nil { + tt.wantKey(t, gotKey) + } + }) + } +} diff --git a/storage/pem.go b/storage/pem.go new file mode 100644 index 0000000..b3881b4 --- /dev/null +++ b/storage/pem.go @@ -0,0 +1,239 @@ +package storage + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/ecdsa" + "crypto/rand" + "crypto/sha256" + "crypto/x509" + "encoding/asn1" + "encoding/pem" + "errors" + "fmt" + "io" + + "golang.org/x/crypto/pbkdf2" +) + +var ( + oidAESCBC = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 5, 16, 12} + oidHMACWithSHA256 = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 5, 9} + oidPBKDF2 = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 5, 12} + oidPBES2 = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 5, 13} +) + +var ErrNotECPrivateKey = errors.New("key is not a valid EC private key") + +// PBKDF2Params are parameters for PBKDF2. See +// https://datatracker.ietf.org/doc/html/rfc8018#appendix-A.2. +type PBKDF2Params struct { + Salt []byte + IterationCount int + PRF asn1.ObjectIdentifier `asn1:"optional"` +} + +// KeyDerivationFunc is part of PBES2 and specify the key derivation function. +// See https://datatracker.ietf.org/doc/html/rfc8018#appendix-A.4. +type KeyDerivationFunc struct { + Algorithm asn1.ObjectIdentifier + PBKDF2Params PBKDF2Params +} + +// EncryptionScheme is part of PBES2 and specifies the encryption algorithm. See +// https://datatracker.ietf.org/doc/html/rfc8018#appendix-A.4. +type EncryptionScheme struct { + EncryptionAlgorithm asn1.ObjectIdentifier + IV []byte +} + +// PBES2Params are parameters for PBES2. See +// https://datatracker.ietf.org/doc/html/rfc8018#appendix-A.4. +type PBES2Params struct { + KeyDerivationFunc KeyDerivationFunc + EncryptionScheme EncryptionScheme +} + +// EncryptionAlgorithmIdentifier is the identifier for the encryption algorithm. +// See https://datatracker.ietf.org/doc/html/rfc5958#section-3. +type EncryptionAlgorithmIdentifier struct { + Algorithm asn1.ObjectIdentifier + Params PBES2Params +} + +// EncryptedPrivateKeyInfo contains meta-info about the encrypted private key. +// See https://datatracker.ietf.org/doc/html/rfc5958#section-3. +type EncryptedPrivateKeyInfo struct { + EncryptionAlgorithm EncryptionAlgorithmIdentifier + EncryptedData []byte +} + +// MarshalECPrivateKeyWithPassword marshals an ECDSA private key protected with +// a password according to PKCS#8 into a byte array +func MarshalECPrivateKeyWithPassword(key *ecdsa.PrivateKey, password []byte) (data []byte, err error) { + var decrypted []byte + decrypted, err = x509.MarshalPKCS8PrivateKey(key) + if err != nil { + // Directly return error here, because we are basically a wrapper around + // x509.MarshalPKCS8PrivateKey and we want our errors to be similar + return nil, err + } + + block, err := EncryptPEMBlock(rand.Reader, decrypted, password) + if err != nil { + return nil, fmt.Errorf("could not encrypt PEM block: %w", err) + } + + return pem.EncodeToMemory(block), nil +} + +// ParseECPrivateKeyFromPEMWithPassword ready an ECDSA private key protected +// with a password according to PKCS#8 from a byte array. +func ParseECPrivateKeyFromPEMWithPassword(data []byte, password []byte) (key *ecdsa.PrivateKey, err error) { + // Parse PEM block + var block *pem.Block + if block, _ = pem.Decode(data); block == nil { + return nil, errors.New("could not decode PEM") + } + + var decrypted []byte + if decrypted, err = DecryptPEMBlock(block, password); err != nil { + return nil, fmt.Errorf("could not decrypt PEM block: %w", err) + } + + parsedKey, err := x509.ParsePKCS8PrivateKey(decrypted) + if err != nil { + // Directly return error here, because we are basically a wrapper around + // x509.ParsePKCS8PrivateKey and we want our errors to be similar + return nil, err + } + + var ok bool + if key, ok = parsedKey.(*ecdsa.PrivateKey); !ok { + return nil, ErrNotECPrivateKey + } + + return +} + +// EncryptPEMBlock encrypts a private key contained in data into a PEM block +// according to PKCS#8. +func EncryptPEMBlock(rand io.Reader, data, password []byte) (block *pem.Block, err error) { + var salt = make([]byte, 8) + if _, err = rand.Read(salt); err != nil { + return nil, fmt.Errorf("error creating salt: %w", err) + } + + var iv = make([]byte, 16) + if _, err = rand.Read(iv); err != nil { + return nil, fmt.Errorf("error creating IV: %w", err) + } + + var pad = 16 - len(data)%16 + + // Build EncryptedPrivateKeyInfo + keyInfo := EncryptedPrivateKeyInfo{ + EncryptionAlgorithm: EncryptionAlgorithmIdentifier{ + Algorithm: oidPBES2, + Params: PBES2Params{ + KeyDerivationFunc: KeyDerivationFunc{ + Algorithm: oidPBKDF2, + PBKDF2Params: PBKDF2Params{ + IterationCount: 1000, + Salt: salt, + PRF: oidHMACWithSHA256, + }, + }, + EncryptionScheme: EncryptionScheme{ + EncryptionAlgorithm: oidAESCBC, + IV: iv, + }, + }, + }, + EncryptedData: make([]byte, len(data), len(data)+pad), // We will encrypt this later + } + + // Derive key using PBKDF2 + key := pbkdf2.Key( + password, + salt, + keyInfo.EncryptionAlgorithm.Params.KeyDerivationFunc.PBKDF2Params.IterationCount, + 32, + sha256.New, + ) + + // Set up symmetric encryption of our block + cipherBlock, err := aes.NewCipher(key) + if err != nil { + return nil, fmt.Errorf("could not create AES cipher mode: %w", err) + } + + mode := cipher.NewCBCEncrypter(cipherBlock, keyInfo.EncryptionAlgorithm.Params.EncryptionScheme.IV) + + copy(keyInfo.EncryptedData, data) + for i := 0; i < pad; i++ { + keyInfo.EncryptedData = append(keyInfo.EncryptedData, byte(pad)) + } + + mode.CryptBlocks(keyInfo.EncryptedData, keyInfo.EncryptedData) + + block = &pem.Block{ + Type: "ENCRYPTED PRIVATE KEY", + Headers: make(map[string]string), + } + + // Marshal key info into ASN1 format, which is the payload of our PEM block + block.Bytes, err = asn1.Marshal(keyInfo) + if err != nil { + return nil, fmt.Errorf("could not marshal ASN1: %w", err) + } + + return +} + +// DecryptPEMBlock is a drop-in replacement for x509.DecryptPEMBlock which only +// supports state-of-the art algorithms such as PBES2. +func DecryptPEMBlock(block *pem.Block, password []byte) ([]byte, error) { + var ( + keyInfo EncryptedPrivateKeyInfo + prf asn1.ObjectIdentifier + err error + ) + + if block.Type != "ENCRYPTED PRIVATE KEY" { + return nil, errors.New("key is not a PKCS#8") + } + + _, err = asn1.Unmarshal(block.Bytes, &keyInfo) + if err != nil { + return nil, fmt.Errorf("failed to retrieve private key info: %w", err) + } + + if !keyInfo.EncryptionAlgorithm.Algorithm.Equal(oidPBES2) { + return nil, errors.New("unsupported encryption algorithm: only PBES2 is supported") + } + + if !keyInfo.EncryptionAlgorithm.Params.KeyDerivationFunc.Algorithm.Equal(oidPBKDF2) { + return nil, errors.New("unsupported key derivation algorithm: only PBKDF2 is supported") + } + + prf = keyInfo.EncryptionAlgorithm.Params.KeyDerivationFunc.PBKDF2Params.PRF + + if prf != nil && !prf.Equal(oidHMACWithSHA256) { + return nil, errors.New("unsupported pseudo-random function: only HMACWithSHA256 is supported") + } + + keyParams := keyInfo.EncryptionAlgorithm.Params.KeyDerivationFunc.PBKDF2Params + keyHash := sha256.New + + symkey := pbkdf2.Key(password, keyParams.Salt, keyParams.IterationCount, 32, keyHash) + cipherBlock, err := aes.NewCipher(symkey) + if err != nil { + return nil, fmt.Errorf("could not create AES cipher mode: %w", err) + } + + mode := cipher.NewCBCDecrypter(cipherBlock, keyInfo.EncryptionAlgorithm.Params.EncryptionScheme.IV) + mode.CryptBlocks(keyInfo.EncryptedData, keyInfo.EncryptedData) + + return keyInfo.EncryptedData, nil +} diff --git a/storage/pem_test.go b/storage/pem_test.go new file mode 100644 index 0000000..decb6cd --- /dev/null +++ b/storage/pem_test.go @@ -0,0 +1,188 @@ +package storage + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "encoding/asn1" + "encoding/pem" + "reflect" + "testing" +) + +func TestParseECPrivateKeyFromPEMWithPassword(t *testing.T) { + type args struct { + data []byte + password []byte + } + tests := []struct { + name string + args args + wantKey func(*testing.T, *ecdsa.PrivateKey) + wantErr bool + }{ + { + name: "Private key with password", + args: args{ + data: []byte( + `-----BEGIN ENCRYPTED PRIVATE KEY----- +MIHsMFcGCSqGSIb3DQEFDTBKMCkGCSqGSIb3DQEFDDAcBAgTz/KWaEQ7xwICCAAw +DAYIKoZIhvcNAgkFADAdBglghkgBZQMEASoEEEoMbQeGZBq+RJGRyY2N8PwEgZAY +U36vBRn5HB8zNSic75MfpGXWRVXki1qm29G/DU+E68hksvfbJlqqpL12fQ5mbOz0 +v8wNrNmehUyxEOQZlRPRdmgJJHObuOZ3Z49iWRJh26uvQLRYj0EdV9KkEKmSzxaF +1ZEAdLc369AgQGD33Ce9WGTtnROB6IIfFZULO5/wj/Ps32+T+jzZLIoGk+M/sng= +-----END ENCRYPTED PRIVATE KEY-----`), + password: []byte("test"), + }, + wantErr: false, + wantKey: func(tt *testing.T, got *ecdsa.PrivateKey) { + if got == nil { + tt.Error("ParseECPrivateKeyFromPEMWithPassword() is nil") + } + }, + }, + { + name: "Private key wrong password", + args: args{ + data: []byte( + `-----BEGIN ENCRYPTED PRIVATE KEY----- +MIHsMFcGCSqGSIb3DQEFDTBKMCkGCSqGSIb3DQEFDDAcBAgTz/KWaEQ7xwICCAAw +DAYIKoZIhvcNAgkFADAdBglghkgBZQMEASoEEEoMbQeGZBq+RJGRyY2N8PwEgZAY +U36vBRn5HB8zNSic75MfpGXWRVXki1qm29G/DU+E68hksvfbJlqqpL12fQ5mbOz0 +v8wNrNmehUyxEOQZlRPRdmgJJHObuOZ3Z49iWRJh26uvQLRYj0EdV9KkEKmSzxaF +1ZEAdLc369AgQGD33Ce9WGTtnROB6IIfFZULO5/wj/Ps32+T+jzZLIoGk+M/sng= +-----END ENCRYPTED PRIVATE KEY-----`), + password: []byte("nottest"), + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotKey, err := ParseECPrivateKeyFromPEMWithPassword(tt.args.data, tt.args.password) + if (err != nil) != tt.wantErr { + t.Errorf("ParseECPrivateKeyFromPEMWithPassword() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if tt.wantKey != nil { + tt.wantKey(t, gotKey) + } + }) + } +} + +func TestMarshalECPrivateKeyWithPassword(t *testing.T) { + pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatal(err) + } + + type args struct { + key *ecdsa.PrivateKey + password []byte + } + tests := []struct { + name string + args args + wantData func(*testing.T, []byte) + wantErr bool + }{ + { + name: "Marshal EC key", + args: args{ + key: pk, + password: []byte("test"), + }, + wantData: func(tt *testing.T, data []byte) { + if len(data) == 0 { + tt.Error("ParseECPrivateKeyFromPEMWithPassword() is empty") + } + }, + }, + { + name: "Error while marshalling EC key", + args: args{ + key: &ecdsa.PrivateKey{}, + password: []byte("test"), + }, + wantErr: true, + wantData: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotData, err := MarshalECPrivateKeyWithPassword(tt.args.key, tt.args.password) + if (err != nil) != tt.wantErr { + t.Errorf("MarshalECPrivateKeyWithPassword() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if tt.wantData != nil { + tt.wantData(t, gotData) + } + }) + } +} + +func TestDecryptPEMBlock(t *testing.T) { + type args struct { + block *pem.Block + password []byte + } + tests := []struct { + name string + args args + want []byte + wantErr bool + }{ + { + name: "wrong type", + args: args{ + block: &pem.Block{ + Type: "SOMETYPE", + }, + }, + wantErr: true, + }, + { + name: "not ASN1", + args: args{ + block: &pem.Block{ + Type: "ENCRYPTED PRIVATE KEY", + Bytes: []byte{1, 2, 3}, + }, + }, + wantErr: true, + }, + { + name: "wrong algorithmn", + args: args{ + block: &pem.Block{ + Type: "ENCRYPTED PRIVATE KEY", + Bytes: func() []byte { + b, _ := asn1.Marshal(&EncryptedPrivateKeyInfo{ + EncryptionAlgorithm: EncryptionAlgorithmIdentifier{}, + }) + return b + }(), + }, + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := DecryptPEMBlock(tt.args.block, tt.args.password) + if (err != nil) != tt.wantErr { + t.Errorf("DecryptPEMBlock() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("DecryptPEMBlock() = %v, want %v", got, tt.want) + } + }) + } +}