Skip to content

Commit

Permalink
fix pyjwt decode secret type complaints
Browse files Browse the repository at this point in the history
  • Loading branch information
liwen authored and cofin committed Aug 8, 2024
1 parent 07ec1cd commit e76d3fe
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 2 deletions.
33 changes: 32 additions & 1 deletion litestar/security/jwt/token.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
from __future__ import annotations

import base64
import dataclasses
from dataclasses import asdict, dataclass, field
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any

import jwt
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric import ec, ed448, ed25519, rsa
from jwt import InvalidTokenError, PyJWTError

from litestar.exceptions import ImproperlyConfiguredException, NotAuthorizedException

if TYPE_CHECKING:
from jwt.algorithms import AllowedPublicKeys
from typing_extensions import Self


Expand Down Expand Up @@ -86,8 +90,35 @@ def decode(cls, encoded_token: str, secret: str | dict[str, str], algorithm: str
Raises:
NotAuthorizedException: If the token is invalid.
"""

def base64url_decode(code: str) -> bytes:
padding = "=" * (4 - (len(code) % 4))
return base64.urlsafe_b64decode(code + padding)

if isinstance(secret, str):
converted_secret: AllowedPublicKeys | str | bytes = secret
else:
if secret["kty"] == "RSA":
n = int.from_bytes(base64url_decode(secret["n"]), byteorder="big")
e = int.from_bytes(base64url_decode(secret["e"]), byteorder="big")
converted_secret = rsa.RSAPublicNumbers(e, n).public_key(default_backend())
elif secret["kty"] == "EC":
x = int.from_bytes(base64url_decode(secret["x"]), byteorder="big")
y = int.from_bytes(base64url_decode(secret["y"]), byteorder="big")
converted_secret = ec.EllipticCurvePublicNumbers(x, y, ec.SECP256R1()).public_key(default_backend())
elif secret["kty"] == "OKP" and secret["crv"] == "Ed25519":
x = base64url_decode(secret["x"])
converted_secret = ed25519.Ed25519PublicKey.from_public_bytes(x)
elif secret["kty"] == "OKP" and secret["crv"] == "Ed448":
x = base64url_decode(secret["x"])
converted_secret = ed448.Ed448PublicKey.from_public_bytes(x)
else:
raise TypeError("The secret is not a form of allowed public key.")

try:
payload = jwt.decode(jwt=encoded_token, key=secret, algorithms=[algorithm], options={"verify_aud": False})
payload = jwt.decode(
jwt=encoded_token, key=converted_secret, algorithms=[algorithm], options={"verify_aud": False}
)
exp = datetime.fromtimestamp(payload.pop("exp"), tz=timezone.utc)
iat = datetime.fromtimestamp(payload.pop("iat"), tz=timezone.utc)
field_names = {f.name for f in dataclasses.fields(Token)}
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ dependencies = [
"click",
"rich>=13.0.0",
"rich-click",
"pyjwt>=2.8.0",
]
description = "Litestar - A production-ready, highly performant, extensible ASGI API Framework"
keywords = ["api", "rest", "asgi", "litestar", "starlite"]
Expand Down

0 comments on commit e76d3fe

Please sign in to comment.