Skip to content

Commit

Permalink
Exporting GenerateToken (#31)
Browse files Browse the repository at this point in the history
  • Loading branch information
oxisto authored Mar 3, 2022
1 parent 95c1c63 commit 3d508d3
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 68 deletions.
102 changes: 56 additions & 46 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ type AuthorizationServer struct {
clients []*Client

// our signing keys
signingKeys []*ecdsa.PrivateKey
signingKeys map[int]*ecdsa.PrivateKey

// our codes and their expiry time and challenge
codes map[string]*codeInfo
Expand Down Expand Up @@ -88,7 +88,7 @@ func NewServer(addr string, opts ...AuthorizationServerOption) *AuthorizationSer
// Create a new private key
var signingKey, _ = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)

srv.signingKeys = []*ecdsa.PrivateKey{signingKey}
srv.signingKeys = map[int]*ecdsa.PrivateKey{0: signingKey}

mux.HandleFunc("/token", srv.handleToken)
mux.HandleFunc("/.well-known/jwks.json", srv.handleJWKS)
Expand Down Expand Up @@ -149,7 +149,7 @@ func (srv *AuthorizationServer) doClientCredentialsFlow(w http.ResponseWriter, r
return
}

token, err = generateToken(client.ClientID, srv.signingKeys[0], 0, nil, 0)
token, err = srv.GenerateToken(client.ClientID, 0, -1)
if err != nil {
http.Error(w, "error while creating JWT", http.StatusInternalServerError)
return
Expand Down Expand Up @@ -193,7 +193,7 @@ func (srv *AuthorizationServer) doAuthorizationCodeFlow(w http.ResponseWriter, r
return
}

token, err = generateToken(client.ClientID, srv.signingKeys[0], 0, srv.signingKeys[0], 0)
token, err = srv.GenerateToken(client.ClientID, 0, 0)
if err != nil {
http.Error(w, "error while creating JWT", http.StatusInternalServerError)
return
Expand Down Expand Up @@ -314,6 +314,58 @@ func (srv *AuthorizationServer) ValidateCode(verifier string, code string) bool
return true
}

// GenerateToken generates a Token (comprising at least an acesss token) for a specific client,
// as specified by its ID. A signingKey needs to be specified, otherwise an error is thrown.
// Optionally, if a refreshKey is specified, that key is used to also create a refresh token.
func (srv *AuthorizationServer) GenerateToken(clientID string, signingKeyID int, refreshKeyID int) (token *Token, err error) {
var (
expiry = time.Now().Add(24 * time.Hour)
signingKey *ecdsa.PrivateKey
refreshKey *ecdsa.PrivateKey
ok bool
)

token = new(oauth2.Token)

token.TokenType = "Bearer"
token.Expiry = expiry

signingKey, ok = srv.signingKeys[signingKeyID]
if !ok {
return nil, errors.New("invalid key ID")
}

// Create a new JWT
t := jwt.NewWithClaims(jwt.SigningMethodES256, jwt.RegisteredClaims{
Subject: clientID,
ExpiresAt: jwt.NewNumericDate(expiry),
})
t.Header["kid"] = fmt.Sprintf("%d", signingKeyID)

if token.AccessToken, err = t.SignedString(signingKey); err != nil {
return nil, err
}

// Create a refresh token, if we have a key for it
if refreshKeyID != -1 {
refreshKey, ok = srv.signingKeys[refreshKeyID]
if !ok {
return nil, errors.New("invalid key ID")
}

t = jwt.NewWithClaims(jwt.SigningMethodES256, jwt.RegisteredClaims{
Subject: clientID,
})
t.Header["kid"] = fmt.Sprintf("%d", refreshKeyID)

if token.RefreshToken, err = t.SignedString(refreshKey); err != nil {
return nil, err
}
}

return
}

func Error(w http.ResponseWriter, error string, statusCode int) {
w.Header().Set("Content-Type", "application/json")

Expand Down Expand Up @@ -350,48 +402,6 @@ func GenerateSecret() string {
return base64.RawStdEncoding.EncodeToString(b)
}

// generateToken generates a Token (comprising at least an acesss token) for a specific client,
// as specified by its ID. A signingKey needs to be specified, otherwise an error is thrown.
// Optionally, if a refreshKey is specified, that key is used to also create a refresh token.
func generateToken(clientID string,
signingKey *ecdsa.PrivateKey,
signingKeyID int,
refreshKey *ecdsa.PrivateKey,
refreshKeyID int,
) (token *Token, err error) {
var expiry = time.Now().Add(24 * time.Hour)

token = new(oauth2.Token)

token.TokenType = "Bearer"
token.Expiry = expiry

// Create a new JWT
t := jwt.NewWithClaims(jwt.SigningMethodES256, jwt.RegisteredClaims{
Subject: clientID,
ExpiresAt: jwt.NewNumericDate(expiry),
})
t.Header["kid"] = fmt.Sprintf("%d", signingKeyID)

if token.AccessToken, err = t.SignedString(signingKey); err != nil {
return nil, err
}

// Create a refresh token, if we have a key for it
if refreshKey != nil {
t = jwt.NewWithClaims(jwt.SigningMethodES256, jwt.RegisteredClaims{
Subject: clientID,
})
t.Header["kid"] = fmt.Sprintf("%d", refreshKeyID)

if token.RefreshToken, err = t.SignedString(refreshKey); err != nil {
return nil, err
}
}

return
}

func GenerateCodeChallenge(verifier string) string {
var digest = sha256.Sum256([]byte(verifier))
return base64.RawURLEncoding.EncodeToString(digest[:])
Expand Down
91 changes: 69 additions & 22 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ var testChallenge = GenerateCodeChallenge(testVerifier)
func TestAuthorizationServer_handleToken(t *testing.T) {
type fields struct {
clients []*Client
signingKeys []*ecdsa.PrivateKey
signingKeys map[int]*ecdsa.PrivateKey
}
type args struct {
r *http.Request
Expand Down Expand Up @@ -103,7 +103,7 @@ func TestAuthorizationServer_handleToken(t *testing.T) {
func TestAuthorizationServer_retrieveClient(t *testing.T) {
type fields struct {
clients []*Client
signingKeys []*ecdsa.PrivateKey
signingKeys map[int]*ecdsa.PrivateKey
}
type args struct {
r *http.Request
Expand Down Expand Up @@ -200,7 +200,7 @@ func TestAuthorizationServer_retrieveClient(t *testing.T) {
func TestAuthorizationServer_handleJWKS(t *testing.T) {
type fields struct {
clients []*Client
signingKeys []*ecdsa.PrivateKey
signingKeys map[int]*ecdsa.PrivateKey
}
type args struct {
r *http.Request
Expand All @@ -215,8 +215,8 @@ func TestAuthorizationServer_handleJWKS(t *testing.T) {
{
name: "retrieve JWKS with GET",
fields: fields{
signingKeys: []*ecdsa.PrivateKey{
{
signingKeys: map[int]*ecdsa.PrivateKey{
0: {
PublicKey: ecdsa.PublicKey{
Curve: elliptic.P256(),
X: big.NewInt(1),
Expand Down Expand Up @@ -325,7 +325,7 @@ func Test_writeJSON(t *testing.T) {
func TestAuthorizationServer_doClientCredentialsFlow(t *testing.T) {
type fields struct {
clients []*Client
signingKeys []*ecdsa.PrivateKey
signingKeys map[int]*ecdsa.PrivateKey
}
type args struct {
r *http.Request
Expand Down Expand Up @@ -359,8 +359,8 @@ func TestAuthorizationServer_doClientCredentialsFlow(t *testing.T) {
ClientSecret: "secret",
},
},
signingKeys: []*ecdsa.PrivateKey{
&badSigningKey,
signingKeys: map[int]*ecdsa.PrivateKey{
0: &badSigningKey,
},
},
args: args{
Expand Down Expand Up @@ -402,7 +402,7 @@ func TestAuthorizationServer_doClientCredentialsFlow(t *testing.T) {
func TestAuthorizationServer_doAuthorizationCodeFlow(t *testing.T) {
type fields struct {
clients []*Client
signingKeys []*ecdsa.PrivateKey
signingKeys map[int]*ecdsa.PrivateKey
codes map[string]*codeInfo
}
type args struct {
Expand Down Expand Up @@ -500,8 +500,8 @@ func TestAuthorizationServer_doAuthorizationCodeFlow(t *testing.T) {
challenge: testChallenge,
},
},
signingKeys: []*ecdsa.PrivateKey{
&badSigningKey,
signingKeys: map[int]*ecdsa.PrivateKey{
0: &badSigningKey,
},
},
args: args{
Expand Down Expand Up @@ -545,7 +545,7 @@ func TestAuthorizationServer_doAuthorizationCodeFlow(t *testing.T) {
func TestAuthorizationServer_ValidateCode(t *testing.T) {
type fields struct {
clients []*Client
signingKeys []*ecdsa.PrivateKey
signingKeys map[int]*ecdsa.PrivateKey
codes map[string]*codeInfo
}
type args struct {
Expand Down Expand Up @@ -634,38 +634,79 @@ func TestAuthorizationServer_ValidateCode(t *testing.T) {
}
}

func Test_generateToken(t *testing.T) {
func TestAuthorizationServer_GenerateToken(t *testing.T) {
type fields struct {
clients []*Client
signingKeys map[int]*ecdsa.PrivateKey
codes map[string]*codeInfo
}
type args struct {
clientID string
signingKey *ecdsa.PrivateKey
signingKeyID int
refreshKey *ecdsa.PrivateKey
refreshKeyID int
}
tests := []struct {
name string
fields fields
args args
wantToken *Token
wantErr bool
}{
{
name: "bad signing key",
fields: fields{
signingKeys: map[int]*ecdsa.PrivateKey{
0: &badSigningKey,
},
},
args: args{
clientID: "client",
signingKey: &badSigningKey,
signingKeyID: 0,
},
wantToken: nil,
wantErr: true,
},
{
name: "invalid key ID",
fields: fields{
signingKeys: map[int]*ecdsa.PrivateKey{
0: &mockSigningKey,
},
},
args: args{
clientID: "client",
signingKeyID: 1,
},
wantToken: nil,
wantErr: true,
},
{
name: "bad refresh key",
fields: fields{
signingKeys: map[int]*ecdsa.PrivateKey{
0: &mockSigningKey,
1: &badSigningKey,
},
},
args: args{
clientID: "client",
signingKey: &mockSigningKey,
signingKeyID: 0,
refreshKey: &badSigningKey,
refreshKeyID: 0,
refreshKeyID: 1,
},
wantToken: nil,
wantErr: true,
},
{
name: "invalid refresh key ID",
fields: fields{
signingKeys: map[int]*ecdsa.PrivateKey{
0: &mockSigningKey,
},
},
args: args{
clientID: "client",
signingKeyID: 0,
refreshKeyID: 1,
},
wantToken: nil,
wantErr: true,
Expand All @@ -674,14 +715,20 @@ func Test_generateToken(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotToken, err := generateToken(tt.args.clientID, tt.args.signingKey, tt.args.signingKeyID, tt.args.refreshKey, tt.args.refreshKeyID)
srv := &AuthorizationServer{
clients: tt.fields.clients,
signingKeys: tt.fields.signingKeys,
codes: tt.fields.codes,
}

gotToken, err := srv.GenerateToken(tt.args.clientID, tt.args.signingKeyID, tt.args.refreshKeyID)
if (err != nil) != tt.wantErr {
t.Errorf("generateToken() error = %v, wantErr %v", err, tt.wantErr)
t.Errorf("AuthorizationServer.GenerateToken() error = %v, wantErr %v", err, tt.wantErr)
return
}

if !reflect.DeepEqual(gotToken, tt.wantToken) {
t.Errorf("generateToken() = %v, want %v", gotToken, tt.wantToken)
t.Errorf("AuthorizationServer.GenerateToken() = %v, want %v", gotToken, tt.wantToken)
}
})
}
Expand Down

0 comments on commit 3d508d3

Please sign in to comment.