Skip to content

Commit

Permalink
Make exception for JSON serialization; add test case
Browse files Browse the repository at this point in the history
  • Loading branch information
danielfett committed Oct 4, 2023
1 parent e3b0b10 commit 06ca411
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 10 deletions.
7 changes: 4 additions & 3 deletions src/sd_jwt/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
DEFAULT_SIGNING_ALG = "ES256"
SD_DIGESTS_KEY = "_sd"
DIGEST_ALG_KEY = "_sd_alg"
KB_DIGEST_KEY = "_sd_hash"
SD_LIST_PREFIX = "..."


Expand Down Expand Up @@ -39,7 +40,7 @@ class SDJWTCommon:
JWS_KEY_KB_JWT = "kb_jwt"
HASH_ALG = {"name": "sha-256", "fn": sha256}

COMBINED_serialization_FORMAT_SEPARATOR = "~"
COMBINED_SERIALIZATION_FORMAT_SEPARATOR = "~"

unsafe_randomness = False

Expand All @@ -53,10 +54,10 @@ def _b64hash(self, raw):
return self._base64url_encode(self.HASH_ALG["fn"](raw).digest())

def _combine(self, *parts):
return self.COMBINED_serialization_FORMAT_SEPARATOR.join(parts)
return self.COMBINED_SERIALIZATION_FORMAT_SEPARATOR.join(parts)

def _split(self, combined):
return combined.split(self.COMBINED_serialization_FORMAT_SEPARATOR)
return combined.split(self.COMBINED_SERIALIZATION_FORMAT_SEPARATOR)

@staticmethod
def _base64url_encode(data: bytes) -> str:
Expand Down
16 changes: 12 additions & 4 deletions src/sd_jwt/holder.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .common import SDJWTCommon, DEFAULT_SIGNING_ALG, SD_DIGESTS_KEY, SD_LIST_PREFIX
from .common import SDJWTCommon, DEFAULT_SIGNING_ALG, SD_DIGESTS_KEY, SD_LIST_PREFIX, KB_DIGEST_KEY
from json import dumps, loads
from time import time
from typing import Dict, List, Optional
Expand Down Expand Up @@ -42,10 +42,17 @@ def create_presentation(

# Optional: Create a key binding JWT
if nonce and aud and holder_key:
self._create_key_binding_jwt(nonce, aud, holder_key, sign_alg)
# Temporarily create the combined presentation in order to create the hash over it
string_to_hash = self._combine(
self.serialized_sd_jwt,
*self.hs_disclosures,
""
)
sd_jwt_presentation_hash = self._b64hash(string_to_hash.encode("ascii"))
self._create_key_binding_jwt(nonce, aud, sd_jwt_presentation_hash, holder_key, sign_alg)

# Create the combined presentation

# Create the combined presentation
if self._serialization_format == "compact":
# Note: If the key binding JWT is not created, then the
# last element is empty, matching the spec.
Expand Down Expand Up @@ -194,7 +201,7 @@ def _select_disclosures_dict(self, sd_jwt_claims, claims_to_disclose):
self._select_disclosures(value, claims_to_disclose.get(key, None))

def _create_key_binding_jwt(
self, nonce, aud, holder_key, sign_alg: Optional[str] = None
self, nonce, aud, presentation_hash, holder_key, sign_alg: Optional[str] = None
):
_alg = sign_alg or DEFAULT_SIGNING_ALG

Expand All @@ -207,6 +214,7 @@ def _create_key_binding_jwt(
"nonce": nonce,
"aud": aud,
"iat": int(time()),
KB_DIGEST_KEY: presentation_hash,
}

# Sign the SD-JWT-Release using the holder's key
Expand Down
2 changes: 1 addition & 1 deletion src/sd_jwt/issuer.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,6 @@ def _create_combined(self):
self.sd_jwt_issuance = self._combine(
self.serialized_sd_jwt, *(d.b64 for d in self.ii_disclosures)
)
self.sd_jwt_issuance += self.COMBINED_serialization_FORMAT_SEPARATOR
self.sd_jwt_issuance += self.COMBINED_SERIALIZATION_FORMAT_SEPARATOR
else:
self.sd_jwt_issuance = self.serialized_sd_jwt
22 changes: 20 additions & 2 deletions src/sd_jwt/verifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
DIGEST_ALG_KEY,
SD_DIGESTS_KEY,
SD_LIST_PREFIX,
KB_DIGEST_KEY,
)

from json import dumps, loads
Expand Down Expand Up @@ -72,10 +73,13 @@ def _verify_key_binding_jwt(
expected_nonce: Union[str, None] = None,
sign_alg: Union[str, None] = None,
):

# Deserialized the key binding JWT
_alg = sign_alg or DEFAULT_SIGNING_ALG
parsed_input_key_binding_jwt = JWS()
parsed_input_key_binding_jwt.deserialize(self._unverified_input_key_binding_jwt)

# Verify the key binding JWT using the holder public key
if not self._holder_public_key_payload:
raise ValueError("No holder public key in SD-JWT")

Expand All @@ -91,17 +95,31 @@ def _verify_key_binding_jwt(

parsed_input_key_binding_jwt.verify(pubkey, alg=_alg)

# Check header typ
key_binding_jwt_header = parsed_input_key_binding_jwt.jose_header

if key_binding_jwt_header["typ"] != self.KB_JWT_TYP_HEADER:
raise ValueError("Invalid header typ")

# Check payload
key_binding_jwt_payload = loads(parsed_input_key_binding_jwt.payload)

if key_binding_jwt_payload["aud"] != expected_aud:
raise ValueError("Invalid audience")
raise ValueError("Invalid audience in KB-JWT")
if key_binding_jwt_payload["nonce"] != expected_nonce:
raise ValueError("Invalid nonce")
raise ValueError("Invalid nonce in KB-JWT")

# Reassemble the SD-JWT in compact format and check digest
if self._serialization_format == "compact":
string_to_hash = self._combine(
self._unverified_input_sd_jwt,
*self._input_disclosures,
""
)
expected_sd_jwt_presentation_hash = self._b64hash(string_to_hash.encode("ascii"))

if key_binding_jwt_payload[KB_DIGEST_KEY] != expected_sd_jwt_presentation_hash:
raise ValueError("Invalid digest in KB-JWT")

def _extract_sd_claims(self):
if DIGEST_ALG_KEY in self._sd_jwt_payload:
Expand Down

0 comments on commit 06ca411

Please sign in to comment.