diff --git a/examples/database_blacklist/app.py b/examples/database_blacklist/app.py index f562a178..b321a4ed 100644 --- a/examples/database_blacklist/app.py +++ b/examples/database_blacklist/app.py @@ -56,8 +56,8 @@ def login(): refresh_token = create_refresh_token(identity=username) # Store the tokens in our store with a status of not currently revoked. - add_token_to_database(access_token) - add_token_to_database(refresh_token) + add_token_to_database(access_token, app.config['JWT_IDENTITY_CLAIM']) + add_token_to_database(refresh_token, app.config['JWT_IDENTITY_CLAIM']) ret = { 'access_token': access_token, @@ -72,7 +72,7 @@ def refresh(): # Do the same thing that we did in the login endpoint here current_user = get_jwt_identity() access_token = create_access_token(identity=current_user) - add_token_to_database(access_token) + add_token_to_database(access_token, app.config['JWT_IDENTITY_CLAIM']) return jsonify({'access_token': access_token}), 201 # Provide a way for a user to look at their tokens diff --git a/examples/database_blacklist/blacklist_helpers.py b/examples/database_blacklist/blacklist_helpers.py index 90cb3e04..960445e8 100644 --- a/examples/database_blacklist/blacklist_helpers.py +++ b/examples/database_blacklist/blacklist_helpers.py @@ -16,14 +16,15 @@ def _epoch_utc_to_datetime(epoch_utc): return datetime.fromtimestamp(epoch_utc) -def add_token_to_database(encoded_token): +def add_token_to_database(encoded_token, identity_claim): """ Adds a new token to the database. It is not revoked when it is added. + :param identity_claim: """ decoded_token = decode_token(encoded_token) jti = decoded_token['jti'] token_type = decoded_token['type'] - user_identity = decoded_token['identity'] + user_identity = decoded_token[identity_claim] expires = _epoch_utc_to_datetime(decoded_token['exp']) revoked = False diff --git a/flask_jwt_extended/view_decorators.py b/flask_jwt_extended/view_decorators.py index 65253598..0bab2dc4 100644 --- a/flask_jwt_extended/view_decorators.py +++ b/flask_jwt_extended/view_decorators.py @@ -33,7 +33,7 @@ def jwt_required(fn): def wrapper(*args, **kwargs): jwt_data = _decode_jwt_from_request(request_type='access') ctx_stack.top.jwt = jwt_data - _load_user(jwt_data['identity']) + _load_user(jwt_data[config.identity_claim]) return fn(*args, **kwargs) return wrapper @@ -53,7 +53,7 @@ def wrapper(*args, **kwargs): try: jwt_data = _decode_jwt_from_request(request_type='access') ctx_stack.top.jwt = jwt_data - _load_user(jwt_data['identity']) + _load_user(jwt_data[config.identity_claim]) except NoAuthorizationError: pass return fn(*args, **kwargs) @@ -77,7 +77,7 @@ def wrapper(*args, **kwargs): raise FreshTokenRequired('Fresh token required') ctx_stack.top.jwt = jwt_data - _load_user(jwt_data['identity']) + _load_user(jwt_data[config.identity_claim]) return fn(*args, **kwargs) return wrapper @@ -92,7 +92,7 @@ def jwt_refresh_token_required(fn): def wrapper(*args, **kwargs): jwt_data = _decode_jwt_from_request(request_type='refresh') ctx_stack.top.jwt = jwt_data - _load_user(jwt_data['identity']) + _load_user(jwt_data[config.identity_claim]) return fn(*args, **kwargs) return wrapper diff --git a/tests/test_blacklist.py b/tests/test_blacklist.py index 6fd4a137..4a6749c1 100644 --- a/tests/test_blacklist.py +++ b/tests/test_blacklist.py @@ -16,6 +16,7 @@ def setUp(self): self.app = Flask(__name__) self.app.secret_key = 'super=secret' self.app.config['JWT_BLACKLIST_ENABLED'] = True + self.app.config['JWT_IDENTITY_CLAIM'] = 'sub' self.jwt_manager = JWTManager(self.app) self.client = self.app.test_client() self.blacklist = set() diff --git a/tests/test_jwt_encode_decode.py b/tests/test_jwt_encode_decode.py index d363eda5..2d31884e 100644 --- a/tests/test_jwt_encode_decode.py +++ b/tests/test_jwt_encode_decode.py @@ -30,24 +30,25 @@ def test_encode_access_token(self): algorithm = 'HS256' token_expire_delta = timedelta(minutes=5) user_claims = {'foo': 'bar'} + identity_claim = 'identity' # Check with a fresh token with self.app.test_request_context(): identity = 'user1' token = encode_access_token(identity, secret, algorithm, token_expire_delta, fresh=True, user_claims=user_claims, csrf=False, - identity_claim='identity') + identity_claim=identity_claim) data = jwt.decode(token, secret, algorithms=[algorithm]) self.assertIn('exp', data) self.assertIn('iat', data) self.assertIn('nbf', data) self.assertIn('jti', data) - self.assertIn('identity', data) + self.assertIn(identity_claim, data) self.assertIn('fresh', data) self.assertIn('type', data) self.assertIn('user_claims', data) self.assertNotIn('csrf', data) - self.assertEqual(data['identity'], identity) + self.assertEqual(data[identity_claim], identity) self.assertEqual(data['fresh'], True) self.assertEqual(data['type'], 'access') self.assertEqual(data['user_claims'], user_claims) @@ -61,18 +62,18 @@ def test_encode_access_token(self): identity = 12345 # identity can be anything json serializable token = encode_access_token(identity, secret, algorithm, token_expire_delta, fresh=False, user_claims=user_claims, csrf=True, - identity_claim='identity') + identity_claim=identity_claim) data = jwt.decode(token, secret, algorithms=[algorithm]) self.assertIn('exp', data) self.assertIn('iat', data) self.assertIn('nbf', data) self.assertIn('jti', data) - self.assertIn('identity', data) + self.assertIn(identity_claim, data) self.assertIn('fresh', data) self.assertIn('type', data) self.assertIn('user_claims', data) self.assertIn('csrf', data) - self.assertEqual(data['identity'], identity) + self.assertEqual(data[identity_claim], identity) self.assertEqual(data['fresh'], False) self.assertEqual(data['type'], 'access') self.assertEqual(data['user_claims'], user_claims) @@ -86,16 +87,17 @@ def test_encode_invalid_access_token(self): # Check with non-serializable json with self.app.test_request_context(): user_claims = datetime + identity_claim = 'identity' with self.assertRaises(Exception): encode_access_token('user1', 'secret', 'HS256', timedelta(hours=1), True, user_claims, - csrf=True, identity_claim='identity') + csrf=True, identity_claim=identity_claim) user_claims = {'foo': timedelta(hours=4)} with self.assertRaises(Exception): encode_access_token('user1', 'secret', 'HS256', timedelta(hours=1), True, user_claims, - csrf=True, identity_claim='identity') + csrf=True, identity_claim=identity_claim) def test_encode_refresh_token(self): secret = 'super-totally-secret-key' @@ -212,6 +214,7 @@ def test_decode_jwt(self): def test_decode_invalid_jwt(self): with self.app.test_request_context(): + identity_claim = 'identity' # Verify underlying pyjwt expires verification works with self.assertRaises(jwt.ExpiredSignatureError): token_data = { @@ -219,18 +222,19 @@ def test_decode_invalid_jwt(self): } encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8') decode_jwt(encoded_token, 'secret', 'HS256', - csrf=False, identity_claim='identity') + csrf=False, identity_claim=identity_claim) # Missing jti with self.assertRaises(JWTDecodeError): + token_data = { 'exp': datetime.utcnow() + timedelta(minutes=5), - 'identity': 'banana', + identity_claim: 'banana', 'type': 'refresh' } encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8') decode_jwt(encoded_token, 'secret', 'HS256', - csrf=False, identity_claim='identity') + csrf=False, identity_claim=identity_claim) # Missing identity with self.assertRaises(JWTDecodeError): @@ -241,61 +245,63 @@ def test_decode_invalid_jwt(self): } encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8') decode_jwt(encoded_token, 'secret', 'HS256', - csrf=False, identity_claim='identity') + csrf=False, identity_claim=identity_claim) # Non-matching identity claim with self.assertRaises(JWTDecodeError): token_data = { 'exp': datetime.utcnow() + timedelta(minutes=5), - 'identity': 'banana', + identity_claim: 'banana', 'type': 'refresh' } + other_identity_claim = 'sub' encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8') + self.assertNotEqual(identity_claim, other_identity_claim) decode_jwt(encoded_token, 'secret', 'HS256', - csrf=False, identity_claim='sub') + csrf=False, identity_claim=other_identity_claim) # Missing type with self.assertRaises(JWTDecodeError): token_data = { 'jti': 'banana', - 'identity': 'banana', + identity_claim: 'banana', 'exp': datetime.utcnow() + timedelta(minutes=5), } encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8') decode_jwt(encoded_token, 'secret', 'HS256', - csrf=False, identity_claim='identity') + csrf=False, identity_claim=identity_claim) # Missing fresh in access token with self.assertRaises(JWTDecodeError): token_data = { 'jti': 'banana', - 'identity': 'banana', + identity_claim: 'banana', 'exp': datetime.utcnow() + timedelta(minutes=5), 'type': 'access', 'user_claims': {} } encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8') decode_jwt(encoded_token, 'secret', 'HS256', - csrf=False, identity_claim='identity') + csrf=False, identity_claim=identity_claim) # Missing user claims in access token with self.assertRaises(JWTDecodeError): token_data = { 'jti': 'banana', - 'identity': 'banana', + identity_claim: 'banana', 'exp': datetime.utcnow() + timedelta(minutes=5), 'type': 'access', 'fresh': True } encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8') decode_jwt(encoded_token, 'secret', 'HS256', - csrf=False, identity_claim='identity') + csrf=False, identity_claim=identity_claim) # Bad token type with self.assertRaises(JWTDecodeError): token_data = { 'jti': 'banana', - 'identity': 'banana', + identity_claim: 'banana', 'exp': datetime.utcnow() + timedelta(minutes=5), 'type': 'banana', 'fresh': True, @@ -303,13 +309,13 @@ def test_decode_invalid_jwt(self): } encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8') decode_jwt(encoded_token, 'secret', 'HS256', - csrf=False, identity_claim='identity') + csrf=False, identity_claim=identity_claim) # Missing csrf in csrf enabled token with self.assertRaises(JWTDecodeError): token_data = { 'jti': 'banana', - 'identity': 'banana', + identity_claim: 'banana', 'exp': datetime.utcnow() + timedelta(minutes=5), 'type': 'access', 'fresh': True, @@ -317,7 +323,7 @@ def test_decode_invalid_jwt(self): } encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8') decode_jwt(encoded_token, 'secret', 'HS256', csrf=True, - identity_claim='identity') + identity_claim=identity_claim) def test_create_jwt_with_object(self): # Complex object to test building a JWT from. Normally if you are using diff --git a/tests/test_protected_endpoints.py b/tests/test_protected_endpoints.py index 165ce505..8869fc83 100644 --- a/tests/test_protected_endpoints.py +++ b/tests/test_protected_endpoints.py @@ -22,6 +22,7 @@ def setUp(self): self.app.config['JWT_ALGORITHM'] = 'HS256' self.app.config['JWT_ACCESS_TOKEN_EXPIRES'] = timedelta(seconds=1) self.app.config['JWT_REFRESH_TOKEN_EXPIRES'] = timedelta(seconds=1) + self.app.config['JWT_IDENTITY_CLAIM'] = 'sub' self.jwt_manager = JWTManager(self.app) self.client = self.app.test_client() @@ -454,6 +455,9 @@ def claims(): claims_keys = [claim for claim in jwt] return jsonify(claims_keys), 200 + # Grab custom identity claim + identity_claim = self.app.config['JWT_IDENTITY_CLAIM'] + # Login response = self.client.post('/auth/login') data = json.loads(response.get_data(as_text=True)) @@ -466,7 +470,7 @@ def claims(): self.assertIn('iat', data) self.assertIn('nbf', data) self.assertIn('jti', data) - self.assertIn('identity', data) + self.assertIn(identity_claim, data) self.assertIn('fresh', data) self.assertIn('type', data) self.assertIn('user_claims', data) @@ -836,12 +840,13 @@ def test_access_endpoints_with_cookie_missing_csrf_field(self): def test_access_endpoints_with_cookie_csrf_claim_not_string(self): now = datetime.utcnow() + identity_claim = self.app.config['JWT_IDENTITY_CLAIM'] token_data = { 'exp': now + timedelta(minutes=5), 'iat': now, 'nbf': now, 'jti': 'banana', - 'identity': 'banana', + identity_claim: 'banana', 'type': 'refresh', 'csrf': 404 }