diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 6152740b..c2703aa6 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -12,6 +12,11 @@ Change Log Unreleased ---------- +Added +~~~~~ +* (`#354 `_) Implemented ``verify_jwk_signature_using_keyset`` function. + This function allows for easy verification of JSON Web Key (JWK) signatures using a provided keyset. + [8.9.3] - 2023-09-13 -------------------- diff --git a/edx_rest_framework_extensions/auth/jwt/decoder.py b/edx_rest_framework_extensions/auth/jwt/decoder.py index 8c57fe13..12534eda 100644 --- a/edx_rest_framework_extensions/auth/jwt/decoder.py +++ b/edx_rest_framework_extensions/auth/jwt/decoder.py @@ -193,8 +193,8 @@ def _verify_jwt_signature(token, jwt_issuer, decode_symmetric_token): # can fully retire code paths for symmetric keys, as part of # DEPR: Symmetric JWTs: https://github.com/openedx/public-engineering/issues/83 - # Use add_symmetric_keys=False to only include asymmetric keys at first - key_set = _get_signing_jwk_key_set(jwt_issuer, add_symmetric_keys=False) + # Pass only asymmetric_keys to only include asymmetric keys at first + key_set = get_verification_jwk_key_set(asymmetric_keys=settings.JWT_AUTH.get('JWT_PUBLIC_SIGNING_JWK_SET')) # .. custom_attribute_name: jwt_auth_verify_asymmetric_keys_count # .. custom_attribute_description: Number of JWT verification keys in use for this # verification. Should be same as number of asymmetric public keys. This is @@ -203,7 +203,7 @@ def _verify_jwt_signature(token, jwt_issuer, decode_symmetric_token): set_custom_attribute('jwt_auth_verify_asymmetric_keys_count', len(key_set)) try: - _verify_jwk_signature_using_keyset(token, key_set, jwt_issuer) + verify_jwk_signature_using_keyset(token, key_set, aud=jwt_issuer['AUDIENCE']) # .. custom_attribute_name: jwt_auth_asymmetric_verified # .. custom_attribute_description: Whether the JWT was successfully verified # using an asymmetric key. @@ -218,7 +218,9 @@ def _verify_jwt_signature(token, jwt_issuer, decode_symmetric_token): # the asymmetric keys here is redundant and unnecessary, but this code is temporary and # will be simplified once symmetric keys have been fully retired. - key_set = _get_signing_jwk_key_set(jwt_issuer, add_symmetric_keys=decode_symmetric_token) + asymmetric_keys = settings.JWT_AUTH.get('JWT_PUBLIC_SIGNING_JWK_SET') + secret_key = jwt_issuer['SECRET_KEY'] if decode_symmetric_token else None + key_set = get_verification_jwk_key_set(asymmetric_keys=asymmetric_keys, secret_key=secret_key) # .. custom_attribute_name: jwt_auth_verify_all_keys_count # .. custom_attribute_description: Number of JWT verification keys in use for this # verification. Should be same as number of asymmetric public keys, plus one if @@ -228,7 +230,7 @@ def _verify_jwt_signature(token, jwt_issuer, decode_symmetric_token): set_custom_attribute('jwt_auth_verify_all_keys_count', len(key_set)) try: - _verify_jwk_signature_using_keyset(token, key_set, jwt_issuer) + verify_jwk_signature_using_keyset(token, key_set, aud=jwt_issuer['AUDIENCE']) # .. custom_attribute_name: jwt_auth_symmetric_verified # .. custom_attribute_description: Whether the JWT was successfully verified # using a symmetric key. @@ -248,7 +250,48 @@ def _verify_jwt_signature(token, jwt_issuer, decode_symmetric_token): raise jwt.InvalidTokenError(exc_info[2]) from token_error -def _verify_jwk_signature_using_keyset(token, key_set, jwt_issuer): +def verify_jwk_signature_using_keyset(token, key_set, aud=None, iss=None, verify_signature=True, verify_exp=True): + """ + Verifies the signature of a JSON Web Token (JWT) using a provided JSON Web Key (PyJWK) key set. + + Args: + token (str): The JWT to be verified. + key_set (list -> PyJWK): A list containing PyJWKs (JSON Web Keys) + for signature verification. + aud (str or None): The expected "aud" (audience) claim in the JWT. + If provided, the JWT's "aud" claim must match this value for + the verification to succeed. + iss (str or None): The expected "iss" (issuer) claim in the JWT. + If provided, the JWT's "iss" claim must match this value for + the verification to succeed. + verify_signature (bool): Whether to verify the JWT's digital signature. + Set to False if you want to skip signature verification + (e.g., if the JWT is already pre-verified). + verify_exp (bool): Whether to verify the JWT's expiration time ("exp" claim). + Set to False if you want to skip expiration time verification. + + Returns: + data (dict): Decoded JWT. + + Raises: + ValueError: If the token is not a valid JWT or if the key_set is empty + or improperly formatted. + jwt.ExpiredSignatureError: If the JWT has expired and verify_exp + is set to True. + jwt.InvalidIssuerError: If the "iss" claim does not match the expected + issuer and iss is provided. + jwt.InvalidAudienceError: If the "aud" claim does not match the expected + audience and aud is provided. + jwt.DecodeError: If the JWT decoding fails for any reason. + """ + options = { + 'verify_signature': verify_signature, + 'verify_exp': verify_exp, + 'verify_aud': bool(aud), + 'verify_iss': bool(iss) + } + data = None + for i in range(0, len(key_set)): try: algorithms = None @@ -257,16 +300,19 @@ def _verify_jwk_signature_using_keyset(token, key_set, jwt_issuer): elif key_set[i].key_type == 'oct': algorithms = ['HS256',] - _ = jwt.decode( + data = jwt.decode( token, key=key_set[i].key, algorithms=algorithms, - audience=jwt_issuer['AUDIENCE'], + issuer=iss, + audience=aud, + options=options ) break except Exception: # pylint: disable=broad-except if i == len(key_set) - 1: raise + return data def _decode_and_verify_token(token, jwt_issuer): @@ -315,21 +361,19 @@ def _decode_and_verify_token(token, jwt_issuer): return decoded_token -def _get_signing_jwk_key_set(jwt_issuer, add_symmetric_keys=True): +def get_verification_jwk_key_set(asymmetric_keys=None, secret_key=None): """ Returns a JWK Keyset containing all active keys that are configured for verifying signatures. """ key_set = [] - # asymmetric keys - signing_jwk_set = settings.JWT_AUTH.get('JWT_PUBLIC_SIGNING_JWK_SET') - if signing_jwk_set: - key_set.extend(PyJWKSet.from_json(signing_jwk_set).keys) + if asymmetric_keys: + key_set.extend(PyJWKSet.from_json(asymmetric_keys).keys) - if add_symmetric_keys: + if secret_key: # symmetric key - encoded_secret_key = base64url_encode(jwt_issuer['SECRET_KEY'].encode('utf-8')) + encoded_secret_key = base64url_encode(secret_key.encode('utf-8')) key_set.append(PyJWK({'k': encoded_secret_key, 'kty': 'oct'})) return key_set