Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OAuth2 Token refresh implemented #328

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,6 @@ examples/*.crt

# Vscode
.vscode/

# Pycharm venv
.venv
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#
# This file is part of Invenio.
# Copyright (C) 2016-2018 CERN.
#
# Invenio is free software; you can redistribute it and/or modify it
# under the terms of the MIT License; see LICENSE file for more details.

"""Add expires_at and refresh_token to remote token."""

import sqlalchemy as sa
import sqlalchemy_utils
from alembic import op

# revision identifiers, used by Alembic.
revision = "7def990b852e"
down_revision = "aaa265b0afa6"
branch_labels = ()
depends_on = ("aaa265b0afa6",)


def upgrade():
"""Upgrade database."""
op.add_column(
"oauthclient_remotetoken",
sa.Column("refresh_token", sqlalchemy_utils.EncryptedType(), nullable=True),
)
op.add_column(
"oauthclient_remotetoken", sa.Column("expires_at", sa.DateTime(), nullable=True)
)


def downgrade():
"""Downgrade database."""
op.drop_column("oauthclient_remotetoken", "expires_at")
op.drop_column("oauthclient_remotetoken", "refresh_token")
9 changes: 9 additions & 0 deletions invenio_oauthclient/alembic/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# -*- coding: utf-8 -*-
#
# This file is part of Invenio.
# Copyright (C) 2016-2018 CERN.
#
# Invenio is free software; you can redistribute it and/or modify it
# under the terms of the MIT License; see LICENSE file for more details.

"""Alembic migrations for Invenio-OAuthClient."""
4 changes: 2 additions & 2 deletions invenio_oauthclient/contrib/keycloak/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(
app_key=None,
icon=None,
scopes="openid",
**kwargs
**kwargs,
):
"""The constructor takes two arguments.

Expand All @@ -64,7 +64,7 @@ def __init__(
request_token_params={"scope": scopes},
access_token_url=access_token_url,
authorize_url=authorize_url,
**kwargs
**kwargs,
)

self._handlers = dict(
Expand Down
9 changes: 8 additions & 1 deletion invenio_oauthclient/handlers/authorized.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,14 @@ def extra_signup_handler(remote, form, *args, **kwargs):
user = _register_user(response, remote, account_info, form)

# Link account and set session data
token = token_setter(remote, oauth_token[0], secret=oauth_token[1], user=user)
token = token_setter(
remote,
oauth_token[0],
secret=oauth_token[1],
user=user,
refresh_token=oauth_token[2] if len(oauth_token) > 2 else None,
expires_at=oauth_token[3] if len(oauth_token) > 3 else None,
)
if token is None:
raise OAuthClientTokenNotSet()

Expand Down
55 changes: 55 additions & 0 deletions invenio_oauthclient/handlers/refresh.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# -*- coding: utf-8 -*-
#
# This file is part of Invenio.
# Copyright (C) 2024 CESNET z.s.p.o.
#
# Invenio is free software; you can redistribute it and/or modify it
# under the terms of the MIT License; see LICENSE file for more details.

"""Handler for refreshing access token."""

from flask_oauthlib.client import OAuthResponse
from flask_oauthlib.utils import to_bytes

from invenio_oauthclient.handlers.token import make_expiration_time

from ..models import RemoteToken
from ..proxies import current_oauthclient


def refresh_access_token(token: RemoteToken):
"""
Internal method to refresh the access token.

:param token: the remote token to be refreshed
:returns tuple of (access_token, secret, refresh_token, expires_at)

Note: the current access/refresh token are invalidated during this call
"""
remote_account = token.remote_account
client_id = remote_account.client_id
remote = next(
x
for x in current_oauthclient.oauth.remote_apps.values()
if x.consumer_key == client_id
)
client = remote.make_client()
refresh_token_request = client.prepare_refresh_token_request(
remote.access_token_url,
refresh_token=token.refresh_token,
client_id=remote.consumer_key,
client_secret=remote.consumer_secret,
)
resp, content = remote.http_request(
refresh_token_request[0],
refresh_token_request[1],
data=to_bytes(refresh_token_request[2], remote.encoding),
method="POST",
)
resp = OAuthResponse(resp, content, remote.content_type)
return (
resp.data.get("access_token"),
"",
resp.data.get("refresh_token"),
make_expiration_time(resp.data.get("expires_in")),
)
53 changes: 47 additions & 6 deletions invenio_oauthclient/handlers/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
# under the terms of the MIT License; see LICENSE file for more details.

"""Funcs to manage tokens."""

from datetime import datetime, timedelta
from functools import partial

from flask import current_app, session
Expand Down Expand Up @@ -112,10 +112,31 @@ def oauth2_token_setter(remote, resp, token_type="", extra_data=None):
secret="",
token_type=token_type,
extra_data=extra_data,
refresh_token=resp.get("refresh_token"),
expires_at=make_expiration_time(resp.get("expires_in")),
)


def token_setter(remote, token, secret="", token_type="", extra_data=None, user=None):
def make_expiration_time(expires_in):
"""Make expiration time from expires_in.

:param expires_in: The time in seconds.
"""
if expires_in is None:
return None
return datetime.utcnow() + timedelta(seconds=expires_in)


def token_setter(
remote,
token,
secret="",
token_type="",
extra_data=None,
user=None,
refresh_token=None,
expires_at=None,
):
"""Set token for user.

:param remote: The remote application.
Expand All @@ -127,7 +148,12 @@ def token_setter(remote, token, secret="", token_type="", extra_data=None, user=
:returns: A :class:`invenio_oauthclient.models.RemoteToken` instance or
``None``.
"""
session[token_session_key(remote.name)] = (token, secret)
session[token_session_key(remote.name)] = (
token,
secret,
refresh_token,
expires_at.isoformat() if expires_at else None,
)
user = user or current_user

# Save token if user is not anonymous (user exists but can be not active at
Expand All @@ -140,10 +166,17 @@ def token_setter(remote, token, secret="", token_type="", extra_data=None, user=
t = RemoteToken.get(uid, cid, token_type=token_type)

if t:
t.update_token(token, secret)
t.update_token(token, secret, refresh_token, expires_at)
else:
t = RemoteToken.create(
uid, cid, token, secret, token_type=token_type, extra_data=extra_data
uid,
cid,
token,
secret,
token_type=token_type,
extra_data=extra_data,
refresh_token=refresh_token,
expires_at=expires_at,
)
return t
return None
Expand Down Expand Up @@ -176,7 +209,15 @@ def token_getter(remote, token=""):
# Store token and secret in session
session[session_key] = remote_token.token()

return session.get(session_key, None)
ret = session.get(session_key, None)
if ret:
if len(ret) == 2:
# no refresh token nor expiration time
return ret[0], ret[1], None, None
if ret[3] is not None:
# refresh token and expiration time
return ret[0], ret[1], ret[2], datetime.fromisoformat(ret[3])
return ret


def token_delete(remote, token=""):
Expand Down
59 changes: 56 additions & 3 deletions invenio_oauthclient/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

"""Models for storing access tokens and links between users and remote apps."""

from datetime import datetime, timedelta

from flask import current_app

# UserIdentity imported for backward compatibility. UserIdentity was originally
Expand Down Expand Up @@ -119,6 +121,14 @@ class RemoteToken(db.Model, Timestamp):
)
"""Access token to remote application."""

refresh_token = db.Column(
EncryptedType(type_in=db.Text, key=_secret_key), nullable=True
)
"""Refresh token to remote application."""

expires_at = db.Column(db.DateTime, nullable=True)
"""Access token expiration date."""

secret = db.Column(db.Text(), default="", nullable=False)
"""Used only by OAuth 1."""

Expand All @@ -130,6 +140,16 @@ class RemoteToken(db.Model, Timestamp):
)
"""SQLAlchemy relationship to RemoteAccount objects."""

@property
def is_expired(self):
"""Check if access token has expired."""
if not self.expires_at:
return False

leeway = current_app.config.get("OAUTHCLIENT_TOKEN_EXPIRES_LEEWAY", 10)
expiration_with_leeway = self.expires_at - timedelta(seconds=leeway)
return expiration_with_leeway < datetime.utcnow()

def __repr__(self):
"""String representation for model."""
return (
Expand All @@ -141,18 +161,37 @@ def token(self):
"""Get token as expected by Flask-OAuthlib."""
return (self.access_token, self.secret)

def update_token(self, token, secret):
def update_token(self, token, secret, refresh_token=None, expires_at=None):
"""Update token with new values.

:param token: The token value.
:param secret: The secret key.
:param refresh_token: The refresh token
:param expires_at: Time when the access token expires
"""
if self.access_token != token or self.secret != secret:
if (
self.access_token != token
or self.secret != secret
or self.refresh_token != refresh_token
or self.expiration != expires_at
):
with db.session.begin_nested():
self.access_token = token
self.secret = secret
self.refresh_token = refresh_token
self.expires_at = expires_at
db.session.add(self)

def refresh_access_token(self):
"""Refresh the access token."""
if not self.refresh_token:
raise ValueError("No refresh token available")
from .handlers.refresh import refresh_access_token

access_token, refresh_token, secret, expires_at = refresh_access_token(self)
self.update_token(access_token, refresh_token, secret, expires_at)
db.session.commit()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Question:
Should we use unit of work here? See: docs

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would make invenio-oauthclient dependent on invenio-records-resources (where the UnitOfWork is defined). I would personally stick with simple commit as the rest of the library uses that, but would document it in the pydoc to make sure that the caller knows that commit will occur - does it make sense?

Copy link
Member

@Samk13 Samk13 May 14, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Valid point, but I could consider using db.session.begin_nested() to allow for rollbacks on errors and maintain consistency with the existing codebase.
Do you think this is a better approach?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Definitely better, but I still have a small issue with that :) The problem is that when the oauth token endpoint is called, it invalidates the previous access & refresh token and returns a new pair (that is at least the case for our perun aai implementation). Then, if the new access & refresh is not stored to database (for example, the refresh method is invoked from an external begin_nested which rolls back afterwards for some reason), the original values stored in remote_token are completely unusable and any subsequent call will fail. That's why I would rather commit the token as soon as possible, or if using the begin_nested I would at least document that the caller should commit it as soon as possible - what would you prefer?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see your point, it will just introduce complexity regarding token management and state consistency.
Let's commit the token immediately after the update to ensure it's saved, avoiding potential issues with nested transactions and token invalidation. This keeps token management straightforward.
would you agree on this?


@classmethod
def get(cls, user_id, client_id, token_type="", access_token=None):
"""Get RemoteToken for user.
Expand Down Expand Up @@ -197,7 +236,17 @@ def get_by_token(cls, client_id, access_token, token_type=""):
)

@classmethod
def create(cls, user_id, client_id, token, secret, token_type="", extra_data=None):
def create(
cls,
user_id,
client_id,
token,
secret,
token_type="",
extra_data=None,
refresh_token=None,
expires_at=None,
):
"""Create a new access token.

.. note:: Creates RemoteAccount as well if it does not exists.
Expand All @@ -209,6 +258,8 @@ def create(cls, user_id, client_id, token, secret, token_type="", extra_data=Non
:param token_type: The token type. (Default: ``''``)
:param extra_data: Extra data to set in the remote account if the
remote account doesn't exists. (Default: ``None``)
:param refresh_token: The refresh token.
:param expires_at: Expiration of the token
:returns: A :class:`invenio_oauthclient.models.RemoteToken` instance.

"""
Expand All @@ -228,6 +279,8 @@ def create(cls, user_id, client_id, token, secret, token_type="", extra_data=Non
remote_account=account,
access_token=token,
secret=secret,
refresh_token=refresh_token,
expires_at=expires_at,
)
db.session.add(token)
return token
Expand Down
3 changes: 1 addition & 2 deletions invenio_oauthclient/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@
"""Utility methods."""

from flask import current_app, request, session
from flask_login import current_user
from flask_principal import RoleNeed, UserNeed
from flask_principal import RoleNeed
from invenio_db.utils import rebuild_encrypted_properties
from itsdangerous import TimedJSONWebSignatureSerializer
from uritools import uricompose, urisplit
Expand Down
5 changes: 5 additions & 0 deletions tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ def mock_remote_get(oauth, remote_app="test", data=None):
oauth.remote_apps[remote_app].get = MagicMock(return_value=data)


def mock_remote_http_request(oauth, remote_app="test", data=None):
"""Mock the oauth remote get response."""
oauth.remote_apps[remote_app].http_request = MagicMock(return_value=data)


def check_redirect_location(resp, loc):
"""Check response redirect location."""
assert resp._status_code == 302
Expand Down
2 changes: 1 addition & 1 deletion tests/test_base_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,4 @@ def test_token_getter(remote, models_fixture, app):
# Populated RemoteToken
RemoteToken.create(user.id, "testkey", "mytoken", "mysecret")
oauth_authenticate("dev", user)
assert token_getter(remote) == ("mytoken", "mysecret")
assert token_getter(remote) == ("mytoken", "mysecret", None, None)
Loading
Loading