Skip to content

Commit

Permalink
feat: make jwt decode function generic
Browse files Browse the repository at this point in the history
  • Loading branch information
mumarkhan999 committed Sep 11, 2023
1 parent 4775c46 commit f1a939b
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 15 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ Change Log
Unreleased
----------

Added
~~~~~
* (`#354 <https://github.com/openedx/edx-drf-extensions/pull/354>`_) Implemented ``verify_jwk_signature_using_keyset`` function.
This function allows for easy verification of JSON Web Key (JWK) signatures using a provided keyset.

[8.9.2] - 2023-08-31
--------------------

Expand Down
74 changes: 59 additions & 15 deletions edx_rest_framework_extensions/auth/jwt/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,8 @@ def _verify_jwt_signature(token, jwt_issuer, decode_symmetric_token):
# can fully retire code paths for symmetric keys, as part of
# DEPR: Symmetric JWTs: https://github.com/openedx/public-engineering/issues/83

# Use add_symmetric_keys=False to only include asymmetric keys at first
key_set = _get_signing_jwk_key_set(jwt_issuer, add_symmetric_keys=False)
# Pass only asymmetric_keys to only include asymmetric keys at first
key_set = get_verification_jwk_key_set(asymmetric_keys=settings.JWT_AUTH.get('JWT_PUBLIC_SIGNING_JWK_SET'))
# .. custom_attribute_name: jwt_auth_verify_asymmetric_keys_count
# .. custom_attribute_description: Number of JWT verification keys in use for this
# verification. Should be same as number of asymmetric public keys. This is
Expand All @@ -203,7 +203,7 @@ def _verify_jwt_signature(token, jwt_issuer, decode_symmetric_token):
set_custom_attribute('jwt_auth_verify_asymmetric_keys_count', len(key_set))

try:
_verify_jwk_signature_using_keyset(token, key_set, jwt_issuer)
verify_jwk_signature_using_keyset(token, key_set, aud=jwt_issuer['AUDIENCE'])
# .. custom_attribute_name: jwt_auth_asymmetric_verified
# .. custom_attribute_description: Whether the JWT was successfully verified
# using an asymmetric key.
Expand All @@ -218,7 +218,9 @@ def _verify_jwt_signature(token, jwt_issuer, decode_symmetric_token):
# the asymmetric keys here is redundant and unnecessary, but this code is temporary and
# will be simplified once symmetric keys have been fully retired.

key_set = _get_signing_jwk_key_set(jwt_issuer, add_symmetric_keys=decode_symmetric_token)
asymmetric_keys = settings.JWT_AUTH.get('JWT_PUBLIC_SIGNING_JWK_SET')
secret_key = jwt_issuer['SECRET_KEY'] if decode_symmetric_token else None
key_set = get_verification_jwk_key_set(asymmetric_keys=asymmetric_keys, secret_key=secret_key)
# .. custom_attribute_name: jwt_auth_verify_all_keys_count
# .. custom_attribute_description: Number of JWT verification keys in use for this
# verification. Should be same as number of asymmetric public keys, plus one if
Expand All @@ -228,7 +230,7 @@ def _verify_jwt_signature(token, jwt_issuer, decode_symmetric_token):
set_custom_attribute('jwt_auth_verify_all_keys_count', len(key_set))

try:
_verify_jwk_signature_using_keyset(token, key_set, jwt_issuer)
verify_jwk_signature_using_keyset(token, key_set, aud=jwt_issuer['AUDIENCE'])
# .. custom_attribute_name: jwt_auth_symmetric_verified
# .. custom_attribute_description: Whether the JWT was successfully verified
# using a symmetric key.
Expand All @@ -248,7 +250,48 @@ def _verify_jwt_signature(token, jwt_issuer, decode_symmetric_token):
raise jwt.InvalidTokenError(exc_info[2]) from token_error


def _verify_jwk_signature_using_keyset(token, key_set, jwt_issuer):
def verify_jwk_signature_using_keyset(token, key_set, aud=None, iss=None, verify_signature=True, verify_exp=True):
"""
Verifies the signature of a JSON Web Token (JWT) using a provided JSON Web Key (PyJWK) key set.
Args:
token (str): The JWT to be verified.
key_set (list -> PyJWK): A list containing PyJWKs (JSON Web Keys)
for signature verification.
aud (str or None): The expected "aud" (audience) claim in the JWT.
If provided, the JWT's "aud" claim must match this value for
the verification to succeed.
iss (str or None): The expected "iss" (issuer) claim in the JWT.
If provided, the JWT's "iss" claim must match this value for
the verification to succeed.
verify_signature (bool): Whether to verify the JWT's digital signature.
Set to False if you want to skip signature verification
(e.g., if the JWT is already pre-verified).
verify_exp (bool): Whether to verify the JWT's expiration time ("exp" claim).
Set to False if you want to skip expiration time verification.
Returns:
data (dict): Decoded JWT.
Raises:
ValueError: If the token is not a valid JWT or if the key_set is empty
or improperly formatted.
jwt.ExpiredSignatureError: If the JWT has expired and verify_exp
is set to True.
jwt.InvalidIssuerError: If the "iss" claim does not match the expected
issuer and iss is provided.
jwt.InvalidAudienceError: If the "aud" claim does not match the expected
audience and aud is provided.
jwt.DecodeError: If the JWT decoding fails for any reason.
"""
options = {
'verify_signature': verify_signature,
'verify_exp': verify_exp,
'verify_aud': bool(aud),
'verify_iss': bool(iss)
}
data = None

for i in range(0, len(key_set)):
try:
algorithms = None
Expand All @@ -257,16 +300,19 @@ def _verify_jwk_signature_using_keyset(token, key_set, jwt_issuer):
elif key_set[i].key_type == 'oct':
algorithms = ['HS256',]

_ = jwt.decode(
data = jwt.decode(
token,
key=key_set[i].key,
algorithms=algorithms,
audience=jwt_issuer['AUDIENCE'],
issuer=iss,
audience=aud,
options=options
)
break
except Exception: # pylint: disable=broad-except
if i == len(key_set) - 1:
raise
return data


def _decode_and_verify_token(token, jwt_issuer):
Expand Down Expand Up @@ -315,21 +361,19 @@ def _decode_and_verify_token(token, jwt_issuer):
return decoded_token


def _get_signing_jwk_key_set(jwt_issuer, add_symmetric_keys=True):
def get_verification_jwk_key_set(asymmetric_keys=None, secret_key=None):
"""
Returns a JWK Keyset containing all active keys that are configured
for verifying signatures.
"""
key_set = []

# asymmetric keys
signing_jwk_set = settings.JWT_AUTH.get('JWT_PUBLIC_SIGNING_JWK_SET')
if signing_jwk_set:
key_set.extend(PyJWKSet.from_json(signing_jwk_set).keys)
if asymmetric_keys:
key_set.extend(PyJWKSet.from_json(asymmetric_keys).keys)

if add_symmetric_keys:
if secret_key:
# symmetric key
encoded_secret_key = base64url_encode(jwt_issuer['SECRET_KEY'].encode('utf-8'))
encoded_secret_key = base64url_encode(secret_key.encode('utf-8'))
key_set.append(PyJWK({'k': encoded_secret_key, 'kty': 'oct'}))

return key_set

0 comments on commit f1a939b

Please sign in to comment.