diff --git a/pyeudiw/federation/schemas/entity_configuration.py b/pyeudiw/federation/schemas/entity_configuration.py index 8b08c6dd..901cc429 100644 --- a/pyeudiw/federation/schemas/entity_configuration.py +++ b/pyeudiw/federation/schemas/entity_configuration.py @@ -5,7 +5,7 @@ from pyeudiw.federation.schemas.federation_entity import FederationEntity from pyeudiw.federation.schemas.wallet_relying_party import WalletRelyingParty -from pyeudiw.jwk.schema import JwksSchema +from pyeudiw.jwk.schemas.public import JwksSchema from pyeudiw.tools.schema_utils import check_algorithm diff --git a/pyeudiw/federation/schemas/federation_configuration.py b/pyeudiw/federation/schemas/federation_configuration.py index 7755a0dd..8c3edcb7 100644 --- a/pyeudiw/federation/schemas/federation_configuration.py +++ b/pyeudiw/federation/schemas/federation_configuration.py @@ -1,6 +1,6 @@ from pydantic import BaseModel, HttpUrl from pyeudiw.federation.schemas.wallet_relying_party import SigningAlgValuesSupported -from pyeudiw.jwk.schemas.jwk import JwkSchema +from pyeudiw.jwk.schemas.public import JwkSchema class FederationEntityMetadata(BaseModel): diff --git a/pyeudiw/federation/schemas/wallet_relying_party.py b/pyeudiw/federation/schemas/wallet_relying_party.py index 91e6df85..ac69c52f 100644 --- a/pyeudiw/federation/schemas/wallet_relying_party.py +++ b/pyeudiw/federation/schemas/wallet_relying_party.py @@ -1,6 +1,6 @@ from enum import Enum from typing import List -from pyeudiw.jwk.schemas.jwk import JwksSchema +from pyeudiw.jwk.schemas.public import JwksSchema from pydantic import BaseModel, HttpUrl, PositiveInt from pyeudiw.openid4vp.schemas.vp_formats import VpFormats from pyeudiw.presentation_exchange.schemas.oid4vc_presentation_definition import PresentationDefinition diff --git a/pyeudiw/jwk/schemas/jwk.py b/pyeudiw/jwk/schemas/jwk.py deleted file mode 100644 index 753958ea..00000000 --- a/pyeudiw/jwk/schemas/jwk.py +++ /dev/null @@ -1,29 +0,0 @@ -from typing import Literal, Annotated, Union, Optional, List - -from pydantic import BaseModel, Field - - -class JwkBaseModel(BaseModel): - use: Optional[Literal["sig", "enc"]] = None - kid: Optional[str] = None - - -class RSAJwkSchema(JwkBaseModel): - kty: Literal["RSA"] - n: str - e: str - - -class ECJwkSchema(JwkBaseModel): - kty: Literal["EC"] - crv: Literal["P-256", "P-384", "P-521"] - x: str - y: str - - -JwkSchema = Annotated[Union[ECJwkSchema, RSAJwkSchema], - Field(discriminator="kty")] - - -class JwksSchema(BaseModel): - keys: List[JwkSchema] diff --git a/pyeudiw/jwk/schemas/public.py b/pyeudiw/jwk/schemas/public.py new file mode 100644 index 00000000..7997affe --- /dev/null +++ b/pyeudiw/jwk/schemas/public.py @@ -0,0 +1,120 @@ +from typing import Annotated, List, Literal, Optional, Union + +from pydantic import BaseModel, Field, field_validator + +_SUPPORTED_KTY = Literal["EC", "RSA"] + +_SUPPORTED_ALGS = Literal[ + "ES256", + "ES384", + "ES512", + "PS256", + "PS384", + "PS512", + "RS256", + "RS384", + "RS512", +] + +_SUPPORTED_ALG_BY_KTY = { + "RSA": ("PS256", "PS384", "PS512", "RS256", "RS384", "RS512"), + "EC": ("ES256", "ES384", "ES512") +} + +# TODO: supported alg by kty and use + +_SUPPORTED_CRVS = Literal[ + "P-256", + "P-384", + "P-521", + "brainpoolP256r1", + "brainpoolP384r1", + "brainpoolP512r1" +] + + +class JwkBaseModel(BaseModel): + kid: Optional[str] + use: Optional[Literal["sig", "enc"]] = None + + +class ECJwkSchema(JwkBaseModel): + kty: Literal["EC"] + crv: _SUPPORTED_CRVS + x: str + y: str + + +class RSAJwkSchema(JwkBaseModel): + kty: Literal["RSA"] + n: str + e: str + + +class JwkSchema(BaseModel): + kid: str # Base64url-encoded thumbprint string + kty: _SUPPORTED_KTY + alg: Annotated[Union[_SUPPORTED_ALGS, None], Field(validate_default=True)] = None + use: Annotated[Union[Literal["sig", "enc"], None], Field(validate_default=True)] = None + n: Annotated[Union[str, None], Field(validate_default=True)] = None # Base64urlUInt-encoded + e: Annotated[Union[str, None], Field(validate_default=True)] = None # Base64urlUInt-encoded + x: Annotated[Union[str, None], Field(validate_default=True)] = None # Base64urlUInt-encoded + y: Annotated[Union[str, None], Field(validate_default=True)] = None # Base64urlUInt-encoded + crv: Annotated[Union[_SUPPORTED_CRVS, None], Field(validate_default=True)] = None + + def _must_specific_kty_only(v, exp_kty: _SUPPORTED_ALGS, v_name: str, values: dict): + """validate a jwk parameter by that it is (1) defined and (2) mandatory + only for one specific kty by checking that it is indeed defined by when + kty matches. + """ + err_msg = f"{v_name} must be present only for kty = {exp_kty}" + obt_kty: Union[_SUPPORTED_KTY, None] = values.get("kty", None) + if obt_kty is None: + if v is not None: + raise ValueError("unexpected validation state: missing kty") + return + if exp_kty == obt_kty: + if v is None: + raise ValueError(err_msg) + return + # in this validation v should NOT be defined if obt_kty != exp_kty + if v is not None: + raise ValueError(err_msg) + return + + @field_validator("alg") + def validate_alg(cls, v, values): + if v is None: + return + kty = values.data.get("kty") + if v not in _SUPPORTED_ALG_BY_KTY[kty]: + raise ValueError(f"alg value {v} is not compatible or not supported with kty {kty}") + return + + @field_validator("n") + def validate_n(cls, v, values): + cls._must_specific_kty_only(v, "RSA", "n", values.data) + + @field_validator("e") + def valisate_e(cls, v, values): + cls._must_specific_kty_only(v, "RSA", "e", values.data) + + @field_validator("x") + def validate_x(cls, v, values): + cls._must_specific_kty_only(v, "EC", "x", values.data) + + @field_validator("y") + def validate_y(cls, v, values): + cls._must_specific_kty_only(v, "EC", "y", values.data) + + @field_validator("crv") + def validate_crv(cls, v, values): + cls._must_specific_kty_only(v, "EC", "crv", values.data) + + +_JwkSchema_T = Annotated[Union[ECJwkSchema, RSAJwkSchema], + Field(discriminator="kty")] + + +class JwksSchema(BaseModel): + keys: List[_JwkSchema_T] diff --git a/pyeudiw/oauth2/dpop/__init__.py b/pyeudiw/oauth2/dpop/__init__.py index 486083e8..068822aa 100644 --- a/pyeudiw/oauth2/dpop/__init__.py +++ b/pyeudiw/oauth2/dpop/__init__.py @@ -3,7 +3,7 @@ import logging import uuid -from pyeudiw.jwk.schema import JwkSchema +from pyeudiw.jwk.schemas.public import JwkSchema from pyeudiw.oauth2.dpop.exceptions import ( InvalidDPoP, InvalidDPoPAth, diff --git a/pyeudiw/oauth2/dpop/schema.py b/pyeudiw/oauth2/dpop/schema.py index 22ad3c57..45688d4b 100644 --- a/pyeudiw/oauth2/dpop/schema.py +++ b/pyeudiw/oauth2/dpop/schema.py @@ -2,7 +2,7 @@ from pydantic import BaseModel, HttpUrl -from pyeudiw.jwk.schemas.jwk import JwkSchema +from pyeudiw.jwk.schemas.public import JwkSchema class DPoPTokenHeaderSchema(BaseModel): diff --git a/pyeudiw/oauth2/par/__init__.py b/pyeudiw/oauth2/par/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/pyeudiw/openid4vp/schemas/cnf_schema.py b/pyeudiw/openid4vp/schemas/cnf_schema.py index 2ed0f8f4..27e9ceaf 100644 --- a/pyeudiw/openid4vp/schemas/cnf_schema.py +++ b/pyeudiw/openid4vp/schemas/cnf_schema.py @@ -1,6 +1,6 @@ from pydantic import BaseModel -from pyeudiw.jwk.schema import JwkSchema +from pyeudiw.jwk.schemas.public import JwkSchema class CNFSchema(BaseModel): diff --git a/pyeudiw/satosa/schemas/config.py b/pyeudiw/satosa/schemas/config.py index 97f2264e..bcec2516 100644 --- a/pyeudiw/satosa/schemas/config.py +++ b/pyeudiw/satosa/schemas/config.py @@ -1,5 +1,5 @@ from pydantic import BaseModel -from pyeudiw.jwk.schemas.jwk import JwkSchema +from pyeudiw.jwk.schemas.public import JwkSchema from pyeudiw.satosa.schemas.endpoint import EndpointsConfig from pyeudiw.satosa.schemas.qrcode import QRCode from pyeudiw.satosa.schemas.autorization import AuthorizationConfig diff --git a/pyeudiw/sd_jwt/schema.py b/pyeudiw/sd_jwt/schema.py index 4d86d8f1..ab905e63 100644 --- a/pyeudiw/sd_jwt/schema.py +++ b/pyeudiw/sd_jwt/schema.py @@ -3,7 +3,7 @@ from pydantic import BaseModel, HttpUrl -from pyeudiw.jwk.schema import JwkSchema +from pyeudiw.jwk.schemas.public import JwkSchema SD_JWT_REGEXP = r"^(([-A-Za-z0-9\=_])*\.([-A-Za-z0-9\=_])*\.([-A-Za-z0-9\=_])*)(~([-A-Za-z0-9\=_\.])*)*$" diff --git a/pyeudiw/tests/jwk/test_schema.py b/pyeudiw/tests/jwk/test_schema.py new file mode 100644 index 00000000..f3d997fc --- /dev/null +++ b/pyeudiw/tests/jwk/test_schema.py @@ -0,0 +1,96 @@ +from pyeudiw.jwk.schemas.public import JwkSchema, ECJwkSchema, RSAJwkSchema + + +def test_valid_rsa_jwk(): + jwk_d = { + "kty": "RSA", + "n": "0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n91CbOpbISD08qNLyrdkt-bFTWhAI4vMQFh6WeZu0fM4lFd2NcRwr3XPksINHaQ-G_xBniIqbw0Ls1jF44-csFCur-kEgU8awapJzKnqDKgw", + "e": "AQAB", + "alg": "RS256", + "kid": "2011-04-29", + } + JwkSchema(**jwk_d) + RSAJwkSchema(**jwk_d) + + +def test_valid_ec_jwk(): + jwk_d = { + "kty": "EC", + "crv": "P-256", + "x": "MKBCTNIcKUSDii11ySs3526iDZ8AiTo7Tu6KPAqv7D4", + "y": "4Etl6SRW2YiLUrN5vfvVHuhp7x8PxltmWWlbbM4IFyM", + "use": "enc", + "kid": "1", + } + JwkSchema(**jwk_d) + ECJwkSchema(**jwk_d) + + +def test_invalid_keys(): + # table with keys that should fail jwk parsing + bad_keys_table: list[tuple[dict, str]] = [ + ( + { + "aaaa": "1" + }, + "non-sense key" + ), + ( + { + "kty": "RSA", + "e": "AQAB", + "alg": "RS256", + "kid": "2011-04-29", + }, + "rsa key with missing attribute [n]" + ), + ( + { + "kty": "RSA", + "n": "0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n91CbOpbISD08qNLyrdkt-bFTWhAI4vMQFh6WeZu0fM4lFd2NcRwr3XPksINHaQ-G_xBniIqbw0Ls1jF44-csFCur-kEgU8awapJzKnqDKgw", + "e": "AQAB", + "x": "MKBCTNIcKUSDii11ySs3526iDZ8AiTo7Tu6KPAqv7D4", + "alg": "RS256", + "kid": "2011-04-29", + }, + "rsa key with unexpected attribute [x]" + ), + ( + { + "kty": "EC", + "x": "MKBCTNIcKUSDii11ySs3526iDZ8AiTo7Tu6KPAqv7D4", + "y": "4Etl6SRW2YiLUrN5vfvVHuhp7x8PxltmWWlbbM4IFyM", + "use": "enc", + "kid": "1", + }, + "ec key with missing attribute [crv]" + ), + ( + { + "kty": "EC", + "crv": "P-256", + "y": "4Etl6SRW2YiLUrN5vfvVHuhp7x8PxltmWWlbbM4IFyM", + "use": "enc", + "kid": "1", + }, + "ec key with missing attribute [x]" + ), + ( + { + "kty": "EC", + "crv": "P-256", + "x": "MKBCTNIcKUSDii11ySs3526iDZ8AiTo7Tu6KPAqv7D4", + "y": "4Etl6SRW2YiLUrN5vfvVHuhp7x8PxltmWWlbbM4IFyM", + "e": "AQAB", + "use": "enc", + "kid": "1", + }, + "ec key with unexpected attribute [e]" + ) + ] + for i, (bad_key, reason) in enumerate(bad_keys_table): + try: + JwkSchema(**bad_key) + assert False, f"failed case {i}: parsing should fail due to: {reason}" + except ValueError: + assert True diff --git a/pyeudiw/tests/test_jwk.py b/pyeudiw/tests/test_jwk.py index 8c8adfaf..40bb24e9 100644 --- a/pyeudiw/tests/test_jwk.py +++ b/pyeudiw/tests/test_jwk.py @@ -2,7 +2,7 @@ from pydantic import TypeAdapter from pyeudiw.jwk import JWK -from pyeudiw.jwk.schemas.jwk import JwkSchema, ECJwkSchema, RSAJwkSchema +from pyeudiw.jwk.schemas.public import ECJwkSchema, RSAJwkSchema, _JwkSchema_T @pytest.mark.parametrize( @@ -57,7 +57,7 @@ def test_export_public_pem(): @pytest.mark.parametrize("key_type", ["EC", "RSA"]) def test_dynamic_schema_validation(key_type): jwk = JWK(key_type=key_type) - model = TypeAdapter(JwkSchema).validate_python(jwk.as_dict()) + model = TypeAdapter(_JwkSchema_T).validate_python(jwk.as_dict()) match key_type: case "EC": assert isinstance(model, ECJwkSchema)