Skip to content

Commit

Permalink
chore: reafactor of jwk schema, minor code cleanup (#262)
Browse files Browse the repository at this point in the history
* chore: refactor of jwk schemas

* chore: rm unused code (oauth2-par)
  • Loading branch information
Zicchio authored Sep 6, 2024
1 parent 28cd720 commit a3662a4
Show file tree
Hide file tree
Showing 13 changed files with 226 additions and 39 deletions.
2 changes: 1 addition & 1 deletion pyeudiw/federation/schemas/entity_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from pyeudiw.federation.schemas.federation_entity import FederationEntity
from pyeudiw.federation.schemas.wallet_relying_party import WalletRelyingParty
from pyeudiw.jwk.schema import JwksSchema
from pyeudiw.jwk.schemas.public import JwksSchema
from pyeudiw.tools.schema_utils import check_algorithm


Expand Down
2 changes: 1 addition & 1 deletion pyeudiw/federation/schemas/federation_configuration.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from pydantic import BaseModel, HttpUrl
from pyeudiw.federation.schemas.wallet_relying_party import SigningAlgValuesSupported
from pyeudiw.jwk.schemas.jwk import JwkSchema
from pyeudiw.jwk.schemas.public import JwkSchema


class FederationEntityMetadata(BaseModel):
Expand Down
2 changes: 1 addition & 1 deletion pyeudiw/federation/schemas/wallet_relying_party.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from enum import Enum
from typing import List
from pyeudiw.jwk.schemas.jwk import JwksSchema
from pyeudiw.jwk.schemas.public import JwksSchema
from pydantic import BaseModel, HttpUrl, PositiveInt
from pyeudiw.openid4vp.schemas.vp_formats import VpFormats
from pyeudiw.presentation_exchange.schemas.oid4vc_presentation_definition import PresentationDefinition
Expand Down
29 changes: 0 additions & 29 deletions pyeudiw/jwk/schemas/jwk.py

This file was deleted.

120 changes: 120 additions & 0 deletions pyeudiw/jwk/schemas/public.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
from typing import Annotated, List, Literal, Optional, Union

from pydantic import BaseModel, Field, field_validator

_SUPPORTED_KTY = Literal["EC", "RSA"]

_SUPPORTED_ALGS = Literal[
"ES256",
"ES384",
"ES512",
"PS256",
"PS384",
"PS512",
"RS256",
"RS384",
"RS512",
]

_SUPPORTED_ALG_BY_KTY = {
"RSA": ("PS256", "PS384", "PS512", "RS256", "RS384", "RS512"),
"EC": ("ES256", "ES384", "ES512")
}

# TODO: supported alg by kty and use

_SUPPORTED_CRVS = Literal[
"P-256",
"P-384",
"P-521",
"brainpoolP256r1",
"brainpoolP384r1",
"brainpoolP512r1"
]


class JwkBaseModel(BaseModel):
kid: Optional[str]
use: Optional[Literal["sig", "enc"]] = None


class ECJwkSchema(JwkBaseModel):
kty: Literal["EC"]
crv: _SUPPORTED_CRVS
x: str
y: str


class RSAJwkSchema(JwkBaseModel):
kty: Literal["RSA"]
n: str
e: str


class JwkSchema(BaseModel):
kid: str # Base64url-encoded thumbprint string
kty: _SUPPORTED_KTY
alg: Annotated[Union[_SUPPORTED_ALGS, None], Field(validate_default=True)] = None
use: Annotated[Union[Literal["sig", "enc"], None], Field(validate_default=True)] = None
n: Annotated[Union[str, None], Field(validate_default=True)] = None # Base64urlUInt-encoded
e: Annotated[Union[str, None], Field(validate_default=True)] = None # Base64urlUInt-encoded
x: Annotated[Union[str, None], Field(validate_default=True)] = None # Base64urlUInt-encoded
y: Annotated[Union[str, None], Field(validate_default=True)] = None # Base64urlUInt-encoded
crv: Annotated[Union[_SUPPORTED_CRVS, None], Field(validate_default=True)] = None

def _must_specific_kty_only(v, exp_kty: _SUPPORTED_ALGS, v_name: str, values: dict):
"""validate a jwk parameter by that it is (1) defined and (2) mandatory
only for one specific kty by checking that it is indeed defined by when
kty matches.
"""
err_msg = f"{v_name} must be present only for kty = {exp_kty}"
obt_kty: Union[_SUPPORTED_KTY, None] = values.get("kty", None)
if obt_kty is None:
if v is not None:
raise ValueError("unexpected validation state: missing kty")
return
if exp_kty == obt_kty:
if v is None:
raise ValueError(err_msg)
return
# in this validation v should NOT be defined if obt_kty != exp_kty
if v is not None:
raise ValueError(err_msg)
return

@field_validator("alg")
def validate_alg(cls, v, values):
if v is None:
return
kty = values.data.get("kty")
if v not in _SUPPORTED_ALG_BY_KTY[kty]:
raise ValueError(f"alg value {v} is not compatible or not supported with kty {kty}")
return

@field_validator("n")
def validate_n(cls, v, values):
cls._must_specific_kty_only(v, "RSA", "n", values.data)

@field_validator("e")
def valisate_e(cls, v, values):
cls._must_specific_kty_only(v, "RSA", "e", values.data)

@field_validator("x")
def validate_x(cls, v, values):
cls._must_specific_kty_only(v, "EC", "x", values.data)

@field_validator("y")
def validate_y(cls, v, values):
cls._must_specific_kty_only(v, "EC", "y", values.data)

@field_validator("crv")
def validate_crv(cls, v, values):
cls._must_specific_kty_only(v, "EC", "crv", values.data)


_JwkSchema_T = Annotated[Union[ECJwkSchema, RSAJwkSchema],
Field(discriminator="kty")]


class JwksSchema(BaseModel):
keys: List[_JwkSchema_T]
2 changes: 1 addition & 1 deletion pyeudiw/oauth2/dpop/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
import uuid

from pyeudiw.jwk.schema import JwkSchema
from pyeudiw.jwk.schemas.public import JwkSchema
from pyeudiw.oauth2.dpop.exceptions import (
InvalidDPoP,
InvalidDPoPAth,
Expand Down
2 changes: 1 addition & 1 deletion pyeudiw/oauth2/dpop/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from pydantic import BaseModel, HttpUrl

from pyeudiw.jwk.schemas.jwk import JwkSchema
from pyeudiw.jwk.schemas.public import JwkSchema


class DPoPTokenHeaderSchema(BaseModel):
Expand Down
Empty file removed pyeudiw/oauth2/par/__init__.py
Empty file.
2 changes: 1 addition & 1 deletion pyeudiw/openid4vp/schemas/cnf_schema.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from pydantic import BaseModel

from pyeudiw.jwk.schema import JwkSchema
from pyeudiw.jwk.schemas.public import JwkSchema


class CNFSchema(BaseModel):
Expand Down
2 changes: 1 addition & 1 deletion pyeudiw/satosa/schemas/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pydantic import BaseModel
from pyeudiw.jwk.schemas.jwk import JwkSchema
from pyeudiw.jwk.schemas.public import JwkSchema
from pyeudiw.satosa.schemas.endpoint import EndpointsConfig
from pyeudiw.satosa.schemas.qrcode import QRCode
from pyeudiw.satosa.schemas.autorization import AuthorizationConfig
Expand Down
2 changes: 1 addition & 1 deletion pyeudiw/sd_jwt/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from pydantic import BaseModel, HttpUrl

from pyeudiw.jwk.schema import JwkSchema
from pyeudiw.jwk.schemas.public import JwkSchema

SD_JWT_REGEXP = r"^(([-A-Za-z0-9\=_])*\.([-A-Za-z0-9\=_])*\.([-A-Za-z0-9\=_])*)(~([-A-Za-z0-9\=_\.])*)*$"

Expand Down
96 changes: 96 additions & 0 deletions pyeudiw/tests/jwk/test_schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
from pyeudiw.jwk.schemas.public import JwkSchema, ECJwkSchema, RSAJwkSchema


def test_valid_rsa_jwk():
jwk_d = {
"kty": "RSA",
"n": "0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n91CbOpbISD08qNLyrdkt-bFTWhAI4vMQFh6WeZu0fM4lFd2NcRwr3XPksINHaQ-G_xBniIqbw0Ls1jF44-csFCur-kEgU8awapJzKnqDKgw",
"e": "AQAB",
"alg": "RS256",
"kid": "2011-04-29",
}
JwkSchema(**jwk_d)
RSAJwkSchema(**jwk_d)


def test_valid_ec_jwk():
jwk_d = {
"kty": "EC",
"crv": "P-256",
"x": "MKBCTNIcKUSDii11ySs3526iDZ8AiTo7Tu6KPAqv7D4",
"y": "4Etl6SRW2YiLUrN5vfvVHuhp7x8PxltmWWlbbM4IFyM",
"use": "enc",
"kid": "1",
}
JwkSchema(**jwk_d)
ECJwkSchema(**jwk_d)


def test_invalid_keys():
# table with keys that should fail jwk parsing
bad_keys_table: list[tuple[dict, str]] = [
(
{
"aaaa": "1"
},
"non-sense key"
),
(
{
"kty": "RSA",
"e": "AQAB",
"alg": "RS256",
"kid": "2011-04-29",
},
"rsa key with missing attribute [n]"
),
(
{
"kty": "RSA",
"n": "0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n91CbOpbISD08qNLyrdkt-bFTWhAI4vMQFh6WeZu0fM4lFd2NcRwr3XPksINHaQ-G_xBniIqbw0Ls1jF44-csFCur-kEgU8awapJzKnqDKgw",
"e": "AQAB",
"x": "MKBCTNIcKUSDii11ySs3526iDZ8AiTo7Tu6KPAqv7D4",
"alg": "RS256",
"kid": "2011-04-29",
},
"rsa key with unexpected attribute [x]"
),
(
{
"kty": "EC",
"x": "MKBCTNIcKUSDii11ySs3526iDZ8AiTo7Tu6KPAqv7D4",
"y": "4Etl6SRW2YiLUrN5vfvVHuhp7x8PxltmWWlbbM4IFyM",
"use": "enc",
"kid": "1",
},
"ec key with missing attribute [crv]"
),
(
{
"kty": "EC",
"crv": "P-256",
"y": "4Etl6SRW2YiLUrN5vfvVHuhp7x8PxltmWWlbbM4IFyM",
"use": "enc",
"kid": "1",
},
"ec key with missing attribute [x]"
),
(
{
"kty": "EC",
"crv": "P-256",
"x": "MKBCTNIcKUSDii11ySs3526iDZ8AiTo7Tu6KPAqv7D4",
"y": "4Etl6SRW2YiLUrN5vfvVHuhp7x8PxltmWWlbbM4IFyM",
"e": "AQAB",
"use": "enc",
"kid": "1",
},
"ec key with unexpected attribute [e]"
)
]
for i, (bad_key, reason) in enumerate(bad_keys_table):
try:
JwkSchema(**bad_key)
assert False, f"failed case {i}: parsing should fail due to: {reason}"
except ValueError:
assert True
4 changes: 2 additions & 2 deletions pyeudiw/tests/test_jwk.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from pydantic import TypeAdapter

from pyeudiw.jwk import JWK
from pyeudiw.jwk.schemas.jwk import JwkSchema, ECJwkSchema, RSAJwkSchema
from pyeudiw.jwk.schemas.public import ECJwkSchema, RSAJwkSchema, _JwkSchema_T


@pytest.mark.parametrize(
Expand Down Expand Up @@ -57,7 +57,7 @@ def test_export_public_pem():
@pytest.mark.parametrize("key_type", ["EC", "RSA"])
def test_dynamic_schema_validation(key_type):
jwk = JWK(key_type=key_type)
model = TypeAdapter(JwkSchema).validate_python(jwk.as_dict())
model = TypeAdapter(_JwkSchema_T).validate_python(jwk.as_dict())
match key_type:
case "EC":
assert isinstance(model, ECJwkSchema)
Expand Down

0 comments on commit a3662a4

Please sign in to comment.