diff --git a/examples/loaders.py b/examples/loaders.py index 78cd9a5e..6a038022 100644 --- a/examples/loaders.py +++ b/examples/loaders.py @@ -13,11 +13,12 @@ # this function whenever an expired but otherwise valid access # token attempts to access an endpoint @jwt.expired_token_loader -def my_expired_token_callback(): +def my_expired_token_callback(expired_token): + token_type = expired_token['type'] return jsonify({ 'status': 401, 'sub_status': 42, - 'msg': 'The token has expired' + 'msg': 'The {} token has expired'.format(token_type) }), 401 diff --git a/flask_jwt_extended/default_callbacks.py b/flask_jwt_extended/default_callbacks.py index 7c85e8cc..fc031769 100644 --- a/flask_jwt_extended/default_callbacks.py +++ b/flask_jwt_extended/default_callbacks.py @@ -34,7 +34,7 @@ def default_user_identity_callback(userdata): return userdata -def default_expired_token_callback(): +def default_expired_token_callback(expired_token): """ By default, if an expired token attempts to access a protected endpoint, we return a generic error message with a 401 status diff --git a/flask_jwt_extended/jwt_manager.py b/flask_jwt_extended/jwt_manager.py index 8cd743cd..23ce0db1 100644 --- a/flask_jwt_extended/jwt_manager.py +++ b/flask_jwt_extended/jwt_manager.py @@ -1,6 +1,11 @@ import datetime +from warnings import warn from jwt import ExpiredSignatureError, InvalidTokenError, InvalidAudienceError +try: + from flask import _app_ctx_stack as ctx_stack +except ImportError: # pragma: no cover + from flask import _request_ctx_stack as ctx_stack from flask_jwt_extended.config import config from flask_jwt_extended.exceptions import ( @@ -90,7 +95,16 @@ def handle_csrf_error(e): @app.errorhandler(ExpiredSignatureError) def handle_expired_error(e): - return self._expired_token_callback() + try: + token = ctx_stack.top.expired_jwt + return self._expired_token_callback(token) + except TypeError: + msg = ( + "jwt.expired_token_loader callback now takes the expired token " + "as an additional paramter. Example: expired_callback(token)" + ) + warn(msg, DeprecationWarning) + return self._expired_token_callback() @app.errorhandler(InvalidHeaderError) def handle_invalid_header_error(e): @@ -244,8 +258,9 @@ def expired_token_loader(self, callback): {"msg": "Token has expired"} - *HINT*: The callback must be a function that takes **zero** arguments, and returns - a *Flask response*. + *HINT*: The callback must be a function that takes **one** argument, + which is a dictionary containing the data for the expired token, and + and returns a *Flask response*. """ self._expired_token_callback = callback return callback diff --git a/flask_jwt_extended/tokens.py b/flask_jwt_extended/tokens.py index 563a4d86..33561f3c 100644 --- a/flask_jwt_extended/tokens.py +++ b/flask_jwt_extended/tokens.py @@ -114,7 +114,7 @@ def encode_refresh_token(identity, secret, algorithm, expires_delta, user_claims def decode_jwt(encoded_token, secret, algorithm, identity_claim_key, user_claims_key, csrf_value=None, audience=None, - leeway=0): + leeway=0, allow_expired=False): """ Decodes an encoded JWT @@ -126,12 +126,16 @@ def decode_jwt(encoded_token, secret, algorithm, identity_claim_key, :param csrf_value: Expected double submit csrf value :param audience: expected audience in the JWT :param leeway: optional leeway to add some margin around expiration times + :param allow_expired: Options to ignore exp claim validation in token :return: Dictionary containing contents of the JWT """ + options = {} + if allow_expired: + options['verify_exp'] = False # This call verifies the ext, iat, nbf, and aud claims data = jwt.decode(encoded_token, secret, algorithms=[algorithm], audience=audience, - leeway=leeway) + leeway=leeway, options=options) # Make sure that any custom claims we expect in the token are present if 'jti' not in data: diff --git a/flask_jwt_extended/utils.py b/flask_jwt_extended/utils.py index 8e27c348..ff300a33 100644 --- a/flask_jwt_extended/utils.py +++ b/flask_jwt_extended/utils.py @@ -65,13 +65,15 @@ def get_jti(encoded_token): return decode_token(encoded_token).get('jti') -def decode_token(encoded_token, csrf_value=None): +def decode_token(encoded_token, csrf_value=None, allow_expired=False): """ Returns the decoded token (python dict) from an encoded JWT. This does all the checks to insure that the decoded token is valid before returning it. :param encoded_token: The encoded JWT to decode into a python dict. :param csrf_value: Expected CSRF double submit value (optional) + :param allow_expired: Options to ignore exp claim validation in token + :return: Dictionary containing contents of the JWT """ jwt_manager = _get_jwt_manager() unverified_claims = jwt.decode( @@ -90,6 +92,7 @@ def decode_token(encoded_token, csrf_value=None): ) warn(msg, DeprecationWarning) secret = jwt_manager._decode_key_callback(unverified_claims) + return decode_jwt( encoded_token=encoded_token, secret=secret, @@ -98,7 +101,8 @@ def decode_token(encoded_token, csrf_value=None): user_claims_key=config.user_claims_key, csrf_value=csrf_value, audience=config.audience, - leeway=config.leeway + leeway=config.leeway, + allow_expired=allow_expired ) diff --git a/flask_jwt_extended/view_decorators.py b/flask_jwt_extended/view_decorators.py index 509ad161..941fa7d9 100644 --- a/flask_jwt_extended/view_decorators.py +++ b/flask_jwt_extended/view_decorators.py @@ -3,6 +3,7 @@ from calendar import timegm from werkzeug.exceptions import BadRequest +from jwt import ExpiredSignatureError from flask import request try: @@ -191,7 +192,7 @@ def _decode_jwt_from_headers(): raise InvalidHeaderError(msg) encoded_token = parts[1] - return decode_token(encoded_token) + return encoded_token, None def _decode_jwt_from_cookies(request_type): @@ -213,7 +214,7 @@ def _decode_jwt_from_cookies(request_type): else: csrf_value = None - return decode_token(encoded_token, csrf_value=csrf_value) + return encoded_token, csrf_value def _decode_jwt_from_query_string(): @@ -222,7 +223,7 @@ def _decode_jwt_from_query_string(): if not encoded_token: raise NoAuthorizationError('Missing "{}" query paramater'.format(query_param)) - return decode_token(encoded_token) + return encoded_token, None def _decode_jwt_from_json(request_type): @@ -241,29 +242,35 @@ def _decode_jwt_from_json(request_type): except BadRequest: raise NoAuthorizationError('Missing "{}" key in json data.'.format(token_key)) - return decode_token(encoded_token) + return encoded_token, None def _decode_jwt_from_request(request_type): # All the places we can get a JWT from in this request - decode_functions = [] + get_encoded_token_functions = [] if config.jwt_in_cookies: - decode_functions.append(lambda: _decode_jwt_from_cookies(request_type)) + get_encoded_token_functions.append(lambda: _decode_jwt_from_cookies(request_type)) if config.jwt_in_query_string: - decode_functions.append(_decode_jwt_from_query_string) + get_encoded_token_functions.append(_decode_jwt_from_query_string) if config.jwt_in_headers: - decode_functions.append(_decode_jwt_from_headers) + get_encoded_token_functions.append(_decode_jwt_from_headers) if config.jwt_in_json: - decode_functions.append(lambda: _decode_jwt_from_json(request_type)) + get_encoded_token_functions.append(lambda: _decode_jwt_from_json(request_type)) # Try to find the token from one of these locations. It only needs to exist # in one place to be valid (not every location). errors = [] decoded_token = None - for decode_function in decode_functions: + for get_encoded_token_function in get_encoded_token_functions: try: - decoded_token = decode_function() + encoded_token, csrf_token = get_encoded_token_function() + decoded_token = decode_token(encoded_token, csrf_token) break + except ExpiredSignatureError: + # Save the expired token so we can access it in a callback later + expired_data = decode_token(encoded_token, csrf_token, allow_expired=True) + ctx_stack.top.expired_jwt = expired_data + raise except NoAuthorizationError as e: errors.append(str(e)) diff --git a/tests/test_decode_tokens.py b/tests/test_decode_tokens.py index 2ca15dcb..cbb513cb 100644 --- a/tests/test_decode_tokens.py +++ b/tests/test_decode_tokens.py @@ -43,14 +43,12 @@ def default_access_token(app): @pytest.fixture(scope='function') def patch_datetime_now(monkeypatch): - - DATE_IN_FUTURE = datetime.utcnow() + timedelta(seconds=30) + date_in_future = datetime.utcnow() + timedelta(seconds=30) class mydatetime(datetime): @classmethod def utcnow(cls): - return DATE_IN_FUTURE - + return date_in_future monkeypatch.setattr(__name__ + ".datetime", mydatetime) monkeypatch.setattr("datetime.datetime", mydatetime) @@ -116,6 +114,17 @@ def test_expired_token(app): decode_token(refresh_token) +def test_allow_expired_token(app): + with app.test_request_context(): + delta = timedelta(minutes=-5) + access_token = create_access_token('username', expires_delta=delta) + refresh_token = create_refresh_token('username', expires_delta=delta) + for token in (access_token, refresh_token): + decoded = decode_token(token, allow_expired=True) + assert decoded['identity'] == 'username' + assert 'exp' in decoded + + def test_never_expire_token(app): with app.test_request_context(): access_token = create_access_token('username', expires_delta=False) diff --git a/tests/test_view_decorators.py b/tests/test_view_decorators.py index ae366075..8ec65413 100644 --- a/tests/test_view_decorators.py +++ b/tests/test_view_decorators.py @@ -1,4 +1,5 @@ import pytest +import warnings from datetime import timedelta from flask import Flask, jsonify @@ -246,14 +247,31 @@ def test_expired_token(app): assert response.status_code == 401 assert response.get_json() == {'msg': 'Token has expired'} - # Test custom response + # Test depreciated custom response @jwtM.expired_token_loader - def custom_response(): + def depreciated_custom_response(): return jsonify(msg='foobar'), 201 - response = test_client.get(url, headers=make_headers(token)) - assert response.status_code == 201 - assert response.get_json() == {'msg': 'foobar'} + warnings.simplefilter("always") + with warnings.catch_warnings(record=True) as w: + response = test_client.get(url, headers=make_headers(token)) + assert response.status_code == 201 + assert response.get_json() == {'msg': 'foobar'} + assert w[0].category == DeprecationWarning + + # Test new custom response + @jwtM.expired_token_loader + def custom_response(token): + assert token['identity'] == 'username' + assert token['type'] == 'access' + return jsonify(msg='foobar'), 201 + + warnings.simplefilter("always") + with warnings.catch_warnings(record=True) as w: + response = test_client.get(url, headers=make_headers(token)) + assert response.status_code == 201 + assert response.get_json() == {'msg': 'foobar'} + assert len(w) == 0 def test_no_token(app):