Skip to content

Commit

Permalink
refactor: simplify logic
Browse files Browse the repository at this point in the history
  • Loading branch information
james-d-elliott committed Jul 20, 2023
1 parent deb3be3 commit 6db02e8
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 85 deletions.
30 changes: 11 additions & 19 deletions access_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,6 @@ type AccessRequest struct {
GrantTypes Arguments `json:"grantTypes" gorethink:"grantTypes"`
HandledGrantType Arguments `json:"handledGrantType" gorethink:"handledGrantType"`

RefreshTokenRequestedScope Arguments
RefreshTokenGrantedScope Arguments

Request
}

Expand All @@ -27,26 +24,21 @@ func (a *AccessRequest) GetGrantTypes() Arguments {
return a.GrantTypes
}

func (a *AccessRequest) GetRefreshTokenRequestedScopes() (scopes Arguments) {
if a.RefreshTokenRequestedScope == nil {
return a.RequestedScope
}

return a.RefreshTokenRequestedScope
func (a *AccessRequest) SetGrantedScopes(scopes Arguments) {
a.GrantedScope = scopes
}

func (a *AccessRequest) SetRefreshTokenRequestedScopes(scopes Arguments) {
a.RefreshTokenRequestedScope = scopes
}
func (a *AccessRequest) SanitizeRestoreRefreshTokenOriginalRequester(requester Requester) Requester {
r := a.Sanitize(nil).(*Request)

func (a *AccessRequest) GetRefreshTokenGrantedScopes() (scopes Arguments) {
if a.RefreshTokenGrantedScope == nil {
return a.GrantedScope
ar := &AccessRequest{
Request: *r,
}

return a.RefreshTokenGrantedScope
}
ar.SetID(requester.GetID())

ar.SetRequestedScopes(requester.GetRequestedScopes())
ar.SetGrantedScopes(requester.GetGrantedScopes())

func (a *AccessRequest) SetRefreshTokenGrantedScopes(scopes Arguments) {
a.RefreshTokenGrantedScope = scopes
return ar
}
21 changes: 6 additions & 15 deletions handler/oauth2/flow_refresh.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,6 @@ func (c *RefreshTokenGrantHandler) HandleTokenEndpointRequest(ctx context.Contex
request.SetRequestedScopes(fosite.RemoveEmpty(strings.Split(scope, " ")))
}

// If a new refresh token is issued, the refresh token scope MUST be identical to that of the refresh token included
// by the client in the request.
if rtRequest, ok := request.(fosite.RefreshTokenAccessRequester); ok {
rtRequest.SetRefreshTokenRequestedScopes(originalRequest.GetRequestedScopes())
rtRequest.SetRefreshTokenGrantedScopes(originalRequest.GetGrantedScopes())
}

request.SetRequestedAudience(originalRequest.GetRequestedAudience())

strategy := c.Config.GetScopeStrategy(ctx)
Expand Down Expand Up @@ -167,30 +160,28 @@ func (c *RefreshTokenGrantHandler) PopulateTokenEndpointResponse(ctx context.Con
err = c.handleRefreshTokenEndpointStorageError(ctx, err)
}()

ts, err := c.TokenRevocationStorage.GetRefreshTokenSession(ctx, signature, nil)
originalRequest, err := c.TokenRevocationStorage.GetRefreshTokenSession(ctx, signature, nil)
if err != nil {
return err
} else if err := c.TokenRevocationStorage.RevokeAccessToken(ctx, ts.GetID()); err != nil {
} else if err := c.TokenRevocationStorage.RevokeAccessToken(ctx, originalRequest.GetID()); err != nil {
return err
}

if err := c.TokenRevocationStorage.RevokeRefreshTokenMaybeGracePeriod(ctx, ts.GetID(), signature); err != nil {
if err := c.TokenRevocationStorage.RevokeRefreshTokenMaybeGracePeriod(ctx, originalRequest.GetID(), signature); err != nil {
return err
}

storeReq := requester.Sanitize([]string{})
storeReq.SetID(ts.GetID())
storeReq.SetID(originalRequest.GetID())

if err = c.TokenRevocationStorage.CreateAccessTokenSession(ctx, accessSignature, storeReq); err != nil {
return err
}

if rtRequest, ok := requester.(fosite.RefreshTokenAccessRequester); ok {
rtStoreReq := requester.Sanitize([]string{}).(*fosite.Request)
rtStoreReq.SetID(ts.GetID())
rtStoreReq := rtRequest.SanitizeRestoreRefreshTokenOriginalRequester(originalRequest)

rtStoreReq.RequestedScope = rtRequest.GetRefreshTokenRequestedScopes()
rtStoreReq.GrantedScope = rtRequest.GetRefreshTokenGrantedScopes()
rtStoreReq.SetSession(requester.GetSession().Clone())

if err = c.TokenRevocationStorage.CreateRefreshTokenSession(ctx, refreshSignature, rtStoreReq); err != nil {
return err
Expand Down
174 changes: 134 additions & 40 deletions integration/refresh_token_grant_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package integration_test

import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
Expand All @@ -12,15 +13,14 @@ import (
"testing"
"time"

"github.com/ory/fosite/internal/gen"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"

"github.com/ory/fosite"
"github.com/ory/fosite/compose"
"github.com/ory/fosite/handler/openid"
"github.com/ory/fosite/internal/gen"
"github.com/ory/fosite/token/jwt"
)

Expand Down Expand Up @@ -266,7 +266,9 @@ func TestRefreshTokenFlow(t *testing.T) {
}
}

func TestRefreshTokenFlowScopeNarrowing(t *testing.T) {
func TestRefreshTokenFlowScopeParameter(t *testing.T) {
ctx := context.Background()

session := &defaultSession{
DefaultSession: &openid.DefaultSession{
Claims: &jwt.IDTokenClaims{
Expand All @@ -278,17 +280,16 @@ func TestRefreshTokenFlowScopeNarrowing(t *testing.T) {
},
}
fc := new(fosite.Config)
fc.RefreshTokenLifespan = -1
fc.GlobalSecret = []byte("some-secret-thats-random-some-secret-thats-random-")
f := compose.ComposeAllEnabled(fc, fositeStore, gen.MustRSAKey())
ts := mockServer(t, f, session)
defer ts.Close()

fc.ScopeStrategy = fosite.ExactScopeStrategy

oauthClient := newOAuth2Client(ts)
oauthClient.Scopes = []string{"openid", "offline", "offline_access", "foo", "bar"}
oauthClient.ClientID = "grant-all-requested-scopes-client"
client := newOAuth2Client(ts)
client.Scopes = []string{"openid", "offline", "offline_access", "foo", "bar"}
client.ClientID = "grant-all-requested-scopes-client"

state := "1234567890"

Expand All @@ -304,53 +305,146 @@ func TestRefreshTokenFlowScopeNarrowing(t *testing.T) {

fositeStore.Clients["grant-all-requested-scopes-client"] = testRefreshingClient

resp, err := http.Get(oauthClient.AuthCodeURL(state))
s := compose.NewOAuth2HMACStrategy(fc)

originalScopes := fosite.Arguments{"openid", "offline", "offline_access", "foo", "bar"}

testCases := []struct {
name string
scopes fosite.Arguments
expected fosite.Arguments
err string
}{
{
"ShouldGrantOriginalScopesWhenOmitted",
nil,
originalScopes,
"",
},
{
"ShouldNarrowScopesWhenIncluded",
fosite.Arguments{"openid", "offline_access", "foo"},
fosite.Arguments{"openid", "offline_access", "foo"},
"",
},
{
"ShouldGrantOriginalScopesWhenOmittedAfterNarrowing",
nil,
originalScopes,
"",
},
{
"ShouldGrantOriginalScopesExplicitlyRequested",
originalScopes,
originalScopes,
"",
},
{
"ShouldErrorWhenBroadeningScopesAllowedByClientButNotOriginallyGranted",
fosite.Arguments{"openid", "offline", "offline_access", "foo", "bar", "baz"},
nil,
"The requested scope is invalid, unknown, or malformed. The requested scope 'baz' was not originally granted by the resource owner.",
},
}

type step struct {
OAuth2 *oauth2.Token
SessionAT, SessionRT fosite.Requester
}

entries := make([]step, len(testCases)+1)

resp, err := http.Get(client.AuthCodeURL(state))
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)

token, err := oauthClient.Exchange(oauth2.NoContext, resp.Request.URL.Query().Get("code"), oauth2.SetAuthURLParam("client_id", oauthClient.ClientID))
entries[0].OAuth2, err = client.Exchange(ctx, resp.Request.URL.Query().Get("code"), oauth2.SetAuthURLParam("client_id", client.ClientID))

require.NoError(t, err)
require.NotEmpty(t, token.AccessToken)
require.NotEmpty(t, token.RefreshToken)
require.NotEmpty(t, entries[0].OAuth2.AccessToken)
require.NotEmpty(t, entries[0].OAuth2.RefreshToken)

assert.Equal(t, "openid offline offline_access foo bar", token.Extra("scope"))
assert.Equal(t, strings.Join(originalScopes, " "), entries[0].OAuth2.Extra("scope"))

token1Refresh, err := doRefresh(oauthClient, token, nil)
entries[0].SessionAT, err = fositeStore.GetAccessTokenSession(ctx, s.AccessTokenSignature(ctx, entries[0].OAuth2.AccessToken), nil)
require.NoError(t, err)
require.NotEmpty(t, token1Refresh.AccessToken)
require.NotEmpty(t, token1Refresh.RefreshToken)

assert.Equal(t, "openid offline offline_access foo bar", token1Refresh.Extra("scope"))

token2Refresh, err := doRefresh(oauthClient, token1Refresh, []string{"openid", "offline_access", "foo"})
entries[0].SessionRT, err = fositeStore.GetRefreshTokenSession(ctx, s.RefreshTokenSignature(ctx, entries[0].OAuth2.RefreshToken), nil)
require.NoError(t, err)
require.NotEmpty(t, token2Refresh.AccessToken)
require.NotEmpty(t, token2Refresh.RefreshToken)

assert.Equal(t, "openid offline_access foo", token2Refresh.Extra("scope"))
assert.ElementsMatch(t, entries[0].SessionAT.GetRequestedScopes(), originalScopes)
assert.ElementsMatch(t, entries[0].SessionRT.GetRequestedScopes(), originalScopes)
assert.ElementsMatch(t, entries[0].SessionAT.GetGrantedScopes(), originalScopes)
assert.ElementsMatch(t, entries[0].SessionRT.GetGrantedScopes(), originalScopes)
assert.Equal(t, strings.Join(originalScopes, " "), entries[0].OAuth2.Extra("scope"))

token3Refresh, err := doRefresh(oauthClient, token2Refresh, []string{"openid", "offline", "offline_access", "foo", "bar"})
require.NoError(t, err)
require.NotEmpty(t, token3Refresh.AccessToken)
require.NotEmpty(t, token3Refresh.RefreshToken)
for i, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
time.Sleep(time.Second)

assert.Equal(t, "openid offline offline_access foo bar", token3Refresh.Extra("scope"))
idx := i + 1

token4Refresh, err := doRefresh(oauthClient, token3Refresh, []string{"openid", "offline", "offline_access", "foo", "bar", "baz"})
require.Error(t, err)
require.Nil(t, token4Refresh)
require.Contains(t, err.Error(), "The requested scope is invalid, unknown, or malformed. The requested scope 'baz' was not originally granted by the resource owner.")
}
opts := []oauth2.AuthCodeOption{
oauth2.SetAuthURLParam("refresh_token", entries[i].OAuth2.RefreshToken),
oauth2.SetAuthURLParam("grant_type", "refresh_token"),
}

func doRefresh(client *oauth2.Config, t *oauth2.Token, scopes []string) (token *oauth2.Token, err error) {
opts := []oauth2.AuthCodeOption{
oauth2.SetAuthURLParam("refresh_token", t.RefreshToken),
oauth2.SetAuthURLParam("grant_type", "refresh_token"),
}
if len(tc.scopes) != 0 {
opts = append(opts, oauth2.SetAuthURLParam("scope", strings.Join(tc.scopes, " ")), oauth2.SetAuthURLParam("client_id", client.ClientID))
}

if len(scopes) != 0 {
opts = append(opts, oauth2.SetAuthURLParam("scope", strings.Join(scopes, " ")), oauth2.SetAuthURLParam("client_id", client.ClientID))
}
entries[idx].OAuth2, err = client.Exchange(ctx, "", opts...)
if len(tc.err) != 0 {
require.Error(t, err)
require.Nil(t, entries[idx].OAuth2)
require.Contains(t, err.Error(), tc.err)

return
}

return client.Exchange(oauth2.NoContext, "", opts...)
require.NoError(t, err)
require.NotEmpty(t, entries[idx].OAuth2.AccessToken)
require.NotEmpty(t, entries[idx].OAuth2.RefreshToken)

entries[idx].SessionAT, err = fositeStore.GetAccessTokenSession(ctx, s.AccessTokenSignature(ctx, entries[idx].OAuth2.AccessToken), nil)
require.NoError(t, err)

entries[idx].SessionRT, err = fositeStore.GetRefreshTokenSession(ctx, s.RefreshTokenSignature(ctx, entries[idx].OAuth2.RefreshToken), nil)
require.NoError(t, err)

if len(tc.scopes) != 0 {
assert.ElementsMatch(t, entries[idx].SessionAT.GetRequestedScopes(), tc.scopes)
assert.Equal(t, strings.Join(tc.expected, " "), entries[idx].OAuth2.Extra("scope"))
} else {
assert.ElementsMatch(t, entries[idx].SessionAT.GetRequestedScopes(), originalScopes)
assert.Equal(t, strings.Join(originalScopes, " "), entries[idx].OAuth2.Extra("scope"))
}
assert.ElementsMatch(t, entries[idx].SessionAT.GetGrantedScopes(), tc.expected)
assert.ElementsMatch(t, entries[idx].SessionRT.GetRequestedScopes(), originalScopes)
assert.ElementsMatch(t, entries[idx].SessionRT.GetGrantedScopes(), originalScopes)

var (
j int
entry step
)

assert.Equal(t, entries[idx].SessionAT.GetID(), entries[idx].SessionRT.GetID())

for j, entry = range entries {
if j == idx {
break
}

assert.Equal(t, entries[idx].SessionAT.GetID(), entry.SessionAT.GetID())
assert.Equal(t, entries[idx].SessionAT.GetID(), entry.SessionRT.GetID())
assert.Equal(t, entries[idx].SessionRT.GetID(), entry.SessionAT.GetID())
assert.Equal(t, entries[idx].SessionRT.GetID(), entry.SessionRT.GetID())

assert.Greater(t, entries[idx].SessionAT.GetSession().GetExpiresAt(fosite.AccessToken).Unix(), entry.SessionAT.GetSession().GetExpiresAt(fosite.AccessToken).Unix())
assert.Greater(t, entries[idx].SessionRT.GetSession().GetExpiresAt(fosite.RefreshToken).Unix(), entry.SessionRT.GetSession().GetExpiresAt(fosite.RefreshToken).Unix())
assert.Greater(t, entries[idx].SessionAT.GetRequestedAt().Unix(), entry.SessionAT.GetRequestedAt().Unix())
assert.Greater(t, entries[idx].SessionRT.GetRequestedAt().Unix(), entry.SessionRT.GetRequestedAt().Unix())
}
})
}
}
14 changes: 3 additions & 11 deletions oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -248,17 +248,9 @@ type Requester interface {
// RefreshTokenAccessRequester is an extended AccessRequester implementation that allows preserving
// the original Requester.
type RefreshTokenAccessRequester interface {
// GetRefreshTokenRequestedScopes returns the request's scopes specifically for the refresh token.
GetRefreshTokenRequestedScopes() (scopes Arguments)

// SetRefreshTokenRequestedScopes sets the request's scopes specifically for the refresh token.
SetRefreshTokenRequestedScopes(scopes Arguments)

// GetRefreshTokenGrantedScopes returns all granted scopes specifically for the refresh token.
GetRefreshTokenGrantedScopes() (scopes Arguments)

// SetRefreshTokenGrantedScopes sets all granted scopes specifically for the refresh token.
SetRefreshTokenGrantedScopes(scopes Arguments)
// SanitizeRestoreRefreshTokenOriginalRequester returns a sanitized copy of this Requester and mutates the relevant
// values from the provided Requester which is the original refresh token session Requester.
SanitizeRestoreRefreshTokenOriginalRequester(requester Requester) Requester

AccessRequester
}
Expand Down

0 comments on commit 6db02e8

Please sign in to comment.