diff --git a/access_request.go b/access_request.go index 6fc21680a..4f1583757 100644 --- a/access_request.go +++ b/access_request.go @@ -7,9 +7,6 @@ type AccessRequest struct { GrantTypes Arguments `json:"grantTypes" gorethink:"grantTypes"` HandledGrantType Arguments `json:"handledGrantType" gorethink:"handledGrantType"` - RefreshTokenRequestedScope Arguments - RefreshTokenGrantedScope Arguments - Request } @@ -27,26 +24,21 @@ func (a *AccessRequest) GetGrantTypes() Arguments { return a.GrantTypes } -func (a *AccessRequest) GetRefreshTokenRequestedScopes() (scopes Arguments) { - if a.RefreshTokenRequestedScope == nil { - return a.RequestedScope - } - - return a.RefreshTokenRequestedScope +func (a *AccessRequest) SetGrantedScopes(scopes Arguments) { + a.GrantedScope = scopes } -func (a *AccessRequest) SetRefreshTokenRequestedScopes(scopes Arguments) { - a.RefreshTokenRequestedScope = scopes -} +func (a *AccessRequest) SanitizeRestoreRefreshTokenOriginalRequester(requester Requester) Requester { + r := a.Sanitize(nil).(*Request) -func (a *AccessRequest) GetRefreshTokenGrantedScopes() (scopes Arguments) { - if a.RefreshTokenGrantedScope == nil { - return a.GrantedScope + ar := &AccessRequest{ + Request: *r, } - return a.RefreshTokenGrantedScope -} + ar.SetID(requester.GetID()) + + ar.SetRequestedScopes(requester.GetRequestedScopes()) + ar.SetGrantedScopes(requester.GetGrantedScopes()) -func (a *AccessRequest) SetRefreshTokenGrantedScopes(scopes Arguments) { - a.RefreshTokenGrantedScope = scopes + return ar } diff --git a/handler/oauth2/flow_refresh.go b/handler/oauth2/flow_refresh.go index c520863bc..659c4ea34 100644 --- a/handler/oauth2/flow_refresh.go +++ b/handler/oauth2/flow_refresh.go @@ -97,13 +97,6 @@ func (c *RefreshTokenGrantHandler) HandleTokenEndpointRequest(ctx context.Contex request.SetRequestedScopes(fosite.RemoveEmpty(strings.Split(scope, " "))) } - // If a new refresh token is issued, the refresh token scope MUST be identical to that of the refresh token included - // by the client in the request. - if rtRequest, ok := request.(fosite.RefreshTokenAccessRequester); ok { - rtRequest.SetRefreshTokenRequestedScopes(originalRequest.GetRequestedScopes()) - rtRequest.SetRefreshTokenGrantedScopes(originalRequest.GetGrantedScopes()) - } - request.SetRequestedAudience(originalRequest.GetRequestedAudience()) strategy := c.Config.GetScopeStrategy(ctx) @@ -167,30 +160,28 @@ func (c *RefreshTokenGrantHandler) PopulateTokenEndpointResponse(ctx context.Con err = c.handleRefreshTokenEndpointStorageError(ctx, err) }() - ts, err := c.TokenRevocationStorage.GetRefreshTokenSession(ctx, signature, nil) + originalRequest, err := c.TokenRevocationStorage.GetRefreshTokenSession(ctx, signature, nil) if err != nil { return err - } else if err := c.TokenRevocationStorage.RevokeAccessToken(ctx, ts.GetID()); err != nil { + } else if err := c.TokenRevocationStorage.RevokeAccessToken(ctx, originalRequest.GetID()); err != nil { return err } - if err := c.TokenRevocationStorage.RevokeRefreshTokenMaybeGracePeriod(ctx, ts.GetID(), signature); err != nil { + if err := c.TokenRevocationStorage.RevokeRefreshTokenMaybeGracePeriod(ctx, originalRequest.GetID(), signature); err != nil { return err } storeReq := requester.Sanitize([]string{}) - storeReq.SetID(ts.GetID()) + storeReq.SetID(originalRequest.GetID()) if err = c.TokenRevocationStorage.CreateAccessTokenSession(ctx, accessSignature, storeReq); err != nil { return err } if rtRequest, ok := requester.(fosite.RefreshTokenAccessRequester); ok { - rtStoreReq := requester.Sanitize([]string{}).(*fosite.Request) - rtStoreReq.SetID(ts.GetID()) + rtStoreReq := rtRequest.SanitizeRestoreRefreshTokenOriginalRequester(originalRequest) - rtStoreReq.RequestedScope = rtRequest.GetRefreshTokenRequestedScopes() - rtStoreReq.GrantedScope = rtRequest.GetRefreshTokenGrantedScopes() + rtStoreReq.SetSession(requester.GetSession().Clone()) if err = c.TokenRevocationStorage.CreateRefreshTokenSession(ctx, refreshSignature, rtStoreReq); err != nil { return err diff --git a/integration/refresh_token_grant_test.go b/integration/refresh_token_grant_test.go index 7e3693f39..c91c8be61 100644 --- a/integration/refresh_token_grant_test.go +++ b/integration/refresh_token_grant_test.go @@ -4,6 +4,7 @@ package integration_test import ( + "context" "encoding/json" "net/http" "net/http/httptest" @@ -12,8 +13,6 @@ import ( "testing" "time" - "github.com/ory/fosite/internal/gen" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/oauth2" @@ -21,6 +20,7 @@ import ( "github.com/ory/fosite" "github.com/ory/fosite/compose" "github.com/ory/fosite/handler/openid" + "github.com/ory/fosite/internal/gen" "github.com/ory/fosite/token/jwt" ) @@ -266,7 +266,9 @@ func TestRefreshTokenFlow(t *testing.T) { } } -func TestRefreshTokenFlowScopeNarrowing(t *testing.T) { +func TestRefreshTokenFlowScopeParameter(t *testing.T) { + ctx := context.Background() + session := &defaultSession{ DefaultSession: &openid.DefaultSession{ Claims: &jwt.IDTokenClaims{ @@ -278,7 +280,6 @@ func TestRefreshTokenFlowScopeNarrowing(t *testing.T) { }, } fc := new(fosite.Config) - fc.RefreshTokenLifespan = -1 fc.GlobalSecret = []byte("some-secret-thats-random-some-secret-thats-random-") f := compose.ComposeAllEnabled(fc, fositeStore, gen.MustRSAKey()) ts := mockServer(t, f, session) @@ -286,9 +287,9 @@ func TestRefreshTokenFlowScopeNarrowing(t *testing.T) { fc.ScopeStrategy = fosite.ExactScopeStrategy - oauthClient := newOAuth2Client(ts) - oauthClient.Scopes = []string{"openid", "offline", "offline_access", "foo", "bar"} - oauthClient.ClientID = "grant-all-requested-scopes-client" + client := newOAuth2Client(ts) + client.Scopes = []string{"openid", "offline", "offline_access", "foo", "bar"} + client.ClientID = "grant-all-requested-scopes-client" state := "1234567890" @@ -304,53 +305,146 @@ func TestRefreshTokenFlowScopeNarrowing(t *testing.T) { fositeStore.Clients["grant-all-requested-scopes-client"] = testRefreshingClient - resp, err := http.Get(oauthClient.AuthCodeURL(state)) + s := compose.NewOAuth2HMACStrategy(fc) + + originalScopes := fosite.Arguments{"openid", "offline", "offline_access", "foo", "bar"} + + testCases := []struct { + name string + scopes fosite.Arguments + expected fosite.Arguments + err string + }{ + { + "ShouldGrantOriginalScopesWhenOmitted", + nil, + originalScopes, + "", + }, + { + "ShouldNarrowScopesWhenIncluded", + fosite.Arguments{"openid", "offline_access", "foo"}, + fosite.Arguments{"openid", "offline_access", "foo"}, + "", + }, + { + "ShouldGrantOriginalScopesWhenOmittedAfterNarrowing", + nil, + originalScopes, + "", + }, + { + "ShouldGrantOriginalScopesExplicitlyRequested", + originalScopes, + originalScopes, + "", + }, + { + "ShouldErrorWhenBroadeningScopesAllowedByClientButNotOriginallyGranted", + fosite.Arguments{"openid", "offline", "offline_access", "foo", "bar", "baz"}, + nil, + "The requested scope is invalid, unknown, or malformed. The requested scope 'baz' was not originally granted by the resource owner.", + }, + } + + type step struct { + OAuth2 *oauth2.Token + SessionAT, SessionRT fosite.Requester + } + + entries := make([]step, len(testCases)+1) + + resp, err := http.Get(client.AuthCodeURL(state)) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) - token, err := oauthClient.Exchange(oauth2.NoContext, resp.Request.URL.Query().Get("code"), oauth2.SetAuthURLParam("client_id", oauthClient.ClientID)) + entries[0].OAuth2, err = client.Exchange(ctx, resp.Request.URL.Query().Get("code"), oauth2.SetAuthURLParam("client_id", client.ClientID)) + require.NoError(t, err) - require.NotEmpty(t, token.AccessToken) - require.NotEmpty(t, token.RefreshToken) + require.NotEmpty(t, entries[0].OAuth2.AccessToken) + require.NotEmpty(t, entries[0].OAuth2.RefreshToken) - assert.Equal(t, "openid offline offline_access foo bar", token.Extra("scope")) + assert.Equal(t, strings.Join(originalScopes, " "), entries[0].OAuth2.Extra("scope")) - token1Refresh, err := doRefresh(oauthClient, token, nil) + entries[0].SessionAT, err = fositeStore.GetAccessTokenSession(ctx, s.AccessTokenSignature(ctx, entries[0].OAuth2.AccessToken), nil) require.NoError(t, err) - require.NotEmpty(t, token1Refresh.AccessToken) - require.NotEmpty(t, token1Refresh.RefreshToken) - - assert.Equal(t, "openid offline offline_access foo bar", token1Refresh.Extra("scope")) - token2Refresh, err := doRefresh(oauthClient, token1Refresh, []string{"openid", "offline_access", "foo"}) + entries[0].SessionRT, err = fositeStore.GetRefreshTokenSession(ctx, s.RefreshTokenSignature(ctx, entries[0].OAuth2.RefreshToken), nil) require.NoError(t, err) - require.NotEmpty(t, token2Refresh.AccessToken) - require.NotEmpty(t, token2Refresh.RefreshToken) - assert.Equal(t, "openid offline_access foo", token2Refresh.Extra("scope")) + assert.ElementsMatch(t, entries[0].SessionAT.GetRequestedScopes(), originalScopes) + assert.ElementsMatch(t, entries[0].SessionRT.GetRequestedScopes(), originalScopes) + assert.ElementsMatch(t, entries[0].SessionAT.GetGrantedScopes(), originalScopes) + assert.ElementsMatch(t, entries[0].SessionRT.GetGrantedScopes(), originalScopes) + assert.Equal(t, strings.Join(originalScopes, " "), entries[0].OAuth2.Extra("scope")) - token3Refresh, err := doRefresh(oauthClient, token2Refresh, []string{"openid", "offline", "offline_access", "foo", "bar"}) - require.NoError(t, err) - require.NotEmpty(t, token3Refresh.AccessToken) - require.NotEmpty(t, token3Refresh.RefreshToken) + for i, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + time.Sleep(time.Second) - assert.Equal(t, "openid offline offline_access foo bar", token3Refresh.Extra("scope")) + idx := i + 1 - token4Refresh, err := doRefresh(oauthClient, token3Refresh, []string{"openid", "offline", "offline_access", "foo", "bar", "baz"}) - require.Error(t, err) - require.Nil(t, token4Refresh) - require.Contains(t, err.Error(), "The requested scope is invalid, unknown, or malformed. The requested scope 'baz' was not originally granted by the resource owner.") -} + opts := []oauth2.AuthCodeOption{ + oauth2.SetAuthURLParam("refresh_token", entries[i].OAuth2.RefreshToken), + oauth2.SetAuthURLParam("grant_type", "refresh_token"), + } -func doRefresh(client *oauth2.Config, t *oauth2.Token, scopes []string) (token *oauth2.Token, err error) { - opts := []oauth2.AuthCodeOption{ - oauth2.SetAuthURLParam("refresh_token", t.RefreshToken), - oauth2.SetAuthURLParam("grant_type", "refresh_token"), - } + if len(tc.scopes) != 0 { + opts = append(opts, oauth2.SetAuthURLParam("scope", strings.Join(tc.scopes, " ")), oauth2.SetAuthURLParam("client_id", client.ClientID)) + } - if len(scopes) != 0 { - opts = append(opts, oauth2.SetAuthURLParam("scope", strings.Join(scopes, " ")), oauth2.SetAuthURLParam("client_id", client.ClientID)) - } + entries[idx].OAuth2, err = client.Exchange(ctx, "", opts...) + if len(tc.err) != 0 { + require.Error(t, err) + require.Nil(t, entries[idx].OAuth2) + require.Contains(t, err.Error(), tc.err) + + return + } - return client.Exchange(oauth2.NoContext, "", opts...) + require.NoError(t, err) + require.NotEmpty(t, entries[idx].OAuth2.AccessToken) + require.NotEmpty(t, entries[idx].OAuth2.RefreshToken) + + entries[idx].SessionAT, err = fositeStore.GetAccessTokenSession(ctx, s.AccessTokenSignature(ctx, entries[idx].OAuth2.AccessToken), nil) + require.NoError(t, err) + + entries[idx].SessionRT, err = fositeStore.GetRefreshTokenSession(ctx, s.RefreshTokenSignature(ctx, entries[idx].OAuth2.RefreshToken), nil) + require.NoError(t, err) + + if len(tc.scopes) != 0 { + assert.ElementsMatch(t, entries[idx].SessionAT.GetRequestedScopes(), tc.scopes) + assert.Equal(t, strings.Join(tc.expected, " "), entries[idx].OAuth2.Extra("scope")) + } else { + assert.ElementsMatch(t, entries[idx].SessionAT.GetRequestedScopes(), originalScopes) + assert.Equal(t, strings.Join(originalScopes, " "), entries[idx].OAuth2.Extra("scope")) + } + assert.ElementsMatch(t, entries[idx].SessionAT.GetGrantedScopes(), tc.expected) + assert.ElementsMatch(t, entries[idx].SessionRT.GetRequestedScopes(), originalScopes) + assert.ElementsMatch(t, entries[idx].SessionRT.GetGrantedScopes(), originalScopes) + + var ( + j int + entry step + ) + + assert.Equal(t, entries[idx].SessionAT.GetID(), entries[idx].SessionRT.GetID()) + + for j, entry = range entries { + if j == idx { + break + } + + assert.Equal(t, entries[idx].SessionAT.GetID(), entry.SessionAT.GetID()) + assert.Equal(t, entries[idx].SessionAT.GetID(), entry.SessionRT.GetID()) + assert.Equal(t, entries[idx].SessionRT.GetID(), entry.SessionAT.GetID()) + assert.Equal(t, entries[idx].SessionRT.GetID(), entry.SessionRT.GetID()) + + assert.Greater(t, entries[idx].SessionAT.GetSession().GetExpiresAt(fosite.AccessToken).Unix(), entry.SessionAT.GetSession().GetExpiresAt(fosite.AccessToken).Unix()) + assert.Greater(t, entries[idx].SessionRT.GetSession().GetExpiresAt(fosite.RefreshToken).Unix(), entry.SessionRT.GetSession().GetExpiresAt(fosite.RefreshToken).Unix()) + assert.Greater(t, entries[idx].SessionAT.GetRequestedAt().Unix(), entry.SessionAT.GetRequestedAt().Unix()) + assert.Greater(t, entries[idx].SessionRT.GetRequestedAt().Unix(), entry.SessionRT.GetRequestedAt().Unix()) + } + }) + } } diff --git a/oauth2.go b/oauth2.go index 65b6227d0..fcd0164d1 100644 --- a/oauth2.go +++ b/oauth2.go @@ -248,17 +248,9 @@ type Requester interface { // RefreshTokenAccessRequester is an extended AccessRequester implementation that allows preserving // the original Requester. type RefreshTokenAccessRequester interface { - // GetRefreshTokenRequestedScopes returns the request's scopes specifically for the refresh token. - GetRefreshTokenRequestedScopes() (scopes Arguments) - - // SetRefreshTokenRequestedScopes sets the request's scopes specifically for the refresh token. - SetRefreshTokenRequestedScopes(scopes Arguments) - - // GetRefreshTokenGrantedScopes returns all granted scopes specifically for the refresh token. - GetRefreshTokenGrantedScopes() (scopes Arguments) - - // SetRefreshTokenGrantedScopes sets all granted scopes specifically for the refresh token. - SetRefreshTokenGrantedScopes(scopes Arguments) + // SanitizeRestoreRefreshTokenOriginalRequester returns a sanitized copy of this Requester and mutates the relevant + // values from the provided Requester which is the original refresh token session Requester. + SanitizeRestoreRefreshTokenOriginalRequester(requester Requester) Requester AccessRequester }