From 258f666d3e7dda8a2f44d5882166b0bb5448849e Mon Sep 17 00:00:00 2001 From: Christian Banse <oxisto@aybaze.com> Date: Fri, 4 Mar 2022 15:05:51 +0100 Subject: [PATCH] Support external loading of singing keys (#33) Fixes #32 --- server.go | 22 ++++++++++++++++++---- server_test.go | 45 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 4 deletions(-) diff --git a/server.go b/server.go index e014815..e351a15 100644 --- a/server.go +++ b/server.go @@ -50,6 +50,8 @@ type AuthorizationServer struct { type AuthorizationServerOption func(srv *AuthorizationServer) +type signingKeysFunc func() (keys map[int]*ecdsa.PrivateKey) + type CodeIssuer interface { IssueCode(challenge string) string ValidateCode(verifier string, code string) bool @@ -69,6 +71,12 @@ func WithClient( } } +func WithSigningKeysFunc(f signingKeysFunc) AuthorizationServerOption { + return func(srv *AuthorizationServer) { + srv.signingKeys = f() + } +} + func NewServer(addr string, opts ...AuthorizationServerOption) *AuthorizationServer { mux := http.NewServeMux() @@ -85,10 +93,9 @@ func NewServer(addr string, opts ...AuthorizationServerOption) *AuthorizationSer o(srv) } - // Create a new private key - var signingKey, _ = ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - - srv.signingKeys = map[int]*ecdsa.PrivateKey{0: signingKey} + if srv.signingKeys == nil { + srv.signingKeys = generateSigningKeys() + } mux.HandleFunc("/token", srv.handleToken) mux.HandleFunc("/.well-known/jwks.json", srv.handleJWKS) @@ -406,3 +413,10 @@ func GenerateCodeChallenge(verifier string) string { var digest = sha256.Sum256([]byte(verifier)) return base64.RawURLEncoding.EncodeToString(digest[:]) } + +// generateSigningKeys generates a set of signing keys +func generateSigningKeys() map[int]*ecdsa.PrivateKey { + var signingKey, _ = ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + + return map[int]*ecdsa.PrivateKey{0: signingKey} +} diff --git a/server_test.go b/server_test.go index 6ab24f5..2b49ae5 100644 --- a/server_test.go +++ b/server_test.go @@ -733,3 +733,48 @@ func TestAuthorizationServer_GenerateToken(t *testing.T) { }) } } + +func TestNewServer(t *testing.T) { + type args struct { + addr string + opts []AuthorizationServerOption + } + tests := []struct { + name string + args args + want *AuthorizationServer + }{ + { + name: "with signing keys func", + args: args{ + opts: []AuthorizationServerOption{ + WithSigningKeysFunc(func() (keys map[int]*ecdsa.PrivateKey) { + return map[int]*ecdsa.PrivateKey{ + 0: &mockSigningKey, + } + })}, + }, + want: &AuthorizationServer{ + clients: []*Client{}, + codes: map[string]*codeInfo{}, + signingKeys: map[int]*ecdsa.PrivateKey{ + 0: &mockSigningKey, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := NewServer(tt.args.addr, tt.args.opts...) + + // Ignore Server.Handler in comparison because we create a new ServeMux + got.Handler = nil + tt.want.Handler = nil + + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("NewServer() = %v, want %v", got, tt.want) + } + }) + } +}