diff --git a/docs/options.rst b/docs/options.rst index 9d2a100f..ffd3700b 100644 --- a/docs/options.rst +++ b/docs/options.rst @@ -39,6 +39,8 @@ General Options: Defaults to ``'identity'`` for legacy reasons. ``JWT_USER_CLAIMS`` Claim in the tokens that is used to store user claims. Defaults to ``'user_claims'``. +``JWT_CLAIMS_IN_REFRESH_TOKEN`` If user claims should be included in refresh tokens. + Defaults to ``False``. ================================= ========================================= diff --git a/flask_jwt_extended/config.py b/flask_jwt_extended/config.py index bc4542bb..9d71f492 100644 --- a/flask_jwt_extended/config.py +++ b/flask_jwt_extended/config.py @@ -247,6 +247,10 @@ def identity_claim_key(self): def user_claims_key(self): return current_app.config['JWT_USER_CLAIMS'] + @property + def user_claims_in_refresh_token(self): + return current_app.config['JWT_CLAIMS_IN_REFRESH_TOKEN'] + @property def exempt_methods(self): return {"OPTIONS"} diff --git a/flask_jwt_extended/jwt_manager.py b/flask_jwt_extended/jwt_manager.py index 531496ab..70b6ef58 100644 --- a/flask_jwt_extended/jwt_manager.py +++ b/flask_jwt_extended/jwt_manager.py @@ -187,6 +187,8 @@ def _set_default_configuration_options(app): app.config.setdefault('JWT_IDENTITY_CLAIM', 'identity') app.config.setdefault('JWT_USER_CLAIMS', 'user_claims') + app.config.setdefault('JWT_CLAIMS_IN_REFRESH_TOKEN', False) + def user_claims_loader(self, callback): """ This decorator sets the callback function for adding custom claims to an @@ -375,13 +377,20 @@ def _create_refresh_token(self, identity, expires_delta=None): if expires_delta is None: expires_delta = config.refresh_expires + if config.user_claims_in_refresh_token: + user_claims = self._user_claims_callback(identity) + else: + user_claims = None + refresh_token = encode_refresh_token( identity=self._user_identity_callback(identity), secret=config.encode_key, algorithm=config.algorithm, expires_delta=expires_delta, + user_claims=user_claims, csrf=config.csrf_protect, identity_claim_key=config.identity_claim_key, + user_claims_key=config.user_claims_key, json_encoder=config.json_encoder ) return refresh_token diff --git a/flask_jwt_extended/tokens.py b/flask_jwt_extended/tokens.py index 9aa0e14e..7f500d9b 100644 --- a/flask_jwt_extended/tokens.py +++ b/flask_jwt_extended/tokens.py @@ -77,8 +77,9 @@ def encode_access_token(identity, secret, algorithm, expires_delta, fresh, json_encoder=json_encoder) -def encode_refresh_token(identity, secret, algorithm, expires_delta, csrf, - identity_claim_key, json_encoder=None): +def encode_refresh_token(identity, secret, algorithm, expires_delta, user_claims, + csrf, identity_claim_key, user_claims_key, + json_encoder=None): """ Creates a new encoded (utf-8) refresh token. @@ -88,15 +89,23 @@ def encode_refresh_token(identity, secret, algorithm, expires_delta, csrf, :param expires_delta: How far in the future this token should expire (set to False to disable expiration) :type expires_delta: datetime.timedelta or False + :param user_claims: Custom claims to include in this token. This data must + be json serializable :param csrf: Whether to include a csrf double submit claim in this token (boolean) :param identity_claim_key: Which key should be used to store the identity + :param user_claims_key: Which key should be used to store the user claims :return: Encoded refresh token """ token_data = { identity_claim_key: identity, 'type': 'refresh', } + + # Don't add extra data to the token if user_claims is empty. + if user_claims: + token_data[user_claims_key] = user_claims + if csrf: token_data['csrf'] = _create_csrf_token() return _encode_jwt(token_data, expires_delta, secret, algorithm, @@ -129,8 +138,8 @@ def decode_jwt(encoded_token, secret, algorithm, identity_claim_key, if data['type'] == 'access': if 'fresh' not in data: raise JWTDecodeError("Missing claim: fresh") - if user_claims_key not in data: - data[user_claims_key] = {} + if user_claims_key not in data: + data[user_claims_key] = {} if csrf_value: if 'csrf' not in data: raise JWTDecodeError("Missing claim: csrf") diff --git a/tests/test_config.py b/tests/test_config.py index 613cdf83..e1ef5ccf 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -61,6 +61,8 @@ def test_default_configs(app): assert config.identity_claim_key == 'identity' assert config.user_claims_key == 'user_claims' + assert config.user_claims_in_refresh_token is False + assert config.json_encoder is app.json_encoder @@ -100,6 +102,8 @@ def test_override_configs(app): app.config['JWT_IDENTITY_CLAIM'] = 'foo' app.config['JWT_USER_CLAIMS'] = 'bar' + app.config['JWT_CLAIMS_IN_REFRESH_TOKEN'] = True + class CustomJSONEncoder(JSONEncoder): pass @@ -148,6 +152,8 @@ class CustomJSONEncoder(JSONEncoder): assert config.identity_claim_key == 'foo' assert config.user_claims_key == 'bar' + assert config.user_claims_in_refresh_token is True + assert config.json_encoder is CustomJSONEncoder diff --git a/tests/test_user_claims_loader.py b/tests/test_user_claims_loader.py index 51f1900d..61990fef 100644 --- a/tests/test_user_claims_loader.py +++ b/tests/test_user_claims_loader.py @@ -3,7 +3,7 @@ from flask_jwt_extended import ( JWTManager, create_access_token, jwt_required, get_jwt_claims, - decode_token + decode_token, jwt_refresh_token_required, create_refresh_token ) from tests.utils import get_jwt_manager, make_headers @@ -19,6 +19,11 @@ def app(): def get_claims(): return jsonify(get_jwt_claims()) + @app.route('/protected2', methods=['GET']) + @jwt_refresh_token_required + def get_refresh_claims(): + return jsonify(get_jwt_claims()) + return app @@ -99,3 +104,36 @@ def add_claims(identity): response = test_client.get('/protected', headers=make_headers(access_token)) assert response.get_json() == {'foo': 'bar'} assert response.status_code == 200 + + +def test_user_claim_not_in_refresh_token(app): + jwt = get_jwt_manager(app) + + @jwt.user_claims_loader + def add_claims(identity): + return {'foo': 'bar'} + + with app.test_request_context(): + refresh_token = create_refresh_token('username') + + test_client = app.test_client() + response = test_client.get('/protected2', headers=make_headers(refresh_token)) + assert response.get_json() == {} + assert response.status_code == 200 + + +def test_user_claim_in_refresh_token(app): + app.config['JWT_CLAIMS_IN_REFRESH_TOKEN'] = True + jwt = get_jwt_manager(app) + + @jwt.user_claims_loader + def add_claims(identity): + return {'foo': 'bar'} + + with app.test_request_context(): + refresh_token = create_refresh_token('username') + + test_client = app.test_client() + response = test_client.get('/protected2', headers=make_headers(refresh_token)) + assert response.get_json() == {'foo': 'bar'} + assert response.status_code == 200