diff --git a/.gitignore b/.gitignore index 77ddcb08..578cda97 100644 --- a/.gitignore +++ b/.gitignore @@ -69,3 +69,6 @@ examples/*.crt # Vscode .vscode/ + +# Pycharm venv +.venv diff --git a/invenio_oauthclient/alembic/7def990b852e_add_expires_at_and_refresh_token_to_.py b/invenio_oauthclient/alembic/7def990b852e_add_expires_at_and_refresh_token_to_.py new file mode 100644 index 00000000..4aeb79e6 --- /dev/null +++ b/invenio_oauthclient/alembic/7def990b852e_add_expires_at_and_refresh_token_to_.py @@ -0,0 +1,35 @@ +# +# This file is part of Invenio. +# Copyright (C) 2016-2018 CERN. +# +# Invenio is free software; you can redistribute it and/or modify it +# under the terms of the MIT License; see LICENSE file for more details. + +"""Add expires_at and refresh_token to remote token.""" + +import sqlalchemy as sa +import sqlalchemy_utils +from alembic import op + +# revision identifiers, used by Alembic. +revision = "7def990b852e" +down_revision = "aaa265b0afa6" +branch_labels = () +depends_on = ("aaa265b0afa6",) + + +def upgrade(): + """Upgrade database.""" + op.add_column( + "oauthclient_remotetoken", + sa.Column("refresh_token", sqlalchemy_utils.EncryptedType(), nullable=True), + ) + op.add_column( + "oauthclient_remotetoken", sa.Column("expires_at", sa.DateTime(), nullable=True) + ) + + +def downgrade(): + """Downgrade database.""" + op.drop_column("oauthclient_remotetoken", "expires_at") + op.drop_column("oauthclient_remotetoken", "refresh_token") diff --git a/invenio_oauthclient/alembic/__init__.py b/invenio_oauthclient/alembic/__init__.py new file mode 100644 index 00000000..d92f3a6f --- /dev/null +++ b/invenio_oauthclient/alembic/__init__.py @@ -0,0 +1,9 @@ +# -*- coding: utf-8 -*- +# +# This file is part of Invenio. +# Copyright (C) 2016-2018 CERN. +# +# Invenio is free software; you can redistribute it and/or modify it +# under the terms of the MIT License; see LICENSE file for more details. + +"""Alembic migrations for Invenio-OAuthClient.""" diff --git a/invenio_oauthclient/contrib/keycloak/settings.py b/invenio_oauthclient/contrib/keycloak/settings.py index 0fd666e0..9b492778 100644 --- a/invenio_oauthclient/contrib/keycloak/settings.py +++ b/invenio_oauthclient/contrib/keycloak/settings.py @@ -38,7 +38,7 @@ def __init__( app_key=None, icon=None, scopes="openid", - **kwargs + **kwargs, ): """The constructor takes two arguments. @@ -64,7 +64,7 @@ def __init__( request_token_params={"scope": scopes}, access_token_url=access_token_url, authorize_url=authorize_url, - **kwargs + **kwargs, ) self._handlers = dict( diff --git a/invenio_oauthclient/handlers/authorized.py b/invenio_oauthclient/handlers/authorized.py index ce6f795b..e1bc1753 100644 --- a/invenio_oauthclient/handlers/authorized.py +++ b/invenio_oauthclient/handlers/authorized.py @@ -211,7 +211,14 @@ def extra_signup_handler(remote, form, *args, **kwargs): user = _register_user(response, remote, account_info, form) # Link account and set session data - token = token_setter(remote, oauth_token[0], secret=oauth_token[1], user=user) + token = token_setter( + remote, + oauth_token[0], + secret=oauth_token[1], + user=user, + refresh_token=oauth_token[2] if len(oauth_token) > 2 else None, + expires_at=oauth_token[3] if len(oauth_token) > 3 else None, + ) if token is None: raise OAuthClientTokenNotSet() diff --git a/invenio_oauthclient/handlers/refresh.py b/invenio_oauthclient/handlers/refresh.py new file mode 100644 index 00000000..22b5a90e --- /dev/null +++ b/invenio_oauthclient/handlers/refresh.py @@ -0,0 +1,55 @@ +# -*- coding: utf-8 -*- +# +# This file is part of Invenio. +# Copyright (C) 2024 CESNET z.s.p.o. +# +# Invenio is free software; you can redistribute it and/or modify it +# under the terms of the MIT License; see LICENSE file for more details. + +"""Handler for refreshing access token.""" + +from flask_oauthlib.client import OAuthResponse +from flask_oauthlib.utils import to_bytes + +from invenio_oauthclient.handlers.token import make_expiration_time + +from ..models import RemoteToken +from ..proxies import current_oauthclient + + +def refresh_access_token(token: RemoteToken): + """ + Internal method to refresh the access token. + + :param token: the remote token to be refreshed + :returns tuple of (access_token, secret, refresh_token, expires_at) + + Note: the current access/refresh token are invalidated during this call + """ + remote_account = token.remote_account + client_id = remote_account.client_id + remote = next( + x + for x in current_oauthclient.oauth.remote_apps.values() + if x.consumer_key == client_id + ) + client = remote.make_client() + refresh_token_request = client.prepare_refresh_token_request( + remote.access_token_url, + refresh_token=token.refresh_token, + client_id=remote.consumer_key, + client_secret=remote.consumer_secret, + ) + resp, content = remote.http_request( + refresh_token_request[0], + refresh_token_request[1], + data=to_bytes(refresh_token_request[2], remote.encoding), + method="POST", + ) + resp = OAuthResponse(resp, content, remote.content_type) + return ( + resp.data.get("access_token"), + "", + resp.data.get("refresh_token"), + make_expiration_time(resp.data.get("expires_in")), + ) diff --git a/invenio_oauthclient/handlers/token.py b/invenio_oauthclient/handlers/token.py index c632a841..e020fac9 100644 --- a/invenio_oauthclient/handlers/token.py +++ b/invenio_oauthclient/handlers/token.py @@ -7,7 +7,7 @@ # under the terms of the MIT License; see LICENSE file for more details. """Funcs to manage tokens.""" - +from datetime import datetime, timedelta from functools import partial from flask import current_app, session @@ -112,10 +112,31 @@ def oauth2_token_setter(remote, resp, token_type="", extra_data=None): secret="", token_type=token_type, extra_data=extra_data, + refresh_token=resp.get("refresh_token"), + expires_at=make_expiration_time(resp.get("expires_in")), ) -def token_setter(remote, token, secret="", token_type="", extra_data=None, user=None): +def make_expiration_time(expires_in): + """Make expiration time from expires_in. + + :param expires_in: The time in seconds. + """ + if expires_in is None: + return None + return datetime.utcnow() + timedelta(seconds=expires_in) + + +def token_setter( + remote, + token, + secret="", + token_type="", + extra_data=None, + user=None, + refresh_token=None, + expires_at=None, +): """Set token for user. :param remote: The remote application. @@ -127,7 +148,12 @@ def token_setter(remote, token, secret="", token_type="", extra_data=None, user= :returns: A :class:`invenio_oauthclient.models.RemoteToken` instance or ``None``. """ - session[token_session_key(remote.name)] = (token, secret) + session[token_session_key(remote.name)] = ( + token, + secret, + refresh_token, + expires_at.isoformat() if expires_at else None, + ) user = user or current_user # Save token if user is not anonymous (user exists but can be not active at @@ -140,10 +166,17 @@ def token_setter(remote, token, secret="", token_type="", extra_data=None, user= t = RemoteToken.get(uid, cid, token_type=token_type) if t: - t.update_token(token, secret) + t.update_token(token, secret, refresh_token, expires_at) else: t = RemoteToken.create( - uid, cid, token, secret, token_type=token_type, extra_data=extra_data + uid, + cid, + token, + secret, + token_type=token_type, + extra_data=extra_data, + refresh_token=refresh_token, + expires_at=expires_at, ) return t return None @@ -176,7 +209,15 @@ def token_getter(remote, token=""): # Store token and secret in session session[session_key] = remote_token.token() - return session.get(session_key, None) + ret = session.get(session_key, None) + if ret: + if len(ret) == 2: + # no refresh token nor expiration time + return ret[0], ret[1], None, None + if ret[3] is not None: + # refresh token and expiration time + return ret[0], ret[1], ret[2], datetime.fromisoformat(ret[3]) + return ret def token_delete(remote, token=""): diff --git a/invenio_oauthclient/models.py b/invenio_oauthclient/models.py index 56039c1f..e1f6b103 100644 --- a/invenio_oauthclient/models.py +++ b/invenio_oauthclient/models.py @@ -8,6 +8,8 @@ """Models for storing access tokens and links between users and remote apps.""" +from datetime import datetime, timedelta + from flask import current_app # UserIdentity imported for backward compatibility. UserIdentity was originally @@ -119,6 +121,14 @@ class RemoteToken(db.Model, Timestamp): ) """Access token to remote application.""" + refresh_token = db.Column( + EncryptedType(type_in=db.Text, key=_secret_key), nullable=True + ) + """Refresh token to remote application.""" + + expires_at = db.Column(db.DateTime, nullable=True) + """Access token expiration date.""" + secret = db.Column(db.Text(), default="", nullable=False) """Used only by OAuth 1.""" @@ -130,6 +140,16 @@ class RemoteToken(db.Model, Timestamp): ) """SQLAlchemy relationship to RemoteAccount objects.""" + @property + def is_expired(self): + """Check if access token has expired.""" + if not self.expires_at: + return False + + leeway = current_app.config.get("OAUTHCLIENT_TOKEN_EXPIRES_LEEWAY", 10) + expiration_with_leeway = self.expires_at - timedelta(seconds=leeway) + return expiration_with_leeway < datetime.utcnow() + def __repr__(self): """String representation for model.""" return ( @@ -141,18 +161,37 @@ def token(self): """Get token as expected by Flask-OAuthlib.""" return (self.access_token, self.secret) - def update_token(self, token, secret): + def update_token(self, token, secret, refresh_token=None, expires_at=None): """Update token with new values. :param token: The token value. :param secret: The secret key. + :param refresh_token: The refresh token + :param expires_at: Time when the access token expires """ - if self.access_token != token or self.secret != secret: + if ( + self.access_token != token + or self.secret != secret + or self.refresh_token != refresh_token + or self.expiration != expires_at + ): with db.session.begin_nested(): self.access_token = token self.secret = secret + self.refresh_token = refresh_token + self.expires_at = expires_at db.session.add(self) + def refresh_access_token(self): + """Refresh the access token.""" + if not self.refresh_token: + raise ValueError("No refresh token available") + from .handlers.refresh import refresh_access_token + + access_token, refresh_token, secret, expires_at = refresh_access_token(self) + self.update_token(access_token, refresh_token, secret, expires_at) + db.session.commit() + @classmethod def get(cls, user_id, client_id, token_type="", access_token=None): """Get RemoteToken for user. @@ -197,7 +236,17 @@ def get_by_token(cls, client_id, access_token, token_type=""): ) @classmethod - def create(cls, user_id, client_id, token, secret, token_type="", extra_data=None): + def create( + cls, + user_id, + client_id, + token, + secret, + token_type="", + extra_data=None, + refresh_token=None, + expires_at=None, + ): """Create a new access token. .. note:: Creates RemoteAccount as well if it does not exists. @@ -209,6 +258,8 @@ def create(cls, user_id, client_id, token, secret, token_type="", extra_data=Non :param token_type: The token type. (Default: ``''``) :param extra_data: Extra data to set in the remote account if the remote account doesn't exists. (Default: ``None``) + :param refresh_token: The refresh token. + :param expires_at: Expiration of the token :returns: A :class:`invenio_oauthclient.models.RemoteToken` instance. """ @@ -228,6 +279,8 @@ def create(cls, user_id, client_id, token, secret, token_type="", extra_data=Non remote_account=account, access_token=token, secret=secret, + refresh_token=refresh_token, + expires_at=expires_at, ) db.session.add(token) return token diff --git a/invenio_oauthclient/utils.py b/invenio_oauthclient/utils.py index 40f4cee6..722ab6ab 100644 --- a/invenio_oauthclient/utils.py +++ b/invenio_oauthclient/utils.py @@ -9,8 +9,7 @@ """Utility methods.""" from flask import current_app, request, session -from flask_login import current_user -from flask_principal import RoleNeed, UserNeed +from flask_principal import RoleNeed from invenio_db.utils import rebuild_encrypted_properties from itsdangerous import TimedJSONWebSignatureSerializer from uritools import uricompose, urisplit diff --git a/tests/helpers.py b/tests/helpers.py index 7407c45b..02bd3467 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -40,6 +40,11 @@ def mock_remote_get(oauth, remote_app="test", data=None): oauth.remote_apps[remote_app].get = MagicMock(return_value=data) +def mock_remote_http_request(oauth, remote_app="test", data=None): + """Mock the oauth remote get response.""" + oauth.remote_apps[remote_app].http_request = MagicMock(return_value=data) + + def check_redirect_location(resp, loc): """Check response redirect location.""" assert resp._status_code == 302 diff --git a/tests/test_base_handlers.py b/tests/test_base_handlers.py index a419a780..83e70053 100644 --- a/tests/test_base_handlers.py +++ b/tests/test_base_handlers.py @@ -51,4 +51,4 @@ def test_token_getter(remote, models_fixture, app): # Populated RemoteToken RemoteToken.create(user.id, "testkey", "mytoken", "mysecret") oauth_authenticate("dev", user) - assert token_getter(remote) == ("mytoken", "mysecret") + assert token_getter(remote) == ("mytoken", "mysecret", None, None) diff --git a/tests/test_refresh.py b/tests/test_refresh.py new file mode 100644 index 00000000..ae76b500 --- /dev/null +++ b/tests/test_refresh.py @@ -0,0 +1,54 @@ +# -*- coding: utf-8 -*- +# +# This file is part of Invenio. +# Copyright (C) 2024 CESNET z.s.p.o. +# +# Invenio is free software; you can redistribute it and/or modify it +# under the terms of the MIT License; see LICENSE file for more details. + +"""Test handlers.""" +import json +from datetime import datetime + +from helpers import mock_remote_http_request + +from invenio_oauthclient.models import RemoteToken + + +def test_refresh(models_fixture, app): + """Test token getter on response from OAuth server.""" + datastore = app.extensions["invenio-accounts"].datastore + existing_email = "existing@inveniosoftware.org" + user = datastore.find_user(email=existing_email) + + rt = RemoteToken.create( + user.id, + "cern_key_changeme", + "mytoken", + "mysecret", + refresh_token="myrefreshtoken", + expires_at=datetime.utcnow(), + ) + assert rt.is_expired is True + + ioc = app.extensions["oauthlib.client"] + mock_remote_http_request( + ioc, + "cern_openid", + [ + None, + json.dumps( + { + "access_token": "newtoken", + "token_type": "bearer", + "expires_in": 1199, + "refresh_token": "newrefreshtoken", + } + ), + ], + ) + + rt.refresh_access_token() + assert rt.is_expired is False + assert rt.access_token == "newtoken" + assert rt.refresh_token == "newrefreshtoken" diff --git a/tests/test_views.py b/tests/test_views.py index 0d7b8740..6c96ec2b 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -391,7 +391,12 @@ def test_token_getter_setter(views_fixture, monkeypatch): # Assert if everything is as it should be. from flask import session as flask_session - assert flask_session["oauth_token_full"] == ("test_access_token", "") + assert flask_session["oauth_token_full"] == ( + "test_access_token", + "", + None, + None, + ) t = RemoteToken.get(1, "fullid") assert t.remote_account.client_id == "fullid" @@ -423,7 +428,7 @@ def test_token_getter_setter(views_fixture, monkeypatch): assert RemoteToken.query.count() == 1 val = token_getter(app.extensions["oauthlib.client"].remote_apps["full"]) - assert val == ("new_access_token", "") + assert val == ("new_access_token", "", None, None) # Disconnect account res = c.get( diff --git a/tests/test_views_rest.py b/tests/test_views_rest.py index 27516596..cb23a206 100644 --- a/tests/test_views_rest.py +++ b/tests/test_views_rest.py @@ -385,7 +385,12 @@ def test_token_getter_setter(app_rest, monkeypatch): # Assert if everything is as it should be. from flask import session as flask_session - assert flask_session["oauth_token_full"] == ("test_access_token", "") + assert flask_session["oauth_token_full"] == ( + "test_access_token", + "", + None, + None, + ) t = RemoteToken.get(1, "fullid") assert t.remote_account.client_id == "fullid" @@ -417,7 +422,7 @@ def test_token_getter_setter(app_rest, monkeypatch): assert RemoteToken.query.count() == 1 val = token_getter(app_rest.extensions["oauthlib.client"].remote_apps["full"]) - assert val == ("new_access_token", "") + assert val == ("new_access_token", "", None, None) # Disconnect account res = c.get(