diff --git a/flask_jwt_extended/default_callbacks.py b/flask_jwt_extended/default_callbacks.py index 7672e17a..d49ae1d8 100644 --- a/flask_jwt_extended/default_callbacks.py +++ b/flask_jwt_extended/default_callbacks.py @@ -11,122 +11,124 @@ from flask_jwt_extended.config import config -def default_user_claims_callback(userdata): - """ - By default, we add no additional claims to the access tokens. - - :param userdata: data passed in as the ```identity``` argument to the - ```create_access_token``` and ```create_refresh_token``` - functions - """ - return {} - - -def default_jwt_headers_callback(default_headers): - """ - By default header typically consists of two parts: the type of the token, - which is JWT, and the signing algorithm being used, such as HMAC SHA256 - or RSA. But we don't set the default header here we set it as empty which - further by default set while encoding the token - :return: default we set None here - """ - return None - - -def default_user_identity_callback(userdata): - """ - By default, we use the passed in object directly as the jwt identity. - See this for additional info: - - :param userdata: data passed in as the ```identity``` argument to the - ```create_access_token``` and ```create_refresh_token``` - functions - """ - return userdata - - -def default_expired_token_callback(_expired_jwt_header, _expired_jwt_data): - """ - By default, if an expired token attempts to access a protected endpoint, - we return a generic error message with a 401 status - """ - return jsonify({config.error_msg_key: "Token has expired"}), 401 - - -def default_invalid_token_callback(error_string): - """ - By default, if an invalid token attempts to access a protected endpoint, we - return the error string for why it is not valid with a 422 status code - - :param error_string: String indicating why the token is invalid - """ - return jsonify({config.error_msg_key: error_string}), 422 - - -def default_unauthorized_callback(error_string): - """ - By default, if a protected endpoint is accessed without a JWT, we return - the error string indicating why this is unauthorized, with a 401 status code - - :param error_string: String indicating why this request is unauthorized - """ - return jsonify({config.error_msg_key: error_string}), 401 - - -def default_needs_fresh_token_callback(jwt_header, jwt_data): - """ - By default, if a non-fresh jwt is used to access a ```fresh_jwt_required``` - endpoint, we return a general error message with a 401 status code - """ - return jsonify({config.error_msg_key: "Fresh token required"}), 401 - - -def default_revoked_token_callback(): - """ - By default, if a revoked token is used to access a protected endpoint, we - return a general error message with a 401 status code - """ - return jsonify({config.error_msg_key: "Token has been revoked"}), 401 - - -def default_user_lookup_error_callback(_jwt_header, jwt_data): - """ - By default, if a user_lookup callback is defined and the callback - function returns None, we return a general error message with a 401 - status code - """ - identity = jwt_data[config.identity_claim_key] - result = {config.error_msg_key: "Error loading the user {}".format(identity)} - return jsonify(result), 401 - - -# TODO: Change this to default_token_verification_callback, pass in header and data. -def default_claims_verification_callback(user_claims): - """ - By default, we do not do any verification of the user claims. - """ - return True - - -def default_verify_claims_failed_callback(_jwt_header, _jwt_data): - """ - By default, if the user claims verification failed, we return a generic - error message with a 400 status code - """ - return jsonify({config.error_msg_key: "User claims verification failed"}), 400 - - -def default_decode_key_callback(claims, headers): - """ - By default, the decode key specified via the JWT_SECRET_KEY or - JWT_PUBLIC_KEY settings will be used to decode all tokens - """ - return config.decode_key - - -def default_encode_key_callback(identity): - """ - By default, the encode key specified via the JWT_SECRET_KEY or - JWT_PRIVATE_KEY settings will be used to encode all tokens - """ - return config.encode_key +class DefaultCallbacks: + @staticmethod + def user_claims(userdata): + """ + By default, we add no additional claims to the access tokens. + + :param userdata: data passed in as the ```identity``` argument to the + ```create_access_token``` and ```create_refresh_token``` + functions + """ + return {} + + @staticmethod + def jwt_headers(default_headers): + """ + By default header typically consists of two parts: the type of the token, + which is JWT, and the signing algorithm being used, such as HMAC SHA256 + or RSA. But we don't set the default header here we set it as empty which + further by default set while encoding the token + :return: default we set None here + """ + return None + + @staticmethod + def user_identity(userdata): + """ + By default, we use the passed in object directly as the jwt identity. + See this for additional info: + + :param userdata: data passed in as the ```identity``` argument to the + ```create_access_token``` and ```create_refresh_token``` + functions + """ + return userdata + + @staticmethod + def expired_token_response(_expired_jwt_header, _expired_jwt_data): + """ + By default, if an expired token attempts to access a protected endpoint, + we return a generic error message with a 401 status + """ + return jsonify({config.error_msg_key: "Token has expired"}), 401 + + @staticmethod + def invalid_token_response(error_string): + """ + By default, if an invalid token attempts to access a protected endpoint, we + return the error string for why it is not valid with a 422 status code + + :param error_string: String indicating why the token is invalid + """ + return jsonify({config.error_msg_key: error_string}), 422 + + @staticmethod + def unauthorized_response(error_string): + """ + By default, if a protected endpoint is accessed without a JWT, we return + the error string indicating why this is unauthorized, with a 401 status code + + :param error_string: String indicating why this request is unauthorized + """ + return jsonify({config.error_msg_key: error_string}), 401 + + @staticmethod + def needs_fresh_token_response(jwt_header, jwt_data): + """ + By default, if a non-fresh jwt is used to access a ```fresh_jwt_required``` + endpoint, we return a general error message with a 401 status code + """ + return jsonify({config.error_msg_key: "Fresh token required"}), 401 + + @staticmethod + def revoked_token_response(): + """ + By default, if a revoked token is used to access a protected endpoint, we + return a general error message with a 401 status code + """ + return jsonify({config.error_msg_key: "Token has been revoked"}), 401 + + @staticmethod + def user_lookup_error_response(_jwt_header, jwt_data): + """ + By default, if a user_lookup callback is defined and the callback + function returns None, we return a general error message with a 401 + status code + """ + identity = jwt_data[config.identity_claim_key] + result = {config.error_msg_key: "Error loading the user {}".format(identity)} + return jsonify(result), 401 + + # TODO: Change this to default_token_verification_callback, pass in header and data. + @staticmethod + def verify_claims(user_claims): + """ + By default, we do not do any verification of the user claims. + """ + return True + + @staticmethod + def invalid_claims_response(_jwt_header, _jwt_data): + """ + By default, if the user claims verification failed, we return a generic + error message with a 400 status code + """ + return jsonify({config.error_msg_key: "User claims verification failed"}), 400 + + @staticmethod + def decode_key(claims, headers): + """ + By default, the decode key specified via the JWT_SECRET_KEY or + JWT_PUBLIC_KEY settings will be used to decode all tokens + """ + return config.decode_key + + @staticmethod + def encode_key(identity): + """ + By default, the encode key specified via the JWT_SECRET_KEY or + JWT_PRIVATE_KEY settings will be used to encode all tokens + """ + return config.encode_key diff --git a/flask_jwt_extended/jwt_manager.py b/flask_jwt_extended/jwt_manager.py index 110dbea3..59bdab24 100644 --- a/flask_jwt_extended/jwt_manager.py +++ b/flask_jwt_extended/jwt_manager.py @@ -8,19 +8,7 @@ from jwt import InvalidTokenError from flask_jwt_extended.config import config -from flask_jwt_extended.default_callbacks import default_claims_verification_callback -from flask_jwt_extended.default_callbacks import default_decode_key_callback -from flask_jwt_extended.default_callbacks import default_encode_key_callback -from flask_jwt_extended.default_callbacks import default_expired_token_callback -from flask_jwt_extended.default_callbacks import default_invalid_token_callback -from flask_jwt_extended.default_callbacks import default_jwt_headers_callback -from flask_jwt_extended.default_callbacks import default_needs_fresh_token_callback -from flask_jwt_extended.default_callbacks import default_revoked_token_callback -from flask_jwt_extended.default_callbacks import default_unauthorized_callback -from flask_jwt_extended.default_callbacks import default_user_claims_callback -from flask_jwt_extended.default_callbacks import default_user_identity_callback -from flask_jwt_extended.default_callbacks import default_user_lookup_error_callback -from flask_jwt_extended.default_callbacks import default_verify_claims_failed_callback +from flask_jwt_extended.default_callbacks import DefaultCallbacks from flask_jwt_extended.exceptions import CSRFError from flask_jwt_extended.exceptions import FreshTokenRequired from flask_jwt_extended.exceptions import InvalidHeaderError @@ -34,7 +22,7 @@ from flask_jwt_extended.tokens import _encode_jwt -class JWTManager(object): +class JWTManager(DefaultCallbacks): """ An object used to hold JWT settings and callback functions for the Flask-JWT-Extended extension. @@ -44,6 +32,11 @@ class JWTManager(object): to your app in a factory function. """ + # register the default error handler callback methods. these can be + # overridden with the appropriate loader decorators + token_is_blacklisted = None + lookup_user = None + def __init__(self, app=None): """ Create the JWTManager instance. You can either pass a flask application @@ -52,24 +45,6 @@ def __init__(self, app=None): :param app: A flask application """ - # Register the default error handler callback methods. These can be - # overridden with the appropriate loader decorators - self._claims_verification_callback = default_claims_verification_callback - self._decode_key_callback = default_decode_key_callback - self._encode_key_callback = default_encode_key_callback - self._expired_token_callback = default_expired_token_callback - self._invalid_token_callback = default_invalid_token_callback - self._jwt_additional_header_callback = default_jwt_headers_callback - self._needs_fresh_token_callback = default_needs_fresh_token_callback - self._revoked_token_callback = default_revoked_token_callback - self._token_in_blacklist_callback = None - self._unauthorized_callback = default_unauthorized_callback - self._user_claims_callback = default_user_claims_callback - self._user_identity_callback = default_user_identity_callback - self._user_lookup_callback = None - self._user_lookup_error_callback = default_user_lookup_error_callback - self._verify_claims_failed_callback = default_verify_claims_failed_callback - # Register this extension with the flask app now (if it is provided) if app is not None: self.init_app(app) @@ -96,59 +71,59 @@ def _set_error_handler_callbacks(self, app): @app.errorhandler(CSRFError) def handle_csrf_error(e): - return self._unauthorized_callback(str(e)) + return self.unauthorized_response(str(e)) @app.errorhandler(DecodeError) def handle_decode_error(e): - return self._invalid_token_callback(str(e)) + return self.invalid_token_response(str(e)) @app.errorhandler(ExpiredSignatureError) def handle_expired_error(e): - return self._expired_token_callback(e.jwt_header, e.jwt_data) + return self.expired_token_response(e.jwt_header, e.jwt_data) @app.errorhandler(FreshTokenRequired) def handle_fresh_token_required(e): - return self._needs_fresh_token_callback(e.jwt_header, e.jwt_data) + return self.needs_fresh_token_response(e.jwt_header, e.jwt_data) @app.errorhandler(InvalidAudienceError) def handle_invalid_audience_error(e): - return self._invalid_token_callback(str(e)) + return self.invalid_token_response(str(e)) @app.errorhandler(InvalidIssuerError) def handle_invalid_issuer_error(e): - return self._invalid_token_callback(str(e)) + return self.invalid_token_response(str(e)) @app.errorhandler(InvalidHeaderError) def handle_invalid_header_error(e): - return self._invalid_token_callback(str(e)) + return self.invalid_token_response(str(e)) @app.errorhandler(InvalidTokenError) def handle_invalid_token_error(e): - return self._invalid_token_callback(str(e)) + return self.invalid_token_response(str(e)) @app.errorhandler(JWTDecodeError) def handle_jwt_decode_error(e): - return self._invalid_token_callback(str(e)) + return self.invalid_token_response(str(e)) @app.errorhandler(NoAuthorizationError) def handle_auth_error(e): - return self._unauthorized_callback(str(e)) + return self.unauthorized_response(str(e)) @app.errorhandler(RevokedTokenError) def handle_revoked_token_error(e): - return self._revoked_token_callback() + return self.revoked_token_response() @app.errorhandler(UserClaimsVerificationError) def handle_failed_user_claims_verification(e): - return self._verify_claims_failed_callback(e.jwt_header, e.jwt_data) + return self.invalid_claims_response(e.jwt_header, e.jwt_data) @app.errorhandler(UserLookupError) def handler_user_lookup_error(e): - return self._user_lookup_error_callback(e.jwt_header, e.jwt_data) + return self.user_lookup_error_response(e.jwt_header, e.jwt_data) @app.errorhandler(WrongTokenError) def handle_wrong_token_error(e): - return self._invalid_token_callback(str(e)) + return self.invalid_token_response(str(e)) @staticmethod def _set_default_configuration_options(app): @@ -212,7 +187,7 @@ def additional_headers_loader(self, callback): claims you want included in the access tokens. This returned claims must be *JSON serializable*. """ - self._jwt_additional_header_callback = callback + self.jwt_headers = callback return callback def claims_verification_failed_loader(self, callback): @@ -227,7 +202,7 @@ def claims_verification_failed_loader(self, callback): *HINT*: This callback must be a function that takes **no** arguments, and returns a *Flask response*. """ - self._verify_claims_failed_callback = callback + self.invalid_claims_response = callback return callback def claims_verification_loader(self, callback): @@ -243,7 +218,7 @@ def claims_verification_loader(self, callback): custom claims (python dict) present in the JWT, and returns *`True`* if the claims are valid, or *`False`* otherwise. """ - self._claims_verification_callback = callback + self.verify_claims = callback return callback def decode_key_loader(self, callback): @@ -260,7 +235,7 @@ def decode_key_loader(self, callback): (dictionaries). The function must return a *string* which is the decode key in PEM format to verify the token. """ - self._decode_key_callback = callback + self.decode_key = callback return callback def encode_key_loader(self, callback): @@ -277,7 +252,7 @@ def encode_key_loader(self, callback): or create_refresh_token functions, and must return a *string* which is the decode key to verify the token. """ - self._encode_key_callback = callback + self.encode_key = callback return callback def expired_token_loader(self, callback): @@ -292,7 +267,7 @@ def expired_token_loader(self, callback): which is a dictionary containing the data for the expired token, and and returns a *Flask response*. """ - self._expired_token_callback = callback + self.expired_token_response = callback return callback def invalid_token_loader(self, callback): @@ -307,7 +282,7 @@ def invalid_token_loader(self, callback): a string which contains the reason why a token is invalid, and returns a *Flask response*. """ - self._invalid_token_callback = callback + self.invalid_token_response = callback return callback def needs_fresh_token_loader(self, callback): @@ -322,7 +297,7 @@ def needs_fresh_token_loader(self, callback): *HINT*: The callback must be a function that takes **no** arguments, and returns a *Flask response*. """ - self._needs_fresh_token_callback = callback + self.needs_fresh_token_response = callback return callback def revoked_token_loader(self, callback): @@ -336,7 +311,7 @@ def revoked_token_loader(self, callback): *HINT*: The callback must be a function that takes **no** arguments, and returns a *Flask response*. """ - self._revoked_token_callback = callback + self.revoked_token_response = callback return callback def token_in_blacklist_loader(self, callback): @@ -350,7 +325,7 @@ def token_in_blacklist_loader(self, callback): has been blacklisted (or is otherwise considered revoked), or *`False`* otherwise. """ - self._token_in_blacklist_callback = callback + self.token_is_blacklisted = callback return callback def unauthorized_loader(self, callback): @@ -365,7 +340,7 @@ def unauthorized_loader(self, callback): a string which contains the reason why a JWT could not be found, and returns a *Flask response*. """ - self._unauthorized_callback = callback + self.unauthorized_response = callback return callback def user_claims_loader(self, callback): @@ -380,7 +355,7 @@ def user_claims_loader(self, callback): claims you want included in the access tokens. This returned claims must be *JSON serializable*. """ - self._user_claims_callback = callback + self.user_claims = callback return callback def user_identity_loader(self, callback): @@ -398,7 +373,7 @@ def user_identity_loader(self, callback): :func:`~flask_jwt_extended.create_refresh_token`, and returns the *JSON serializable* identity of this token. """ - self._user_identity_callback = callback + self.user_identity = callback return callback def user_lookup_loader(self, callback): @@ -415,7 +390,7 @@ def user_lookup_loader(self, callback): `None`, the :meth:`~flask_jwt_extended.JWTManager.user_lookup_error_loader` will be called. """ - self._user_lookup_callback = callback + self.lookup_user = callback return callback def user_lookup_error_loader(self, callback): @@ -431,7 +406,7 @@ def user_lookup_error_loader(self, callback): *HINT*: The callback must be a function that takes **one** argument, which is the identity of the user who failed to load, and must return a *Flask response*. """ - self._user_lookup_error_callback = callback + self.user_lookup_error_response = callback return callback def _encode_jwt_from_config( @@ -447,10 +422,10 @@ def _encode_jwt_from_config( expires_delta = config.refresh_expires if headers is None: - headers = self._jwt_additional_header_callback(identity) + headers = self.jwt_headers(identity) if token_type == "access" or config.user_claims_in_refresh_token: - claim_overrides = self._user_claims_callback(identity) + claim_overrides = self.user_claims(identity) else: claim_overrides = {} @@ -464,10 +439,10 @@ def _encode_jwt_from_config( expires_delta=expires_delta, fresh=fresh, headers=headers, - identity=self._user_identity_callback(identity), + identity=self.user_identity(identity), identity_claim_key=config.identity_claim_key, json_encoder=config.json_encoder, - secret=self._encode_key_callback(identity), + secret=self.encode_key(identity), token_type=token_type, ) @@ -478,7 +453,7 @@ def _decode_jwt_from_config( encoded_token, verify=False, algorithms=config.decode_algorithms ) unverified_headers = jwt.get_unverified_header(encoded_token) - secret = self._decode_key_callback(unverified_claims, unverified_headers) + secret = self.decode_key(unverified_claims, unverified_headers) kwargs = { "algorithms": config.decode_algorithms, diff --git a/flask_jwt_extended/utils.py b/flask_jwt_extended/utils.py index 766a5392..de2de43f 100644 --- a/flask_jwt_extended/utils.py +++ b/flask_jwt_extended/utils.py @@ -172,22 +172,22 @@ def create_refresh_token(identity, expires_delta=None, user_claims=None, headers def has_user_lookup(): jwt_manager = _get_jwt_manager() - return jwt_manager._user_lookup_callback is not None + return jwt_manager.lookup_user is not None def user_lookup(*args, **kwargs): jwt_manager = _get_jwt_manager() - return jwt_manager._user_lookup_callback(*args, **kwargs) + return jwt_manager.lookup_user(*args, **kwargs) def has_token_in_blacklist_callback(): jwt_manager = _get_jwt_manager() - return jwt_manager._token_in_blacklist_callback is not None + return jwt_manager.token_is_blacklisted is not None def token_in_blacklist(*args, **kwargs): jwt_manager = _get_jwt_manager() - return jwt_manager._token_in_blacklist_callback(*args, **kwargs) + return jwt_manager.token_is_blacklisted(*args, **kwargs) def verify_token_type(decoded_token, expected_type): @@ -214,7 +214,7 @@ def verify_token_not_blacklisted(decoded_token, request_type): def _verify_token_claims(jwt_header, jwt_data): jwt_manager = _get_jwt_manager() - if not jwt_manager._claims_verification_callback(jwt_data): + if not jwt_manager.verify_claims(jwt_data): error_msg = "User claims verification failed" raise UserClaimsVerificationError(error_msg, jwt_header, jwt_data) diff --git a/tests/test_decode_tokens.py b/tests/test_decode_tokens.py index 133d0898..8cb2702c 100644 --- a/tests/test_decode_tokens.py +++ b/tests/test_decode_tokens.py @@ -165,20 +165,20 @@ def test_encode_decode_callback_values(app, default_access_token): jwtM = get_jwt_manager(app) app.config["JWT_SECRET_KEY"] = "foobarbaz" with app.test_request_context(): - assert jwtM._decode_key_callback({}, {}) == "foobarbaz" - assert jwtM._encode_key_callback({}) == "foobarbaz" + assert jwtM.decode_key({}, {}) == "foobarbaz" + assert jwtM.encode_key({}) == "foobarbaz" @jwtM.encode_key_loader def get_encode_key_1(identity): return "different secret" - assert jwtM._encode_key_callback("") == "different secret" + assert jwtM.encode_key("") == "different secret" @jwtM.decode_key_loader def get_decode_key_1(claims, headers): return "different secret" - assert jwtM._decode_key_callback({}, {}) == "different secret" + assert jwtM.decode_key({}, {}) == "different secret" def test_custom_encode_decode_key_callbacks(app, default_access_token): diff --git a/tests/test_oop_callbacks.py b/tests/test_oop_callbacks.py new file mode 100644 index 00000000..3d24c06d --- /dev/null +++ b/tests/test_oop_callbacks.py @@ -0,0 +1,38 @@ +import pytest +from flask import Flask +from flask import jsonify + +from flask_jwt_extended import create_access_token +from flask_jwt_extended import get_jwt +from flask_jwt_extended import jwt_required +from flask_jwt_extended import JWTManager as JWTManager_ +from tests.utils import make_headers + + +class JWTManager(JWTManager_): + def user_claims(self, identity): + return {"foo": "bar"} + + +@pytest.fixture(scope="function") +def app(): + app = Flask(__name__) + app.config["JWT_SECRET_KEY"] = "foobarbaz" + JWTManager(app) + + @app.route("/protected", methods=["GET"]) + @jwt_required() + def get_claims(): + return jsonify(get_jwt()) + + return app + + +def test_user_claim_in_access_token(app): + with app.test_request_context(): + access_token = create_access_token("username") + + test_client = app.test_client() + response = test_client.get("/protected", headers=make_headers(access_token)) + assert response.get_json()["foo"] == "bar" + assert response.status_code == 200