Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: example+sd-jwt default typ value plus small linting #10

Merged
merged 2 commits into from
Oct 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "sd-jwt"
version = "0.9.1"
version = "0.10.0"
description = "The reference implementation of the IETF SD-JWT specification."
authors = ["Daniel Fett <[email protected]>"]
readme = "README.md"
Expand Down
2 changes: 1 addition & 1 deletion src/sd_jwt/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.9.1"
__version__ = "0.10.0"
2 changes: 1 addition & 1 deletion src/sd_jwt/bin/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def generate_test_case_data(settings: Dict, testcase_path: Path, type: str):
use_decoys = testcase.get("add_decoy_claims", False)
serialization_format = testcase.get("serialization_format", "compact")
include_default_claims = testcase.get("include_default_claims", True)
extra_header_parameters = testcase.get("extra_header_parameters", None)
extra_header_parameters = testcase.get("extra_header_parameters", {})

claims = {}
if include_default_claims:
Expand Down
23 changes: 14 additions & 9 deletions src/sd_jwt/common.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import logging
import os
import random
import secrets

from base64 import urlsafe_b64decode, urlsafe_b64encode
from dataclasses import dataclass
from hashlib import sha256
Expand All @@ -11,6 +14,8 @@
DIGEST_ALG_KEY = "_sd_alg"
SD_LIST_PREFIX = "..."

logger = logging.getLogger("sd_jwt")


@dataclass
class SDObj:
Expand All @@ -33,7 +38,7 @@ def __init__(self, error_location: any):


class SDJWTCommon:
SD_JWT_TYP_HEADER = None # "sd+jwt"
SD_JWT_HEADER = os.getenv("SD_JWT_HEADER", "example+sd-jwt") # overwriteable with extra_header_parameters = {"typ": "other-example+sd-jwt"}
KB_JWT_TYP_HEADER = "kb+jwt"
JWS_KEY_DISCLOSURES = "disclosures"
JWS_KEY_KB_JWT = "kb_jwt"
Expand Down Expand Up @@ -71,8 +76,8 @@ def _generate_salt(self):
if self.unsafe_randomness:
# This is not cryptographically secure, but it is deterministic
# and allows for repeatable output for the generation of the examples.
print(
"WARNING: Using unsafe randomness - output is not suitable for production use!"
logger.warning(
"Using unsafe randomness is not suitable for production use."
)
return self._base64url_encode(
bytes(random.getrandbits(8) for _ in range(16))
Expand All @@ -91,14 +96,14 @@ def _create_hash_mappings(self, disclosurses_list: List):
decoded_disclosure = loads(
self._base64url_decode(disclosure).decode("utf-8")
)
hash = self._b64hash(disclosure.encode("ascii"))
if hash in self._hash_to_decoded_disclosure:
_hash = self._b64hash(disclosure.encode("ascii"))
if _hash in self._hash_to_decoded_disclosure:
raise ValueError(
f"Duplicate disclosure hash {hash} for disclosure {decoded_disclosure}"
f"Duplicate disclosure hash {_hash} for disclosure {decoded_disclosure}"
)

self._hash_to_decoded_disclosure[hash] = decoded_disclosure
self._hash_to_disclosure[hash] = disclosure
self._hash_to_decoded_disclosure[_hash] = decoded_disclosure
self._hash_to_disclosure[_hash] = disclosure

def _check_for_sd_claim(self, the_object):
# Recursively check for the presence of the _sd claim, also
Expand All @@ -121,7 +126,7 @@ def _parse_sd_jwt(self, sd_jwt):
(
self._unverified_input_sd_jwt,
*self._input_disclosures,
self._unverified_input_key_binding_jwt,
self._unverified_input_key_binding_jwt
) = self._split(sd_jwt)

# Extract only the body from SD-JWT without verifying the signature
Expand Down
22 changes: 13 additions & 9 deletions src/sd_jwt/holder.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import logging

from .common import SDJWTCommon, DEFAULT_SIGNING_ALG, SD_DIGESTS_KEY, SD_LIST_PREFIX
from json import dumps, loads
from time import time
Expand All @@ -6,6 +8,8 @@

from jwcrypto.jws import JWS

logger = logging.getLogger("sd_jwt")


class SDJWTHolder(SDJWTCommon):
hs_disclosures: List
Expand Down Expand Up @@ -94,7 +98,7 @@ def _select_disclosures_list(self, sd_jwt_claims, claims_to_disclose):
zip_longest(claims_to_disclose, sd_jwt_claims, fillvalue=None)
):
if (
type(element) is dict
isinstance(element, dict)
and len(element) == 1
and SD_LIST_PREFIX in element
and type(element[SD_LIST_PREFIX]) is str
Expand All @@ -116,11 +120,11 @@ def _select_disclosures_list(self, sd_jwt_claims, claims_to_disclose):
continue

self.hs_disclosures.append(self._hash_to_disclosure[digest_to_check])
if type(disclosure_value) is dict:
if isinstance(disclosure_value, dict):
if claims_to_disclose_element is True:
# Tolerate a "True" for a disclosure of an object
claims_to_disclose_element = {}
if not type(claims_to_disclose_element) is dict:
if not isinstance(claims_to_disclose_element, dict):
raise ValueError(
f"To disclose object elements in arrays, provide an object (can be empty).\n"
f"Found {claims_to_disclose_element} instead.\n"
Expand All @@ -130,11 +134,11 @@ def _select_disclosures_list(self, sd_jwt_claims, claims_to_disclose):
self._select_disclosures(
disclosure_value, claims_to_disclose_element
)
elif type(disclosure_value) is list:
elif isinstance(disclosure_value, list):
if claims_to_disclose_element is True:
# Tolerate a "True" for a disclosure of an array
claims_to_disclose_element = []
if not type(claims_to_disclose_element) is list:
if not isinstance(claims_to_disclose_element, list):
raise ValueError(
f"To disclose array elements nested in arrays, provide an array (can be empty).\n"
f"Found {claims_to_disclose_element} instead.\n"
Expand All @@ -155,7 +159,7 @@ def _select_disclosures_dict(self, sd_jwt_claims, claims_to_disclose):
if claims_to_disclose is True:
# Tolerate a "True" for a disclosure of an object
claims_to_disclose = {}
if not type(claims_to_disclose) is dict:
if not isinstance(claims_to_disclose, dict):
raise ValueError(
f"To disclose object elements, an object must be provided as disclosure information.\n"
f"Found {claims_to_disclose} (type {type(claims_to_disclose)}) instead.\n"
Expand All @@ -170,16 +174,16 @@ def _select_disclosures_dict(self, sd_jwt_claims, claims_to_disclose):
_, key, value = self._hash_to_decoded_disclosure[digest_to_check]

try:
print(
logger.debug(
f"In _select_disclosures_dict: {key}, {value}, {claims_to_disclose}"
)
if key in claims_to_disclose and claims_to_disclose[key]:
print(f"Adding disclosure for {digest_to_check}")
logger.debug(f"Adding disclosure for {digest_to_check}")
self.hs_disclosures.append(
self._hash_to_disclosure[digest_to_check]
)
else:
print(
logger.debug(
f"Not adding disclosure for {digest_to_check}, {key} (type {type(key)}) not in {claims_to_disclose}"
)
except TypeError:
Expand Down
30 changes: 15 additions & 15 deletions src/sd_jwt/issuer.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(
sign_alg=None,
add_decoy_claims: bool = False,
serialization_format: str = "compact",
extra_header_parameters: Dict = None,
extra_header_parameters: dict = {},
):
super().__init__(serialization_format=serialization_format)

Expand Down Expand Up @@ -78,21 +78,20 @@ def _create_sd_claims(self, user_claims):
#
# If the user claims are a list, apply this function
# to each item in the list.
if type(user_claims) is list:
if isinstance(user_claims, list):
return self._create_sd_claims_list(user_claims)

# If the user claims are a dictionary, apply this function
# to each key/value pair in the dictionary.
elif type(user_claims) is dict:
elif isinstance(user_claims, dict):
return self._create_sd_claims_object(user_claims)

# For other types, assume that the value can be disclosed.
else:
if isinstance(user_claims, SDObj):
raise ValueError(
f"SDObj found in illegal place.\nThe claim value '{user_claims}' should not be wrapped by SDObj."
)
return user_claims
elif isinstance(user_claims, SDObj):
raise ValueError(
f"SDObj found in illegal place.\nThe claim value '{user_claims}' should not be wrapped by SDObj."
)
return user_claims

def _create_sd_claims_list(self, user_claims: List):
# Walk through all elements in the list.
Expand Down Expand Up @@ -168,12 +167,13 @@ def _create_signed_jws(self):

self.sd_jwt = JWS(payload=dumps(self.sd_jwt_payload))

# Assemble protected headers
_protected_headers = {"alg": self._sign_alg}
if self.SD_JWT_TYP_HEADER:
_protected_headers["typ"] = self.SD_JWT_TYP_HEADER
if self._extra_header_parameters:
_protected_headers.update(self._extra_header_parameters)
# Assemble protected headers starting with default
_protected_headers = {
"alg": self._sign_alg,
"typ": self.SD_JWT_HEADER
}
# override if any
_protected_headers.update(self._extra_header_parameters)

self.sd_jwt.add_signature(
self._issuer_key,
Expand Down
2 changes: 1 addition & 1 deletion src/sd_jwt/utils/demo_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from jwcrypto.jwk import JWK
from typing import Union

logger = logging.getLogger(__name__)
logger = logging.getLogger("sd_jwt")


def load_yaml_settings(file):
Expand Down
2 changes: 1 addition & 1 deletion src/sd_jwt/utils/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def textwrap_json(data, width=EXAMPLE_MAX_WIDTH):
else:
# Check if line is of the form "key": "value"
if not line.strip().startswith('"'):
print("WARNING: unexpected line " + line)
logger.warning("unexpected line " + line)
output.append(line)
continue
# Determine number of spaces before the value
Expand Down
14 changes: 9 additions & 5 deletions tests/test_disclose_all_shortcut.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ def test_e2e(testcase, settings):
demo_keys = get_jwk(settings["key_settings"], True, seed)
use_decoys = testcase.get("add_decoy_claims", False)
serialization_format = testcase.get("serialization_format", "compact")
extra_header_parameters = testcase.get("extra_header_parameters", None)

extra_header_parameters = {"typ": "testcase+sd-jwt"}
extra_header_parameters.update(testcase.get("extra_header_parameters", {}))

# Issuer: Produce SD-JWT and issuance format for selected example

Expand Down Expand Up @@ -39,7 +41,6 @@ def cb_get_issuer_key(issuer, header_parameters):
sdjwt_header_parameters.update(header_parameters)
return demo_keys["issuer_public_key"]


sdjwt_at_verifier = SDJWTVerifier(
output_holder,
cb_get_issuer_key,
Expand All @@ -54,13 +55,16 @@ def cb_get_issuer_key(issuer, header_parameters):
expected_claims["iss"] = settings["identifiers"]["issuer"]

if testcase.get("key_binding", False):
expected_claims["cnf"] = {"jwk": demo_keys["holder_key"].export_public(as_dict=True)}
expected_claims["cnf"] = {
"jwk": demo_keys["holder_key"].export_public(as_dict=True)
}

assert verified == expected_claims

expected_header_parameters = {
"alg": testcase.get("sign_alg", "ES256")
"alg": testcase.get("sign_alg", "ES256"),
"typ": "testcase+sd-jwt"
}
expected_header_parameters.update(extra_header_parameters or {})
expected_header_parameters.update(extra_header_parameters)

assert sdjwt_header_parameters == expected_header_parameters
9 changes: 6 additions & 3 deletions tests/test_e2e_testcases.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ def test_e2e(testcase, settings):
demo_keys = get_jwk(settings["key_settings"], True, seed)
use_decoys = testcase.get("add_decoy_claims", False)
serialization_format = testcase.get("serialization_format", "compact")
extra_header_parameters = testcase.get("extra_header_parameters", None)

extra_header_parameters = {"typ": "testcase+sd-jwt"}
extra_header_parameters.update(testcase.get("extra_header_parameters", {}))

# Issuer: Produce SD-JWT and issuance format for selected example

Expand Down Expand Up @@ -74,8 +76,9 @@ def cb_get_issuer_key(issuer, header_parameters):
assert verified == expected_claims

expected_header_parameters = {
"alg": testcase.get("sign_alg", "ES256")
"alg": testcase.get("sign_alg", "ES256"),
"typ": "testcase+sd-jwt"
}
expected_header_parameters.update(extra_header_parameters or {})
expected_header_parameters.update(extra_header_parameters)

assert sdjwt_header_parameters == expected_header_parameters