Skip to content

Commit

Permalink
Make token decoding more flexible
Browse files Browse the repository at this point in the history
Fixes #208
  • Loading branch information
acrossen authored and vimalloc committed Dec 7, 2018
1 parent 6fe88c7 commit 234aa6e
Show file tree
Hide file tree
Showing 8 changed files with 128 additions and 28 deletions.
3 changes: 3 additions & 0 deletions docs/options.rst
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ General Options:
``JWT_ERROR_MESSAGE_KEY`` The key of the error message in a JSON error response when using
the default error handlers.
Defaults to ``'msg'``.
``JWT_DECODE_AUDIENCE`` The audience you expect in a JWT when decoding it.
If this option differs from the 'aud' claim in a JWT, the ``'invalid_token_callback'`` is invoked.
Defaults to ``'None'``.
================================= =========================================


Expand Down
4 changes: 4 additions & 0 deletions flask_jwt_extended/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,5 +278,9 @@ def error_msg_key(self):
def json_encoder(self):
return current_app.json_encoder

@property
def audience(self):
return current_app.config['JWT_DECODE_AUDIENCE']


config = _Config()
2 changes: 1 addition & 1 deletion flask_jwt_extended/default_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def default_verify_claims_failed_callback():
return jsonify({config.error_msg_key: 'User claims verification failed'}), 400


def default_decode_key_callback(claims):
def default_decode_key_callback(claims, headers):
"""
By default, the decode key specified via the JWT_SECRET_KEY or
JWT_PUBLIC_KEY settings will be used to decode all tokens
Expand Down
14 changes: 10 additions & 4 deletions flask_jwt_extended/jwt_manager.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import datetime

from jwt import ExpiredSignatureError, InvalidTokenError
from jwt import ExpiredSignatureError, InvalidTokenError, InvalidAudienceError

from flask_jwt_extended.config import config
from flask_jwt_extended.exceptions import (
Expand Down Expand Up @@ -108,6 +108,10 @@ def handle_jwt_decode_error(e):
def handle_wrong_token_error(e):
return self._invalid_token_callback(str(e))

@app.errorhandler(InvalidAudienceError)
def handle_invalid_audience_error(e):
return self._invalid_token_callback(str(e))

@app.errorhandler(RevokedTokenError)
def handle_revoked_token_error(e):
return self._revoked_token_callback()
Expand Down Expand Up @@ -192,6 +196,7 @@ def _set_default_configuration_options(app):

app.config.setdefault('JWT_IDENTITY_CLAIM', 'identity')
app.config.setdefault('JWT_USER_CLAIMS', 'user_claims')
app.config.setdefault('JWT_DECODE_AUDIENCE', None)

app.config.setdefault('JWT_CLAIMS_IN_REFRESH_TOKEN', False)

Expand Down Expand Up @@ -390,9 +395,10 @@ def decode_key_loader(self, callback):
The default implementation returns the decode key specified by
`JWT_SECRET_KEY` or `JWT_PUBLIC_KEY`, depending on the signing algorithm.
*HINT*: The callback function must be a function that takes only **one** argument,
which is the unverified claims of the jwt (dictionary) and must return a *string*
which is the decode key to verify the token.
*HINT*: The callback function should be a function that takes
**two** arguments, which are the unverified claims and headers of the jwt
(dictionaries). The function must return a *string* which is the decode key
in PEM format to verify the token.
"""
self._decode_key_callback = callback
return callback
Expand Down
17 changes: 10 additions & 7 deletions flask_jwt_extended/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def _create_csrf_token():

def _encode_jwt(additional_token_data, expires_delta, secret, algorithm,
json_encoder=None):
uid = str(uuid.uuid4())
uid = _create_csrf_token()
now = datetime.datetime.utcnow()
token_data = {
'iat': now,
Expand Down Expand Up @@ -113,7 +113,7 @@ def encode_refresh_token(identity, secret, algorithm, expires_delta, user_claims


def decode_jwt(encoded_token, secret, algorithm, identity_claim_key,
user_claims_key, csrf_value=None):
user_claims_key, csrf_value=None, audience=None):
"""
Decodes an encoded JWT
Expand All @@ -123,21 +123,24 @@ def decode_jwt(encoded_token, secret, algorithm, identity_claim_key,
:param identity_claim_key: expected key that contains the identity
:param user_claims_key: expected key that contains the user claims
:param csrf_value: Expected double submit csrf value
:param audience: expected audience in the JWT
:return: Dictionary containing contents of the JWT
"""
# This call verifies the ext, iat, and nbf claims
data = jwt.decode(encoded_token, secret, algorithms=[algorithm])
# This call verifies the ext, iat, nbf, and aud claims
data = jwt.decode(encoded_token, secret, algorithms=[algorithm], audience=audience)

# Make sure that any custom claims we expect in the token are present
if 'jti' not in data:
raise JWTDecodeError("Missing claim: jti")
data['jti'] = None
if identity_claim_key not in data:
raise JWTDecodeError("Missing claim: {}".format(identity_claim_key))
if 'type' not in data or data['type'] not in ('refresh', 'access'):
if 'type' not in data:
data['type'] = 'access'
if data['type'] not in ('refresh', 'access'):
raise JWTDecodeError("Missing or invalid claim: type")
if data['type'] == 'access':
if 'fresh' not in data:
raise JWTDecodeError("Missing claim: fresh")
data['fresh'] = False
if user_claims_key not in data:
data[user_claims_key] = {}
if csrf_value:
Expand Down
17 changes: 15 additions & 2 deletions flask_jwt_extended/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from flask import current_app
from werkzeug.local import LocalProxy
from warnings import warn

try:
from flask import _app_ctx_stack as ctx_stack
Expand Down Expand Up @@ -76,14 +77,26 @@ def decode_token(encoded_token, csrf_value=None):
unverified_claims = jwt.decode(
encoded_token, verify=False, algorithms=config.algorithm
)
secret = jwt_manager._decode_key_callback(unverified_claims)
unverified_headers = jwt.get_unverified_header(encoded_token)
# Attempt to call callback with both claims and headers, but fallback to just claims
# for backwards compatibility
try:
secret = jwt_manager._decode_key_callback(unverified_claims, unverified_headers)
except TypeError:
msg = (
"The single-argument (unverified_claims) form of decode_key_callback is deprecated. "
"Update your code to use the two-argument form (unverified_claims, unverified_headers)."
)
warn(msg, DeprecationWarning)
secret = jwt_manager._decode_key_callback(unverified_claims)
return decode_jwt(
encoded_token=encoded_token,
secret=secret,
algorithm=config.algorithm,
identity_claim_key=config.identity_claim_key,
user_claims_key=config.user_claims_key,
csrf_value=csrf_value
csrf_value=csrf_value,
audience=config.audience
)


Expand Down
73 changes: 60 additions & 13 deletions tests/test_decode_tokens.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import jwt
import pytest
from datetime import timedelta
from datetime import datetime, timedelta
import warnings

from flask import Flask
from jwt import ExpiredSignatureError, InvalidSignatureError
from jwt import ExpiredSignatureError, InvalidSignatureError, InvalidAudienceError

from flask_jwt_extended import (
JWTManager, create_access_token, decode_token, create_refresh_token,
Expand Down Expand Up @@ -54,16 +55,29 @@ def empty_user_loader_return(identity):
assert config.user_claims_key in extension_decoded


@pytest.mark.parametrize("missing_claim", ['jti', 'type', 'identity', 'fresh', 'csrf'])
def test_missing_jti_claim(app, default_access_token, missing_claim):
del default_access_token[missing_claim]
@pytest.mark.parametrize("missing_claims", ['identity', 'csrf'])
def test_missing_claims(app, default_access_token, missing_claims):
del default_access_token[missing_claims]
missing_jwt_token = encode_token(app, default_access_token)

with pytest.raises(JWTDecodeError):
with app.test_request_context():
decode_token(missing_jwt_token, csrf_value='abcd')


def test_default_decode_token_values(app, default_access_token):
del default_access_token['type']
del default_access_token['jti']
del default_access_token['fresh']
token = encode_token(app, default_access_token)

with app.test_request_context():
decoded = decode_token(token)
assert decoded['type'] == 'access'
assert decoded['jti'] is None
assert decoded['fresh'] is False


def test_bad_token_type(app, default_access_token):
default_access_token['type'] = 'banana'
bad_type_token = encode_token(app, default_access_token)
Expand Down Expand Up @@ -123,19 +137,36 @@ def test_encode_decode_callback_values(app, default_access_token):
jwtM = get_jwt_manager(app)
app.config['JWT_SECRET_KEY'] = 'foobarbaz'
with app.test_request_context():
assert jwtM._decode_key_callback({}) == 'foobarbaz'
assert jwtM._decode_key_callback({}, {}) == 'foobarbaz'
assert jwtM._encode_key_callback({}) == 'foobarbaz'

@jwtM.decode_key_loader
def get_decode_key_1(claims):
@jwtM.encode_key_loader
def get_encode_key_1(identity):
return 'different secret'
assert jwtM._encode_key_callback('') == 'different secret'

@jwtM.encode_key_loader
def get_decode_key_2(identity):
@jwtM.decode_key_loader
def get_decode_key_1(claims, headers):
return 'different secret'
assert jwtM._decode_key_callback({}, {}) == 'different secret'

assert jwtM._decode_key_callback({}) == 'different secret'
assert jwtM._encode_key_callback('') == 'different secret'

def test_legacy_decode_key_callback(app, default_access_token):
jwtM = get_jwt_manager(app)
app.config['JWT_SECRET_KEY'] = 'foobarbaz'

# test decode key callback with one argument (backwards compatibility)
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")

@jwtM.decode_key_loader
def get_decode_key_legacy(claims):
return 'foobarbaz'
with app.test_request_context():
token = encode_token(app, default_access_token)
decode_token(token)
assert len(w) == 1
assert issubclass(w[-1].category, DeprecationWarning)


def test_custom_encode_decode_key_callbacks(app, default_access_token):
Expand All @@ -157,7 +188,7 @@ def get_encode_key_1(identity):
decode_token(token)

@jwtM.decode_key_loader
def get_decode_key_1(claims):
def get_decode_key_1(claims, headers):
assert claims['identity'] == 'username'
return 'different secret'

Expand All @@ -166,3 +197,19 @@ def get_decode_key_1(claims):
decode_token(token)
token = create_refresh_token('username')
decode_token(token)


def test_valid_aud(app, default_access_token):
app.config['JWT_DECODE_AUDIENCE'] = 'foo'

default_access_token['aud'] = 'bar'
invalid_token = encode_token(app, default_access_token)
with pytest.raises(InvalidAudienceError):
with app.test_request_context():
decode_token(invalid_token)

default_access_token['aud'] = 'foo'
valid_token = encode_token(app, default_access_token)
with app.test_request_context():
decoded = decode_token(valid_token)
assert decoded['aud'] == 'foo'
26 changes: 25 additions & 1 deletion tests/test_view_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,31 @@ def test_jwt_missing_claims(app):

response = test_client.get(url, headers=make_headers(token))
assert response.status_code == 422
assert response.get_json() == {'msg': 'Missing claim: jti'}
assert response.get_json() == {'msg': 'Missing claim: identity'}


def test_jwt_invalid_audience(app):
url = '/protected'
jwtM = get_jwt_manager(app)
test_client = app.test_client()

# No audience claim expected or provided - OK
access_token = encode_token(app, {'identity': 'me'})
response = test_client.get(url, headers=make_headers(access_token))
assert response.status_code == 200

# Audience claim expected and not provided - not OK
app.config['JWT_DECODE_AUDIENCE'] = 'my_audience'
access_token = encode_token(app, {'identity': 'me'})
response = test_client.get(url, headers=make_headers(access_token))
assert response.status_code == 422
assert response.get_json() == {'msg': 'Token is missing the "aud" claim'}

# Audience claim still expected and wrong one provided - not OK
access_token = encode_token(app, {'aud': 'different_audience', 'identity': 'me'})
response = test_client.get(url, headers=make_headers(access_token))
assert response.status_code == 422
assert response.get_json() == {'msg': 'Invalid audience'}


def test_expired_token(app):
Expand Down

0 comments on commit 234aa6e

Please sign in to comment.