From 06ca411afd00e70576f8856bd1ea62a3729b1fad Mon Sep 17 00:00:00 2001 From: Daniel Fett Date: Wed, 4 Oct 2023 15:19:30 +0200 Subject: [PATCH] Make exception for JSON serialization; add test case --- src/sd_jwt/common.py | 7 ++++--- src/sd_jwt/holder.py | 16 ++++++++++++---- src/sd_jwt/issuer.py | 2 +- src/sd_jwt/verifier.py | 22 ++++++++++++++++++++-- 4 files changed, 37 insertions(+), 10 deletions(-) diff --git a/src/sd_jwt/common.py b/src/sd_jwt/common.py index 1398e69..ef58aae 100644 --- a/src/sd_jwt/common.py +++ b/src/sd_jwt/common.py @@ -9,6 +9,7 @@ DEFAULT_SIGNING_ALG = "ES256" SD_DIGESTS_KEY = "_sd" DIGEST_ALG_KEY = "_sd_alg" +KB_DIGEST_KEY = "_sd_hash" SD_LIST_PREFIX = "..." @@ -39,7 +40,7 @@ class SDJWTCommon: JWS_KEY_KB_JWT = "kb_jwt" HASH_ALG = {"name": "sha-256", "fn": sha256} - COMBINED_serialization_FORMAT_SEPARATOR = "~" + COMBINED_SERIALIZATION_FORMAT_SEPARATOR = "~" unsafe_randomness = False @@ -53,10 +54,10 @@ def _b64hash(self, raw): return self._base64url_encode(self.HASH_ALG["fn"](raw).digest()) def _combine(self, *parts): - return self.COMBINED_serialization_FORMAT_SEPARATOR.join(parts) + return self.COMBINED_SERIALIZATION_FORMAT_SEPARATOR.join(parts) def _split(self, combined): - return combined.split(self.COMBINED_serialization_FORMAT_SEPARATOR) + return combined.split(self.COMBINED_SERIALIZATION_FORMAT_SEPARATOR) @staticmethod def _base64url_encode(data: bytes) -> str: diff --git a/src/sd_jwt/holder.py b/src/sd_jwt/holder.py index 605f8c2..ed22168 100644 --- a/src/sd_jwt/holder.py +++ b/src/sd_jwt/holder.py @@ -1,4 +1,4 @@ -from .common import SDJWTCommon, DEFAULT_SIGNING_ALG, SD_DIGESTS_KEY, SD_LIST_PREFIX +from .common import SDJWTCommon, DEFAULT_SIGNING_ALG, SD_DIGESTS_KEY, SD_LIST_PREFIX, KB_DIGEST_KEY from json import dumps, loads from time import time from typing import Dict, List, Optional @@ -42,10 +42,17 @@ def create_presentation( # Optional: Create a key binding JWT if nonce and aud and holder_key: - self._create_key_binding_jwt(nonce, aud, holder_key, sign_alg) + # Temporarily create the combined presentation in order to create the hash over it + string_to_hash = self._combine( + self.serialized_sd_jwt, + *self.hs_disclosures, + "" + ) + sd_jwt_presentation_hash = self._b64hash(string_to_hash.encode("ascii")) + self._create_key_binding_jwt(nonce, aud, sd_jwt_presentation_hash, holder_key, sign_alg) - # Create the combined presentation + # Create the combined presentation if self._serialization_format == "compact": # Note: If the key binding JWT is not created, then the # last element is empty, matching the spec. @@ -194,7 +201,7 @@ def _select_disclosures_dict(self, sd_jwt_claims, claims_to_disclose): self._select_disclosures(value, claims_to_disclose.get(key, None)) def _create_key_binding_jwt( - self, nonce, aud, holder_key, sign_alg: Optional[str] = None + self, nonce, aud, presentation_hash, holder_key, sign_alg: Optional[str] = None ): _alg = sign_alg or DEFAULT_SIGNING_ALG @@ -207,6 +214,7 @@ def _create_key_binding_jwt( "nonce": nonce, "aud": aud, "iat": int(time()), + KB_DIGEST_KEY: presentation_hash, } # Sign the SD-JWT-Release using the holder's key diff --git a/src/sd_jwt/issuer.py b/src/sd_jwt/issuer.py index 44f1845..3b5acb8 100644 --- a/src/sd_jwt/issuer.py +++ b/src/sd_jwt/issuer.py @@ -198,6 +198,6 @@ def _create_combined(self): self.sd_jwt_issuance = self._combine( self.serialized_sd_jwt, *(d.b64 for d in self.ii_disclosures) ) - self.sd_jwt_issuance += self.COMBINED_serialization_FORMAT_SEPARATOR + self.sd_jwt_issuance += self.COMBINED_SERIALIZATION_FORMAT_SEPARATOR else: self.sd_jwt_issuance = self.serialized_sd_jwt diff --git a/src/sd_jwt/verifier.py b/src/sd_jwt/verifier.py index 3eb82ef..7f0400d 100644 --- a/src/sd_jwt/verifier.py +++ b/src/sd_jwt/verifier.py @@ -4,6 +4,7 @@ DIGEST_ALG_KEY, SD_DIGESTS_KEY, SD_LIST_PREFIX, + KB_DIGEST_KEY, ) from json import dumps, loads @@ -72,10 +73,13 @@ def _verify_key_binding_jwt( expected_nonce: Union[str, None] = None, sign_alg: Union[str, None] = None, ): + + # Deserialized the key binding JWT _alg = sign_alg or DEFAULT_SIGNING_ALG parsed_input_key_binding_jwt = JWS() parsed_input_key_binding_jwt.deserialize(self._unverified_input_key_binding_jwt) + # Verify the key binding JWT using the holder public key if not self._holder_public_key_payload: raise ValueError("No holder public key in SD-JWT") @@ -91,17 +95,31 @@ def _verify_key_binding_jwt( parsed_input_key_binding_jwt.verify(pubkey, alg=_alg) + # Check header typ key_binding_jwt_header = parsed_input_key_binding_jwt.jose_header if key_binding_jwt_header["typ"] != self.KB_JWT_TYP_HEADER: raise ValueError("Invalid header typ") + # Check payload key_binding_jwt_payload = loads(parsed_input_key_binding_jwt.payload) if key_binding_jwt_payload["aud"] != expected_aud: - raise ValueError("Invalid audience") + raise ValueError("Invalid audience in KB-JWT") if key_binding_jwt_payload["nonce"] != expected_nonce: - raise ValueError("Invalid nonce") + raise ValueError("Invalid nonce in KB-JWT") + + # Reassemble the SD-JWT in compact format and check digest + if self._serialization_format == "compact": + string_to_hash = self._combine( + self._unverified_input_sd_jwt, + *self._input_disclosures, + "" + ) + expected_sd_jwt_presentation_hash = self._b64hash(string_to_hash.encode("ascii")) + + if key_binding_jwt_payload[KB_DIGEST_KEY] != expected_sd_jwt_presentation_hash: + raise ValueError("Invalid digest in KB-JWT") def _extract_sd_claims(self): if DIGEST_ALG_KEY in self._sd_jwt_payload: