Skip to content

Commit

Permalink
Include user claims in refresh tokens
Browse files Browse the repository at this point in the history
  • Loading branch information
Eliška Roubalová committed Jun 4, 2018
1 parent 02c9fc6 commit 664f8b4
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 5 deletions.
2 changes: 2 additions & 0 deletions docs/options.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ General Options:
Defaults to ``'identity'`` for legacy reasons.
``JWT_USER_CLAIMS`` Claim in the tokens that is used to store user claims.
Defaults to ``'user_claims'``.
``JWT_CLAIMS_IN_REFRESH_TOKEN`` If user claims should be included in refresh tokens.
Defaults to ``False``.
================================= =========================================


Expand Down
4 changes: 4 additions & 0 deletions flask_jwt_extended/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,10 @@ def identity_claim_key(self):
def user_claims_key(self):
return current_app.config['JWT_USER_CLAIMS']

@property
def user_claims_in_refresh_token(self):
return current_app.config['JWT_CLAIMS_IN_REFRESH_TOKEN']

@property
def exempt_methods(self):
return {"OPTIONS"}
Expand Down
9 changes: 9 additions & 0 deletions flask_jwt_extended/jwt_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,8 @@ def _set_default_configuration_options(app):
app.config.setdefault('JWT_IDENTITY_CLAIM', 'identity')
app.config.setdefault('JWT_USER_CLAIMS', 'user_claims')

app.config.setdefault('JWT_CLAIMS_IN_REFRESH_TOKEN', False)

def user_claims_loader(self, callback):
"""
This decorator sets the callback function for adding custom claims to an
Expand Down Expand Up @@ -375,13 +377,20 @@ def _create_refresh_token(self, identity, expires_delta=None):
if expires_delta is None:
expires_delta = config.refresh_expires

if config.user_claims_in_refresh_token:
user_claims = self._user_claims_callback(identity)
else:
user_claims = None

refresh_token = encode_refresh_token(
identity=self._user_identity_callback(identity),
secret=config.encode_key,
algorithm=config.algorithm,
expires_delta=expires_delta,
user_claims=user_claims,
csrf=config.csrf_protect,
identity_claim_key=config.identity_claim_key,
user_claims_key=config.user_claims_key,
json_encoder=config.json_encoder
)
return refresh_token
Expand Down
17 changes: 13 additions & 4 deletions flask_jwt_extended/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,9 @@ def encode_access_token(identity, secret, algorithm, expires_delta, fresh,
json_encoder=json_encoder)


def encode_refresh_token(identity, secret, algorithm, expires_delta, csrf,
identity_claim_key, json_encoder=None):
def encode_refresh_token(identity, secret, algorithm, expires_delta, user_claims,
csrf, identity_claim_key, user_claims_key,
json_encoder=None):
"""
Creates a new encoded (utf-8) refresh token.
Expand All @@ -88,15 +89,23 @@ def encode_refresh_token(identity, secret, algorithm, expires_delta, csrf,
:param expires_delta: How far in the future this token should expire
(set to False to disable expiration)
:type expires_delta: datetime.timedelta or False
:param user_claims: Custom claims to include in this token. This data must
be json serializable
:param csrf: Whether to include a csrf double submit claim in this token
(boolean)
:param identity_claim_key: Which key should be used to store the identity
:param user_claims_key: Which key should be used to store the user claims
:return: Encoded refresh token
"""
token_data = {
identity_claim_key: identity,
'type': 'refresh',
}

# Don't add extra data to the token if user_claims is empty.
if user_claims:
token_data[user_claims_key] = user_claims

if csrf:
token_data['csrf'] = _create_csrf_token()
return _encode_jwt(token_data, expires_delta, secret, algorithm,
Expand Down Expand Up @@ -129,8 +138,8 @@ def decode_jwt(encoded_token, secret, algorithm, identity_claim_key,
if data['type'] == 'access':
if 'fresh' not in data:
raise JWTDecodeError("Missing claim: fresh")
if user_claims_key not in data:
data[user_claims_key] = {}
if user_claims_key not in data:
data[user_claims_key] = {}
if csrf_value:
if 'csrf' not in data:
raise JWTDecodeError("Missing claim: csrf")
Expand Down
6 changes: 6 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ def test_default_configs(app):
assert config.identity_claim_key == 'identity'
assert config.user_claims_key == 'user_claims'

assert config.user_claims_in_refresh_token is False

assert config.json_encoder is app.json_encoder


Expand Down Expand Up @@ -100,6 +102,8 @@ def test_override_configs(app):
app.config['JWT_IDENTITY_CLAIM'] = 'foo'
app.config['JWT_USER_CLAIMS'] = 'bar'

app.config['JWT_CLAIMS_IN_REFRESH_TOKEN'] = True

class CustomJSONEncoder(JSONEncoder):
pass

Expand Down Expand Up @@ -148,6 +152,8 @@ class CustomJSONEncoder(JSONEncoder):
assert config.identity_claim_key == 'foo'
assert config.user_claims_key == 'bar'

assert config.user_claims_in_refresh_token is True

assert config.json_encoder is CustomJSONEncoder


Expand Down
40 changes: 39 additions & 1 deletion tests/test_user_claims_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from flask_jwt_extended import (
JWTManager, create_access_token, jwt_required, get_jwt_claims,
decode_token
decode_token, jwt_refresh_token_required, create_refresh_token
)
from tests.utils import get_jwt_manager, make_headers

Expand All @@ -19,6 +19,11 @@ def app():
def get_claims():
return jsonify(get_jwt_claims())

@app.route('/protected2', methods=['GET'])
@jwt_refresh_token_required
def get_refresh_claims():
return jsonify(get_jwt_claims())

return app


Expand Down Expand Up @@ -99,3 +104,36 @@ def add_claims(identity):
response = test_client.get('/protected', headers=make_headers(access_token))
assert response.get_json() == {'foo': 'bar'}
assert response.status_code == 200


def test_user_claim_not_in_refresh_token(app):
jwt = get_jwt_manager(app)

@jwt.user_claims_loader
def add_claims(identity):
return {'foo': 'bar'}

with app.test_request_context():
refresh_token = create_refresh_token('username')

test_client = app.test_client()
response = test_client.get('/protected2', headers=make_headers(refresh_token))
assert response.get_json() == {}
assert response.status_code == 200


def test_user_claim_in_refresh_token(app):
app.config['JWT_CLAIMS_IN_REFRESH_TOKEN'] = True
jwt = get_jwt_manager(app)

@jwt.user_claims_loader
def add_claims(identity):
return {'foo': 'bar'}

with app.test_request_context():
refresh_token = create_refresh_token('username')

test_client = app.test_client()
response = test_client.get('/protected2', headers=make_headers(refresh_token))
assert response.get_json() == {'foo': 'bar'}
assert response.status_code == 200

0 comments on commit 664f8b4

Please sign in to comment.