From ccde23b523e2135a3a62fdce32b569d55afd1a2c Mon Sep 17 00:00:00 2001 From: Giuseppe De Marco Date: Fri, 27 Oct 2023 11:19:27 +0200 Subject: [PATCH 1/2] feat: protected header kwarg, small linting and logging Signed-off-by: Giuseppe De Marco --- pyproject.toml | 2 +- src/sd_jwt/__init__.py | 2 +- src/sd_jwt/bin/generate.py | 2 +- src/sd_jwt/common.py | 19 ++++++++++-------- src/sd_jwt/holder.py | 22 ++++++++++++--------- src/sd_jwt/issuer.py | 30 ++++++++++++++--------------- src/sd_jwt/utils/demo_utils.py | 2 +- src/sd_jwt/utils/formatting.py | 2 +- tests/test_disclose_all_shortcut.py | 12 ++++++++---- tests/test_e2e_testcases.py | 9 ++++++--- 10 files changed, 58 insertions(+), 44 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f3d3e8c..4cd3ab5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "sd-jwt" -version = "0.9.1" +version = "0.10.0" description = "The reference implementation of the IETF SD-JWT specification." authors = ["Daniel Fett "] readme = "README.md" diff --git a/src/sd_jwt/__init__.py b/src/sd_jwt/__init__.py index d69d16e..61fb31c 100644 --- a/src/sd_jwt/__init__.py +++ b/src/sd_jwt/__init__.py @@ -1 +1 @@ -__version__ = "0.9.1" +__version__ = "0.10.0" diff --git a/src/sd_jwt/bin/generate.py b/src/sd_jwt/bin/generate.py index 01341c5..97d0c6b 100755 --- a/src/sd_jwt/bin/generate.py +++ b/src/sd_jwt/bin/generate.py @@ -39,7 +39,7 @@ def generate_test_case_data(settings: Dict, testcase_path: Path, type: str): use_decoys = testcase.get("add_decoy_claims", False) serialization_format = testcase.get("serialization_format", "compact") include_default_claims = testcase.get("include_default_claims", True) - extra_header_parameters = testcase.get("extra_header_parameters", None) + extra_header_parameters = testcase.get("extra_header_parameters") claims = {} if include_default_claims: diff --git a/src/sd_jwt/common.py b/src/sd_jwt/common.py index 1398e69..6c8d3e4 100644 --- a/src/sd_jwt/common.py +++ b/src/sd_jwt/common.py @@ -1,3 +1,4 @@ +import logging import random import secrets from base64 import urlsafe_b64decode, urlsafe_b64encode @@ -11,6 +12,8 @@ DIGEST_ALG_KEY = "_sd_alg" SD_LIST_PREFIX = "..." +logger = logging.getLogger("sd_jwt") + @dataclass class SDObj: @@ -33,7 +36,7 @@ def __init__(self, error_location: any): class SDJWTCommon: - SD_JWT_TYP_HEADER = None # "sd+jwt" + SD_JWT_HEADER = None # overwriteable with extra_header_parameters = {"typ": "other-example+sd-jwt"} KB_JWT_TYP_HEADER = "kb+jwt" JWS_KEY_DISCLOSURES = "disclosures" JWS_KEY_KB_JWT = "kb_jwt" @@ -71,8 +74,8 @@ def _generate_salt(self): if self.unsafe_randomness: # This is not cryptographically secure, but it is deterministic # and allows for repeatable output for the generation of the examples. - print( - "WARNING: Using unsafe randomness - output is not suitable for production use!" + logger.warning( + "Using unsafe randomness is not suitable for production use." ) return self._base64url_encode( bytes(random.getrandbits(8) for _ in range(16)) @@ -91,14 +94,14 @@ def _create_hash_mappings(self, disclosurses_list: List): decoded_disclosure = loads( self._base64url_decode(disclosure).decode("utf-8") ) - hash = self._b64hash(disclosure.encode("ascii")) - if hash in self._hash_to_decoded_disclosure: + _hash = self._b64hash(disclosure.encode("ascii")) + if _hash in self._hash_to_decoded_disclosure: raise ValueError( - f"Duplicate disclosure hash {hash} for disclosure {decoded_disclosure}" + f"Duplicate disclosure hash {_hash} for disclosure {decoded_disclosure}" ) - self._hash_to_decoded_disclosure[hash] = decoded_disclosure - self._hash_to_disclosure[hash] = disclosure + self._hash_to_decoded_disclosure[_hash] = decoded_disclosure + self._hash_to_disclosure[_hash] = disclosure def _check_for_sd_claim(self, the_object): # Recursively check for the presence of the _sd claim, also diff --git a/src/sd_jwt/holder.py b/src/sd_jwt/holder.py index 605f8c2..f1dc6c8 100644 --- a/src/sd_jwt/holder.py +++ b/src/sd_jwt/holder.py @@ -1,3 +1,5 @@ +import logging + from .common import SDJWTCommon, DEFAULT_SIGNING_ALG, SD_DIGESTS_KEY, SD_LIST_PREFIX from json import dumps, loads from time import time @@ -6,6 +8,8 @@ from jwcrypto.jws import JWS +logger = logging.getLogger("sd_jwt") + class SDJWTHolder(SDJWTCommon): hs_disclosures: List @@ -94,7 +98,7 @@ def _select_disclosures_list(self, sd_jwt_claims, claims_to_disclose): zip_longest(claims_to_disclose, sd_jwt_claims, fillvalue=None) ): if ( - type(element) is dict + isinstance(element, dict) and len(element) == 1 and SD_LIST_PREFIX in element and type(element[SD_LIST_PREFIX]) is str @@ -116,11 +120,11 @@ def _select_disclosures_list(self, sd_jwt_claims, claims_to_disclose): continue self.hs_disclosures.append(self._hash_to_disclosure[digest_to_check]) - if type(disclosure_value) is dict: + if isinstance(disclosure_value, dict): if claims_to_disclose_element is True: # Tolerate a "True" for a disclosure of an object claims_to_disclose_element = {} - if not type(claims_to_disclose_element) is dict: + if not isinstance(claims_to_disclose_element, dict): raise ValueError( f"To disclose object elements in arrays, provide an object (can be empty).\n" f"Found {claims_to_disclose_element} instead.\n" @@ -130,11 +134,11 @@ def _select_disclosures_list(self, sd_jwt_claims, claims_to_disclose): self._select_disclosures( disclosure_value, claims_to_disclose_element ) - elif type(disclosure_value) is list: + elif isinstance(disclosure_value, list): if claims_to_disclose_element is True: # Tolerate a "True" for a disclosure of an array claims_to_disclose_element = [] - if not type(claims_to_disclose_element) is list: + if not isinstance(claims_to_disclose_element, list): raise ValueError( f"To disclose array elements nested in arrays, provide an array (can be empty).\n" f"Found {claims_to_disclose_element} instead.\n" @@ -155,7 +159,7 @@ def _select_disclosures_dict(self, sd_jwt_claims, claims_to_disclose): if claims_to_disclose is True: # Tolerate a "True" for a disclosure of an object claims_to_disclose = {} - if not type(claims_to_disclose) is dict: + if not isinstance(claims_to_disclose, dict): raise ValueError( f"To disclose object elements, an object must be provided as disclosure information.\n" f"Found {claims_to_disclose} (type {type(claims_to_disclose)}) instead.\n" @@ -170,16 +174,16 @@ def _select_disclosures_dict(self, sd_jwt_claims, claims_to_disclose): _, key, value = self._hash_to_decoded_disclosure[digest_to_check] try: - print( + logger.debug( f"In _select_disclosures_dict: {key}, {value}, {claims_to_disclose}" ) if key in claims_to_disclose and claims_to_disclose[key]: - print(f"Adding disclosure for {digest_to_check}") + logger.debug(f"Adding disclosure for {digest_to_check}") self.hs_disclosures.append( self._hash_to_disclosure[digest_to_check] ) else: - print( + logger.debug( f"Not adding disclosure for {digest_to_check}, {key} (type {type(key)}) not in {claims_to_disclose}" ) except TypeError: diff --git a/src/sd_jwt/issuer.py b/src/sd_jwt/issuer.py index 44f1845..5b49fdf 100644 --- a/src/sd_jwt/issuer.py +++ b/src/sd_jwt/issuer.py @@ -36,7 +36,7 @@ def __init__( sign_alg=None, add_decoy_claims: bool = False, serialization_format: str = "compact", - extra_header_parameters: Dict = None, + extra_header_parameters: dict = {}, ): super().__init__(serialization_format=serialization_format) @@ -78,21 +78,20 @@ def _create_sd_claims(self, user_claims): # # If the user claims are a list, apply this function # to each item in the list. - if type(user_claims) is list: + if isinstance(user_claims, list): return self._create_sd_claims_list(user_claims) # If the user claims are a dictionary, apply this function # to each key/value pair in the dictionary. - elif type(user_claims) is dict: + elif isinstance(user_claims, dict): return self._create_sd_claims_object(user_claims) # For other types, assume that the value can be disclosed. - else: - if isinstance(user_claims, SDObj): - raise ValueError( - f"SDObj found in illegal place.\nThe claim value '{user_claims}' should not be wrapped by SDObj." - ) - return user_claims + elif isinstance(user_claims, SDObj): + raise ValueError( + f"SDObj found in illegal place.\nThe claim value '{user_claims}' should not be wrapped by SDObj." + ) + return user_claims def _create_sd_claims_list(self, user_claims: List): # Walk through all elements in the list. @@ -168,12 +167,13 @@ def _create_signed_jws(self): self.sd_jwt = JWS(payload=dumps(self.sd_jwt_payload)) - # Assemble protected headers - _protected_headers = {"alg": self._sign_alg} - if self.SD_JWT_TYP_HEADER: - _protected_headers["typ"] = self.SD_JWT_TYP_HEADER - if self._extra_header_parameters: - _protected_headers.update(self._extra_header_parameters) + # Assemble protected headers starting with default + _protected_headers = { + "alg": self._sign_alg, + "typ": self.SD_JWT_HEADER + } + # override if any + _protected_headers.update(self._extra_header_parameters) self.sd_jwt.add_signature( self._issuer_key, diff --git a/src/sd_jwt/utils/demo_utils.py b/src/sd_jwt/utils/demo_utils.py index dd7656c..fee6632 100644 --- a/src/sd_jwt/utils/demo_utils.py +++ b/src/sd_jwt/utils/demo_utils.py @@ -8,7 +8,7 @@ from jwcrypto.jwk import JWK from typing import Union -logger = logging.getLogger(__name__) +logger = logging.getLogger("sd_jwt") def load_yaml_settings(file): diff --git a/src/sd_jwt/utils/formatting.py b/src/sd_jwt/utils/formatting.py index 08e4fe8..e7bc6da 100644 --- a/src/sd_jwt/utils/formatting.py +++ b/src/sd_jwt/utils/formatting.py @@ -26,7 +26,7 @@ def textwrap_json(data, width=EXAMPLE_MAX_WIDTH): else: # Check if line is of the form "key": "value" if not line.strip().startswith('"'): - print("WARNING: unexpected line " + line) + logger.warning("unexpected line " + line) output.append(line) continue # Determine number of spaces before the value diff --git a/tests/test_disclose_all_shortcut.py b/tests/test_disclose_all_shortcut.py index e1eac5b..05477bf 100644 --- a/tests/test_disclose_all_shortcut.py +++ b/tests/test_disclose_all_shortcut.py @@ -10,7 +10,9 @@ def test_e2e(testcase, settings): demo_keys = get_jwk(settings["key_settings"], True, seed) use_decoys = testcase.get("add_decoy_claims", False) serialization_format = testcase.get("serialization_format", "compact") - extra_header_parameters = testcase.get("extra_header_parameters", None) + + extra_header_parameters = {"typ": "testcase+sd-jwt"} + extra_header_parameters.update(testcase.get("extra_header_parameters", {})) # Issuer: Produce SD-JWT and issuance format for selected example @@ -39,7 +41,6 @@ def cb_get_issuer_key(issuer, header_parameters): sdjwt_header_parameters.update(header_parameters) return demo_keys["issuer_public_key"] - sdjwt_at_verifier = SDJWTVerifier( output_holder, cb_get_issuer_key, @@ -54,12 +55,15 @@ def cb_get_issuer_key(issuer, header_parameters): expected_claims["iss"] = settings["identifiers"]["issuer"] if testcase.get("key_binding", False): - expected_claims["cnf"] = {"jwk": demo_keys["holder_key"].export_public(as_dict=True)} + expected_claims["cnf"] = { + "jwk": demo_keys["holder_key"].export_public(as_dict=True) + } assert verified == expected_claims expected_header_parameters = { - "alg": testcase.get("sign_alg", "ES256") + "alg": testcase.get("sign_alg", "ES256"), + "typ": "testcase+sd-jwt" } expected_header_parameters.update(extra_header_parameters or {}) diff --git a/tests/test_e2e_testcases.py b/tests/test_e2e_testcases.py index f6ca529..04d881a 100644 --- a/tests/test_e2e_testcases.py +++ b/tests/test_e2e_testcases.py @@ -10,7 +10,9 @@ def test_e2e(testcase, settings): demo_keys = get_jwk(settings["key_settings"], True, seed) use_decoys = testcase.get("add_decoy_claims", False) serialization_format = testcase.get("serialization_format", "compact") - extra_header_parameters = testcase.get("extra_header_parameters", None) + + extra_header_parameters = {"typ": "testcase+sd-jwt"} + extra_header_parameters.update(testcase.get("extra_header_parameters", {})) # Issuer: Produce SD-JWT and issuance format for selected example @@ -74,8 +76,9 @@ def cb_get_issuer_key(issuer, header_parameters): assert verified == expected_claims expected_header_parameters = { - "alg": testcase.get("sign_alg", "ES256") + "alg": testcase.get("sign_alg", "ES256"), + "typ": "testcase+sd-jwt" } - expected_header_parameters.update(extra_header_parameters or {}) + expected_header_parameters.update(extra_header_parameters) assert sdjwt_header_parameters == expected_header_parameters From 18d25faaa41f85a0125bc089051356270b250c2b Mon Sep 17 00:00:00 2001 From: Giuseppe De Marco Date: Fri, 27 Oct 2023 11:43:50 +0200 Subject: [PATCH 2/2] feat: SD_JWT_HEADER configurable via ENV var Signed-off-by: Giuseppe De Marco --- src/sd_jwt/bin/generate.py | 2 +- src/sd_jwt/common.py | 6 ++++-- tests/test_disclose_all_shortcut.py | 2 +- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/sd_jwt/bin/generate.py b/src/sd_jwt/bin/generate.py index 97d0c6b..ad00641 100755 --- a/src/sd_jwt/bin/generate.py +++ b/src/sd_jwt/bin/generate.py @@ -39,7 +39,7 @@ def generate_test_case_data(settings: Dict, testcase_path: Path, type: str): use_decoys = testcase.get("add_decoy_claims", False) serialization_format = testcase.get("serialization_format", "compact") include_default_claims = testcase.get("include_default_claims", True) - extra_header_parameters = testcase.get("extra_header_parameters") + extra_header_parameters = testcase.get("extra_header_parameters", {}) claims = {} if include_default_claims: diff --git a/src/sd_jwt/common.py b/src/sd_jwt/common.py index 6c8d3e4..0eeefe3 100644 --- a/src/sd_jwt/common.py +++ b/src/sd_jwt/common.py @@ -1,6 +1,8 @@ import logging +import os import random import secrets + from base64 import urlsafe_b64decode, urlsafe_b64encode from dataclasses import dataclass from hashlib import sha256 @@ -36,7 +38,7 @@ def __init__(self, error_location: any): class SDJWTCommon: - SD_JWT_HEADER = None # overwriteable with extra_header_parameters = {"typ": "other-example+sd-jwt"} + SD_JWT_HEADER = os.getenv("SD_JWT_HEADER", "example+sd-jwt") # overwriteable with extra_header_parameters = {"typ": "other-example+sd-jwt"} KB_JWT_TYP_HEADER = "kb+jwt" JWS_KEY_DISCLOSURES = "disclosures" JWS_KEY_KB_JWT = "kb_jwt" @@ -124,7 +126,7 @@ def _parse_sd_jwt(self, sd_jwt): ( self._unverified_input_sd_jwt, *self._input_disclosures, - self._unverified_input_key_binding_jwt, + self._unverified_input_key_binding_jwt ) = self._split(sd_jwt) # Extract only the body from SD-JWT without verifying the signature diff --git a/tests/test_disclose_all_shortcut.py b/tests/test_disclose_all_shortcut.py index 05477bf..51579c5 100644 --- a/tests/test_disclose_all_shortcut.py +++ b/tests/test_disclose_all_shortcut.py @@ -65,6 +65,6 @@ def cb_get_issuer_key(issuer, header_parameters): "alg": testcase.get("sign_alg", "ES256"), "typ": "testcase+sd-jwt" } - expected_header_parameters.update(extra_header_parameters or {}) + expected_header_parameters.update(extra_header_parameters) assert sdjwt_header_parameters == expected_header_parameters