Skip to content

Commit

Permalink
feat: example+sd-jwt default typ value plus small linting (#10)
Browse files Browse the repository at this point in the history
* feat: protected header kwarg, small linting and logging

Signed-off-by: Giuseppe De Marco <[email protected]>

* feat: SD_JWT_HEADER configurable via ENV var

Signed-off-by: Giuseppe De Marco <[email protected]>

---------

Signed-off-by: Giuseppe De Marco <[email protected]>
  • Loading branch information
peppelinux authored Oct 27, 2023
1 parent 23f9d1d commit 956dc43
Show file tree
Hide file tree
Showing 10 changed files with 62 additions and 46 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "sd-jwt"
version = "0.9.1"
version = "0.10.0"
description = "The reference implementation of the IETF SD-JWT specification."
authors = ["Daniel Fett <[email protected]>"]
readme = "README.md"
Expand Down
2 changes: 1 addition & 1 deletion src/sd_jwt/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.9.1"
__version__ = "0.10.0"
2 changes: 1 addition & 1 deletion src/sd_jwt/bin/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def generate_test_case_data(settings: Dict, testcase_path: Path, type: str):
use_decoys = testcase.get("add_decoy_claims", False)
serialization_format = testcase.get("serialization_format", "compact")
include_default_claims = testcase.get("include_default_claims", True)
extra_header_parameters = testcase.get("extra_header_parameters", None)
extra_header_parameters = testcase.get("extra_header_parameters", {})

claims = {}
if include_default_claims:
Expand Down
23 changes: 14 additions & 9 deletions src/sd_jwt/common.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import logging
import os
import random
import secrets

from base64 import urlsafe_b64decode, urlsafe_b64encode
from dataclasses import dataclass
from hashlib import sha256
Expand All @@ -11,6 +14,8 @@
DIGEST_ALG_KEY = "_sd_alg"
SD_LIST_PREFIX = "..."

logger = logging.getLogger("sd_jwt")


@dataclass
class SDObj:
Expand All @@ -33,7 +38,7 @@ def __init__(self, error_location: any):


class SDJWTCommon:
SD_JWT_TYP_HEADER = None # "sd+jwt"
SD_JWT_HEADER = os.getenv("SD_JWT_HEADER", "example+sd-jwt") # overwriteable with extra_header_parameters = {"typ": "other-example+sd-jwt"}
KB_JWT_TYP_HEADER = "kb+jwt"
JWS_KEY_DISCLOSURES = "disclosures"
JWS_KEY_KB_JWT = "kb_jwt"
Expand Down Expand Up @@ -71,8 +76,8 @@ def _generate_salt(self):
if self.unsafe_randomness:
# This is not cryptographically secure, but it is deterministic
# and allows for repeatable output for the generation of the examples.
print(
"WARNING: Using unsafe randomness - output is not suitable for production use!"
logger.warning(
"Using unsafe randomness is not suitable for production use."
)
return self._base64url_encode(
bytes(random.getrandbits(8) for _ in range(16))
Expand All @@ -91,14 +96,14 @@ def _create_hash_mappings(self, disclosurses_list: List):
decoded_disclosure = loads(
self._base64url_decode(disclosure).decode("utf-8")
)
hash = self._b64hash(disclosure.encode("ascii"))
if hash in self._hash_to_decoded_disclosure:
_hash = self._b64hash(disclosure.encode("ascii"))
if _hash in self._hash_to_decoded_disclosure:
raise ValueError(
f"Duplicate disclosure hash {hash} for disclosure {decoded_disclosure}"
f"Duplicate disclosure hash {_hash} for disclosure {decoded_disclosure}"
)

self._hash_to_decoded_disclosure[hash] = decoded_disclosure
self._hash_to_disclosure[hash] = disclosure
self._hash_to_decoded_disclosure[_hash] = decoded_disclosure
self._hash_to_disclosure[_hash] = disclosure

def _check_for_sd_claim(self, the_object):
# Recursively check for the presence of the _sd claim, also
Expand All @@ -121,7 +126,7 @@ def _parse_sd_jwt(self, sd_jwt):
(
self._unverified_input_sd_jwt,
*self._input_disclosures,
self._unverified_input_key_binding_jwt,
self._unverified_input_key_binding_jwt
) = self._split(sd_jwt)

# Extract only the body from SD-JWT without verifying the signature
Expand Down
22 changes: 13 additions & 9 deletions src/sd_jwt/holder.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import logging

from .common import SDJWTCommon, DEFAULT_SIGNING_ALG, SD_DIGESTS_KEY, SD_LIST_PREFIX
from json import dumps, loads
from time import time
Expand All @@ -6,6 +8,8 @@

from jwcrypto.jws import JWS

logger = logging.getLogger("sd_jwt")


class SDJWTHolder(SDJWTCommon):
hs_disclosures: List
Expand Down Expand Up @@ -94,7 +98,7 @@ def _select_disclosures_list(self, sd_jwt_claims, claims_to_disclose):
zip_longest(claims_to_disclose, sd_jwt_claims, fillvalue=None)
):
if (
type(element) is dict
isinstance(element, dict)
and len(element) == 1
and SD_LIST_PREFIX in element
and type(element[SD_LIST_PREFIX]) is str
Expand All @@ -116,11 +120,11 @@ def _select_disclosures_list(self, sd_jwt_claims, claims_to_disclose):
continue

self.hs_disclosures.append(self._hash_to_disclosure[digest_to_check])
if type(disclosure_value) is dict:
if isinstance(disclosure_value, dict):
if claims_to_disclose_element is True:
# Tolerate a "True" for a disclosure of an object
claims_to_disclose_element = {}
if not type(claims_to_disclose_element) is dict:
if not isinstance(claims_to_disclose_element, dict):
raise ValueError(
f"To disclose object elements in arrays, provide an object (can be empty).\n"
f"Found {claims_to_disclose_element} instead.\n"
Expand All @@ -130,11 +134,11 @@ def _select_disclosures_list(self, sd_jwt_claims, claims_to_disclose):
self._select_disclosures(
disclosure_value, claims_to_disclose_element
)
elif type(disclosure_value) is list:
elif isinstance(disclosure_value, list):
if claims_to_disclose_element is True:
# Tolerate a "True" for a disclosure of an array
claims_to_disclose_element = []
if not type(claims_to_disclose_element) is list:
if not isinstance(claims_to_disclose_element, list):
raise ValueError(
f"To disclose array elements nested in arrays, provide an array (can be empty).\n"
f"Found {claims_to_disclose_element} instead.\n"
Expand All @@ -155,7 +159,7 @@ def _select_disclosures_dict(self, sd_jwt_claims, claims_to_disclose):
if claims_to_disclose is True:
# Tolerate a "True" for a disclosure of an object
claims_to_disclose = {}
if not type(claims_to_disclose) is dict:
if not isinstance(claims_to_disclose, dict):
raise ValueError(
f"To disclose object elements, an object must be provided as disclosure information.\n"
f"Found {claims_to_disclose} (type {type(claims_to_disclose)}) instead.\n"
Expand All @@ -170,16 +174,16 @@ def _select_disclosures_dict(self, sd_jwt_claims, claims_to_disclose):
_, key, value = self._hash_to_decoded_disclosure[digest_to_check]

try:
print(
logger.debug(
f"In _select_disclosures_dict: {key}, {value}, {claims_to_disclose}"
)
if key in claims_to_disclose and claims_to_disclose[key]:
print(f"Adding disclosure for {digest_to_check}")
logger.debug(f"Adding disclosure for {digest_to_check}")
self.hs_disclosures.append(
self._hash_to_disclosure[digest_to_check]
)
else:
print(
logger.debug(
f"Not adding disclosure for {digest_to_check}, {key} (type {type(key)}) not in {claims_to_disclose}"
)
except TypeError:
Expand Down
30 changes: 15 additions & 15 deletions src/sd_jwt/issuer.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(
sign_alg=None,
add_decoy_claims: bool = False,
serialization_format: str = "compact",
extra_header_parameters: Dict = None,
extra_header_parameters: dict = {},
):
super().__init__(serialization_format=serialization_format)

Expand Down Expand Up @@ -78,21 +78,20 @@ def _create_sd_claims(self, user_claims):
#
# If the user claims are a list, apply this function
# to each item in the list.
if type(user_claims) is list:
if isinstance(user_claims, list):
return self._create_sd_claims_list(user_claims)

# If the user claims are a dictionary, apply this function
# to each key/value pair in the dictionary.
elif type(user_claims) is dict:
elif isinstance(user_claims, dict):
return self._create_sd_claims_object(user_claims)

# For other types, assume that the value can be disclosed.
else:
if isinstance(user_claims, SDObj):
raise ValueError(
f"SDObj found in illegal place.\nThe claim value '{user_claims}' should not be wrapped by SDObj."
)
return user_claims
elif isinstance(user_claims, SDObj):
raise ValueError(
f"SDObj found in illegal place.\nThe claim value '{user_claims}' should not be wrapped by SDObj."
)
return user_claims

def _create_sd_claims_list(self, user_claims: List):
# Walk through all elements in the list.
Expand Down Expand Up @@ -168,12 +167,13 @@ def _create_signed_jws(self):

self.sd_jwt = JWS(payload=dumps(self.sd_jwt_payload))

# Assemble protected headers
_protected_headers = {"alg": self._sign_alg}
if self.SD_JWT_TYP_HEADER:
_protected_headers["typ"] = self.SD_JWT_TYP_HEADER
if self._extra_header_parameters:
_protected_headers.update(self._extra_header_parameters)
# Assemble protected headers starting with default
_protected_headers = {
"alg": self._sign_alg,
"typ": self.SD_JWT_HEADER
}
# override if any
_protected_headers.update(self._extra_header_parameters)

self.sd_jwt.add_signature(
self._issuer_key,
Expand Down
2 changes: 1 addition & 1 deletion src/sd_jwt/utils/demo_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from jwcrypto.jwk import JWK
from typing import Union

logger = logging.getLogger(__name__)
logger = logging.getLogger("sd_jwt")


def load_yaml_settings(file):
Expand Down
2 changes: 1 addition & 1 deletion src/sd_jwt/utils/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def textwrap_json(data, width=EXAMPLE_MAX_WIDTH):
else:
# Check if line is of the form "key": "value"
if not line.strip().startswith('"'):
print("WARNING: unexpected line " + line)
logger.warning("unexpected line " + line)
output.append(line)
continue
# Determine number of spaces before the value
Expand Down
14 changes: 9 additions & 5 deletions tests/test_disclose_all_shortcut.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ def test_e2e(testcase, settings):
demo_keys = get_jwk(settings["key_settings"], True, seed)
use_decoys = testcase.get("add_decoy_claims", False)
serialization_format = testcase.get("serialization_format", "compact")
extra_header_parameters = testcase.get("extra_header_parameters", None)

extra_header_parameters = {"typ": "testcase+sd-jwt"}
extra_header_parameters.update(testcase.get("extra_header_parameters", {}))

# Issuer: Produce SD-JWT and issuance format for selected example

Expand Down Expand Up @@ -39,7 +41,6 @@ def cb_get_issuer_key(issuer, header_parameters):
sdjwt_header_parameters.update(header_parameters)
return demo_keys["issuer_public_key"]


sdjwt_at_verifier = SDJWTVerifier(
output_holder,
cb_get_issuer_key,
Expand All @@ -54,13 +55,16 @@ def cb_get_issuer_key(issuer, header_parameters):
expected_claims["iss"] = settings["identifiers"]["issuer"]

if testcase.get("key_binding", False):
expected_claims["cnf"] = {"jwk": demo_keys["holder_key"].export_public(as_dict=True)}
expected_claims["cnf"] = {
"jwk": demo_keys["holder_key"].export_public(as_dict=True)
}

assert verified == expected_claims

expected_header_parameters = {
"alg": testcase.get("sign_alg", "ES256")
"alg": testcase.get("sign_alg", "ES256"),
"typ": "testcase+sd-jwt"
}
expected_header_parameters.update(extra_header_parameters or {})
expected_header_parameters.update(extra_header_parameters)

assert sdjwt_header_parameters == expected_header_parameters
9 changes: 6 additions & 3 deletions tests/test_e2e_testcases.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ def test_e2e(testcase, settings):
demo_keys = get_jwk(settings["key_settings"], True, seed)
use_decoys = testcase.get("add_decoy_claims", False)
serialization_format = testcase.get("serialization_format", "compact")
extra_header_parameters = testcase.get("extra_header_parameters", None)

extra_header_parameters = {"typ": "testcase+sd-jwt"}
extra_header_parameters.update(testcase.get("extra_header_parameters", {}))

# Issuer: Produce SD-JWT and issuance format for selected example

Expand Down Expand Up @@ -74,8 +76,9 @@ def cb_get_issuer_key(issuer, header_parameters):
assert verified == expected_claims

expected_header_parameters = {
"alg": testcase.get("sign_alg", "ES256")
"alg": testcase.get("sign_alg", "ES256"),
"typ": "testcase+sd-jwt"
}
expected_header_parameters.update(extra_header_parameters or {})
expected_header_parameters.update(extra_header_parameters)

assert sdjwt_header_parameters == expected_header_parameters

0 comments on commit 956dc43

Please sign in to comment.