Skip to content

Commit

Permalink
Implementing refresh_token flow (#37)
Browse files Browse the repository at this point in the history
  • Loading branch information
oxisto authored Mar 7, 2022
1 parent 1742760 commit b135c89
Show file tree
Hide file tree
Showing 3 changed files with 271 additions and 16 deletions.
30 changes: 28 additions & 2 deletions integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"strconv"
"strings"
"testing"
"time"

"github.com/golang-jwt/jwt/v4"
oauth2 "github.com/oxisto/oauth2go"
Expand Down Expand Up @@ -63,6 +64,8 @@ func TestThreeLeggedFlowPublicClient(t *testing.T) {
form url.Values
session *http.Cookie
token *oauth2.Token
newToken *oauth2.Token
source oauth2.TokenSource
code string
challenge string
verifier string
Expand Down Expand Up @@ -160,11 +163,34 @@ func TestThreeLeggedFlowPublicClient(t *testing.T) {
}

if token.AccessToken == "" {
t.Error("Access token is empty", err)
t.Error("Access token is empty")
}

if token.RefreshToken == "" {
t.Error("Access token is empty", err)
t.Error("Access token is empty")
}

// For some extra fun, let's use our refresh token by declaring our token expired
token.Expiry = time.Now().Add(-5 * time.Minute)
source = config.TokenSource(context.Background(), token)

newToken, err = source.Token()
if err != nil {
t.Errorf("Error while fetching from token source: %v", err)
}

if newToken.AccessToken == "" {
t.Error("Access token is empty")
}

// Access tokens should be different
if newToken.AccessToken == token.AccessToken {
t.Error("New token is not different")
}

// Refresh tokens should be the same
if newToken.RefreshToken != token.RefreshToken {
t.Error("Refresh token is different")
}
}

Expand Down
62 changes: 62 additions & 0 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"fmt"
"net/http"
"net/url"
"strconv"
"time"

"github.com/golang-jwt/jwt/v4"
Expand Down Expand Up @@ -135,6 +136,8 @@ func (srv *AuthorizationServer) handleToken(w http.ResponseWriter, r *http.Reque
srv.doClientCredentialsFlow(w, r)
case "authorization_code":
srv.doAuthorizationCodeFlow(w, r)
case "refresh_token":
srv.doRefreshTokenFlow(w, r)
default:
Error(w, "unsupported_grant_type", http.StatusBadRequest)
return
Expand Down Expand Up @@ -211,6 +214,65 @@ func (srv *AuthorizationServer) doAuthorizationCodeFlow(w http.ResponseWriter, r
writeToken(w, token)
}

// doRefreshTokenFlow implements refreshing an access token.
// See https://datatracker.ietf.org/doc/html/rfc6749#section-6).
func (srv *AuthorizationServer) doRefreshTokenFlow(w http.ResponseWriter, r *http.Request) {
var (
err error
refreshToken string
claims jwt.RegisteredClaims
client *Client
token *Token
)

// Retrieve the token first, as we need it to find out which client this is
refreshToken = r.FormValue("refresh_token")
if refreshToken == "" {
Error(w, ErrorInvalidRequest, http.StatusBadRequest)
return
}

// Try to parse it as a JWT
_, err = jwt.ParseWithClaims(refreshToken, &claims, func(t *jwt.Token) (interface{}, error) {
kid, _ := strconv.ParseInt(t.Header["kid"].(string), 10, 64)

return srv.PublicKeys()[kid], nil
})
if err != nil {
fmt.Printf("%+v", err)
Error(w, ErrorInvalidGrant, http.StatusBadRequest)
return
}

// The subject contains our client ID.
client, err = srv.GetClient(claims.Subject)
if err != nil {
Error(w, ErrorInvalidClient, http.StatusUnauthorized)
return
}

// If this is a public client, we can issue a new token
if client.ClientSecret == "" {
goto issue
}

// Otherwise, we must check for authentication
client, err = srv.retrieveClient(r, false)
if err != nil {
Error(w, ErrorInvalidClient, http.StatusUnauthorized)
return
}

issue:
token, err = srv.GenerateToken(client.ClientID, 0, -1)
if err != nil {
http.Error(w, "error while creating JWT", http.StatusInternalServerError)
return
}

writeToken(w, token)
}

func (srv *AuthorizationServer) handleJWKS(w http.ResponseWriter, r *http.Request) {
var (
keySet *JSONWebKeySet
Expand Down
195 changes: 181 additions & 14 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package oauth2
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"encoding/json"
"errors"
"fmt"
Expand All @@ -15,6 +16,7 @@ import (
"testing"
"time"

"github.com/golang-jwt/jwt/v4"
"github.com/oxisto/oauth2go/internal/mock"
)

Expand All @@ -38,18 +40,35 @@ var badSigningKey = ecdsa.PrivateKey{
},
}

var mockSigningKey = ecdsa.PrivateKey{
D: big.NewInt(1),
PublicKey: ecdsa.PublicKey{
X: big.NewInt(1),
Y: big.NewInt(2),
Curve: elliptic.P256(),
},
}

var testSigningKey *ecdsa.PrivateKey
var testVerifier = "012345678901234567890123456789012345678901234567890123456789"
var testChallenge = GenerateCodeChallenge(testVerifier)

// testRefreshTokenClientKID1MockSingingKey is a valid refresh token signed by mockSigningKey with the KID 1
var testRefreshTokenClientKID1MockSingingKey string

func init() {
var (
err error
t *jwt.Token
)

testSigningKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
panic(err)
}

t = jwt.NewWithClaims(jwt.SigningMethodES256, jwt.RegisteredClaims{
Subject: "client",
})
t.Header["kid"] = fmt.Sprintf("%d", 1)

testRefreshTokenClientKID1MockSingingKey, err = t.SignedString(testSigningKey)
if err != nil {
panic(err)
}
}

func TestAuthorizationServer_handleToken(t *testing.T) {
type fields struct {
clients []*Client
Expand Down Expand Up @@ -542,6 +561,154 @@ func TestAuthorizationServer_doAuthorizationCodeFlow(t *testing.T) {
}
}

func TestAuthorizationServer_doRefreshTokenFlow(t *testing.T) {
type fields struct {
clients []*Client
signingKeys map[int]*ecdsa.PrivateKey
codes map[string]*codeInfo
}
type args struct {
r *http.Request
}
tests := []struct {
name string
fields fields
args args
wantCode int
wantBody string
}{
{
name: "missing refresh token",
args: args{
r: &http.Request{
Method: "POST",
Header: http.Header{
http.CanonicalHeaderKey("Content-Type"): []string{"application/x-www-form-urlencoded"},
},
Body: nil,
},
},
wantCode: http.StatusBadRequest,
wantBody: `{"error": "invalid_request"}`,
},
{
name: "invalid refresh token",
args: args{
r: &http.Request{
Method: "POST",
Header: http.Header{
http.CanonicalHeaderKey("Content-Type"): []string{"application/x-www-form-urlencoded"},
},
Body: io.NopCloser(strings.NewReader(fmt.Sprintf("refresh_token=%s", "notatoken"))),
},
},
wantCode: http.StatusBadRequest,
wantBody: `{"error": "invalid_grant"}`,
},
{
name: "wrong client",
fields: fields{
clients: []*Client{
{
ClientID: "notclient",
ClientSecret: "secret",
},
},
signingKeys: map[int]*ecdsa.PrivateKey{
0: &badSigningKey,
1: testSigningKey,
},
},
args: args{
r: &http.Request{
Method: "POST",
Header: http.Header{
http.CanonicalHeaderKey("Content-Type"): []string{"application/x-www-form-urlencoded"},
},
Body: io.NopCloser(strings.NewReader(fmt.Sprintf("refresh_token=%s", testRefreshTokenClientKID1MockSingingKey))),
},
},
wantCode: http.StatusUnauthorized,
wantBody: `{"error": "invalid_client"}`,
},
{
name: "missing authentication for confidential client",
fields: fields{
clients: []*Client{
{
ClientID: "client",
ClientSecret: "secret",
},
},
signingKeys: map[int]*ecdsa.PrivateKey{
0: &badSigningKey,
1: testSigningKey,
},
},
args: args{
r: &http.Request{
Method: "POST",
Header: http.Header{
http.CanonicalHeaderKey("Content-Type"): []string{"application/x-www-form-urlencoded"},
},
Body: io.NopCloser(strings.NewReader(fmt.Sprintf("refresh_token=%s", testRefreshTokenClientKID1MockSingingKey))),
},
},
wantCode: http.StatusUnauthorized,
wantBody: `{"error": "invalid_client"}`,
},
{
name: "problem with JWT creation",
fields: fields{
clients: []*Client{
{
ClientID: "client",
ClientSecret: "",
},
},
signingKeys: map[int]*ecdsa.PrivateKey{
0: &badSigningKey,
1: testSigningKey,
},
},
args: args{
r: &http.Request{
Method: "POST",
Header: http.Header{
http.CanonicalHeaderKey("Content-Type"): []string{"application/x-www-form-urlencoded"},
},
Body: io.NopCloser(strings.NewReader(fmt.Sprintf("refresh_token=%s", testRefreshTokenClientKID1MockSingingKey))),
},
},
wantCode: http.StatusInternalServerError,
wantBody: `error while creating JWT`,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
srv := &AuthorizationServer{
clients: tt.fields.clients,
signingKeys: tt.fields.signingKeys,
codes: tt.fields.codes,
}

rr := httptest.NewRecorder()
srv.doRefreshTokenFlow(rr, tt.args.r)

gotCode := rr.Code
if gotCode != tt.wantCode {
t.Errorf("AuthorizationServer.doRefreshTokenFlow() code = %v, wantCode %v", gotCode, tt.wantCode)
}

gotBody := strings.Trim(rr.Body.String(), "\n")
if gotBody != tt.wantBody {
t.Errorf("AuthorizationServer.doRefreshTokenFlow() body = %v, wantBody %v", gotBody, tt.wantBody)
}
})
}
}

func TestAuthorizationServer_ValidateCode(t *testing.T) {
type fields struct {
clients []*Client
Expand Down Expand Up @@ -670,7 +837,7 @@ func TestAuthorizationServer_GenerateToken(t *testing.T) {
name: "invalid key ID",
fields: fields{
signingKeys: map[int]*ecdsa.PrivateKey{
0: &mockSigningKey,
0: testSigningKey,
},
},
args: args{
Expand All @@ -684,7 +851,7 @@ func TestAuthorizationServer_GenerateToken(t *testing.T) {
name: "bad refresh key",
fields: fields{
signingKeys: map[int]*ecdsa.PrivateKey{
0: &mockSigningKey,
0: testSigningKey,
1: &badSigningKey,
},
},
Expand All @@ -700,7 +867,7 @@ func TestAuthorizationServer_GenerateToken(t *testing.T) {
name: "invalid refresh key ID",
fields: fields{
signingKeys: map[int]*ecdsa.PrivateKey{
0: &mockSigningKey,
0: testSigningKey,
},
},
args: args{
Expand Down Expand Up @@ -750,15 +917,15 @@ func TestNewServer(t *testing.T) {
opts: []AuthorizationServerOption{
WithSigningKeysFunc(func() (keys map[int]*ecdsa.PrivateKey) {
return map[int]*ecdsa.PrivateKey{
0: &mockSigningKey,
0: testSigningKey,
}
})},
},
want: &AuthorizationServer{
clients: []*Client{},
codes: map[string]*codeInfo{},
signingKeys: map[int]*ecdsa.PrivateKey{
0: &mockSigningKey,
0: testSigningKey,
},
},
},
Expand Down

0 comments on commit b135c89

Please sign in to comment.