Skip to content

Commit

Permalink
feat: protected header by construction kwarg and small linting
Browse files Browse the repository at this point in the history
  • Loading branch information
peppelinux committed Oct 10, 2023
1 parent 23f9d1d commit f7205e3
Show file tree
Hide file tree
Showing 8 changed files with 36 additions and 34 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "sd-jwt"
version = "0.9.1"
version = "0.10.0"
description = "The reference implementation of the IETF SD-JWT specification."
authors = ["Daniel Fett <[email protected]>"]
readme = "README.md"
Expand Down
2 changes: 1 addition & 1 deletion src/sd_jwt/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.9.1"
__version__ = "0.10.0"
2 changes: 1 addition & 1 deletion src/sd_jwt/bin/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def generate_test_case_data(settings: Dict, testcase_path: Path, type: str):
use_decoys = testcase.get("add_decoy_claims", False)
serialization_format = testcase.get("serialization_format", "compact")
include_default_claims = testcase.get("include_default_claims", True)
extra_header_parameters = testcase.get("extra_header_parameters", None)
extra_header_parameters = testcase.get("extra_header_parameters", {})

claims = {}
if include_default_claims:
Expand Down
12 changes: 6 additions & 6 deletions src/sd_jwt/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(self, error_location: any):


class SDJWTCommon:
SD_JWT_TYP_HEADER = None # "sd+jwt"
SD_JWT_HEADER = "sd+jwt"
KB_JWT_TYP_HEADER = "kb+jwt"
JWS_KEY_DISCLOSURES = "disclosures"
JWS_KEY_KB_JWT = "kb_jwt"
Expand Down Expand Up @@ -91,14 +91,14 @@ def _create_hash_mappings(self, disclosurses_list: List):
decoded_disclosure = loads(
self._base64url_decode(disclosure).decode("utf-8")
)
hash = self._b64hash(disclosure.encode("ascii"))
if hash in self._hash_to_decoded_disclosure:
_hash = self._b64hash(disclosure.encode("ascii"))
if _hash in self._hash_to_decoded_disclosure:
raise ValueError(
f"Duplicate disclosure hash {hash} for disclosure {decoded_disclosure}"
f"Duplicate disclosure hash {_hash} for disclosure {decoded_disclosure}"
)

self._hash_to_decoded_disclosure[hash] = decoded_disclosure
self._hash_to_disclosure[hash] = disclosure
self._hash_to_decoded_disclosure[_hash] = decoded_disclosure
self._hash_to_disclosure[_hash] = disclosure

def _check_for_sd_claim(self, the_object):
# Recursively check for the presence of the _sd claim, also
Expand Down
12 changes: 6 additions & 6 deletions src/sd_jwt/holder.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def _select_disclosures_list(self, sd_jwt_claims, claims_to_disclose):
zip_longest(claims_to_disclose, sd_jwt_claims, fillvalue=None)
):
if (
type(element) is dict
isinstance(element, dict)
and len(element) == 1
and SD_LIST_PREFIX in element
and type(element[SD_LIST_PREFIX]) is str
Expand All @@ -116,11 +116,11 @@ def _select_disclosures_list(self, sd_jwt_claims, claims_to_disclose):
continue

self.hs_disclosures.append(self._hash_to_disclosure[digest_to_check])
if type(disclosure_value) is dict:
if isinstance(disclosure_value, dict):
if claims_to_disclose_element is True:
# Tolerate a "True" for a disclosure of an object
claims_to_disclose_element = {}
if not type(claims_to_disclose_element) is dict:
if not isinstance(claims_to_disclose_element, dict):
raise ValueError(
f"To disclose object elements in arrays, provide an object (can be empty).\n"
f"Found {claims_to_disclose_element} instead.\n"
Expand All @@ -130,11 +130,11 @@ def _select_disclosures_list(self, sd_jwt_claims, claims_to_disclose):
self._select_disclosures(
disclosure_value, claims_to_disclose_element
)
elif type(disclosure_value) is list:
elif isinstance(disclosure_value, list):
if claims_to_disclose_element is True:
# Tolerate a "True" for a disclosure of an array
claims_to_disclose_element = []
if not type(claims_to_disclose_element) is list:
if not isinstance(claims_to_disclose_element, list):
raise ValueError(
f"To disclose array elements nested in arrays, provide an array (can be empty).\n"
f"Found {claims_to_disclose_element} instead.\n"
Expand All @@ -155,7 +155,7 @@ def _select_disclosures_dict(self, sd_jwt_claims, claims_to_disclose):
if claims_to_disclose is True:
# Tolerate a "True" for a disclosure of an object
claims_to_disclose = {}
if not type(claims_to_disclose) is dict:
if not isinstance(claims_to_disclose, dict):
raise ValueError(
f"To disclose object elements, an object must be provided as disclosure information.\n"
f"Found {claims_to_disclose} (type {type(claims_to_disclose)}) instead.\n"
Expand Down
30 changes: 15 additions & 15 deletions src/sd_jwt/issuer.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(
sign_alg=None,
add_decoy_claims: bool = False,
serialization_format: str = "compact",
extra_header_parameters: Dict = None,
extra_header_parameters: dict = {},
):
super().__init__(serialization_format=serialization_format)

Expand Down Expand Up @@ -78,21 +78,20 @@ def _create_sd_claims(self, user_claims):
#
# If the user claims are a list, apply this function
# to each item in the list.
if type(user_claims) is list:
if isinstance(user_claims, list):
return self._create_sd_claims_list(user_claims)

# If the user claims are a dictionary, apply this function
# to each key/value pair in the dictionary.
elif type(user_claims) is dict:
elif isinstance(user_claims, dict):
return self._create_sd_claims_object(user_claims)

# For other types, assume that the value can be disclosed.
else:
if isinstance(user_claims, SDObj):
raise ValueError(
f"SDObj found in illegal place.\nThe claim value '{user_claims}' should not be wrapped by SDObj."
)
return user_claims
elif isinstance(user_claims, SDObj):
raise ValueError(
f"SDObj found in illegal place.\nThe claim value '{user_claims}' should not be wrapped by SDObj."
)
return user_claims

def _create_sd_claims_list(self, user_claims: List):
# Walk through all elements in the list.
Expand Down Expand Up @@ -168,12 +167,13 @@ def _create_signed_jws(self):

self.sd_jwt = JWS(payload=dumps(self.sd_jwt_payload))

# Assemble protected headers
_protected_headers = {"alg": self._sign_alg}
if self.SD_JWT_TYP_HEADER:
_protected_headers["typ"] = self.SD_JWT_TYP_HEADER
if self._extra_header_parameters:
_protected_headers.update(self._extra_header_parameters)
# Assemble protected headers starting with default
_protected_headers = {
"alg": self._sign_alg,
"typ": self.SD_JWT_HEADER
}
# override if any
_protected_headers.update(self._extra_header_parameters)

self.sd_jwt.add_signature(
self._issuer_key,
Expand Down
5 changes: 3 additions & 2 deletions tests/test_disclose_all_shortcut.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def test_e2e(testcase, settings):
demo_keys = get_jwk(settings["key_settings"], True, seed)
use_decoys = testcase.get("add_decoy_claims", False)
serialization_format = testcase.get("serialization_format", "compact")
extra_header_parameters = testcase.get("extra_header_parameters", None)
extra_header_parameters = testcase.get("extra_header_parameters", {})

# Issuer: Produce SD-JWT and issuance format for selected example

Expand Down Expand Up @@ -59,7 +59,8 @@ def cb_get_issuer_key(issuer, header_parameters):
assert verified == expected_claims

expected_header_parameters = {
"alg": testcase.get("sign_alg", "ES256")
"alg": testcase.get("sign_alg", "ES256"),
"typ": "sd+jwt"
}
expected_header_parameters.update(extra_header_parameters or {})

Expand Down
5 changes: 3 additions & 2 deletions tests/test_e2e_testcases.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def test_e2e(testcase, settings):
demo_keys = get_jwk(settings["key_settings"], True, seed)
use_decoys = testcase.get("add_decoy_claims", False)
serialization_format = testcase.get("serialization_format", "compact")
extra_header_parameters = testcase.get("extra_header_parameters", None)
extra_header_parameters = testcase.get("extra_header_parameters", {})

# Issuer: Produce SD-JWT and issuance format for selected example

Expand Down Expand Up @@ -74,7 +74,8 @@ def cb_get_issuer_key(issuer, header_parameters):
assert verified == expected_claims

expected_header_parameters = {
"alg": testcase.get("sign_alg", "ES256")
"alg": testcase.get("sign_alg", "ES256"),
"typ": "sd+jwt"
}
expected_header_parameters.update(extra_header_parameters or {})

Expand Down

0 comments on commit f7205e3

Please sign in to comment.