-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implemented basic storage for the private key used in the server.
This code is donated from https://github.com/clouditor/clouditor/tree/main/internal/auth, which we will move upstream to this project. This also cleans up the existing code a little bit.
- Loading branch information
Showing
4 changed files
with
730 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,211 @@ | ||
package storage | ||
|
||
import ( | ||
"crypto/ecdsa" | ||
"crypto/elliptic" | ||
"crypto/rand" | ||
"errors" | ||
"fmt" | ||
"log/slog" | ||
"os" | ||
"os/user" | ||
"path/filepath" | ||
"strings" | ||
) | ||
|
||
const ( | ||
// DefaultPrivateKeySaveOnCreate specifies whether a created private key | ||
// will be saved. This is useful to turn off in unit tests, where we only | ||
// want a temporary key. | ||
DefaultPrivateKeySaveOnCreate = true | ||
|
||
// DefaultPrivateKeyPassword is the default password to protect the private | ||
// key. | ||
DefaultPrivateKeyPassword = "changeme" | ||
|
||
// DefaultPrivateKeyPath is the default path for the private key. | ||
DefaultPrivateKeyPath = DefaultConfigDirectory + "/private.key" | ||
|
||
// DefaultConfigDirectory is the default path for the oauth2go | ||
// configuration, such as keys. | ||
DefaultConfigDirectory = "~/.oauth2go" | ||
) | ||
|
||
type keyLoader struct { | ||
path string | ||
password string | ||
saveOnCreate bool | ||
} | ||
|
||
// LoadSigningKeys implements a singing keys func for our internal authorization server | ||
func LoadSigningKeys(path string, password string, saveOnCreate bool) map[int]*ecdsa.PrivateKey { | ||
// create a key loader with our arguments | ||
loader := keyLoader{ | ||
path: path, | ||
password: password, | ||
saveOnCreate: saveOnCreate, | ||
} | ||
|
||
return map[int]*ecdsa.PrivateKey{ | ||
0: loader.LoadKey(), | ||
} | ||
} | ||
|
||
func (l *keyLoader) LoadKey() (key *ecdsa.PrivateKey) { | ||
var ( | ||
err error | ||
) | ||
|
||
// Try to load the key from the given path | ||
key, err = loadKeyFromFile(l.path, []byte(l.password)) | ||
if err != nil { | ||
key = l.recoverFromLoadApiKeyError(err, l.path == DefaultPrivateKeyPath) | ||
} | ||
|
||
return | ||
} | ||
|
||
// recoverFromLoadApiKeyError tries to recover from an error during key loading. | ||
// We treat different errors differently. For example if the path is the default | ||
// path and the error is [os.ErrNotExist], this could be just the first start of | ||
// Clouditor. So we only treat this as an information that we will create a new | ||
// key, which we will save, based on the config. | ||
// | ||
// If the user specifies a custom path and this one does not exist, we will | ||
// report an error here. | ||
func (l *keyLoader) recoverFromLoadApiKeyError(err error, defaultPath bool) (key *ecdsa.PrivateKey) { | ||
// In any case, create a new temporary API key | ||
key, _ = ecdsa.GenerateKey(elliptic.P256(), rand.Reader) | ||
|
||
if defaultPath && errors.Is(err, os.ErrNotExist) { | ||
slog.Info("API key does not exist at the default location yet. We will create a new one") | ||
|
||
if l.saveOnCreate { | ||
// Also make sure that default config path exists | ||
err = ensureConfigFolderExistence() | ||
// Error while error handling, meh | ||
if err != nil { | ||
return | ||
} | ||
|
||
// Also save the key in this case, so we can load it next time | ||
err = saveKeyToFile(key, l.path, l.password) | ||
|
||
// Error while error handling, meh | ||
if err != nil { | ||
slog.Error("Error while saving the new API key", "err", err) | ||
} | ||
} | ||
} else if err != nil { | ||
slog.Error("Could not load key from file, continuing with a temporary key", "err", err) | ||
} | ||
|
||
return key | ||
} | ||
|
||
// loadKeyFromFile loads an ecdsa.PrivateKey from a path. The key must in PEM format and protected by | ||
// a password using PKCS#8 with PBES2. | ||
func loadKeyFromFile(path string, password []byte) (key *ecdsa.PrivateKey, err error) { | ||
var ( | ||
keyFile string | ||
) | ||
|
||
keyFile, err = expandPath(path) | ||
if err != nil { | ||
return nil, fmt.Errorf("error while expanding path: %w", err) | ||
} | ||
|
||
if _, err = os.Stat(keyFile); os.IsNotExist(err) { | ||
return nil, fmt.Errorf("file does not exist (yet): %w", err) | ||
} | ||
|
||
// Check, if we already have a persisted API key | ||
data, err := os.ReadFile(keyFile) | ||
if err != nil { | ||
return nil, fmt.Errorf("error while reading key: %w", err) | ||
} | ||
|
||
key, err = ParseECPrivateKeyFromPEMWithPassword(data, password) | ||
if err != nil { | ||
return nil, fmt.Errorf("error while parsing private key: %w", err) | ||
} | ||
|
||
return key, nil | ||
} | ||
|
||
// saveKeyToFile saves an ecdsa.PrivateKey to a path. The key will be saved in | ||
// PEM format and protected by a password using PKCS#8 with PBES2. | ||
func saveKeyToFile(apiKey *ecdsa.PrivateKey, keyPath string, password string) (err error) { | ||
keyPath, err = expandPath(keyPath) | ||
if err != nil { | ||
return fmt.Errorf("error while expanding path: %w", err) | ||
} | ||
|
||
// Check, if we already have a persisted API key | ||
f, err := os.OpenFile(keyPath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600) | ||
if err != nil { | ||
return fmt.Errorf("error while opening the file: %w", err) | ||
} | ||
defer func() { | ||
_ = f.Close() | ||
}() | ||
|
||
data, err := MarshalECPrivateKeyWithPassword(apiKey, []byte(password)) | ||
if err != nil { | ||
return fmt.Errorf("error while marshalling private key: %w", err) | ||
} | ||
|
||
_, err = f.Write(data) | ||
if err != nil { | ||
return fmt.Errorf("error while writing file content: %w", err) | ||
} | ||
|
||
return nil | ||
} | ||
|
||
// expandPath expands a path that possible contains a tilde (~) character into | ||
// the home directory of the user | ||
func expandPath(path string) (out string, err error) { | ||
var ( | ||
u *user.User | ||
) | ||
|
||
// Fetch the current user home directory | ||
u, err = user.Current() | ||
if err != nil { | ||
return path, fmt.Errorf("could not find retrieve current user: %w", err) | ||
} | ||
|
||
if path == "~" { | ||
return u.HomeDir, nil | ||
} else if strings.HasPrefix(path, "~") { | ||
// We only allow ~ at the beginning of the path | ||
return filepath.Join(u.HomeDir, path[2:]), nil | ||
} | ||
|
||
return path, nil | ||
} | ||
|
||
// ensureConfigesFolderExistence ensures that the config folder exists. | ||
func ensureConfigFolderExistence() (err error) { | ||
var configPath string | ||
|
||
// Expand the config directory, if it contains any ~ characters. | ||
configPath, err = expandPath(DefaultConfigDirectory) | ||
if err != nil { | ||
// Directly return the error here, no need for additional wrapping | ||
return err | ||
} | ||
|
||
// Create the directory, if it not exists | ||
_, err = os.Stat(configPath) | ||
if errors.Is(err, os.ErrNotExist) { | ||
err = os.Mkdir(configPath, os.ModePerm) | ||
if err != nil { | ||
// Directly return the error here, no need for additional wrapping | ||
return err | ||
} | ||
} | ||
|
||
return | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
package storage | ||
|
||
import ( | ||
"crypto/ecdsa" | ||
"io" | ||
"os" | ||
"testing" | ||
) | ||
|
||
func Test_keyLoader_recoverFromLoadApiKeyError(t *testing.T) { | ||
var tmpFile, _ = os.CreateTemp("", "api.key") | ||
// Close it immediately , since we want to write to it | ||
tmpFile.Close() | ||
|
||
defer func() { | ||
os.Remove(tmpFile.Name()) | ||
}() | ||
|
||
type fields struct { | ||
path string | ||
password string | ||
saveOnCreate bool | ||
} | ||
type args struct { | ||
err error | ||
defaultPath bool | ||
} | ||
tests := []struct { | ||
name string | ||
fields fields | ||
args args | ||
wantKey func(*testing.T, *ecdsa.PrivateKey) | ||
}{ | ||
{ | ||
name: "Could not load key from custom path", | ||
fields: fields{ | ||
saveOnCreate: false, | ||
path: "doesnotexist", | ||
password: "test", | ||
}, | ||
args: args{ | ||
err: os.ErrNotExist, | ||
defaultPath: false, | ||
}, | ||
wantKey: func(tt *testing.T, got *ecdsa.PrivateKey) { | ||
if got == nil { | ||
tt.Error("keyLoader.recoverFromLoadApiKeyError() is nil") | ||
} | ||
}, | ||
}, | ||
{ | ||
name: "Could not load key from default path and save it", | ||
fields: fields{ | ||
saveOnCreate: true, | ||
path: tmpFile.Name(), | ||
password: "test", | ||
}, | ||
args: args{ | ||
err: os.ErrNotExist, | ||
defaultPath: true, | ||
}, | ||
wantKey: func(tt *testing.T, got *ecdsa.PrivateKey) { | ||
if got == nil { | ||
tt.Error("keyLoader.recoverFromLoadApiKeyError() is nil") | ||
} | ||
|
||
f, _ := os.OpenFile(tmpFile.Name(), os.O_RDONLY, 0600) | ||
// Our tmp file should also contain something now | ||
data, _ := io.ReadAll(f) | ||
|
||
if len(data) == 0 { | ||
tt.Error("keyLoader.recoverFromLoadApiKeyError() did not write key on file") | ||
} | ||
}, | ||
}, | ||
} | ||
|
||
for _, tt := range tests { | ||
t.Run(tt.name, func(t *testing.T) { | ||
l := &keyLoader{ | ||
path: tt.fields.path, | ||
password: tt.fields.password, | ||
saveOnCreate: tt.fields.saveOnCreate, | ||
} | ||
gotKey := l.recoverFromLoadApiKeyError(tt.args.err, tt.args.defaultPath) | ||
|
||
if tt.wantKey != nil { | ||
tt.wantKey(t, gotKey) | ||
} | ||
}) | ||
} | ||
} |
Oops, something went wrong.