From 258f666d3e7dda8a2f44d5882166b0bb5448849e Mon Sep 17 00:00:00 2001
From: Christian Banse <oxisto@aybaze.com>
Date: Fri, 4 Mar 2022 15:05:51 +0100
Subject: [PATCH] Support external loading of singing keys (#33)

Fixes #32
---
 server.go      | 22 ++++++++++++++++++----
 server_test.go | 45 +++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 63 insertions(+), 4 deletions(-)

diff --git a/server.go b/server.go
index e014815..e351a15 100644
--- a/server.go
+++ b/server.go
@@ -50,6 +50,8 @@ type AuthorizationServer struct {
 
 type AuthorizationServerOption func(srv *AuthorizationServer)
 
+type signingKeysFunc func() (keys map[int]*ecdsa.PrivateKey)
+
 type CodeIssuer interface {
 	IssueCode(challenge string) string
 	ValidateCode(verifier string, code string) bool
@@ -69,6 +71,12 @@ func WithClient(
 	}
 }
 
+func WithSigningKeysFunc(f signingKeysFunc) AuthorizationServerOption {
+	return func(srv *AuthorizationServer) {
+		srv.signingKeys = f()
+	}
+}
+
 func NewServer(addr string, opts ...AuthorizationServerOption) *AuthorizationServer {
 	mux := http.NewServeMux()
 
@@ -85,10 +93,9 @@ func NewServer(addr string, opts ...AuthorizationServerOption) *AuthorizationSer
 		o(srv)
 	}
 
-	// Create a new private key
-	var signingKey, _ = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
-
-	srv.signingKeys = map[int]*ecdsa.PrivateKey{0: signingKey}
+	if srv.signingKeys == nil {
+		srv.signingKeys = generateSigningKeys()
+	}
 
 	mux.HandleFunc("/token", srv.handleToken)
 	mux.HandleFunc("/.well-known/jwks.json", srv.handleJWKS)
@@ -406,3 +413,10 @@ func GenerateCodeChallenge(verifier string) string {
 	var digest = sha256.Sum256([]byte(verifier))
 	return base64.RawURLEncoding.EncodeToString(digest[:])
 }
+
+// generateSigningKeys generates a set of signing keys
+func generateSigningKeys() map[int]*ecdsa.PrivateKey {
+	var signingKey, _ = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+
+	return map[int]*ecdsa.PrivateKey{0: signingKey}
+}
diff --git a/server_test.go b/server_test.go
index 6ab24f5..2b49ae5 100644
--- a/server_test.go
+++ b/server_test.go
@@ -733,3 +733,48 @@ func TestAuthorizationServer_GenerateToken(t *testing.T) {
 		})
 	}
 }
+
+func TestNewServer(t *testing.T) {
+	type args struct {
+		addr string
+		opts []AuthorizationServerOption
+	}
+	tests := []struct {
+		name string
+		args args
+		want *AuthorizationServer
+	}{
+		{
+			name: "with signing keys func",
+			args: args{
+				opts: []AuthorizationServerOption{
+					WithSigningKeysFunc(func() (keys map[int]*ecdsa.PrivateKey) {
+						return map[int]*ecdsa.PrivateKey{
+							0: &mockSigningKey,
+						}
+					})},
+			},
+			want: &AuthorizationServer{
+				clients: []*Client{},
+				codes:   map[string]*codeInfo{},
+				signingKeys: map[int]*ecdsa.PrivateKey{
+					0: &mockSigningKey,
+				},
+			},
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got := NewServer(tt.args.addr, tt.args.opts...)
+
+			// Ignore Server.Handler in comparison because we create a new ServeMux
+			got.Handler = nil
+			tt.want.Handler = nil
+
+			if !reflect.DeepEqual(got, tt.want) {
+				t.Errorf("NewServer() = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}