Skip to content

Commit

Permalink
Pass expired token to expired_token_callback
Browse files Browse the repository at this point in the history
Refs #220
  • Loading branch information
vimalloc authored Jan 20, 2019
1 parent 8ba49aa commit ea98fcc
Show file tree
Hide file tree
Showing 8 changed files with 88 additions and 30 deletions.
5 changes: 3 additions & 2 deletions examples/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@
# this function whenever an expired but otherwise valid access
# token attempts to access an endpoint
@jwt.expired_token_loader
def my_expired_token_callback():
def my_expired_token_callback(expired_token):
token_type = expired_token['type']
return jsonify({
'status': 401,
'sub_status': 42,
'msg': 'The token has expired'
'msg': 'The {} token has expired'.format(token_type)
}), 401


Expand Down
2 changes: 1 addition & 1 deletion flask_jwt_extended/default_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def default_user_identity_callback(userdata):
return userdata


def default_expired_token_callback():
def default_expired_token_callback(expired_token):
"""
By default, if an expired token attempts to access a protected endpoint,
we return a generic error message with a 401 status
Expand Down
21 changes: 18 additions & 3 deletions flask_jwt_extended/jwt_manager.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import datetime
from warnings import warn

from jwt import ExpiredSignatureError, InvalidTokenError, InvalidAudienceError
try:
from flask import _app_ctx_stack as ctx_stack
except ImportError: # pragma: no cover
from flask import _request_ctx_stack as ctx_stack

from flask_jwt_extended.config import config
from flask_jwt_extended.exceptions import (
Expand Down Expand Up @@ -90,7 +95,16 @@ def handle_csrf_error(e):

@app.errorhandler(ExpiredSignatureError)
def handle_expired_error(e):
return self._expired_token_callback()
try:
token = ctx_stack.top.expired_jwt
return self._expired_token_callback(token)
except TypeError:
msg = (
"jwt.expired_token_loader callback now takes the expired token "
"as an additional paramter. Example: expired_callback(token)"
)
warn(msg, DeprecationWarning)
return self._expired_token_callback()

@app.errorhandler(InvalidHeaderError)
def handle_invalid_header_error(e):
Expand Down Expand Up @@ -244,8 +258,9 @@ def expired_token_loader(self, callback):
{"msg": "Token has expired"}
*HINT*: The callback must be a function that takes **zero** arguments, and returns
a *Flask response*.
*HINT*: The callback must be a function that takes **one** argument,
which is a dictionary containing the data for the expired token, and
and returns a *Flask response*.
"""
self._expired_token_callback = callback
return callback
Expand Down
8 changes: 6 additions & 2 deletions flask_jwt_extended/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def encode_refresh_token(identity, secret, algorithm, expires_delta, user_claims

def decode_jwt(encoded_token, secret, algorithm, identity_claim_key,
user_claims_key, csrf_value=None, audience=None,
leeway=0):
leeway=0, allow_expired=False):
"""
Decodes an encoded JWT
Expand All @@ -126,12 +126,16 @@ def decode_jwt(encoded_token, secret, algorithm, identity_claim_key,
:param csrf_value: Expected double submit csrf value
:param audience: expected audience in the JWT
:param leeway: optional leeway to add some margin around expiration times
:param allow_expired: Options to ignore exp claim validation in token
:return: Dictionary containing contents of the JWT
"""
options = {}
if allow_expired:
options['verify_exp'] = False

# This call verifies the ext, iat, nbf, and aud claims
data = jwt.decode(encoded_token, secret, algorithms=[algorithm], audience=audience,
leeway=leeway)
leeway=leeway, options=options)

# Make sure that any custom claims we expect in the token are present
if 'jti' not in data:
Expand Down
8 changes: 6 additions & 2 deletions flask_jwt_extended/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,15 @@ def get_jti(encoded_token):
return decode_token(encoded_token).get('jti')


def decode_token(encoded_token, csrf_value=None):
def decode_token(encoded_token, csrf_value=None, allow_expired=False):
"""
Returns the decoded token (python dict) from an encoded JWT. This does all
the checks to insure that the decoded token is valid before returning it.
:param encoded_token: The encoded JWT to decode into a python dict.
:param csrf_value: Expected CSRF double submit value (optional)
:param allow_expired: Options to ignore exp claim validation in token
:return: Dictionary containing contents of the JWT
"""
jwt_manager = _get_jwt_manager()
unverified_claims = jwt.decode(
Expand All @@ -90,6 +92,7 @@ def decode_token(encoded_token, csrf_value=None):
)
warn(msg, DeprecationWarning)
secret = jwt_manager._decode_key_callback(unverified_claims)

return decode_jwt(
encoded_token=encoded_token,
secret=secret,
Expand All @@ -98,7 +101,8 @@ def decode_token(encoded_token, csrf_value=None):
user_claims_key=config.user_claims_key,
csrf_value=csrf_value,
audience=config.audience,
leeway=config.leeway
leeway=config.leeway,
allow_expired=allow_expired
)


Expand Down
29 changes: 18 additions & 11 deletions flask_jwt_extended/view_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from calendar import timegm

from werkzeug.exceptions import BadRequest
from jwt import ExpiredSignatureError

from flask import request
try:
Expand Down Expand Up @@ -191,7 +192,7 @@ def _decode_jwt_from_headers():
raise InvalidHeaderError(msg)
encoded_token = parts[1]

return decode_token(encoded_token)
return encoded_token, None


def _decode_jwt_from_cookies(request_type):
Expand All @@ -213,7 +214,7 @@ def _decode_jwt_from_cookies(request_type):
else:
csrf_value = None

return decode_token(encoded_token, csrf_value=csrf_value)
return encoded_token, csrf_value


def _decode_jwt_from_query_string():
Expand All @@ -222,7 +223,7 @@ def _decode_jwt_from_query_string():
if not encoded_token:
raise NoAuthorizationError('Missing "{}" query paramater'.format(query_param))

return decode_token(encoded_token)
return encoded_token, None


def _decode_jwt_from_json(request_type):
Expand All @@ -241,29 +242,35 @@ def _decode_jwt_from_json(request_type):
except BadRequest:
raise NoAuthorizationError('Missing "{}" key in json data.'.format(token_key))

return decode_token(encoded_token)
return encoded_token, None


def _decode_jwt_from_request(request_type):
# All the places we can get a JWT from in this request
decode_functions = []
get_encoded_token_functions = []
if config.jwt_in_cookies:
decode_functions.append(lambda: _decode_jwt_from_cookies(request_type))
get_encoded_token_functions.append(lambda: _decode_jwt_from_cookies(request_type))
if config.jwt_in_query_string:
decode_functions.append(_decode_jwt_from_query_string)
get_encoded_token_functions.append(_decode_jwt_from_query_string)
if config.jwt_in_headers:
decode_functions.append(_decode_jwt_from_headers)
get_encoded_token_functions.append(_decode_jwt_from_headers)
if config.jwt_in_json:
decode_functions.append(lambda: _decode_jwt_from_json(request_type))
get_encoded_token_functions.append(lambda: _decode_jwt_from_json(request_type))

# Try to find the token from one of these locations. It only needs to exist
# in one place to be valid (not every location).
errors = []
decoded_token = None
for decode_function in decode_functions:
for get_encoded_token_function in get_encoded_token_functions:
try:
decoded_token = decode_function()
encoded_token, csrf_token = get_encoded_token_function()
decoded_token = decode_token(encoded_token, csrf_token)
break
except ExpiredSignatureError:
# Save the expired token so we can access it in a callback later
expired_data = decode_token(encoded_token, csrf_token, allow_expired=True)
ctx_stack.top.expired_jwt = expired_data
raise
except NoAuthorizationError as e:
errors.append(str(e))

Expand Down
17 changes: 13 additions & 4 deletions tests/test_decode_tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,12 @@ def default_access_token(app):

@pytest.fixture(scope='function')
def patch_datetime_now(monkeypatch):

DATE_IN_FUTURE = datetime.utcnow() + timedelta(seconds=30)
date_in_future = datetime.utcnow() + timedelta(seconds=30)

class mydatetime(datetime):
@classmethod
def utcnow(cls):
return DATE_IN_FUTURE

return date_in_future
monkeypatch.setattr(__name__ + ".datetime", mydatetime)
monkeypatch.setattr("datetime.datetime", mydatetime)

Expand Down Expand Up @@ -116,6 +114,17 @@ def test_expired_token(app):
decode_token(refresh_token)


def test_allow_expired_token(app):
with app.test_request_context():
delta = timedelta(minutes=-5)
access_token = create_access_token('username', expires_delta=delta)
refresh_token = create_refresh_token('username', expires_delta=delta)
for token in (access_token, refresh_token):
decoded = decode_token(token, allow_expired=True)
assert decoded['identity'] == 'username'
assert 'exp' in decoded


def test_never_expire_token(app):
with app.test_request_context():
access_token = create_access_token('username', expires_delta=False)
Expand Down
28 changes: 23 additions & 5 deletions tests/test_view_decorators.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
import warnings
from datetime import timedelta
from flask import Flask, jsonify

Expand Down Expand Up @@ -246,14 +247,31 @@ def test_expired_token(app):
assert response.status_code == 401
assert response.get_json() == {'msg': 'Token has expired'}

# Test custom response
# Test depreciated custom response
@jwtM.expired_token_loader
def custom_response():
def depreciated_custom_response():
return jsonify(msg='foobar'), 201

response = test_client.get(url, headers=make_headers(token))
assert response.status_code == 201
assert response.get_json() == {'msg': 'foobar'}
warnings.simplefilter("always")
with warnings.catch_warnings(record=True) as w:
response = test_client.get(url, headers=make_headers(token))
assert response.status_code == 201
assert response.get_json() == {'msg': 'foobar'}
assert w[0].category == DeprecationWarning

# Test new custom response
@jwtM.expired_token_loader
def custom_response(token):
assert token['identity'] == 'username'
assert token['type'] == 'access'
return jsonify(msg='foobar'), 201

warnings.simplefilter("always")
with warnings.catch_warnings(record=True) as w:
response = test_client.get(url, headers=make_headers(token))
assert response.status_code == 201
assert response.get_json() == {'msg': 'foobar'}
assert len(w) == 0


def test_no_token(app):
Expand Down

0 comments on commit ea98fcc

Please sign in to comment.