From b135c891f36ea062cb72a9f8d34a7900e0bf3783 Mon Sep 17 00:00:00 2001 From: Christian Banse Date: Mon, 7 Mar 2022 20:31:16 +0100 Subject: [PATCH] Implementing `refresh_token` flow (#37) --- integration_test.go | 30 ++++++- server.go | 62 ++++++++++++++ server_test.go | 195 ++++++++++++++++++++++++++++++++++++++++---- 3 files changed, 271 insertions(+), 16 deletions(-) diff --git a/integration_test.go b/integration_test.go index 178e342..ecc12aa 100644 --- a/integration_test.go +++ b/integration_test.go @@ -10,6 +10,7 @@ import ( "strconv" "strings" "testing" + "time" "github.com/golang-jwt/jwt/v4" oauth2 "github.com/oxisto/oauth2go" @@ -63,6 +64,8 @@ func TestThreeLeggedFlowPublicClient(t *testing.T) { form url.Values session *http.Cookie token *oauth2.Token + newToken *oauth2.Token + source oauth2.TokenSource code string challenge string verifier string @@ -160,11 +163,34 @@ func TestThreeLeggedFlowPublicClient(t *testing.T) { } if token.AccessToken == "" { - t.Error("Access token is empty", err) + t.Error("Access token is empty") } if token.RefreshToken == "" { - t.Error("Access token is empty", err) + t.Error("Access token is empty") + } + + // For some extra fun, let's use our refresh token by declaring our token expired + token.Expiry = time.Now().Add(-5 * time.Minute) + source = config.TokenSource(context.Background(), token) + + newToken, err = source.Token() + if err != nil { + t.Errorf("Error while fetching from token source: %v", err) + } + + if newToken.AccessToken == "" { + t.Error("Access token is empty") + } + + // Access tokens should be different + if newToken.AccessToken == token.AccessToken { + t.Error("New token is not different") + } + + // Refresh tokens should be the same + if newToken.RefreshToken != token.RefreshToken { + t.Error("Refresh token is different") } } diff --git a/server.go b/server.go index a4142b4..9f1a8e4 100644 --- a/server.go +++ b/server.go @@ -12,6 +12,7 @@ import ( "fmt" "net/http" "net/url" + "strconv" "time" "github.com/golang-jwt/jwt/v4" @@ -135,6 +136,8 @@ func (srv *AuthorizationServer) handleToken(w http.ResponseWriter, r *http.Reque srv.doClientCredentialsFlow(w, r) case "authorization_code": srv.doAuthorizationCodeFlow(w, r) + case "refresh_token": + srv.doRefreshTokenFlow(w, r) default: Error(w, "unsupported_grant_type", http.StatusBadRequest) return @@ -211,6 +214,65 @@ func (srv *AuthorizationServer) doAuthorizationCodeFlow(w http.ResponseWriter, r writeToken(w, token) } +// doRefreshTokenFlow implements refreshing an access token. +// See https://datatracker.ietf.org/doc/html/rfc6749#section-6). +func (srv *AuthorizationServer) doRefreshTokenFlow(w http.ResponseWriter, r *http.Request) { + var ( + err error + refreshToken string + claims jwt.RegisteredClaims + client *Client + token *Token + ) + + // Retrieve the token first, as we need it to find out which client this is + refreshToken = r.FormValue("refresh_token") + if refreshToken == "" { + Error(w, ErrorInvalidRequest, http.StatusBadRequest) + return + } + + // Try to parse it as a JWT + _, err = jwt.ParseWithClaims(refreshToken, &claims, func(t *jwt.Token) (interface{}, error) { + kid, _ := strconv.ParseInt(t.Header["kid"].(string), 10, 64) + + return srv.PublicKeys()[kid], nil + }) + if err != nil { + fmt.Printf("%+v", err) + Error(w, ErrorInvalidGrant, http.StatusBadRequest) + return + } + + // The subject contains our client ID. + client, err = srv.GetClient(claims.Subject) + if err != nil { + Error(w, ErrorInvalidClient, http.StatusUnauthorized) + return + } + + // If this is a public client, we can issue a new token + if client.ClientSecret == "" { + goto issue + } + + // Otherwise, we must check for authentication + client, err = srv.retrieveClient(r, false) + if err != nil { + Error(w, ErrorInvalidClient, http.StatusUnauthorized) + return + } + +issue: + token, err = srv.GenerateToken(client.ClientID, 0, -1) + if err != nil { + http.Error(w, "error while creating JWT", http.StatusInternalServerError) + return + } + + writeToken(w, token) +} + func (srv *AuthorizationServer) handleJWKS(w http.ResponseWriter, r *http.Request) { var ( keySet *JSONWebKeySet diff --git a/server_test.go b/server_test.go index 2b49ae5..b68eda6 100644 --- a/server_test.go +++ b/server_test.go @@ -3,6 +3,7 @@ package oauth2 import ( "crypto/ecdsa" "crypto/elliptic" + "crypto/rand" "encoding/json" "errors" "fmt" @@ -15,6 +16,7 @@ import ( "testing" "time" + "github.com/golang-jwt/jwt/v4" "github.com/oxisto/oauth2go/internal/mock" ) @@ -38,18 +40,35 @@ var badSigningKey = ecdsa.PrivateKey{ }, } -var mockSigningKey = ecdsa.PrivateKey{ - D: big.NewInt(1), - PublicKey: ecdsa.PublicKey{ - X: big.NewInt(1), - Y: big.NewInt(2), - Curve: elliptic.P256(), - }, -} - +var testSigningKey *ecdsa.PrivateKey var testVerifier = "012345678901234567890123456789012345678901234567890123456789" var testChallenge = GenerateCodeChallenge(testVerifier) +// testRefreshTokenClientKID1MockSingingKey is a valid refresh token signed by mockSigningKey with the KID 1 +var testRefreshTokenClientKID1MockSingingKey string + +func init() { + var ( + err error + t *jwt.Token + ) + + testSigningKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + panic(err) + } + + t = jwt.NewWithClaims(jwt.SigningMethodES256, jwt.RegisteredClaims{ + Subject: "client", + }) + t.Header["kid"] = fmt.Sprintf("%d", 1) + + testRefreshTokenClientKID1MockSingingKey, err = t.SignedString(testSigningKey) + if err != nil { + panic(err) + } +} + func TestAuthorizationServer_handleToken(t *testing.T) { type fields struct { clients []*Client @@ -542,6 +561,154 @@ func TestAuthorizationServer_doAuthorizationCodeFlow(t *testing.T) { } } +func TestAuthorizationServer_doRefreshTokenFlow(t *testing.T) { + type fields struct { + clients []*Client + signingKeys map[int]*ecdsa.PrivateKey + codes map[string]*codeInfo + } + type args struct { + r *http.Request + } + tests := []struct { + name string + fields fields + args args + wantCode int + wantBody string + }{ + { + name: "missing refresh token", + args: args{ + r: &http.Request{ + Method: "POST", + Header: http.Header{ + http.CanonicalHeaderKey("Content-Type"): []string{"application/x-www-form-urlencoded"}, + }, + Body: nil, + }, + }, + wantCode: http.StatusBadRequest, + wantBody: `{"error": "invalid_request"}`, + }, + { + name: "invalid refresh token", + args: args{ + r: &http.Request{ + Method: "POST", + Header: http.Header{ + http.CanonicalHeaderKey("Content-Type"): []string{"application/x-www-form-urlencoded"}, + }, + Body: io.NopCloser(strings.NewReader(fmt.Sprintf("refresh_token=%s", "notatoken"))), + }, + }, + wantCode: http.StatusBadRequest, + wantBody: `{"error": "invalid_grant"}`, + }, + { + name: "wrong client", + fields: fields{ + clients: []*Client{ + { + ClientID: "notclient", + ClientSecret: "secret", + }, + }, + signingKeys: map[int]*ecdsa.PrivateKey{ + 0: &badSigningKey, + 1: testSigningKey, + }, + }, + args: args{ + r: &http.Request{ + Method: "POST", + Header: http.Header{ + http.CanonicalHeaderKey("Content-Type"): []string{"application/x-www-form-urlencoded"}, + }, + Body: io.NopCloser(strings.NewReader(fmt.Sprintf("refresh_token=%s", testRefreshTokenClientKID1MockSingingKey))), + }, + }, + wantCode: http.StatusUnauthorized, + wantBody: `{"error": "invalid_client"}`, + }, + { + name: "missing authentication for confidential client", + fields: fields{ + clients: []*Client{ + { + ClientID: "client", + ClientSecret: "secret", + }, + }, + signingKeys: map[int]*ecdsa.PrivateKey{ + 0: &badSigningKey, + 1: testSigningKey, + }, + }, + args: args{ + r: &http.Request{ + Method: "POST", + Header: http.Header{ + http.CanonicalHeaderKey("Content-Type"): []string{"application/x-www-form-urlencoded"}, + }, + Body: io.NopCloser(strings.NewReader(fmt.Sprintf("refresh_token=%s", testRefreshTokenClientKID1MockSingingKey))), + }, + }, + wantCode: http.StatusUnauthorized, + wantBody: `{"error": "invalid_client"}`, + }, + { + name: "problem with JWT creation", + fields: fields{ + clients: []*Client{ + { + ClientID: "client", + ClientSecret: "", + }, + }, + signingKeys: map[int]*ecdsa.PrivateKey{ + 0: &badSigningKey, + 1: testSigningKey, + }, + }, + args: args{ + r: &http.Request{ + Method: "POST", + Header: http.Header{ + http.CanonicalHeaderKey("Content-Type"): []string{"application/x-www-form-urlencoded"}, + }, + Body: io.NopCloser(strings.NewReader(fmt.Sprintf("refresh_token=%s", testRefreshTokenClientKID1MockSingingKey))), + }, + }, + wantCode: http.StatusInternalServerError, + wantBody: `error while creating JWT`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + srv := &AuthorizationServer{ + clients: tt.fields.clients, + signingKeys: tt.fields.signingKeys, + codes: tt.fields.codes, + } + + rr := httptest.NewRecorder() + srv.doRefreshTokenFlow(rr, tt.args.r) + + gotCode := rr.Code + if gotCode != tt.wantCode { + t.Errorf("AuthorizationServer.doRefreshTokenFlow() code = %v, wantCode %v", gotCode, tt.wantCode) + } + + gotBody := strings.Trim(rr.Body.String(), "\n") + if gotBody != tt.wantBody { + t.Errorf("AuthorizationServer.doRefreshTokenFlow() body = %v, wantBody %v", gotBody, tt.wantBody) + } + }) + } +} + func TestAuthorizationServer_ValidateCode(t *testing.T) { type fields struct { clients []*Client @@ -670,7 +837,7 @@ func TestAuthorizationServer_GenerateToken(t *testing.T) { name: "invalid key ID", fields: fields{ signingKeys: map[int]*ecdsa.PrivateKey{ - 0: &mockSigningKey, + 0: testSigningKey, }, }, args: args{ @@ -684,7 +851,7 @@ func TestAuthorizationServer_GenerateToken(t *testing.T) { name: "bad refresh key", fields: fields{ signingKeys: map[int]*ecdsa.PrivateKey{ - 0: &mockSigningKey, + 0: testSigningKey, 1: &badSigningKey, }, }, @@ -700,7 +867,7 @@ func TestAuthorizationServer_GenerateToken(t *testing.T) { name: "invalid refresh key ID", fields: fields{ signingKeys: map[int]*ecdsa.PrivateKey{ - 0: &mockSigningKey, + 0: testSigningKey, }, }, args: args{ @@ -750,7 +917,7 @@ func TestNewServer(t *testing.T) { opts: []AuthorizationServerOption{ WithSigningKeysFunc(func() (keys map[int]*ecdsa.PrivateKey) { return map[int]*ecdsa.PrivateKey{ - 0: &mockSigningKey, + 0: testSigningKey, } })}, }, @@ -758,7 +925,7 @@ func TestNewServer(t *testing.T) { clients: []*Client{}, codes: map[string]*codeInfo{}, signingKeys: map[int]*ecdsa.PrivateKey{ - 0: &mockSigningKey, + 0: testSigningKey, }, }, },