diff --git a/litestar/security/jwt/token.py b/litestar/security/jwt/token.py index de9cac78d8..e8317a6bee 100644 --- a/litestar/security/jwt/token.py +++ b/litestar/security/jwt/token.py @@ -1,16 +1,20 @@ from __future__ import annotations +import base64 import dataclasses from dataclasses import asdict, dataclass, field from datetime import datetime, timezone from typing import TYPE_CHECKING, Any import jwt +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives.asymmetric import ec, ed448, ed25519, rsa from jwt import InvalidTokenError, PyJWTError from litestar.exceptions import ImproperlyConfiguredException, NotAuthorizedException if TYPE_CHECKING: + from jwt.algorithms import AllowedPublicKeys from typing_extensions import Self @@ -86,8 +90,35 @@ def decode(cls, encoded_token: str, secret: str | dict[str, str], algorithm: str Raises: NotAuthorizedException: If the token is invalid. """ + + def base64url_decode(code: str) -> bytes: + padding = "=" * (4 - (len(code) % 4)) + return base64.urlsafe_b64decode(code + padding) + + if isinstance(secret, str): + converted_secret: AllowedPublicKeys | str | bytes = secret + else: + if secret["kty"] == "RSA": + n = int.from_bytes(base64url_decode(secret["n"]), byteorder="big") + e = int.from_bytes(base64url_decode(secret["e"]), byteorder="big") + converted_secret = rsa.RSAPublicNumbers(e, n).public_key(default_backend()) + elif secret["kty"] == "EC": + x = int.from_bytes(base64url_decode(secret["x"]), byteorder="big") + y = int.from_bytes(base64url_decode(secret["y"]), byteorder="big") + converted_secret = ec.EllipticCurvePublicNumbers(x, y, ec.SECP256R1()).public_key(default_backend()) + elif secret["kty"] == "OKP" and secret["crv"] == "Ed25519": + x = base64url_decode(secret["x"]) + converted_secret = ed25519.Ed25519PublicKey.from_public_bytes(x) + elif secret["kty"] == "OKP" and secret["crv"] == "Ed448": + x = base64url_decode(secret["x"]) + converted_secret = ed448.Ed448PublicKey.from_public_bytes(x) + else: + raise TypeError("The secret is not a form of allowed public key.") + try: - payload = jwt.decode(jwt=encoded_token, key=secret, algorithms=[algorithm], options={"verify_aud": False}) + payload = jwt.decode( + jwt=encoded_token, key=converted_secret, algorithms=[algorithm], options={"verify_aud": False} + ) exp = datetime.fromtimestamp(payload.pop("exp"), tz=timezone.utc) iat = datetime.fromtimestamp(payload.pop("iat"), tz=timezone.utc) field_names = {f.name for f in dataclasses.fields(Token)} diff --git a/pyproject.toml b/pyproject.toml index 2a7c6402a5..0c04941b56 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,6 @@ dependencies = [ "click", "rich>=13.0.0", "rich-click", - "pyjwt>=2.8.0", ] description = "Litestar - A production-ready, highly performant, extensible ASGI API Framework" keywords = ["api", "rest", "asgi", "litestar", "starlite"]