diff --git a/.travis.yml b/.travis.yml index bcafc11..ea3a3ba 100644 --- a/.travis.yml +++ b/.travis.yml @@ -12,6 +12,7 @@ addons: - rustc - cargo install: +- pip install --upgrade pip - pip install codecov - pip install isort - pip install tox @@ -24,12 +25,3 @@ after_success: - codecov notifications: email: false -deploy: - provider: pypi - on: - tags: true - distributions: bdist_wheel - skip_existing: true - user: __token__ - password: - secure: AH3jGGXVjV/oOlg4cOnsN7pURlZ7JMcd3Prr69Q++rxfsrmFpxCPtQLpO0LUNPisfyctoImpY64auNMHh20AHdlnvXQu8k/YFZCVcyK6N2d66wgJbO9AOT21N6IkFGyW11K3lYIHzURv9RsTEhzSkOhKmPUack5UhSJ+yAUTZXpt6iZqXBvmxMNzNiCLQdUmTMj4HxxkUVPabpef8PLqyDXvAxJxOCss+QcJVZuWFs85Niw0scTkU4SWz2lhOxeqQNg8s+CEgje2KaIoRy2kETywK53G3RFkSp5ytIJPp8RQK039laeal5yjMsWP4KlbDhHrywyNN7yS69FwPuLC41ppde5G054WcuJTm60Y2uckGu6L3oTBMHsAtSfZuEym/qfDngxYADA+xrATJQF5XSrCz13IiBnoz8Y9zI7t9s66PZSBHg99L85jM45M2kJYCDKxNPffJ/JzCnAMTP0yiBMEQ/UfguMDfJMw+6oSPzGcZHuQVzjLO5mUni71X528Psd/iEYCyN+Vi1QbvDZjbNo/oOLtvegOcnu/H1tGWkH4uEXsg2giqkld2hrZi6K3KfcpPtltuP66Z6ohMqcLegqGUNr8mPMP2I58p7if/6xLEu1e7MNZuoV459bnWepoNMMug2NLq/WIPCiGLNCyV4tdbzqcZQLNwjc5ruFLoz8= diff --git a/src/oidcmsg/__init__.py b/src/oidcmsg/__init__.py index 72fd00d..5ab331f 100755 --- a/src/oidcmsg/__init__.py +++ b/src/oidcmsg/__init__.py @@ -1,13 +1,13 @@ -__author__ = 'Roland Hedberg' -__version__ = '1.2.0' +__author__ = "Roland Hedberg" +__version__ = "1.3.0" import os -VERIFIED_CLAIM_PREFIX = '__verified' +VERIFIED_CLAIM_PREFIX = "__verified" def verified_claim_name(claim): - return '{}_{}'.format(VERIFIED_CLAIM_PREFIX, claim) + return "{}_{}".format(VERIFIED_CLAIM_PREFIX, claim) def proper_path(path): diff --git a/src/oidcmsg/context.py b/src/oidcmsg/context.py index 46f70d0..f4e9e5a 100644 --- a/src/oidcmsg/context.py +++ b/src/oidcmsg/context.py @@ -11,11 +11,11 @@ def add_issuer(conf, issuer): res = {} for key, val in conf.items(): - if key == 'abstract_storage_cls': + if key == "abstract_storage_cls": res[key] = val else: _val = copy.copy(val) - _val['issuer'] = quote_plus(issuer) + _val["issuer"] = quote_plus(issuer) res[key] = _val return res @@ -23,7 +23,7 @@ def add_issuer(conf, issuer): class OidcContext(ImpExp): parameter = {"keyjar": KeyJar, "issuer": None} - def __init__(self, config=None, keyjar=None, entity_id=''): + def __init__(self, config=None, keyjar=None, entity_id=""): ImpExp.__init__(self) if config is None: config = {} @@ -31,21 +31,21 @@ def __init__(self, config=None, keyjar=None, entity_id=''): self.issuer = entity_id self.keyjar = self._keyjar(keyjar, conf=config, entity_id=entity_id) - def _keyjar(self, keyjar=None, conf=None, entity_id=''): + def _keyjar(self, keyjar=None, conf=None, entity_id=""): if keyjar is None: - if 'keys' in conf: + if "keys" in conf: args = {k: v for k, v in conf["keys"].items() if k != "uri_path"} _keyjar = init_key_jar(**args) else: _keyjar = KeyJar() - if 'jwks' in conf: - _keyjar.import_jwks(conf['jwks'], '') + if "jwks" in conf: + _keyjar.import_jwks(conf["jwks"], "") - if '' in _keyjar and entity_id: + if "" in _keyjar and entity_id: # make sure I have the keys under my own name too (if I know it) - _keyjar.import_jwks_as_json(_keyjar.export_jwks_as_json(True, ''), entity_id) + _keyjar.import_jwks_as_json(_keyjar.export_jwks_as_json(True, ""), entity_id) - _httpc_params = conf.get('httpc_params') + _httpc_params = conf.get("httpc_params") if _httpc_params: _keyjar.httpc_params = _httpc_params diff --git a/src/oidcmsg/exception.py b/src/oidcmsg/exception.py index 3a14442..d885b10 100755 --- a/src/oidcmsg/exception.py +++ b/src/oidcmsg/exception.py @@ -1,4 +1,4 @@ -__author__ = 'Roland Hedberg' +__author__ = "Roland Hedberg" class OidcMsgError(Exception): diff --git a/src/oidcmsg/impexp.py b/src/oidcmsg/impexp.py index 2a8dca0..c0b01c5 100644 --- a/src/oidcmsg/impexp.py +++ b/src/oidcmsg/impexp.py @@ -1,6 +1,8 @@ +from typing import Any from typing import List from typing import Optional +from cryptojwt.utils import as_bytes from cryptojwt.utils import importer from cryptojwt.utils import qualified_name @@ -9,19 +11,24 @@ class ImpExp: parameter = {} + special_load_dump = {} + init_args = [] def __init__(self): pass - def _dump(self, cls, item, exclude_attributes: Optional[List[str]] = None) -> dict: - if cls in [None, "", [], {}]: - val = item + def dump_attr(self, cls, item, exclude_attributes: Optional[List[str]] = None) -> dict: + if cls in [None, 0, "", [], {}, bool, b'']: + if cls == b'': + val = as_bytes(item) + else: + val = item elif isinstance(item, Message): val = {qualified_name(item.__class__): item.to_dict()} elif cls == object: val = qualified_name(item) elif isinstance(cls, list): - val = [self._dump(cls[0], v, exclude_attributes) for v in item] + val = [self.dump_attr(cls[0], v, exclude_attributes) for v in item] else: val = item.dump(exclude_attributes=exclude_attributes) @@ -31,42 +38,98 @@ def dump(self, exclude_attributes: Optional[List[str]] = None) -> dict: _exclude_attributes = exclude_attributes or [] info = {} for attr, cls in self.parameter.items(): - if attr in _exclude_attributes: + if attr in _exclude_attributes or attr in self.special_load_dump: continue item = getattr(self, attr, None) if item is None: continue - info[attr] = self._dump(cls, item, exclude_attributes) + info[attr] = self.dump_attr(cls, item, exclude_attributes) + + for attr, d in self.special_load_dump.items(): + item = getattr(self, attr, None) + if item: + info[attr] = d["dump"](item, exclude_attributes=exclude_attributes) return info - def _local_adjustments(self): + def local_load_adjustments(self, **kwargs): pass - def _load(self, cls, item): - if cls in [None, "", [], {}]: - val = item + def load_attr( + self, + cls: Any, + item: dict, + init_args: Optional[dict] = None, + load_args: Optional[dict] = None, + ) -> Any: + if load_args: + _kwargs = {"load_args": load_args} + _load_args = load_args + else: + _kwargs = {} + _load_args = {} + + if init_args: + _kwargs["init_args"] = init_args + + if cls in [None, 0, "", [], {}, bool, b'']: + if cls == b'': + val = as_bytes(item) + else: + val = item elif cls == object: val = importer(item) elif isinstance(cls, list): - val = [cls[0]().load(v) for v in item] + if isinstance(cls[0], str): + _cls = importer(cls[0]) + else: + _cls = cls[0] + + if issubclass(_cls, ImpExp) and init_args: + _args = {k: v for k, v in init_args.items() if k in _cls.init_args} + else: + _args = {} + + val = [_cls(**_args).load(v, **_kwargs) for v in item] elif issubclass(cls, Message): - val = cls().from_dict(item) + _cls_name = list(item.keys())[0] + _cls = importer(_cls_name) + val = _cls().from_dict(item[_cls_name]) else: - val = cls().load(item) + if issubclass(cls, ImpExp) and init_args: + _args = {k: v for k, v in init_args.items() if k in cls.init_args} + else: + _args = {} + + val = cls(**_args).load(item, **_kwargs) return val - def load(self, item: dict): + def load(self, item: dict, init_args: Optional[dict] = None, load_args: Optional[dict] = None): + + if load_args: + _kwargs = {"load_args": load_args} + _load_args = load_args + else: + _kwargs = {} + _load_args = {} + + if init_args: + _kwargs["init_args"] = init_args + for attr, cls in self.parameter.items(): - if attr not in item: + if attr not in item or attr in self.special_load_dump: continue - setattr(self, attr, self._load(cls, item[attr])) + setattr(self, attr, self.load_attr(cls, item[attr], **_kwargs)) + + for attr, func in self.special_load_dump.items(): + if attr in item: + setattr(self, attr, func["load"](item[attr], **_kwargs)) - self._local_adjustments() + self.local_load_adjustments(**_load_args) return self def flush(self): @@ -78,6 +141,10 @@ def flush(self): for attr, cls in self.parameter.items(): if cls is None: setattr(self, attr, None) + elif cls == 0: + setattr(self, attr, 0) + elif cls is bool: + setattr(self, attr, False) elif cls == "": setattr(self, attr, "") elif cls == []: diff --git a/src/oidcmsg/item.py b/src/oidcmsg/item.py new file mode 100644 index 0000000..77e90cd --- /dev/null +++ b/src/oidcmsg/item.py @@ -0,0 +1,132 @@ +from typing import List +from typing import Optional + +from oidcmsg.impexp import ImpExp +from oidcmsg.message import Message +from oidcmsg.storage import importer +from oidcmsg.storage.utils import qualified_name + + +class DLDict(ImpExp): + parameter = {"db": {}} + + def __init__(self, **kwargs): + ImpExp.__init__(self) + self.db = kwargs + + def __setitem__(self, key: str, val): + self.db[key] = val + + def __getitem__(self, key: str): + return self.db[key] + + def __delitem__(self, key: str): + del self.db[key] + + def dump(self, exclude_attributes: Optional[List[str]] = None) -> dict: + res = {} + + for k, v in self.db.items(): + _class = qualified_name(v.__class__) + res[k] = [_class, v.dump(exclude_attributes=exclude_attributes)] + + return res + + def load( + self, spec: dict, init_args: Optional[dict] = None, load_args: Optional[dict] = None + ) -> "DLDict": + if load_args: + _kwargs = {"load_args": load_args} + _load_args = {} + else: + _load_args = {} + _kwargs = {} + + if init_args: + _kwargs["init_args"] = init_args + + for attr, (_item_cls, _item) in spec.items(): + _cls = importer(_item_cls) + + if issubclass(_cls, ImpExp) and init_args: + _args = {k: v for k, v in init_args.items() if k in _cls.init_args} + else: + _args = {} + + _x = _cls(**_args) + _x.load(_item, **_kwargs) + self.db[attr] = _x + + self.local_load_adjustments(**_load_args) + + return self + + def keys(self): + return self.db.keys() + + def items(self): + return self.db.items() + + def values(self): + return self.db.values() + + def __contains__(self, item): + return item in self.db + + def get(self, item, default=None): + return self.db.get(item, default) + + def __len__(self): + return len(self.db) + + +def dump_dldict(item, exclude_attributes: Optional[List[str]] = None) -> dict: + res = {} + + for k, v in item.items(): + _class = qualified_name(v.__class__) + if isinstance(v, Message): + res[k] = [_class, v.to_dict()] + else: + res[k] = [_class, v.dump(exclude_attributes=exclude_attributes)] + + return res + + +def load_dldict( + spec: dict, init_args: Optional[dict] = None, load_args: Optional[dict] = None +) -> dict: + db = {} + + for attr, (_item_cls, _item) in spec.items(): + _class = importer(_item_cls) + if issubclass(_class, Message): + db[attr] = _class().from_dict(_item) + else: + if issubclass(_class, ImpExp) and init_args: + _args = {k: v for k, v in init_args.items() if k in _class.init_args} + else: + _args = {} + + db[attr] = _class(**_args).load(_item) + + return db + + +def dump_class_map(item, exclude_attributes: Optional[List[str]] = None) -> dict: + _dump = {} + for key, val in item.items(): + if isinstance(val, str): + _dump[key] = val + else: + _dump[key] = qualified_name(val) + return _dump + + +def load_class_map( + spec: dict, init_args: Optional[dict] = None, load_args: Optional[dict] = None +) -> dict: + _item = {} + for key, val in spec.items(): + _item[key] = importer(val) + return _item diff --git a/src/oidcmsg/message.py b/src/oidcmsg/message.py index f2991ed..f4878b3 100755 --- a/src/oidcmsg/message.py +++ b/src/oidcmsg/message.py @@ -30,6 +30,7 @@ class Message(MutableMapping): """ Represents a basic protocol nessage/item in OAuth2/OIDC """ + c_param = {} c_default = {} c_allowed_values = {} @@ -88,8 +89,7 @@ def to_urlencoded(self, lev=0): if not self.lax: for attribute, (_, req, _ser, _, _) in _spec.items(): if req and attribute not in self._dict: - raise MissingRequiredAttribute("%s" % attribute, - "%s" % self) + raise MissingRequiredAttribute("%s" % attribute, "%s" % self) params = [] @@ -102,7 +102,7 @@ def to_urlencoded(self, lev=0): (_, req, _ser, _deser, null_allowed) = _spec[_key] except (ValueError, KeyError): try: - (_, req, _ser, _, null_allowed) = _spec['*'] + (_, req, _ser, _, null_allowed) = _spec["*"] except KeyError: _ser = None null_allowed = False @@ -115,11 +115,10 @@ def to_urlencoded(self, lev=0): params.append((key, val.encode("utf-8"))) elif isinstance(val, list): if _ser: - params.append((key, str(_ser(val, sformat="urlencoded", - lev=lev)))) + params.append((key, str(_ser(val, sformat="urlencoded", lev=lev)))) else: for item in val: - params.append((key, str(item).encode('utf-8'))) + params.append((key, str(item).encode("utf-8"))) elif isinstance(val, Message): try: _val = json.dumps(_ser(val, sformat="dict", lev=lev + 1)) @@ -196,7 +195,7 @@ def from_urlencoded(self, urlencoded, **kwargs): _info = parse_qs(urlencoded) if len(urlencoded) and _info == {}: - raise FormatError('Wrong format') + raise FormatError("Wrong format") for key, val in _info.items(): try: @@ -207,7 +206,7 @@ def from_urlencoded(self, urlencoded, **kwargs): (typ, _, _, _deser, _) = _spec[_key] except (ValueError, KeyError): try: - (typ, _, _, _deser, _) = _spec['*'] + (typ, _, _, _deser, _) = _spec["*"] except KeyError: if len(val) == 1: val = val[0] @@ -229,7 +228,7 @@ def from_urlencoded(self, urlencoded, **kwargs): else: self._dict[key] = val[0] else: - raise TooManyValues('{}'.format(key)) + raise TooManyValues("{}".format(key)) return self @@ -253,7 +252,7 @@ def to_dict(self, lev=0): _ser = _spec[_key][2] except (ValueError, KeyError): try: - _ser = _spec['*'][2] + _ser = _spec["*"][2] except KeyError: _ser = None @@ -262,8 +261,7 @@ def to_dict(self, lev=0): if isinstance(val, Message): _res[key] = val.to_dict(lev + 1) - elif isinstance(val, list) and isinstance( - next(iter(val or []), None), Message): + elif isinstance(val, list) and isinstance(next(iter(val or []), None), Message): _res[key] = [v.to_dict(lev) for v in val] else: _res[key] = val @@ -296,7 +294,7 @@ def from_dict(self, dictionary, **kwargs): _key = skey.split("#")[0] except ValueError: try: - (vtyp, _, _, _deser, null_allowed) = _spec['*'] + (vtyp, _, _, _deser, null_allowed) = _spec["*"] if val is None: self._dict[key] = val continue @@ -308,7 +306,7 @@ def from_dict(self, dictionary, **kwargs): (vtyp, _, _, _deser, null_allowed) = _spec[_key] except KeyError: try: - (vtyp, _, _, _deser, null_allowed) = _spec['*'] + (vtyp, _, _, _deser, null_allowed) = _spec["*"] if val is None: self._dict[key] = val continue @@ -375,8 +373,7 @@ def _add_value(self, skey, vtyp, key, val, _deser, null_allowed): else: for v in val: if not isinstance(v, vtype): - raise DecodeError( - ERRTXT % (key, "type != %s (%s)" % (vtype, type(v)))) + raise DecodeError(ERRTXT % (key, "type != %s (%s)" % (vtype, type(v)))) self._dict[skey] = val elif isinstance(val, dict): @@ -395,8 +392,7 @@ def _add_value(self, skey, vtyp, key, val, _deser, null_allowed): if vtyp is bool: self._dict[skey] = val else: - raise ValueError( - '"{}", wrong type of value for "{}"'.format(val, skey)) + raise ValueError('"{}", wrong type of value for "{}"'.format(val, skey)) elif isinstance(val, vtyp): # Not necessary to do anything self._dict[skey] = val else: @@ -421,24 +417,17 @@ def _add_value(self, skey, vtyp, key, val, _deser, null_allowed): try: self._dict[skey] = int(val) except (ValueError, TypeError): - raise ValueError( - '"{}", wrong type of value for "{}"'.format(val, - skey)) + raise ValueError('"{}", wrong type of value for "{}"'.format(val, skey)) elif vtyp is bool: - raise ValueError( - '"{}", wrong type of value for "{}"'.format(val, skey)) + raise ValueError('"{}", wrong type of value for "{}"'.format(val, skey)) elif vtyp != type(val): if vtyp == Message: if isinstance(val, (dict, str)): self._dict[skey] = val else: - raise ValueError( - '"{}", wrong type of value for "{}"'.format( - val, skey)) + raise ValueError('"{}", wrong type of value for "{}"'.format(val, skey)) else: - raise ValueError( - '"{}", wrong type of value for "{}"'.format(val, - skey)) + raise ValueError('"{}", wrong type of value for "{}"'.format(val, skey)) def to_json(self, lev=0, indent=None): """ @@ -494,10 +483,10 @@ def from_jwt(self, txt, keyjar, verify=True, **kwargs): """ algarg = {} - if 'encalg' in kwargs: - algarg['alg'] = kwargs['encalg'] - if 'encenc' in kwargs: - algarg['enc'] = kwargs['encenc'] + if "encalg" in kwargs: + algarg["alg"] = kwargs["encalg"] + if "encenc" in kwargs: + algarg["enc"] = kwargs["encenc"] _decryptor = jwe_factory(txt, **algarg) if _decryptor: @@ -505,9 +494,9 @@ def from_jwt(self, txt, keyjar, verify=True, **kwargs): dkeys = keyjar.get_decrypt_key(owner="") - logger.debug('Decrypt class: %s', _decryptor.__class__) + logger.debug("Decrypt class: %s", _decryptor.__class__) _res = _decryptor.decrypt(txt, dkeys) - logger.debug('decrypted message: %s', _res) + logger.debug("decrypted message: %s", _res) if isinstance(_res, tuple): txt = as_unicode(_res[0]) elif isinstance(_res, list) and len(_res) == 2: @@ -516,8 +505,8 @@ def from_jwt(self, txt, keyjar, verify=True, **kwargs): txt = as_unicode(_res) self.jwe_header = _decryptor.jwt.headers - if kwargs.get('sigalg'): - _verifier = jws_factory(txt, alg=kwargs['sigalg']) + if kwargs.get("sigalg"): + _verifier = jws_factory(txt, alg=kwargs["sigalg"]) else: _verifier = jws_factory(txt) @@ -540,8 +529,7 @@ def from_jwt(self, txt, keyjar, verify=True, **kwargs): if "alg" in _header and _header["alg"] != "none": if not key: - raise MissingSigningKey( - "alg=%s" % _header["alg"]) + raise MissingSigningKey("alg=%s" % _header["alg"]) logger.debug("Found signing key.") try: @@ -565,7 +553,7 @@ def __str__(self): :return: A string representation of this class """ - return '{}'.format(self.to_dict()) + return "{}".format(self.to_dict()) @staticmethod def _type_check(typ, _allowed, val, na=False): @@ -738,14 +726,13 @@ def __len__(self): def extra(self): """ - Return the extra parameters that this instance. Extra meaning those + Return the extra parameters that this instance contains. Extra meaning those that are not listed in the c_params specification. :return: The key,value pairs for keys that are not in the c_params specification, """ - return dict([(key, val) for key, val in - self._dict.items() if key not in self.c_param]) + return dict([(key, val) for key, val in self._dict.items() if key not in self.c_param]) def only_extras(self): """ @@ -845,6 +832,7 @@ def value_type(self, parameter): return None + # ============================================================================= @@ -860,6 +848,7 @@ def add_non_standard(msg1, msg2): # ============================================================================= + def list_serializer(vals, sformat="urlencoded", lev=0): if isinstance(vals, str) or not isinstance(vals, list): raise ValueError("Expected list: %s" % vals) @@ -916,7 +905,7 @@ def msg_deser(val, sformat="urlencoded"): def msg_ser(inst, sformat, lev=0): if sformat in ["urlencoded", "json"]: if isinstance(inst, dict): - if sformat == 'json': + if sformat == "json": res = json.dumps(inst) else: res = urlencode([(k, v) for k, v in inst.items()]) @@ -968,28 +957,32 @@ def msg_list_ser(val, sformat="urlencoded", lev=0): SINGLE_REQUIRED_INT = (int, True, None, None, False) SINGLE_REQUIRED_BOOLEAN = (bool, True, None, None, False) -OPTIONAL_LIST_OF_STRINGS = ([str], False, list_serializer, - list_deserializer, False) -REQUIRED_LIST_OF_STRINGS = ([str], True, list_serializer, - list_deserializer, False) -OPTIONAL_LIST_OF_SP_SEP_STRINGS = ([str], False, sp_sep_list_serializer, - sp_sep_list_deserializer, False) -REQUIRED_LIST_OF_SP_SEP_STRINGS = ([str], True, sp_sep_list_serializer, - sp_sep_list_deserializer, False) -SINGLE_OPTIONAL_JSON = (dict, False, json_serializer, json_deserializer, - False) - -SINGLE_REQUIRED_JSON = (dict, True, json_serializer, json_deserializer, - False) - -REQUIRED = [SINGLE_REQUIRED_STRING, REQUIRED_LIST_OF_STRINGS, - REQUIRED_LIST_OF_SP_SEP_STRINGS] +OPTIONAL_LIST_OF_STRINGS = ([str], False, list_serializer, list_deserializer, False) +REQUIRED_LIST_OF_STRINGS = ([str], True, list_serializer, list_deserializer, False) +OPTIONAL_LIST_OF_SP_SEP_STRINGS = ( + [str], + False, + sp_sep_list_serializer, + sp_sep_list_deserializer, + False, +) +REQUIRED_LIST_OF_SP_SEP_STRINGS = ( + [str], + True, + sp_sep_list_serializer, + sp_sep_list_deserializer, + False, +) +SINGLE_OPTIONAL_JSON = (dict, False, json_serializer, json_deserializer, False) + +SINGLE_REQUIRED_JSON = (dict, True, json_serializer, json_deserializer, False) + +REQUIRED = [SINGLE_REQUIRED_STRING, REQUIRED_LIST_OF_STRINGS, REQUIRED_LIST_OF_SP_SEP_STRINGS] OPTIONAL_MESSAGE = (Message, False, msg_ser, msg_deser, False) REQUIRED_MESSAGE = (Message, True, msg_ser, msg_deser, False) -OPTIONAL_LIST_OF_MESSAGES = ([Message], False, msg_list_ser, msg_list_deser, - False) +OPTIONAL_LIST_OF_MESSAGES = ([Message], False, msg_list_ser, msg_list_deser, False) def any_ser(val, sformat="urlencoded", lev=0): diff --git a/src/oidcmsg/oauth2/__init__.py b/src/oidcmsg/oauth2/__init__.py index 30f65b0..b3d9db9 100755 --- a/src/oidcmsg/oauth2/__init__.py +++ b/src/oidcmsg/oauth2/__init__.py @@ -22,7 +22,7 @@ def is_error_message(msg): - if 'error' in msg: + if "error" in msg: return True else: return False @@ -35,10 +35,11 @@ class ResponseMessage(Message): """ The basic error response """ + c_param = { "error": SINGLE_OPTIONAL_STRING, "error_description": SINGLE_OPTIONAL_STRING, - "error_uri": SINGLE_OPTIONAL_STRING + "error_uri": SINGLE_OPTIONAL_STRING, } def verify(self, **kwargs): @@ -55,28 +56,39 @@ class AuthorizationErrorResponse(ResponseMessage): """ Authorization error response. """ + c_param = ResponseMessage.c_param.copy() c_param.update({"state": SINGLE_OPTIONAL_STRING}) c_allowed_values = ResponseMessage.c_allowed_values.copy() - c_allowed_values.update({ - "error": ["invalid_request", - "unauthorized_client", - "access_denied", - "unsupported_response_type", - "invalid_scope", "server_error", - "temporarily_unavailable"] - }) + c_allowed_values.update( + { + "error": [ + "invalid_request", + "unauthorized_client", + "access_denied", + "unsupported_response_type", + "invalid_scope", + "server_error", + "temporarily_unavailable", + ] + } + ) class TokenErrorResponse(ResponseMessage): """ Error response from the token endpoint """ + c_allowed_values = { - "error": ["invalid_request", "invalid_client", - "invalid_grant", "unauthorized_client", - "unsupported_grant_type", - "invalid_scope"] + "error": [ + "invalid_request", + "invalid_client", + "invalid_grant", + "unauthorized_client", + "unsupported_grant_type", + "invalid_scope", + ] } @@ -84,13 +96,14 @@ class AccessTokenRequest(Message): """ An access token request """ + c_param = { "grant_type": SINGLE_REQUIRED_STRING, "code": SINGLE_REQUIRED_STRING, "redirect_uri": SINGLE_REQUIRED_STRING, "client_id": SINGLE_OPTIONAL_STRING, "client_secret": SINGLE_OPTIONAL_STRING, - 'state': SINGLE_OPTIONAL_STRING + "state": SINGLE_OPTIONAL_STRING, } c_default = {"grant_type": "authorization_code"} @@ -99,6 +112,7 @@ class AuthorizationRequest(Message): """ An authorization request """ + c_param = { "response_type": REQUIRED_LIST_OF_SP_SEP_STRINGS, "client_id": SINGLE_REQUIRED_STRING, @@ -119,7 +133,7 @@ def merge(self, request_object, treatement="strict", whitelist=None): result, this is the list to use. """ - if treatement == 'strict': + if treatement == "strict": params = list(self.keys()) # remove all parameters in request that does not appear in request_object for param in params: @@ -143,31 +157,34 @@ class AuthorizationResponse(ResponseMessage): a client_id value provided when calling the verify method. The same with *iss* (issuer). """ + c_param = ResponseMessage.c_param.copy() - c_param.update({ - "code": SINGLE_REQUIRED_STRING, - "state": SINGLE_OPTIONAL_STRING, - 'iss': SINGLE_OPTIONAL_STRING, - 'client_id': SINGLE_OPTIONAL_STRING - }) + c_param.update( + { + "code": SINGLE_REQUIRED_STRING, + "state": SINGLE_OPTIONAL_STRING, + "iss": SINGLE_OPTIONAL_STRING, + "client_id": SINGLE_OPTIONAL_STRING, + } + ) def verify(self, **kwargs): super(AuthorizationResponse, self).verify(**kwargs) - if 'client_id' in self: + if "client_id" in self: try: - if self['client_id'] != kwargs['client_id']: - raise VerificationError('client_id mismatch') + if self["client_id"] != kwargs["client_id"]: + raise VerificationError("client_id mismatch") except KeyError: - logger.info('No client_id to verify against') + logger.info("No client_id to verify against") pass - if 'iss' in self: + if "iss" in self: try: # Issuer URL for the authorization server issuing the response. - if self['iss'] != kwargs['iss']: - raise VerificationError('Issuer mismatch') + if self["iss"] != kwargs["iss"]: + raise VerificationError("Issuer mismatch") except KeyError: - logger.info('No issuer set in the Client config') + logger.info("No issuer set in the Client config") pass return True @@ -177,33 +194,35 @@ class AccessTokenResponse(ResponseMessage): """ Access token response """ + c_param = ResponseMessage.c_param.copy() - c_param.update({ - "access_token": SINGLE_REQUIRED_STRING, - "token_type": SINGLE_REQUIRED_STRING, - "expires_in": SINGLE_OPTIONAL_INT, - "refresh_token": SINGLE_OPTIONAL_STRING, - "scope": OPTIONAL_LIST_OF_SP_SEP_STRINGS, - "state": SINGLE_OPTIONAL_STRING - }) + c_param.update( + { + "access_token": SINGLE_REQUIRED_STRING, + "token_type": SINGLE_REQUIRED_STRING, + "expires_in": SINGLE_OPTIONAL_INT, + "refresh_token": SINGLE_OPTIONAL_STRING, + "scope": OPTIONAL_LIST_OF_SP_SEP_STRINGS, + "state": SINGLE_OPTIONAL_STRING, + } + ) class NoneResponse(ResponseMessage): c_param = ResponseMessage.c_param.copy() - c_param.update({ - "state": SINGLE_OPTIONAL_STRING - }) + c_param.update({"state": SINGLE_OPTIONAL_STRING}) class ROPCAccessTokenRequest(Message): """ Resource Owner Password Credentials Grant flow access token request """ + c_param = { "grant_type": SINGLE_REQUIRED_STRING, "username": SINGLE_OPTIONAL_STRING, "password": SINGLE_OPTIONAL_STRING, - "scope": OPTIONAL_LIST_OF_SP_SEP_STRINGS + "scope": OPTIONAL_LIST_OF_SP_SEP_STRINGS, } @@ -211,10 +230,8 @@ class CCAccessTokenRequest(Message): """ Client Credential grant flow access token request """ - c_param = { - "grant_type": SINGLE_REQUIRED_STRING, - "scope": OPTIONAL_LIST_OF_SP_SEP_STRINGS - } + + c_param = {"grant_type": SINGLE_REQUIRED_STRING, "scope": OPTIONAL_LIST_OF_SP_SEP_STRINGS} c_default = {"grant_type": "client_credentials"} c_allowed_values = {"grant_type": ["client_credentials"]} @@ -223,12 +240,13 @@ class RefreshAccessTokenRequest(Message): """ Access token refresh request """ + c_param = { "grant_type": SINGLE_REQUIRED_STRING, "refresh_token": SINGLE_REQUIRED_STRING, "scope": OPTIONAL_LIST_OF_SP_SEP_STRINGS, "client_id": SINGLE_OPTIONAL_STRING, - "client_secret": SINGLE_OPTIONAL_STRING + "client_secret": SINGLE_OPTIONAL_STRING, } c_default = {"grant_type": "refresh_token"} c_allowed_values = {"grant_type": ["refresh_token"]} @@ -242,27 +260,29 @@ class ASConfigurationResponse(Message): """ Authorization Server configuration response """ + c_param = ResponseMessage.c_param.copy() - c_param.update({ - "issuer": SINGLE_REQUIRED_STRING, - "authorization_endpoint": SINGLE_OPTIONAL_STRING, - "token_endpoint": SINGLE_OPTIONAL_STRING, - "jwks_uri": SINGLE_OPTIONAL_STRING, - "registration_endpoint": SINGLE_OPTIONAL_STRING, - "scopes_supported": OPTIONAL_LIST_OF_STRINGS, - "response_types_supported": REQUIRED_LIST_OF_STRINGS, - "response_modes_supported": OPTIONAL_LIST_OF_STRINGS, - "grant_types_supported": REQUIRED_LIST_OF_STRINGS, - "token_endpoint_auth_methods_supported": OPTIONAL_LIST_OF_STRINGS, - "token_endpoint_auth_signing_alg_values_supported": - OPTIONAL_LIST_OF_STRINGS, - "service_documentation": SINGLE_OPTIONAL_STRING, - "ui_locales_supported": OPTIONAL_LIST_OF_STRINGS, - "op_policy_uri": SINGLE_OPTIONAL_STRING, - "op_tos_uri": SINGLE_OPTIONAL_STRING, - 'revocation_endpoint': SINGLE_OPTIONAL_STRING, - 'introspection_endpoint': SINGLE_OPTIONAL_STRING, - }) + c_param.update( + { + "issuer": SINGLE_REQUIRED_STRING, + "authorization_endpoint": SINGLE_OPTIONAL_STRING, + "token_endpoint": SINGLE_OPTIONAL_STRING, + "jwks_uri": SINGLE_OPTIONAL_STRING, + "registration_endpoint": SINGLE_OPTIONAL_STRING, + "scopes_supported": OPTIONAL_LIST_OF_STRINGS, + "response_types_supported": REQUIRED_LIST_OF_STRINGS, + "response_modes_supported": OPTIONAL_LIST_OF_STRINGS, + "grant_types_supported": REQUIRED_LIST_OF_STRINGS, + "token_endpoint_auth_methods_supported": OPTIONAL_LIST_OF_STRINGS, + "token_endpoint_auth_signing_alg_values_supported": OPTIONAL_LIST_OF_STRINGS, + "service_documentation": SINGLE_OPTIONAL_STRING, + "ui_locales_supported": OPTIONAL_LIST_OF_STRINGS, + "op_policy_uri": SINGLE_OPTIONAL_STRING, + "op_tos_uri": SINGLE_OPTIONAL_STRING, + "revocation_endpoint": SINGLE_OPTIONAL_STRING, + "introspection_endpoint": SINGLE_OPTIONAL_STRING, + } + ) c_default = {"version": "3.0"} @@ -323,10 +343,7 @@ class TokenExchangeResponse(Message): class JWTSecuredAuthorizationRequest(AuthorizationRequest): c_param = AuthorizationRequest.c_param.copy() - c_param.update({ - "request": SINGLE_OPTIONAL_STRING, - "request_uri": SINGLE_OPTIONAL_STRING - }) + c_param.update({"request": SINGLE_OPTIONAL_STRING, "request_uri": SINGLE_OPTIONAL_STRING}) def verify(self, **kwargs): if "request" in self: @@ -335,15 +352,14 @@ def verify(self, **kwargs): del self[_vc_name] args = {} - for arg in ["keyjar", "opponent_id", "sender", "alg", "encalg", - "encenc"]: + for arg in ["keyjar", "opponent_id", "sender", "alg", "encalg", "encenc"]: try: args[arg] = kwargs[arg] except KeyError: pass _req = AuthorizationRequest().from_jwt(str(self["request"]), **args) - self.merge(_req, 'strict') + self.merge(_req, "strict") self[_vc_name] = _req elif "request_uri" not in self: raise MissingAttribute("One of request or request_uri must be present") @@ -353,9 +369,7 @@ def verify(self, **kwargs): class PushedAuthorizationRequest(AuthorizationRequest): c_param = AuthorizationRequest.c_param.copy() - c_param.update({ - "request": SINGLE_OPTIONAL_STRING - }) + c_param.update({"request": SINGLE_OPTIONAL_STRING}) def verify(self, **kwargs): if "request" in self: @@ -364,8 +378,7 @@ def verify(self, **kwargs): del self[_vc_name] args = {} - for arg in ["keyjar", "opponent_id", "sender", "alg", "encalg", - "encenc"]: + for arg in ["keyjar", "opponent_id", "sender", "alg", "encalg", "encenc"]: try: args[arg] = kwargs[arg] except KeyError: @@ -388,7 +401,7 @@ class SecurityEventToken(Message): "exp": SINGLE_OPTIONAL_INT, "events": SINGLE_OPTIONAL_JSON, "txt": SINGLE_OPTIONAL_STRING, - "toe": SINGLE_OPTIONAL_INT + "toe": SINGLE_OPTIONAL_INT, } diff --git a/src/oidcmsg/oidc/__init__.py b/src/oidcmsg/oidc/__init__.py index 0626247..8488469 100755 --- a/src/oidcmsg/oidc/__init__.py +++ b/src/oidcmsg/oidc/__init__.py @@ -41,7 +41,7 @@ from oidcmsg.oauth2 import ResponseMessage from oidcmsg.time_util import utc_time_sans_frac -__author__ = 'Roland Hedberg' +__author__ = "Roland Hedberg" logger = logging.getLogger(__name__) @@ -66,18 +66,18 @@ class IATError(VerificationError): def deserialize_from_one_of(val, msgtype, sformat): if sformat in ["dict", "json"]: - flist = ['json', 'urlencoded'] + flist = ["json", "urlencoded"] if not isinstance(val, str): val = json.dumps(val) else: - flist = ['urlencoded', 'json'] + flist = ["urlencoded", "json"] for _format in flist: try: return msgtype().deserialize(val, _format) except FormatError: pass - raise FormatError('Unexpected format') + raise FormatError("Unexpected format") def json_ser(val, sformat=None, lev=0): @@ -179,7 +179,7 @@ def claims_request_deser(val, sformat="json"): if not isinstance(val, str): val = json.dumps(val) sformat = "json" - elif sformat == 'dict': + elif sformat == "dict": if isinstance(val, str): val = json.loads(val) @@ -209,10 +209,8 @@ def dict_deser(val, sformat="json"): SINGLE_OPTIONAL_IDTOKEN = (Message, False, msg_ser, None, False) SINGLE_REQUIRED_IDTOKEN = (Message, True, msg_ser, None, False) -SINGLE_OPTIONAL_REGISTRATION_REQUEST = (Message, False, msg_ser, - registration_request_deser, False) -SINGLE_OPTIONAL_CLAIMSREQ = (Message, False, msg_ser_json, claims_request_deser, - False) +SINGLE_OPTIONAL_REGISTRATION_REQUEST = (Message, False, msg_ser, registration_request_deser, False) +SINGLE_OPTIONAL_CLAIMSREQ = (Message, False, msg_ser_json, claims_request_deser, False) SINGLE_OPTIONAL_DICT = (dict, False, msg_ser_json, dict_deser, False) @@ -220,7 +218,7 @@ def dict_deser(val, sformat="json"): SCOPE_CHARSET = [] -for char in ['\x21', ('\x23', '\x5b'), ('\x5d', '\x7E')]: +for char in ["\x21", ("\x23", "\x5b"), ("\x5d", "\x7E")]: if isinstance(char, tuple): c = char[0] while c <= char[1]: @@ -238,12 +236,24 @@ def check_char_set(string, allowed): # ----------------------------------------------------------------------------- -ID_TOKEN_VERIFY_ARGS = ['keyjar', 'verify', 'encalg', 'encenc', 'sigalg', - 'issuer', 'allow_missing_kid', 'no_kid_issuer', - 'trusting', 'skew', 'nonce_storage_time', 'client_id', - 'allow_sign_alg_none', 'allowed_sign_alg'] - -CLAIMS_WITH_VERIFIED = ['id_token', 'id_token_hint', 'request'] +ID_TOKEN_VERIFY_ARGS = [ + "keyjar", + "verify", + "encalg", + "encenc", + "sigalg", + "issuer", + "allow_missing_kid", + "no_kid_issuer", + "trusting", + "skew", + "nonce_storage_time", + "client_id", + "allow_sign_alg_none", + "allowed_sign_alg", +] + +CLAIMS_WITH_VERIFIED = ["id_token", "id_token_hint", "request"] def clear_verified_claims(msg): @@ -264,7 +274,7 @@ class TokenErrorResponse(oauth2.TokenErrorResponse): pass -def verify_id_token(msg, check_hash=False, claim='id_token', **kwargs): +def verify_id_token(msg, check_hash=False, claim="id_token", **kwargs): # Try to decode the JWT, checks the signature args = {} for arg in ID_TOKEN_VERIFY_ARGS: @@ -275,34 +285,35 @@ def verify_id_token(msg, check_hash=False, claim='id_token', **kwargs): _jws = jws_factory(msg[claim]) if not _jws: - raise ValueError('{} not a signed JWT'.format(claim)) + raise ValueError("{} not a signed JWT".format(claim)) - if _jws.jwt.headers['alg'] == 'none': + if _jws.jwt.headers["alg"] == "none": _signed = False _sign_alg = kwargs.get("sigalg") if _sign_alg == "none": _allowed = True else: # There might or might not be a specified signing alg - if kwargs.get('allow_sign_alg_none', False) is False: - logger.info('Signing algorithm None not allowed') - raise UnsupportedAlgorithm('Signing algorithm None not allowed') + if kwargs.get("allow_sign_alg_none", False) is False: + logger.info("Signing algorithm None not allowed") + raise UnsupportedAlgorithm("Signing algorithm None not allowed") else: _signed = True if "allowed_sign_alg" in kwargs: - if _jws.jwt.headers['alg'] != kwargs["allowed_sign_alg"]: + if _jws.jwt.headers["alg"] != kwargs["allowed_sign_alg"]: _msg = "Wrong token signing algorithm, {} != {}".format( - _jws.jwt.headers['alg'], kwargs["allowed_sign_alg"]) + _jws.jwt.headers["alg"], kwargs["allowed_sign_alg"] + ) logger.error(_msg) raise UnsupportedAlgorithm(_msg) _body = _jws.jwt.payload() - if _signed and 'keyjar' in kwargs: + if _signed and "keyjar" in kwargs: try: - if _body['iss'] not in kwargs['keyjar']: - logger.info('KeyJar issuers: {}'.format(kwargs['keyjar'])) - raise ValueError('Unknown issuer: "{}"'.format(_body['iss'])) + if _body["iss"] not in kwargs["keyjar"]: + logger.info("KeyJar issuers: {}".format(kwargs["keyjar"])) + raise ValueError('Unknown issuer: "{}"'.format(_body["iss"])) except KeyError: - raise MissingRequiredAttribute('iss') + raise MissingRequiredAttribute("iss") idt = IdToken().from_jwt(str(msg[claim]), **args) if not idt.verify(**kwargs): @@ -314,22 +325,18 @@ def verify_id_token(msg, check_hash=False, claim='id_token', **kwargs): if "access_token" in msg: if "at_hash" not in idt: - raise MissingRequiredAttribute("Missing at_hash property", - idt) - if idt["at_hash"] != left_hash(msg["access_token"], - hfunc): - raise AtHashError( - "Failed to verify access_token hash", idt) + raise MissingRequiredAttribute("Missing at_hash property", idt) + if idt["at_hash"] != left_hash(msg["access_token"], hfunc): + raise AtHashError("Failed to verify access_token hash", idt) if "code" in msg: if "c_hash" not in idt: - raise MissingRequiredAttribute("Missing c_hash property", - idt) + raise MissingRequiredAttribute("Missing c_hash property", idt) if idt["c_hash"] != left_hash(msg["code"], hfunc): raise CHashError("Failed to verify code hash", idt) msg[verified_claim_name(claim)] = idt - logger.info('Verified {}: {}'.format(claim, idt.to_dict())) + logger.info("Verified {}: {}".format(claim, idt.to_dict())) return True @@ -356,19 +363,20 @@ class UserInfoRequest(Message): } -class AuthorizationResponse(oauth2.AuthorizationResponse, - oauth2.AccessTokenResponse): +class AuthorizationResponse(oauth2.AuthorizationResponse, oauth2.AccessTokenResponse): c_param = oauth2.AuthorizationResponse.c_param.copy() c_param.update(oauth2.AccessTokenResponse.c_param) - c_param.update({ - "code": SINGLE_OPTIONAL_STRING, - # "nonce": SINGLE_OPTIONAL_STRING, - "access_token": SINGLE_OPTIONAL_STRING, - "token_type": SINGLE_OPTIONAL_STRING, - "id_token": SINGLE_OPTIONAL_IDTOKEN, - # Below is REQUIRED if doing session management - "session_state": SINGLE_OPTIONAL_STRING - }) + c_param.update( + { + "code": SINGLE_OPTIONAL_STRING, + # "nonce": SINGLE_OPTIONAL_STRING, + "access_token": SINGLE_OPTIONAL_STRING, + "token_type": SINGLE_OPTIONAL_STRING, + "id_token": SINGLE_OPTIONAL_IDTOKEN, + # Below is REQUIRED if doing session management + "session_state": SINGLE_OPTIONAL_STRING, + } + ) def verify(self, **kwargs): super(AuthorizationResponse, self).verify(**kwargs) @@ -388,17 +396,20 @@ def verify(self, **kwargs): class AuthorizationErrorResponse(oauth2.AuthorizationErrorResponse): - c_allowed_values = oauth2.AuthorizationErrorResponse.c_allowed_values \ - .copy() - c_allowed_values["error"].extend(["interaction_required", - "login_required", - "session_selection_required", - "consent_required", - "invalid_request_uri", - "invalid_request_object", - "registration_not_supported", - "request_not_supported", - "request_uri_not_supported"]) + c_allowed_values = oauth2.AuthorizationErrorResponse.c_allowed_values.copy() + c_allowed_values["error"].extend( + [ + "interaction_required", + "login_required", + "session_selection_required", + "consent_required", + "invalid_request_uri", + "invalid_request_object", + "registration_not_supported", + "request_not_supported", + "request_uri_not_supported", + ] + ) class AuthorizationRequest(oauth2.AuthorizationRequest): @@ -425,10 +436,12 @@ class AuthorizationRequest(oauth2.AuthorizationRequest): } ) c_allowed_values = oauth2.AuthorizationRequest.c_allowed_values.copy() - c_allowed_values.update({ - "display": ["page", "popup", "touch", "wap"], - "prompt": ["none", "login", "consent", "select_account"] - }) + c_allowed_values.update( + { + "display": ["page", "popup", "touch", "wap"], + "prompt": ["none", "login", "consent", "select_account"], + } + ) def verify(self, **kwargs): """Authorization Request parameters that are OPTIONAL in the OAuth 2.0 @@ -444,8 +457,7 @@ def verify(self, **kwargs): clear_verified_claims(self) args = {} - for arg in ["keyjar", "opponent_id", "sender", "alg", "encalg", - "encenc"]: + for arg in ["keyjar", "opponent_id", "sender", "alg", "encalg", "encenc"]: try: args[arg] = kwargs[arg] except KeyError: @@ -464,7 +476,7 @@ def verify(self, **kwargs): if key in self: if self[key] != val: # log but otherwise ignore - logger.warning('{} != {}'.format(self[key], val)) + logger.warning("{} != {}".format(self[key], val)) # remove all claims _keys = list(self.keys()) @@ -491,10 +503,8 @@ def verify(self, **kwargs): raise MissingRequiredAttribute("Nonce missing", self) else: try: - if self['nonce'] != kwargs['nonce']: - raise ValueError( - 'Nonce in id_token not matching nonce in authz ' - 'request') + if self["nonce"] != kwargs["nonce"]: + raise ValueError("Nonce in id_token not matching nonce in authz " "request") except KeyError: pass @@ -507,22 +517,22 @@ def verify(self, **kwargs): if "prompt" in self: if "none" in self["prompt"] and len(self["prompt"]) > 1: - raise InvalidRequest("prompt none combined with other value", - self) + raise InvalidRequest("prompt none combined with other value", self) return True class AccessTokenRequest(oauth2.AccessTokenRequest): c_param = oauth2.AccessTokenRequest.c_param.copy() - c_param.update({ - "client_assertion_type": SINGLE_OPTIONAL_STRING, - "client_assertion": SINGLE_OPTIONAL_STRING - }) + c_param.update( + { + "client_assertion_type": SINGLE_OPTIONAL_STRING, + "client_assertion": SINGLE_OPTIONAL_STRING, + } + ) c_default = {"grant_type": "authorization_code"} c_allowed_values = { - "client_assertion_type": [ - "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"], + "client_assertion_type": ["urn:ietf:params:oauth:client-assertion-type:jwt-bearer"], } @@ -533,7 +543,7 @@ class AddressClaim(Message): "locality": SINGLE_OPTIONAL_STRING, "region": SINGLE_OPTIONAL_STRING, "postal_code": SINGLE_OPTIONAL_STRING, - "country": SINGLE_OPTIONAL_STRING + "country": SINGLE_OPTIONAL_STRING, } @@ -562,8 +572,9 @@ class OpenIDSchema(ResponseMessage): "address": OPTIONAL_ADDRESS, "updated_at": SINGLE_OPTIONAL_INT, "_claim_names": OPTIONAL_MESSAGE, - "_claim_sources": OPTIONAL_MESSAGE - }) + "_claim_sources": OPTIONAL_MESSAGE, + } + ) def verify(self, **kwargs): super(OpenIDSchema, self).verify(**kwargs) @@ -628,12 +639,12 @@ class RegistrationRequest(Message): "backchannel_logout_uri": SINGLE_OPTIONAL_STRING, "backchannel_logout_session_supported": SINGLE_OPTIONAL_BOOLEAN, "federation_type": OPTIONAL_LIST_OF_STRINGS, - "organization_name": SINGLE_OPTIONAL_STRING + "organization_name": SINGLE_OPTIONAL_STRING, } c_default = {"application_type": "web", "response_types": ["code"]} c_allowed_values = { "application_type": ["native", "web"], - "subject_type": ["public", "pairwise"] + "subject_type": ["public", "pairwise"], } def verify(self, **kwargs): @@ -641,11 +652,13 @@ def verify(self, **kwargs): if "initiate_login_uri" in self: if not self["initiate_login_uri"].startswith("https:"): - raise ValueError('Wrong scheme') + raise ValueError("Wrong scheme") - for param in ["request_object_encryption", - "id_token_encrypted_response", - "userinfo_encrypted_response"]: + for param in [ + "request_object_encryption", + "id_token_encrypted_response", + "userinfo_encrypted_response", + ]: alg_param = "%s_alg" % param enc_param = "%s_enc" % param if alg_param in self: @@ -655,7 +668,7 @@ def verify(self, **kwargs): # both or none if enc_param in self: if alg_param not in self: - raise MissingRequiredAttribute('alg_param') + raise MissingRequiredAttribute("alg_param") if "token_endpoint_auth_signing_alg" in self: if self["token_endpoint_auth_signing_alg"] == "none": @@ -668,6 +681,7 @@ class RegistrationResponse(ResponseMessage): """ Response to client_register registration requests """ + c_param = ResponseMessage.c_param.copy() c_param.update( { @@ -676,8 +690,9 @@ class RegistrationResponse(ResponseMessage): "registration_access_token": SINGLE_OPTIONAL_STRING, "registration_client_uri": SINGLE_OPTIONAL_STRING, "client_id_issued_at": SINGLE_OPTIONAL_INT, - "client_secret_expires_at": SINGLE_OPTIONAL_INT - }) + "client_secret_expires_at": SINGLE_OPTIONAL_INT, + } + ) c_param.update(RegistrationRequest.c_param) def verify(self, **kwargs): @@ -692,40 +707,45 @@ def verify(self, **kwargs): has_reg_uri = "registration_client_uri" in self has_reg_at = "registration_access_token" in self if has_reg_uri != has_reg_at: - raise VerificationError(( - "Only one of registration_client_uri" - " and registration_access_token present"), self) + raise VerificationError( + ("Only one of registration_client_uri" " and registration_access_token present"), + self, + ) return True class ClientRegistrationErrorResponse(oauth2.ResponseMessage): c_allowed_values = { - "error": ["invalid_redirect_uri", - "invalid_client_metadata", - "invalid_configuration_parameter"] + "error": [ + "invalid_redirect_uri", + "invalid_client_metadata", + "invalid_configuration_parameter", + ] } class IdToken(OpenIDSchema): c_param = OpenIDSchema.c_param.copy() - c_param.update({ - "iss": SINGLE_REQUIRED_STRING, - "sub": SINGLE_REQUIRED_STRING, - "aud": REQUIRED_LIST_OF_STRINGS, # Array of strings or string - "exp": SINGLE_REQUIRED_INT, - "iat": SINGLE_REQUIRED_INT, - "auth_time": SINGLE_OPTIONAL_INT, - "nonce": SINGLE_OPTIONAL_STRING, - "at_hash": SINGLE_OPTIONAL_STRING, - "c_hash": SINGLE_OPTIONAL_STRING, - "acr": SINGLE_OPTIONAL_STRING, - "amr": OPTIONAL_LIST_OF_STRINGS, - "azp": SINGLE_OPTIONAL_STRING, - "sub_jwk": SINGLE_OPTIONAL_STRING, - "sid": SINGLE_OPTIONAL_STRING - }) - hashable = {'access_token': 'at_hash', 'code': 'c_hash'} + c_param.update( + { + "iss": SINGLE_REQUIRED_STRING, + "sub": SINGLE_REQUIRED_STRING, + "aud": REQUIRED_LIST_OF_STRINGS, # Array of strings or string + "exp": SINGLE_REQUIRED_INT, + "iat": SINGLE_REQUIRED_INT, + "auth_time": SINGLE_OPTIONAL_INT, + "nonce": SINGLE_OPTIONAL_STRING, + "at_hash": SINGLE_OPTIONAL_STRING, + "c_hash": SINGLE_OPTIONAL_STRING, + "acr": SINGLE_OPTIONAL_STRING, + "amr": OPTIONAL_LIST_OF_STRINGS, + "azp": SINGLE_OPTIONAL_STRING, + "sub_jwk": SINGLE_OPTIONAL_STRING, + "sid": SINGLE_OPTIONAL_STRING, + } + ) + hashable = {"access_token": "at_hash", "code": "c_hash"} def val_hash(self, alg): halg = "HS%s" % alg[-3:] @@ -739,15 +759,15 @@ def val_hash(self, alg): del self[attr] def pack_init(self, lifetime=0): - self['iat'] = utc_time_sans_frac() + self["iat"] = utc_time_sans_frac() if lifetime: - self['exp'] = self['iat'] + lifetime + self["exp"] = self["iat"] + lifetime - def pack(self, alg='', **kwargs): + def pack(self, alg="", **kwargs): self.val_hash(alg) - if 'lifetime' in kwargs: - self.pack_init(kwargs['lifetime']) + if "lifetime" in kwargs: + self.pack_init(kwargs["lifetime"]) else: self.pack_init() @@ -767,9 +787,8 @@ def verify(self, **kwargs): super(IdToken, self).verify(**kwargs) try: - if kwargs['iss'] != self['iss']: - raise IssuerMismatch( - '{} != {}'.format(kwargs['iss'], self['iss'])) + if kwargs["iss"] != self["iss"]: + raise IssuerMismatch("{} != {}".format(kwargs["iss"], self["iss"])) except KeyError: pass @@ -778,61 +797,58 @@ def verify(self, **kwargs): # check that I'm among the recipients if kwargs["client_id"] not in self["aud"]: raise NotForMe( - "{} not in aud:{}".format(kwargs["client_id"], - self["aud"]), self) + "{} not in aud:{}".format(kwargs["client_id"], self["aud"]), self + ) # Then azp has to be present and be one of the aud values if len(self["aud"]) > 1: if "azp" in self: if self["azp"] not in self["aud"]: - raise VerificationError( - "Mismatch between azp and aud claims", self) + raise VerificationError("Mismatch between azp and aud claims", self) else: raise VerificationError("azp missing", self) if "azp" in self: if "client_id" in kwargs: if kwargs["client_id"] != self["azp"]: - raise NotForMe( - "{} != azp:{}".format(kwargs["client_id"], - self["azp"]), self) + raise NotForMe("{} != azp:{}".format(kwargs["client_id"], self["azp"]), self) _now = time_util.utc_time_sans_frac() try: - _skew = kwargs['skew'] + _skew = kwargs["skew"] except KeyError: _skew = 0 try: - _exp = self['exp'] + _exp = self["exp"] except KeyError: - raise MissingRequiredAttribute('exp') + raise MissingRequiredAttribute("exp") else: if (_now - _skew) > _exp: - raise EXPError('Invalid expiration time') + raise EXPError("Invalid expiration time") try: - _storage_time = kwargs['nonce_storage_time'] + _storage_time = kwargs["nonce_storage_time"] except KeyError: _storage_time = NONCE_STORAGE_TIME try: - _iat = self['iat'] + _iat = self["iat"] except KeyError: - raise MissingRequiredAttribute('iat') + raise MissingRequiredAttribute("iat") else: if (_iat + _storage_time) < (_now - _skew): - raise IATError('Issued too long ago') + raise IATError("Issued too long ago") elif _iat > _now + _skew: - raise IATError('Issued sometime in the future') + raise IATError("Issued sometime in the future") if _exp < _iat: - raise IATError('Expiration time can not be earlier the issued at') + raise IATError("Expiration time can not be earlier the issued at") - if 'nonce' in kwargs and 'nonce' in self: - if kwargs['nonce'] != self['nonce']: - raise ValueError('Not the same nonce') + if "nonce" in kwargs and "nonce" in self: + if kwargs["nonce"] != self["nonce"]: + raise ValueError("Not the same nonce") return True @@ -855,64 +871,62 @@ class OpenIDRequest(AuthorizationRequest): class ProviderConfigurationResponse(ResponseMessage): c_param = ResponseMessage.c_param.copy() - c_param.update({ - "issuer": SINGLE_REQUIRED_STRING, - "authorization_endpoint": SINGLE_REQUIRED_STRING, - "token_endpoint": SINGLE_OPTIONAL_STRING, - "userinfo_endpoint": SINGLE_OPTIONAL_STRING, - "jwks_uri": SINGLE_REQUIRED_STRING, - "registration_endpoint": SINGLE_OPTIONAL_STRING, - "scopes_supported": OPTIONAL_LIST_OF_STRINGS, - "response_types_supported": REQUIRED_LIST_OF_STRINGS, - "response_modes_supported": OPTIONAL_LIST_OF_STRINGS, - "grant_types_supported": OPTIONAL_LIST_OF_STRINGS, - "acr_values_supported": OPTIONAL_LIST_OF_STRINGS, - "subject_types_supported": REQUIRED_LIST_OF_STRINGS, - "id_token_signing_alg_values_supported": REQUIRED_LIST_OF_STRINGS, - "id_token_encryption_alg_values_supported": OPTIONAL_LIST_OF_STRINGS, - "id_token_encryption_enc_values_supported": OPTIONAL_LIST_OF_STRINGS, - "userinfo_signing_alg_values_supported": OPTIONAL_LIST_OF_STRINGS, - "userinfo_encryption_alg_values_supported": OPTIONAL_LIST_OF_STRINGS, - "userinfo_encryption_enc_values_supported": OPTIONAL_LIST_OF_STRINGS, - "request_object_signing_alg_values_supported": OPTIONAL_LIST_OF_STRINGS, - "request_object_encryption_alg_values_supported": - OPTIONAL_LIST_OF_STRINGS, - "request_object_encryption_enc_values_supported": - OPTIONAL_LIST_OF_STRINGS, - "token_endpoint_auth_methods_supported": OPTIONAL_LIST_OF_STRINGS, - "token_endpoint_auth_signing_alg_values_supported": - OPTIONAL_LIST_OF_STRINGS, - "display_values_supported": OPTIONAL_LIST_OF_STRINGS, - "claim_types_supported": OPTIONAL_LIST_OF_STRINGS, - "claims_supported": OPTIONAL_LIST_OF_STRINGS, - "service_documentation": SINGLE_OPTIONAL_STRING, - "claims_locales_supported": OPTIONAL_LIST_OF_STRINGS, - "ui_locales_supported": OPTIONAL_LIST_OF_STRINGS, - "claims_parameter_supported": SINGLE_OPTIONAL_BOOLEAN, - "request_parameter_supported": SINGLE_OPTIONAL_BOOLEAN, - "request_uri_parameter_supported": SINGLE_OPTIONAL_BOOLEAN, - "require_request_uri_registration": SINGLE_OPTIONAL_BOOLEAN, - "op_policy_uri": SINGLE_OPTIONAL_STRING, - "op_tos_uri": SINGLE_OPTIONAL_STRING, - "check_session_iframe": SINGLE_OPTIONAL_STRING, - "end_session_endpoint": SINGLE_OPTIONAL_STRING, - "frontchannel_logout_supported": SINGLE_OPTIONAL_BOOLEAN, - "frontchannel_logout_session_supported": SINGLE_OPTIONAL_BOOLEAN, - "backchannel_logout_supported": SINGLE_OPTIONAL_BOOLEAN, - "backchannel_logout_session_supported": SINGLE_OPTIONAL_BOOLEAN - # "jwk_encryption_url": SINGLE_OPTIONAL_STRING, - # "x509_url": SINGLE_REQUIRED_STRING, - # "x509_encryption_url": SINGLE_OPTIONAL_STRING, - }) + c_param.update( + { + "issuer": SINGLE_REQUIRED_STRING, + "authorization_endpoint": SINGLE_REQUIRED_STRING, + "token_endpoint": SINGLE_OPTIONAL_STRING, + "userinfo_endpoint": SINGLE_OPTIONAL_STRING, + "jwks_uri": SINGLE_REQUIRED_STRING, + "registration_endpoint": SINGLE_OPTIONAL_STRING, + "scopes_supported": OPTIONAL_LIST_OF_STRINGS, + "response_types_supported": REQUIRED_LIST_OF_STRINGS, + "response_modes_supported": OPTIONAL_LIST_OF_STRINGS, + "grant_types_supported": OPTIONAL_LIST_OF_STRINGS, + "acr_values_supported": OPTIONAL_LIST_OF_STRINGS, + "subject_types_supported": REQUIRED_LIST_OF_STRINGS, + "id_token_signing_alg_values_supported": REQUIRED_LIST_OF_STRINGS, + "id_token_encryption_alg_values_supported": OPTIONAL_LIST_OF_STRINGS, + "id_token_encryption_enc_values_supported": OPTIONAL_LIST_OF_STRINGS, + "userinfo_signing_alg_values_supported": OPTIONAL_LIST_OF_STRINGS, + "userinfo_encryption_alg_values_supported": OPTIONAL_LIST_OF_STRINGS, + "userinfo_encryption_enc_values_supported": OPTIONAL_LIST_OF_STRINGS, + "request_object_signing_alg_values_supported": OPTIONAL_LIST_OF_STRINGS, + "request_object_encryption_alg_values_supported": OPTIONAL_LIST_OF_STRINGS, + "request_object_encryption_enc_values_supported": OPTIONAL_LIST_OF_STRINGS, + "token_endpoint_auth_methods_supported": OPTIONAL_LIST_OF_STRINGS, + "token_endpoint_auth_signing_alg_values_supported": OPTIONAL_LIST_OF_STRINGS, + "display_values_supported": OPTIONAL_LIST_OF_STRINGS, + "claim_types_supported": OPTIONAL_LIST_OF_STRINGS, + "claims_supported": OPTIONAL_LIST_OF_STRINGS, + "service_documentation": SINGLE_OPTIONAL_STRING, + "claims_locales_supported": OPTIONAL_LIST_OF_STRINGS, + "ui_locales_supported": OPTIONAL_LIST_OF_STRINGS, + "claims_parameter_supported": SINGLE_OPTIONAL_BOOLEAN, + "request_parameter_supported": SINGLE_OPTIONAL_BOOLEAN, + "request_uri_parameter_supported": SINGLE_OPTIONAL_BOOLEAN, + "require_request_uri_registration": SINGLE_OPTIONAL_BOOLEAN, + "op_policy_uri": SINGLE_OPTIONAL_STRING, + "op_tos_uri": SINGLE_OPTIONAL_STRING, + "check_session_iframe": SINGLE_OPTIONAL_STRING, + "end_session_endpoint": SINGLE_OPTIONAL_STRING, + "frontchannel_logout_supported": SINGLE_OPTIONAL_BOOLEAN, + "frontchannel_logout_session_supported": SINGLE_OPTIONAL_BOOLEAN, + "backchannel_logout_supported": SINGLE_OPTIONAL_BOOLEAN, + "backchannel_logout_session_supported": SINGLE_OPTIONAL_BOOLEAN + # "jwk_encryption_url": SINGLE_OPTIONAL_STRING, + # "x509_url": SINGLE_REQUIRED_STRING, + # "x509_encryption_url": SINGLE_OPTIONAL_STRING, + } + ) c_default = { "version": "3.0", - "token_endpoint_auth_methods_supported": [ - "client_secret_basic"], + "token_endpoint_auth_methods_supported": ["client_secret_basic"], "claims_parameter_supported": False, "request_parameter_supported": False, "request_uri_parameter_supported": True, "require_request_uri_registration": True, - "grant_types_supported": ["authorization_code", "implicit"] + "grant_types_supported": ["authorization_code", "implicit"], } def verify(self, **kwargs): @@ -925,29 +939,33 @@ def verify(self, **kwargs): check_char_set(scope, SCOPE_CHARSET) parts = urlparse(self["issuer"]) - if 'allow_http' in kwargs: + if "allow_http" in kwargs: pass elif parts.scheme != "https": raise SchemeError("Not HTTPS") # The parameter is optional - if "token_endpoint_auth_signing_alg_values_supported" in self and "none" in self[ - "token_endpoint_auth_signing_alg_values_supported"]: + if ( + "token_endpoint_auth_signing_alg_values_supported" in self + and "none" in self["token_endpoint_auth_signing_alg_values_supported"] + ): raise ValueError( "The value none must not be used for " "token_endpoint_auth_signing_alg_values_supported" ) if "RS256" not in self["id_token_signing_alg_values_supported"]: - raise ValueError('RS256 missing from id_token_signing_alg_values_supported') + raise ValueError("RS256 missing from id_token_signing_alg_values_supported") if not parts.query and not parts.fragment: pass else: - raise ValueError('Issuer ID invalid') + raise ValueError("Issuer ID invalid") - if any("code" in rt for rt in self[ - "response_types_supported"]) and "token_endpoint" not in self: + if ( + any("code" in rt for rt in self["response_types_supported"]) + and "token_endpoint" not in self + ): raise MissingRequiredAttribute("token_endpoint") return True @@ -971,48 +989,48 @@ def verify(self, **kwargs): _now = utc_time_sans_frac() try: - _skew = kwargs['skew'] + _skew = kwargs["skew"] except KeyError: _skew = 0 try: - _exp = self['exp'] + _exp = self["exp"] except KeyError: pass else: if (_now - _skew) > _exp: - raise EXPError('Invalid expiration time') + raise EXPError("Invalid expiration time") try: - _iat = self['iat'] + _iat = self["iat"] except KeyError: pass else: if _iat > (_now + _skew): - raise EXPError('Invalid issued-at time') + raise EXPError("Invalid issued-at time") try: - _nbf = self['nbf'] + _nbf = self["nbf"] except KeyError: pass else: if _nbf > (_now - _skew): - raise EXPError('Not valid yet') + raise EXPError("Not valid yet") try: - _aud = self['aud'] + _aud = self["aud"] except KeyError: pass else: try: - if kwargs['aud'] not in _aud: - raise NotForMe('Not among intended audience') + if kwargs["aud"] not in _aud: + raise NotForMe("Not among intended audience") except KeyError: pass - if 'iss' in kwargs and 'iss' in self: - if kwargs['iss'] != self['iss']: - raise ValueError('Wrong issuer') + if "iss" in kwargs and "iss" in self: + if kwargs["iss"] != self["iss"]: + raise ValueError("Wrong issuer") return True @@ -1043,28 +1061,25 @@ def jwt_deser(val, sformat="json"): class UserInfoErrorResponse(oauth2.ResponseMessage): c_allowed_values = { - "error": ["invalid_schema", "invalid_request", - "invalid_token", "insufficient_scope"] + "error": ["invalid_schema", "invalid_request", "invalid_token", "insufficient_scope"] } class DiscoveryRequest(Message): - c_param = { - "resource": SINGLE_REQUIRED_STRING, - "rel": SINGLE_REQUIRED_STRING - } + c_param = {"resource": SINGLE_REQUIRED_STRING, "rel": SINGLE_REQUIRED_STRING} class Link(Message): """ https://tools.ietf.org/html/rfc5988 """ + c_param = { "rel": SINGLE_REQUIRED_STRING, "type": SINGLE_OPTIONAL_STRING, "href": SINGLE_OPTIONAL_STRING, "titles": SINGLE_OPTIONAL_DICT, - "properties": SINGLE_OPTIONAL_DICT + "properties": SINGLE_OPTIONAL_DICT, } @@ -1088,7 +1103,7 @@ def link_deser(val, sformat="urlencoded"): def link_ser(inst, sformat, lev=0): if sformat in ["urlencoded", "json"]: if isinstance(inst, dict): - if sformat == 'json': + if sformat == "json": res = json.dumps(inst) else: res = urlencode([(k, v) for k, v in inst.items()]) @@ -1125,19 +1140,17 @@ class JRD(ResponseMessage): """ JSON Resource Descriptor https://tools.ietf.org/html/rfc7033#section-4.4 """ + c_param = { "subject": SINGLE_OPTIONAL_STRING, "aliases": OPTIONAL_LIST_OF_STRINGS, "properties": SINGLE_OPTIONAL_DICT, - "links": REQUIRED_LINKS + "links": REQUIRED_LINKS, } class WebFingerRequest(Message): - c_param = { - "resource": SINGLE_REQUIRED_STRING, - "rel": SINGLE_REQUIRED_STRING - } + c_param = {"resource": SINGLE_REQUIRED_STRING, "rel": SINGLE_REQUIRED_STRING} c_default = {"rel": "http://openid.net/specs/connect/1.0/issuer"} @@ -1151,10 +1164,7 @@ class Claims(Message): class ClaimsRequest(Message): - c_param = { - "userinfo": OPTIONAL_MULTIPLE_Claims, - "id_token": OPTIONAL_MULTIPLE_Claims - } + c_param = {"userinfo": OPTIONAL_MULTIPLE_Claims, "id_token": OPTIONAL_MULTIPLE_Claims} def factory(msgtype, **kwargs): @@ -1170,8 +1180,9 @@ def factory(msgtype, **kwargs): return oauth2.factory(msgtype, **kwargs) -def make_openid_request(arq, keys, issuer, request_object_signing_alg, recv, with_jti=False, - lifetime=0): +def make_openid_request( + arq, keys, issuer, request_object_signing_alg, recv, with_jti=False, lifetime=0 +): """ Construct the JWT to be passed by value (the request parameter) or by reference (request_uri). @@ -1219,7 +1230,7 @@ def claims_match(value, claimspec): elif key == "values": if value in val: matched = True - elif key == 'essential': + elif key == "essential": # Whether it's essential or not doesn't change anything here continue @@ -1227,7 +1238,7 @@ def claims_match(value, claimspec): break # No values to test against so it's just about being there or not - if list(claimspec.keys()) == ['essential']: + if list(claimspec.keys()) == ["essential"]: return True return matched diff --git a/src/oidcmsg/oidc/identity_assurance.py b/src/oidcmsg/oidc/identity_assurance.py index cff6eb2..9a65032 100644 --- a/src/oidcmsg/oidc/identity_assurance.py +++ b/src/oidcmsg/oidc/identity_assurance.py @@ -24,7 +24,7 @@ class PlaceOfBirth(Message): c_param = { "country": SINGLE_REQUIRED_STRING, "region": SINGLE_OPTIONAL_STRING, - "locality": SINGLE_REQUIRED_STRING + "locality": SINGLE_REQUIRED_STRING, } @@ -37,7 +37,7 @@ def place_of_birth_deser(val, sformat="json"): if not isinstance(val, str): val = json.dumps(val) sformat = "json" - elif sformat == 'dict': + elif sformat == "dict": if isinstance(val, str): val = json.loads(val) @@ -150,25 +150,24 @@ def date_deser(val, sformat="", lev=0): class IdentityAssuranceClaims(OpenIDSchema): c_param = OpenIDSchema.c_param.copy() - c_param.update({ - "place_of_birth": SINGLE_OPTIONAL_JSON, - "nationalities": SINGLE_OPTIONAL_STRING, - "birth_family_name": SINGLE_OPTIONAL_STRING, - "birth_given_name": SINGLE_OPTIONAL_STRING, - "birth_middle_name": SINGLE_OPTIONAL_STRING, - "salutation": SINGLE_OPTIONAL_STRING, - "title": SINGLE_OPTIONAL_STRING - }) + c_param.update( + { + "place_of_birth": SINGLE_OPTIONAL_JSON, + "nationalities": SINGLE_OPTIONAL_STRING, + "birth_family_name": SINGLE_OPTIONAL_STRING, + "birth_given_name": SINGLE_OPTIONAL_STRING, + "birth_middle_name": SINGLE_OPTIONAL_STRING, + "salutation": SINGLE_OPTIONAL_STRING, + "title": SINGLE_OPTIONAL_STRING, + } + ) OPTIONAL_IDA_CLAIMS = (IdentityAssuranceClaims, False, msg_ser, msg_deser, False) class Verifier(Message): - c_param = { - "organization": SINGLE_REQUIRED_STRING, - "txn": SINGLE_REQUIRED_STRING - } + c_param = {"organization": SINGLE_REQUIRED_STRING, "txn": SINGLE_REQUIRED_STRING} def verifier_deser(val, sformat="urlencoded"): @@ -185,10 +184,7 @@ def verifier_deser(val, sformat="urlencoded"): class Issuer(Message): - c_param = { - "name": SINGLE_REQUIRED_STRING, - "country": SINGLE_REQUIRED_STRING - } + c_param = {"name": SINGLE_REQUIRED_STRING, "country": SINGLE_REQUIRED_STRING} def issuer_deser(val, sformat="urlencoded"): @@ -210,7 +206,7 @@ class Document(Message): "number": SINGLE_REQUIRED_STRING, "issuer": REQUIRED_ISSUER, "date_of_issuance": REQURIED_TIME_STAMP, - "date_of_expiry": REQURIED_TIME_STAMP + "date_of_expiry": REQURIED_TIME_STAMP, } @@ -222,9 +218,7 @@ def document_deser(val, sformat="urlencoded"): class Evidence(Message): - c_param = { - "type": SINGLE_OPTIONAL_STRING - } + c_param = {"type": SINGLE_OPTIONAL_STRING} def verify(self, **kwargs): _type = self.get("type") @@ -274,12 +268,14 @@ def evidence_list_deser(val, sformat="urlencoded", lev=0): class IdDocument(Evidence): c_param = Evidence.c_param.copy() - c_param.update({ - "method": SINGLE_REQUIRED_STRING, - "verifier": REQUIRED_VERIFIER, - "time": OPTIONAL_TIME_STAMP, - "document": OPTIONAL_DOCUMENT - }) + c_param.update( + { + "method": SINGLE_REQUIRED_STRING, + "verifier": REQUIRED_VERIFIER, + "time": OPTIONAL_TIME_STAMP, + "document": OPTIONAL_DOCUMENT, + } + ) def id_document_deser(val, sformat="urlencoded"): @@ -298,9 +294,11 @@ def id_document_deser(val, sformat="urlencoded"): class Provider(AddressClaim): c_param = AddressClaim.c_param.copy() - c_param.update({ - "name": SINGLE_OPTIONAL_STRING, - }) + c_param.update( + { + "name": SINGLE_OPTIONAL_STRING, + } + ) def provider_deser(val, sformat="urlencoded"): @@ -318,10 +316,7 @@ def provider_deser(val, sformat="urlencoded"): class UtilityBill(Evidence): c_param = Evidence.c_param.copy() - c_param.update({ - "provider": REQUIRED_PROVIDER, - "date": OPTIONAL_TIME_STAMP - }) + c_param.update({"provider": REQUIRED_PROVIDER, "date": OPTIONAL_TIME_STAMP}) def utility_bill_deser(val, sformat="urlencoded"): @@ -340,11 +335,13 @@ def utility_bill_deser(val, sformat="urlencoded"): class QES(Evidence): c_param = Evidence.c_param.copy() - c_param.update({ - "issuer": SINGLE_REQUIRED_STRING, - "serial_number": SINGLE_REQUIRED_STRING, - "created_at": REQURIED_TIME_STAMP - }) + c_param.update( + { + "issuer": SINGLE_REQUIRED_STRING, + "serial_number": SINGLE_REQUIRED_STRING, + "created_at": REQURIED_TIME_STAMP, + } + ) def qes_deser(val, sformat="urlencoded"): @@ -388,14 +385,16 @@ def verification_element_deser(val, sformat="urlencoded"): OPTIONAL_VERIFICATION_ELEMENT = ( - VerificationElement, False, msg_ser, verification_element_deser, False) + VerificationElement, + False, + msg_ser, + verification_element_deser, + False, +) class VerifiedClaims(Message): - c_param = { - "verification": OPTIONAL_VERIFICATION_ELEMENT, - "claims": OPTIONAL_IDA_CLAIMS - } + c_param = {"verification": OPTIONAL_VERIFICATION_ELEMENT, "claims": OPTIONAL_IDA_CLAIMS} SINGLE_OPTIONAL_CLAIMSREQ = (ClaimsRequest, False, msg_ser_json, claims_request_deser, False) @@ -488,14 +487,16 @@ def verification_element_request_deser(val, sformat="urlencoded"): OPTIONAL_VERIFICATION_ELEMENT_REQUEST = ( - VerificationElementRequest, False, msg_ser, verification_element_request_deser, True) + VerificationElementRequest, + False, + msg_ser, + verification_element_request_deser, + True, +) class VerifiedClaimsRequest(Message): - c_param = { - "verification": OPTIONAL_MESSAGE, - "claims": OPTIONAL_IDA_CLAIMS - } + c_param = {"verification": OPTIONAL_MESSAGE, "claims": OPTIONAL_IDA_CLAIMS} def verify(self, **kwargs): super(VerifiedClaimsRequest, self).verify(**kwargs) @@ -536,7 +537,9 @@ def __setitem__(self, key, value): if _value_type: if isinstance(value, ClaimsConstructor): if not isinstance(value.base_class, _value_type): - raise ValueError("Wrong type of value '{}':'{}'".format(key, type(value.base_class))) + raise ValueError( + "Wrong type of value '{}':'{}'".format(key, type(value.base_class)) + ) elif not _correct_value_type(value, _value_type): raise ValueError("Wrong type of value '{}':'{}'".format(key, type(value))) @@ -552,4 +555,4 @@ def to_dict(self): return res def to_json(self): - return json.dumps(self.to_dict()) \ No newline at end of file + return json.dumps(self.to_dict()) diff --git a/src/oidcmsg/oidc/session.py b/src/oidcmsg/oidc/session.py index 58d2c1c..21b0676 100644 --- a/src/oidcmsg/oidc/session.py +++ b/src/oidcmsg/oidc/session.py @@ -27,10 +27,7 @@ class RefreshSessionRequest(MessageWithIdToken): c_param = MessageWithIdToken.c_param.copy() - c_param.update({ - "redirect_url": SINGLE_REQUIRED_STRING, - "state": SINGLE_REQUIRED_STRING - }) + c_param.update({"redirect_url": SINGLE_REQUIRED_STRING, "state": SINGLE_REQUIRED_STRING}) class RefreshSessionResponse(MessageWithIdToken, ResponseMessage): @@ -52,14 +49,14 @@ class EndSessionRequest(Message): "id_token_hint": SINGLE_OPTIONAL_IDTOKEN, "post_logout_redirect_uri": SINGLE_OPTIONAL_STRING, "state": SINGLE_OPTIONAL_STRING, - "ui_locales": OPTIONAL_LIST_OF_SP_SEP_STRINGS + "ui_locales": OPTIONAL_LIST_OF_SP_SEP_STRINGS, } def verify(self, **kwargs): super(EndSessionRequest, self).verify(**kwargs) clear_verified_claims(self) - if 'post_logout_redirect_uri' in self: + if "post_logout_redirect_uri" in self: if "id_token_hint" not in self: return False @@ -72,7 +69,7 @@ def verify(self, **kwargs): except KeyError: pass idt = IdToken().from_jwt(str(self["id_token_hint"]), **args) - if not verify_id_token(self, claim='id_token_hint', **kwargs): + if not verify_id_token(self, claim="id_token_hint", **kwargs): return False # Add the verified ID Token to the message instance self[verified_claim_name("id_token_hint")] = idt @@ -90,67 +87,68 @@ class LogoutToken(Message): Defined in https://openid.net/specs/openid-connect-backchannel-1_0.html#LogoutToken """ + c_param = { "iss": SINGLE_REQUIRED_STRING, "sub": SINGLE_OPTIONAL_STRING, "aud": REQUIRED_LIST_OF_STRINGS, # Array of strings or string "iat": SINGLE_REQUIRED_INT, "jti": SINGLE_REQUIRED_STRING, - 'events': SINGLE_REQUIRED_JSON, - 'sid': SINGLE_OPTIONAL_STRING + "events": SINGLE_REQUIRED_JSON, + "sid": SINGLE_OPTIONAL_STRING, } def verify(self, **kwargs): super(LogoutToken, self).verify(**kwargs) - if 'nonce' in self: - raise MessageException('"nonce" is prohibited from appearing in ' - 'a LogoutToken.') + if "nonce" in self: + raise MessageException('"nonce" is prohibited from appearing in ' "a LogoutToken.") # Check the 'events' JSON - _keys = list(self['events'].keys()) + _keys = list(self["events"].keys()) if len(_keys) != 1: raise ValueError('Must only be one member in "events"') if _keys[0] != "http://schemas.openid.net/event/backchannel-logout": raise ValueError('Wrong member in "events"') - if self['events'][_keys[0]] != {}: + if self["events"][_keys[0]] != {}: raise ValueError('Wrong member value in "events"') # There must be either a 'sub' or a 'sid', and may contain both - if not ('sub' in self or 'sid' in self): + if not ("sub" in self or "sid" in self): raise ValueError('There MUST be either a "sub" or a "sid"') try: - if kwargs['aud'] not in self['aud']: - raise NotForMe('Not among intended audience') + if kwargs["aud"] not in self["aud"]: + raise NotForMe("Not among intended audience") except KeyError: pass try: - if kwargs['iss'] != self['iss']: - raise NotForMe('Wrong issuer') + if kwargs["iss"] != self["iss"]: + raise NotForMe("Wrong issuer") except KeyError: pass _now = utc_time_sans_frac() try: - _skew = kwargs['skew'] + _skew = kwargs["skew"] except KeyError: _skew = 0 try: - _exp = self['iat'] + _exp = self["iat"] except KeyError: pass else: - if self['iat'] > (_now + _skew): - raise ValueError('Invalid issued_at time') + if self["iat"] > (_now + _skew): + raise ValueError("Invalid issued_at time") _allowed = kwargs.get("allowed_sign_alg") - if _allowed and self.jws_header['alg'] != _allowed: + if _allowed and self.jws_header["alg"] != _allowed: _msg = "Wrong token signing algorithm, {} != {}".format( - self.jws_header['alg'], kwargs["allowed_sign_alg"]) + self.jws_header["alg"], kwargs["allowed_sign_alg"] + ) raise UnsupportedAlgorithm(_msg) return True @@ -165,9 +163,7 @@ class BackChannelLogoutRequest(Message): https://openid.net/specs/openid-connect-backchannel-1_0.html#LogoutToken """ - c_param = { - "logout_token": SINGLE_REQUIRED_STRING - } + c_param = {"logout_token": SINGLE_REQUIRED_STRING} def verify(self, **kwargs): super(BackChannelLogoutRequest, self).verify(**kwargs) @@ -183,6 +179,6 @@ def verify(self, **kwargs): return False self[verified_claim_name("logout_token")] = idt - logger.info('Verified ID Token: {}'.format(idt.to_dict())) + logger.info("Verified ID Token: {}".format(idt.to_dict())) return True diff --git a/src/oidcmsg/storage/__init__.py b/src/oidcmsg/storage/__init__.py index eda66cd..fa9daea 100644 --- a/src/oidcmsg/storage/__init__.py +++ b/src/oidcmsg/storage/__init__.py @@ -48,4 +48,3 @@ def synch(self): def keys(self): raise NotImplemented() - diff --git a/src/oidcmsg/storage/abfile.py b/src/oidcmsg/storage/abfile.py index d35beee..0d15e49 100644 --- a/src/oidcmsg/storage/abfile.py +++ b/src/oidcmsg/storage/abfile.py @@ -24,7 +24,10 @@ class AbstractFileSystem(Storage): Not directories in directories. """ - def __init__(self, conf_dict, ): + def __init__( + self, + conf_dict, + ): """ items = FileSystem( { @@ -46,7 +49,7 @@ def __init__(self, conf_dict, ): super().__init__(conf_dict) self.config = conf_dict - _fdir = conf_dict.get('fdir', '.') + _fdir = conf_dict.get("fdir", ".") if "issuer" in conf_dict: _fdir = os.path.join(_fdir, quote_plus(conf_dict["issuer"])) @@ -54,13 +57,13 @@ def __init__(self, conf_dict, ): self.fmtime = {} self.storage = {} - key_conv = conf_dict.get('key_conv') + key_conv = conf_dict.get("key_conv") if key_conv: self.key_conv = importer(key_conv)() else: self.key_conv = QPKey() - value_conv = conf_dict.get('value_conv') + value_conv = conf_dict.get("value_conv") if value_conv: self.value_conv = importer(value_conv)() else: @@ -71,7 +74,7 @@ def __init__(self, conf_dict, ): self.synch() - def get(self, item, default=None ): + def get(self, item, default=None): try: return self[item] except KeyError: @@ -87,7 +90,7 @@ def __getitem__(self, item): item = self.key_conv.serialize(item) fname = os.path.join(self.fdir, item) if self._is_file(fname): - lock = FileLock('{}.lock'.format(fname)) + lock = FileLock("{}.lock".format(fname)) with lock: if self.is_changed(item, fname): logger.info("File content change in {}".format(item)) @@ -118,9 +121,9 @@ def __setitem__(self, key, value): _key = key fname = os.path.join(self.fdir, _key) - lock = FileLock('{}.lock'.format(fname)) + lock = FileLock("{}.lock".format(fname)) with lock: - with open(fname, 'w') as fp: + with open(fname, "w") as fp: fp.write(self.value_conv.serialize(value)) self.storage[_key] = value @@ -130,7 +133,7 @@ def __setitem__(self, key, value): def __delitem__(self, key): fname = os.path.join(self.fdir, key) if os.path.isfile(fname): - lock = FileLock('{}.lock'.format(fname)) + lock = FileLock("{}.lock".format(fname)) with lock: os.unlink(fname) @@ -193,15 +196,15 @@ def is_changed(self, item, fname): def _read_info(self, fname): if os.path.isfile(fname): try: - lock = FileLock('{}.lock'.format(fname)) + lock = FileLock("{}.lock".format(fname)) with lock: - info = open(fname, 'r').read().strip() + info = open(fname, "r").read().strip() return self.value_conv.deserialize(info) except Exception as err: logger.error(err) raise else: - logger.error('No such file: {}'.format(fname)) + logger.error("No such file: {}".format(fname)) return None def synch(self): @@ -217,7 +220,7 @@ def synch(self): if not os.path.isfile(fname): continue - if fname.endswith('.lock'): + if fname.endswith(".lock"): continue if f in self.fmtime: @@ -228,7 +231,7 @@ def synch(self): try: self.storage[f] = self._read_info(fname) except Exception as err: - logger.warning('Bad content in {} ({})'.format(fname, err)) + logger.warning("Bad content in {} ({})".format(fname, err)) else: self.fmtime[f] = mtime @@ -281,27 +284,28 @@ def __len__(self): if not os.path.isfile(fname): continue - if fname.endswith('.lock'): + if fname.endswith(".lock"): continue n += 1 return n def __str__(self): - return '{config:' + str(self.config) + ', info:' + str(self.storage) + '}' + return "{config:" + str(self.config) + ", info:" + str(self.storage) + "}" + class LabeledAbstractFileSystem(Storage): - def __init__(self, conf_dict, label=''): - _conf = {k: v for k, v in conf_dict.items() if k != 'label'} + def __init__(self, conf_dict, label=""): + _conf = {k: v for k, v in conf_dict.items() if k != "label"} Storage.__init__(self, conf_dict=_conf) self.storage = AbstractFileSystem(conf_dict) - _label = label or conf_dict.get('label', '') + _label = label or conf_dict.get("label", "") if not _label: - self.label = '' + self.label = "" else: - self.label = '__{}__'.format(_label) + self.label = "__{}__".format(_label) @key_label def get(self, k, default=None): @@ -330,15 +334,15 @@ def __contains__(self, k): def __iter__(self): for key, val in self.storage.__iter__(): if key.startswith(self.label): - yield key[len(self.label):], val + yield key[len(self.label) :], val def keys(self): - return [k[len(self.label):] for k in self.storage.keys() if k.startswith(self.label)] + return [k[len(self.label) :] for k in self.storage.keys() if k.startswith(self.label)] def items(self): for key, val in self.storage.__iter__(): if key.startswith(self.label): - yield key[len(self.label):], val + yield key[len(self.label) :], val def __len__(self): if not os.path.isdir(self.storage.fdir): @@ -351,7 +355,7 @@ def __len__(self): if not os.path.isfile(fname): continue - if fname.endswith('.lock'): + if fname.endswith(".lock"): continue n += 1 diff --git a/src/oidcmsg/storage/absqlalchemy.py b/src/oidcmsg/storage/absqlalchemy.py index 30c0179..96eeb48 100644 --- a/src/oidcmsg/storage/absqlalchemy.py +++ b/src/oidcmsg/storage/absqlalchemy.py @@ -7,14 +7,15 @@ PlainDict = dict + class AbstractStorageSQLAlchemy: def __init__(self, conf_dict): - self.engine = alchemy_db.create_engine(conf_dict['url']) + self.engine = alchemy_db.create_engine(conf_dict["url"]) self.connection = self.engine.connect() self.metadata = alchemy_db.MetaData() - self.table = alchemy_db.Table(conf_dict['params']['table'], - self.metadata, autoload=True, - autoload_with=self.engine) + self.table = alchemy_db.Table( + conf_dict["params"]["table"], self.metadata, autoload=True, autoload_with=self.engine + ) Session = sessionmaker(bind=self.engine) self.session = scoped_session(Session) @@ -26,20 +27,17 @@ def get(self, k): def set(self, k, v): self.delete(k) - ins = self.table.insert().values(owner=k, - data=v) + ins = self.table.insert().values(owner=k, data=v) self.session.execute(ins) self.session.commit() return 1 def update(self, k, v): """ - k = value_to_match - v = value_to_be_substituted + k = value_to_match + v = value_to_be_substituted """ - upquery = self.table.update(). \ - where(self.table.c.owner == k). \ - values(**{'data': v}) + upquery = self.table.update().where(self.table.c.owner == k).values(**{"data": v}) self.session.execute(upquery) self.session.commit() return 1 diff --git a/src/oidcmsg/storage/db_setup.py b/src/oidcmsg/storage/db_setup.py index d4ecd7e..023bc65 100644 --- a/src/oidcmsg/storage/db_setup.py +++ b/src/oidcmsg/storage/db_setup.py @@ -7,23 +7,20 @@ class Thing(Base): - __tablename__ = 'thing' - - id = alchemy_db.Column(alchemy_db.Integer, - alchemy_db.Sequence('thing_id_seq'), - primary_key=True) - owner = alchemy_db.Column(alchemy_db.String(80), - unique=False, nullable=False) - data = alchemy_db.Column(alchemy_db.String(4096), - unique=False, nullable=False) - created = alchemy_db.Column(alchemy_db.DateTime, - default=datetime.datetime.utcnow) + __tablename__ = "thing" + + id = alchemy_db.Column( + alchemy_db.Integer, alchemy_db.Sequence("thing_id_seq"), primary_key=True + ) + owner = alchemy_db.Column(alchemy_db.String(80), unique=False, nullable=False) + data = alchemy_db.Column(alchemy_db.String(4096), unique=False, nullable=False) + created = alchemy_db.Column(alchemy_db.DateTime, default=datetime.datetime.utcnow) def __repr__(self): - return '' % self.owner + return "" % self.owner def create_database(conf_dict): - engine = alchemy_db.create_engine(conf_dict['url']) + engine = alchemy_db.create_engine(conf_dict["url"]) connection = engine.connect() Base.metadata.create_all(engine) diff --git a/src/oidcmsg/storage/extension.py b/src/oidcmsg/storage/extension.py index 5554688..d6f3d4a 100644 --- a/src/oidcmsg/storage/extension.py +++ b/src/oidcmsg/storage/extension.py @@ -20,15 +20,15 @@ def delete(self, item): del self[item] -class LabeledDict(): - def __init__(self, label=''): +class LabeledDict: + def __init__(self, label=""): self.storage = SetGetDict() - if label == '': + if label == "": self.label = label self.label_len = 0 else: - self.label = '__{}__'.format(label) + self.label = "__{}__".format(label) self.label_len = len(self.label) @key_label @@ -62,7 +62,7 @@ def __contains__(self, k): def __iter__(self): for key, val in self.storage.__iter__(): if key.startswith(self.label): - yield key[self.label_len:], val + yield key[self.label_len :], val def keys(self): - return [k[self.label_len:] for k in self.storage.keys() if k.startswith(self.label)] + return [k[self.label_len :] for k in self.storage.keys() if k.startswith(self.label)] diff --git a/src/oidcmsg/storage/init.py b/src/oidcmsg/storage/init.py index 08ac3e5..05bdaea 100644 --- a/src/oidcmsg/storage/init.py +++ b/src/oidcmsg/storage/init.py @@ -25,14 +25,14 @@ class ConfigurationError(Exception): pass -def get_storage_conf(db_conf=None, typ='default'): +def get_storage_conf(db_conf=None, typ="default"): _conf = None if db_conf: _conf = db_conf.get(typ) if _conf: return _conf - elif typ != 'default': - _conf = db_conf.get('default') + elif typ != "default": + _conf = db_conf.get("default") else: raise ConfigurationError() @@ -40,16 +40,16 @@ def get_storage_conf(db_conf=None, typ='default'): def storage_factory(configuration): - _handler = configuration.get('handler') + _handler = configuration.get("handler") if _handler: storage_cls = importer(_handler) else: - raise ConfigurationError('Missing handler specification') - _conf = {k: v for k, v in configuration.items() if k != 'handler'} + raise ConfigurationError("Missing handler specification") + _conf = {k: v for k, v in configuration.items() if k != "handler"} return storage_cls(_conf) -def init_storage(db_conf=None, key='default'): +def init_storage(db_conf=None, key="default"): """ Returns a storage instance. @@ -62,4 +62,4 @@ def init_storage(db_conf=None, key='default'): if _conf: return storage_factory(_conf) - return LabeledDict({'label': key}) + return LabeledDict({"label": key}) diff --git a/src/oidcmsg/storage/utils.py b/src/oidcmsg/storage/utils.py index c8ed4b6..d7525c3 100644 --- a/src/oidcmsg/storage/utils.py +++ b/src/oidcmsg/storage/utils.py @@ -3,17 +3,17 @@ def modsplit(name): """Split importable""" - if ':' in name: - _part = name.split(':') + if ":" in name: + _part = name.split(":") if len(_part) != 2: raise ValueError("Syntax error: {s}") return _part[0], _part[1] - _part = name.split('.') + _part = name.split(".") if len(_part) < 2: raise ValueError("Syntax error: {s}") - return '.'.join(_part[:-1]), _part[-1] + return ".".join(_part[:-1]), _part[-1] def importer(name): diff --git a/src/oidcmsg/time_util.py b/src/oidcmsg/time_util.py index 093db1c..7d28095 100755 --- a/src/oidcmsg/time_util.py +++ b/src/oidcmsg/time_util.py @@ -27,8 +27,7 @@ from datetime import timedelta TIME_FORMAT = "%Y-%m-%dT%H:%M:%SZ" -TIME_FORMAT_WITH_FRAGMENT = re.compile( - "^(\d{4,4}-\d{2,2}-\d{2,2}T\d{2,2}:\d{2,2}:\d{2,2})\.\d*Z$") +TIME_FORMAT_WITH_FRAGMENT = re.compile("^(\d{4,4}-\d{2,2}-\d{2,2}T\d{2,2}:\d{2,2}:\d{2,2})\.\d*Z$") class TimeUtilError(Exception): @@ -70,18 +69,18 @@ def maximum_day_in_month_for(year, month): ("T", None), ("H", "tm_hour"), ("M", "tm_min"), - ("S", "tm_sec") + ("S", "tm_sec"), ] def parse_duration(duration): # (-)PnYnMnDTnHnMnS index = 0 - if duration[0] == '-': - sign = '-' + if duration[0] == "-": + sign = "-" index += 1 else: - sign = '+' + sign = "+" if duration[index] != "P": raise ValueError('duration index {} != "P"'.format(duration[index])) index += 1 @@ -90,7 +89,7 @@ def parse_duration(duration): for code, typ in D_FORMAT: # print duration[index:], code - if duration[index] == '-': + if duration[index] == "-": raise TimeUtilError("Negation not allowed on individual items") if code == "T": if duration[index] == "T": @@ -103,16 +102,15 @@ def parse_duration(duration): try: mod = duration[index:].index(code) try: - dic[typ] = int(duration[index:index + mod]) + dic[typ] = int(duration[index : index + mod]) except ValueError: if code == "S": try: - dic[typ] = float(duration[index:index + mod]) + dic[typ] = float(duration[index : index + mod]) except ValueError: raise TimeUtilError("Not a float") else: - raise TimeUtilError( - "Fractions not allow on anything byt seconds") + raise TimeUtilError("Fractions not allow on anything byt seconds") index = mod + index + 1 except ValueError: dic[typ] = 0 @@ -126,7 +124,7 @@ def parse_duration(duration): def add_duration(tid, duration): (sign, dur) = parse_duration(duration) - if sign == '+': + if sign == "+": # Months temp = tid.tm_mon + dur["tm_mon"] month = modulo(temp, 1, 13) @@ -165,8 +163,7 @@ def add_duration(tid, duration): month = modulo(temp, 1, 13) year += f_quotient(temp, 1, 13) - return time.localtime(time.mktime((year, month, days, hour, minutes, - secs, 0, 0, -1))) + return time.localtime(time.mktime((year, month, days, hour, minutes, secs, 0, 0, -1))) else: pass @@ -174,23 +171,22 @@ def add_duration(tid, duration): # --------------------------------------------------------------------------- -def time_in_a_while(days=0, seconds=0, microseconds=0, milliseconds=0, - minutes=0, hours=0, weeks=0): +def time_in_a_while(days=0, seconds=0, microseconds=0, milliseconds=0, minutes=0, hours=0, weeks=0): """ Will return a time specification for a time sometime in the future. :return: datetime instance using UTC time """ - delta = timedelta(days, seconds, microseconds, milliseconds, - minutes, hours, weeks) + delta = timedelta(days, seconds, microseconds, milliseconds, minutes, hours, weeks) return datetime.utcnow() + delta -def time_a_while_ago(days=0, seconds=0, microseconds=0, milliseconds=0, - minutes=0, hours=0, weeks=0): +def time_a_while_ago( + days=0, seconds=0, microseconds=0, milliseconds=0, minutes=0, hours=0, weeks=0 +): """ Will return a time specification for a time sometime in the past. - + :param days: :param seconds: :param microseconds: @@ -200,13 +196,20 @@ def time_a_while_ago(days=0, seconds=0, microseconds=0, milliseconds=0, :param weeks: :return: datetime instance using UTC time """ - delta = timedelta(days, seconds, microseconds, milliseconds, - minutes, hours, weeks) + delta = timedelta(days, seconds, microseconds, milliseconds, minutes, hours, weeks) return datetime.utcnow() - delta -def in_a_while(days=0, seconds=0, microseconds=0, milliseconds=0, - minutes=0, hours=0, weeks=0, time_format=TIME_FORMAT): +def in_a_while( + days=0, + seconds=0, + microseconds=0, + milliseconds=0, + minutes=0, + hours=0, + weeks=0, + time_format=TIME_FORMAT, +): """ :param days: :param seconds: @@ -221,12 +224,21 @@ def in_a_while(days=0, seconds=0, microseconds=0, milliseconds=0, if not time_format: time_format = TIME_FORMAT - return time_in_a_while(days, seconds, microseconds, milliseconds, - minutes, hours, weeks).strftime(time_format) - - -def a_while_ago(days=0, seconds=0, microseconds=0, milliseconds=0, - minutes=0, hours=0, weeks=0, time_format=TIME_FORMAT): + return time_in_a_while( + days, seconds, microseconds, milliseconds, minutes, hours, weeks + ).strftime(time_format) + + +def a_while_ago( + days=0, + seconds=0, + microseconds=0, + milliseconds=0, + minutes=0, + hours=0, + weeks=0, + time_format=TIME_FORMAT, +): """ :param days: @@ -239,15 +251,16 @@ def a_while_ago(days=0, seconds=0, microseconds=0, milliseconds=0, :param time_format: :return: Formatet string """ - return time_a_while_ago(days, seconds, microseconds, milliseconds, - minutes, hours, weeks).strftime(time_format) + return time_a_while_ago( + days, seconds, microseconds, milliseconds, minutes, hours, weeks + ).strftime(time_format) # --------------------------------------------------------------------------- def shift_time(dtime, shift): - """ Adds/deletes an integer amount of seconds from a datetime specification + """Adds/deletes an integer amount of seconds from a datetime specification :param dtime: The datatime specification :param shift: The wanted time shift (+/-) @@ -274,7 +287,7 @@ def str_to_time(timestr, time_format=TIME_FORMAT): try: elem = TIME_FORMAT_WITH_FRAGMENT.match(timestr) except Exception as exc: - print >> sys.stderr, "Exception: %s on %s" % (exc, timestr) + print >>sys.stderr, "Exception: %s on %s" % (exc, timestr) raise then = time.strptime(elem.groups()[0] + "Z", TIME_FORMAT) @@ -342,8 +355,9 @@ def time_sans_frac(): return int("%d" % time.time()) -def epoch_in_a_while(days=0, seconds=0, microseconds=0, milliseconds=0, - minutes=0, hours=0, weeks=0): +def epoch_in_a_while( + days=0, seconds=0, microseconds=0, milliseconds=0, minutes=0, hours=0, weeks=0 +): """ Return the number of seconds since epoch a while from now. @@ -357,6 +371,5 @@ def epoch_in_a_while(days=0, seconds=0, microseconds=0, milliseconds=0, :return: Seconds since epoch (1970-01-01) """ - dt = time_in_a_while(days, seconds, microseconds, milliseconds, minutes, - hours, weeks) + dt = time_in_a_while(days, seconds, microseconds, milliseconds, minutes, hours, weeks) return int((dt - datetime(1970, 1, 1)).total_seconds()) diff --git a/tests/test_03_time_util.py b/tests/test_03_time_util.py index 3638170..a053676 100755 --- a/tests/test_03_time_util.py +++ b/tests/test_03_time_util.py @@ -27,7 +27,7 @@ from oidcmsg.time_util import utc_time_sans_frac from oidcmsg.time_util import valid -__author__ = 'rohe0002' +__author__ = "rohe0002" def test_f_quotient(): @@ -69,12 +69,12 @@ def test_modulo_2(): def test_parse_duration(): (sign, d) = parse_duration("P1Y3M5DT7H10M3.3S") assert sign == "+" - assert d['tm_sec'] == 3.3 - assert d['tm_mon'] == 3 - assert d['tm_hour'] == 7 - assert d['tm_mday'] == 5 - assert d['tm_year'] == 1 - assert d['tm_min'] == 10 + assert d["tm_sec"] == 3.3 + assert d["tm_mon"] == 3 + assert d["tm_hour"] == 7 + assert d["tm_mday"] == 5 + assert d["tm_year"] == 1 + assert d["tm_min"] == 10 def test_add_duration_1(): @@ -174,22 +174,25 @@ def test_not_on_or_after(): def test_parse_duration_1(): (sign, d) = parse_duration("-P1Y3M5DT7H10M3.3S") assert sign == "-" - assert d['tm_sec'] == 3.3 - assert d['tm_mon'] == 3 - assert d['tm_hour'] == 7 - assert d['tm_mday'] == 5 - assert d['tm_year'] == 1 - assert d['tm_min'] == 10 - - -@pytest.mark.parametrize("duration", [ - "-P1Y-3M5DT7H10M3.3S", - "-P1Y3M5DU7H10M3.3S", - "-P1Y3M5DT", - "-P1Y3M5DU7H10M3.S", - "-P1Y3M5DT7H10MxS", - "-P1Y4M4DT7H10.5M3S" -]) + assert d["tm_sec"] == 3.3 + assert d["tm_mon"] == 3 + assert d["tm_hour"] == 7 + assert d["tm_mday"] == 5 + assert d["tm_year"] == 1 + assert d["tm_min"] == 10 + + +@pytest.mark.parametrize( + "duration", + [ + "-P1Y-3M5DT7H10M3.3S", + "-P1Y3M5DU7H10M3.3S", + "-P1Y3M5DT", + "-P1Y3M5DU7H10M3.S", + "-P1Y3M5DT7H10MxS", + "-P1Y4M4DT7H10.5M3S", + ], +) def test_parse_duration_error(duration): with pytest.raises(TimeUtilError): parse_duration(duration) diff --git a/tests/test_04_message.py b/tests/test_04_message.py index 3940904..259af93 100755 --- a/tests/test_04_message.py +++ b/tests/test_04_message.py @@ -28,7 +28,7 @@ from oidcmsg.message import sp_sep_list_deserializer from oidcmsg.oauth2 import Message -__author__ = 'Roland Hedberg' +__author__ = "Roland Hedberg" from oidcmsg.oauth2 import ResponseMessage @@ -48,14 +48,14 @@ KEYJAR = build_keyjar(keys) IKEYJAR = build_keyjar(keys) -IKEYJAR.import_jwks(IKEYJAR.export_jwks(private=True), 'issuer') -del IKEYJAR[''] +IKEYJAR.import_jwks(IKEYJAR.export_jwks(private=True), "issuer") +del IKEYJAR[""] KEYJARS = {} -for iss in ['A', 'B', 'C']: +for iss in ["A", "B", "C"]: _kj = build_keyjar(keym) _kj.import_jwks(_kj.export_jwks(private=True), iss) - del _kj[''] + del _kj[""] KEYJARS[iss] = _kj @@ -116,165 +116,188 @@ class DummyMessage(Message): "opt_int": SINGLE_OPTIONAL_INT, "opt_str_list": OPTIONAL_LIST_OF_STRINGS, "req_str_list": REQUIRED_LIST_OF_STRINGS, - "opt_json": SINGLE_OPTIONAL_JSON + "opt_json": SINGLE_OPTIONAL_JSON, } class TestMessage(object): def test_json_serialization(self): - item = DummyMessage(req_str="Fair", opt_str="game", opt_int=9, - opt_str_list=["one", "two"], - req_str_list=["spike", "lee"], - opt_json='{"ford": "green"}') + item = DummyMessage( + req_str="Fair", + opt_str="game", + opt_int=9, + opt_str_list=["one", "two"], + req_str_list=["spike", "lee"], + opt_json='{"ford": "green"}', + ) jso = item.serialize(method="json") item2 = DummyMessage().deserialize(jso, "json") - assert _eq(item2.keys(), - ['opt_str', 'req_str', 'opt_json', 'req_str_list', - 'opt_str_list', 'opt_int']) + assert _eq( + item2.keys(), + ["opt_str", "req_str", "opt_json", "req_str_list", "opt_str_list", "opt_int"], + ) def test_from_dict(self): - _dict = { - "req_str": "Fair", "req_str_list": ["spike", "lee"], - "opt_int": 9 - } + _dict = {"req_str": "Fair", "req_str_list": ["spike", "lee"], "opt_int": 9} _msg = DummyMessage() _msg.from_dict(_dict) assert set(_msg.keys()) == set(_dict.keys()) def test_from_dict_lang_tag_unknown_key(self): _dict = { - "req_str": "Fair", "req_str_list": ["spike", "lee"], - "opt_int": 9, 'attribute#se': 'value' + "req_str": "Fair", + "req_str_list": ["spike", "lee"], + "opt_int": 9, + "attribute#se": "value", } _msg = DummyMessage() _msg.from_dict(_dict) assert set(_msg.keys()) == set(_dict.keys()) def test_from_dict_lang_tag(self): - _dict = { - "req_str#se": "Fair", "req_str_list": ["spike", "lee"], - "opt_int": 9 - } + _dict = {"req_str#se": "Fair", "req_str_list": ["spike", "lee"], "opt_int": 9} _msg = DummyMessage() _msg.from_dict(_dict) assert set(_msg.keys()) == set(_dict.keys()) def test_from_json(self): - jso = '{"req_str": "Fair", "req_str_list": ["spike", "lee"], ' \ - '"opt_int": 9}' + jso = '{"req_str": "Fair", "req_str_list": ["spike", "lee"], ' '"opt_int": 9}' item = DummyMessage().deserialize(jso, "json") - assert _eq(item.keys(), ['req_str', 'req_str_list', 'opt_int']) + assert _eq(item.keys(), ["req_str", "req_str_list", "opt_int"]) assert item["opt_int"] == 9 def test_single_optional(self): - jso = '{"req_str": "Fair", "req_str_list": ["spike", "lee"], ' \ - '"opt_int": [9, 10]}' + jso = '{"req_str": "Fair", "req_str_list": ["spike", "lee"], ' '"opt_int": [9, 10]}' with pytest.raises(ValueError): DummyMessage().deserialize(jso, "json") def test_extra_param(self): - jso = '{"req_str": "Fair", "req_str_list": ["spike", "lee"], "extra": ' \ - '' \ - '' \ - '"out"}' + jso = '{"req_str": "Fair", "req_str_list": ["spike", "lee"], "extra": ' "" "" '"out"}' item = DummyMessage().deserialize(jso, "json") - assert _eq(item.keys(), ['req_str', 'req_str_list', 'extra']) + assert _eq(item.keys(), ["req_str", "req_str_list", "extra"]) assert item["extra"] == "out" def test_to_from_jwt(self): - item = DummyMessage(req_str="Fair", opt_str="game", opt_int=9, - opt_str_list=["one", "two"], - req_str_list=["spike", "lee"], - opt_json='{"ford": "green"}') + item = DummyMessage( + req_str="Fair", + opt_str="game", + opt_int=9, + opt_str_list=["one", "two"], + req_str_list=["spike", "lee"], + opt_json='{"ford": "green"}', + ) keyjar = KeyJar() - keyjar.add_symmetric('', b"A1B2C3D4E5F6G7H8") - jws = item.to_jwt(key=keyjar.get_signing_key('oct'), - algorithm="HS256") + keyjar.add_symmetric("", b"A1B2C3D4E5F6G7H8") + jws = item.to_jwt(key=keyjar.get_signing_key("oct"), algorithm="HS256") jitem = DummyMessage().from_jwt(jws, keyjar) - assert _eq(jitem.keys(), ['opt_str', 'req_str', 'opt_json', - 'req_str_list', 'opt_str_list', 'opt_int']) + assert _eq( + jitem.keys(), + ["opt_str", "req_str", "opt_json", "req_str_list", "opt_str_list", "opt_int"], + ) def test_to_from_jwe(self): - msg = DummyMessage(req_str="Fair", opt_str="game", opt_int=9, - opt_str_list=["one", "two"], - req_str_list=["spike", "lee"], - opt_json='{"ford": "green"}') + msg = DummyMessage( + req_str="Fair", + opt_str="game", + opt_int=9, + opt_str_list=["one", "two"], + req_str_list=["spike", "lee"], + opt_json='{"ford": "green"}', + ) keys = [SYMKey(key="A1B2C3D4E5F6G7H8")] jwe = msg.to_jwe(keys, alg="A128KW", enc="A128CBC-HS256") jitem = DummyMessage().from_jwe(jwe, keys=keys) - assert _eq(jitem.keys(), ['opt_str', 'req_str', 'opt_json', - 'req_str_list', 'opt_str_list', 'opt_int']) + assert _eq( + jitem.keys(), + ["opt_str", "req_str", "opt_json", "req_str_list", "opt_str_list", "opt_int"], + ) def test_to_jwe_from_jwt(self): - msg = DummyMessage(req_str="Fair", opt_str="game", opt_int=9, - opt_str_list=["one", "two"], - req_str_list=["spike", "lee"], - opt_json='{"ford": "green"}') + msg = DummyMessage( + req_str="Fair", + opt_str="game", + opt_int=9, + opt_str_list=["one", "two"], + req_str_list=["spike", "lee"], + opt_json='{"ford": "green"}', + ) keys = [SYMKey(key="A1B2C3D4E5F6G7H8")] jwe = msg.to_jwe(keys, alg="A128KW", enc="A128CBC-HS256") keyjar = KeyJar() - keyjar.add_symmetric('', 'A1B2C3D4E5F6G7H8') + keyjar.add_symmetric("", "A1B2C3D4E5F6G7H8") jitem = DummyMessage().from_jwt(jwe, keyjar) - assert _eq(jitem.keys(), ['opt_str', 'req_str', 'opt_json', - 'req_str_list', 'opt_str_list', 'opt_int']) + assert _eq( + jitem.keys(), + ["opt_str", "req_str", "opt_json", "req_str_list", "opt_str_list", "opt_int"], + ) def test_verify(self): _dict = { - "req_str": "Fair", "opt_str": "game", "opt_int": 9, + "req_str": "Fair", + "opt_str": "game", + "opt_int": 9, "opt_str_list": ["one", "two"], "req_str_list": ["spike", "lee"], - "opt_json": '{"ford": "green"}' + "opt_json": '{"ford": "green"}', } cls = DummyMessage(**_dict) assert cls.verify() - assert _eq(cls.keys(), ['opt_str', 'req_str', 'opt_json', - 'req_str_list', 'opt_str_list', 'opt_int']) + assert _eq( + cls.keys(), + ["opt_str", "req_str", "opt_json", "req_str_list", "opt_str_list", "opt_int"], + ) _dict = { - "req_str": "Fair", "opt_str": "game", "opt_int": 9, + "req_str": "Fair", + "opt_str": "game", + "opt_int": 9, "opt_str_list": ["one", "two"], "req_str_list": ["spike", "lee"], - "opt_json": '{"ford": "green"}', "extra": "internal" + "opt_json": '{"ford": "green"}', + "extra": "internal", } cls = DummyMessage(**_dict) assert cls.verify() - assert _eq(cls.keys(), ['opt_str', 'req_str', 'extra', 'opt_json', - 'req_str_list', 'opt_str_list', 'opt_int']) + assert _eq( + cls.keys(), + ["opt_str", "req_str", "extra", "opt_json", "req_str_list", "opt_str_list", "opt_int"], + ) _dict = { - "req_str": "Fair", "opt_str": "game", "opt_int": 9, + "req_str": "Fair", + "opt_str": "game", + "opt_int": 9, "opt_str_list": ["one", "two"], - "req_str_list": ["spike", "lee"] + "req_str_list": ["spike", "lee"], } cls = DummyMessage(**_dict) cls.verify() - assert _eq(cls.keys(), ['opt_str', 'req_str', 'req_str_list', - 'opt_str_list', 'opt_int']) + assert _eq(cls.keys(), ["opt_str", "req_str", "req_str_list", "opt_str_list", "opt_int"]) def test_request(self): - req = DummyMessage(req_str="Fair", - req_str_list=["game"]).request("http://example.com") - assert url_compare(req, - "http://example.com?req_str=Fair&req_str_list=game") + req = DummyMessage(req_str="Fair", req_str_list=["game"]).request("http://example.com") + assert url_compare(req, "http://example.com?req_str=Fair&req_str_list=game") def test_get(self): _dict = { - "req_str": "Fair", "opt_str": "game", "opt_int": 9, + "req_str": "Fair", + "opt_str": "game", + "opt_int": 9, "opt_str_list": ["one", "two"], "req_str_list": ["spike", "lee"], - "opt_json": '{"ford": "green"}' + "opt_json": '{"ford": "green"}', } cls = DummyMessage(**_dict) @@ -286,41 +309,41 @@ def test_get(self): def test_int_instead_of_string(self): with pytest.raises(ValueError): - DummyMessage(req_str=2, req_str_list=['foo']) + DummyMessage(req_str=2, req_str_list=["foo"]) -@pytest.mark.parametrize("keytype,alg", [ - ('RSA', 'RS256'), - ('EC', 'ES256') -]) +@pytest.mark.parametrize("keytype,alg", [("RSA", "RS256"), ("EC", "ES256")]) def test_to_jwt(keytype, alg): - msg = Message(a='foo', b='bar', c='tjoho') - _jwt = msg.to_jwt(KEYJAR.get_signing_key(keytype, ''), alg) + msg = Message(a="foo", b="bar", c="tjoho") + _jwt = msg.to_jwt(KEYJAR.get_signing_key(keytype, ""), alg) msg1 = Message().from_jwt(_jwt, KEYJAR) assert msg1 == msg -@pytest.mark.parametrize("keytype,alg,enc", [ - ('RSA', 'RSA1_5', 'A128CBC-HS256'), - ('EC', 'ECDH-ES', 'A128GCM'), -]) +@pytest.mark.parametrize( + "keytype,alg,enc", + [ + ("RSA", "RSA1_5", "A128CBC-HS256"), + ("EC", "ECDH-ES", "A128GCM"), + ], +) def test_to_jwe(keytype, alg, enc): - msg = Message(a='foo', b='bar', c='tjoho') - _jwe = msg.to_jwe(KEYJAR.get_encrypt_key(keytype, ''), alg=alg, enc=enc) - msg1 = Message().from_jwe(_jwe, KEYJAR.get_encrypt_key(keytype, '')) + msg = Message(a="foo", b="bar", c="tjoho") + _jwe = msg.to_jwe(KEYJAR.get_encrypt_key(keytype, ""), alg=alg, enc=enc) + msg1 = Message().from_jwe(_jwe, KEYJAR.get_encrypt_key(keytype, "")) assert msg1 == msg def test_to_dict_with_message_obj(): - content = Message(a={'a': {'foo': {'bar': [{'bat': []}]}}}) + content = Message(a={"a": {"foo": {"bar": [{"bat": []}]}}}) _dict = content.to_dict(lev=0) - content_fixture = {'a': {'a': {'foo': {'bar': [{'bat': []}]}}}} + content_fixture = {"a": {"a": {"foo": {"bar": [{"bat": []}]}}}} assert _dict == content_fixture def test_to_dict_with_raw_types(): msg = Message(c_default=[]) - content_fixture = {'c_default': []} + content_fixture = {"c_default": []} _dict = msg.to_dict(lev=1) assert _dict == content_fixture @@ -332,22 +355,19 @@ class MsgMessage(Message): "opt_str": SINGLE_OPTIONAL_STRING, } - _dict = { - "req_str": "Fair", "req_str_list": ["spike", "lee"], - "opt_int": 9 - } + _dict = {"req_str": "Fair", "req_str_list": ["spike", "lee"], "opt_int": 9} _msg = DummyMessage() _msg.from_dict(_dict) msg = MsgMessage() - msg['msg'] = _msg - msg['opt_str'] = 'string' + msg["msg"] = _msg + msg["opt_str"] = "string" mjson = msg.to_json() mm = MsgMessage().from_json(mjson) - assert mm['opt_str'] == 'string' - assert set(mm['msg'].keys()) == set(_msg.keys()) + assert mm["opt_str"] == "string" + assert set(mm["msg"].keys()) == set(_msg.keys()) def test_msg_list_deserializer(): @@ -357,23 +377,20 @@ class MsgMessage(Message): "opt_str": SINGLE_OPTIONAL_STRING, } - _dict = { - "req_str": "Fair", "req_str_list": ["spike", "lee"], - "opt_int": 9 - } + _dict = {"req_str": "Fair", "req_str_list": ["spike", "lee"], "opt_int": 9} _msg = DummyMessage() _msg.from_dict(_dict) msg = MsgMessage() - msg['msgs'] = [_msg] - msg['opt_str'] = 'string' + msg["msgs"] = [_msg] + msg["opt_str"] = "string" mjson = msg.to_json() mm = MsgMessage().from_json(mjson) - assert mm['opt_str'] == 'string' - assert len(mm['msgs']) == 1 - assert set(mm['msgs'][0].keys()) == set(_msg.keys()) + assert mm["opt_str"] == "string" + assert len(mm["msgs"]) == 1 + assert set(mm["msgs"][0].keys()) == set(_msg.keys()) def test_msg_list_deserializer_dict(): @@ -383,21 +400,18 @@ class MsgMessage(Message): "opt_str": SINGLE_OPTIONAL_STRING, } - _dict = { - "req_str": "Fair", "req_str_list": ["spike", "lee"], - "opt_int": 9 - } + _dict = {"req_str": "Fair", "req_str_list": ["spike", "lee"], "opt_int": 9} msg = MsgMessage() - msg['msgs'] = _dict - msg['opt_str'] = 'string' + msg["msgs"] = _dict + msg["opt_str"] = "string" mjson = msg.to_json() mm = MsgMessage().from_json(mjson) - assert mm['opt_str'] == 'string' - assert len(mm['msgs']) == 1 - assert set(mm['msgs'][0].keys()) == set(_dict.keys()) + assert mm["opt_str"] == "string" + assert len(mm["msgs"]) == 1 + assert set(mm["msgs"][0].keys()) == set(_dict.keys()) def test_msg_list_deserializer_url(): @@ -407,21 +421,18 @@ class MsgMessage(Message): "opt_str": SINGLE_OPTIONAL_STRING, } - _dict = { - "req_str": "Fair", "req_str_list": ["spike", "lee"], - "opt_int": 9 - } + _dict = {"req_str": "Fair", "req_str_list": ["spike", "lee"], "opt_int": 9} _msg = DummyMessage(**_dict) msg = MsgMessage() with pytest.raises(DecodeError): - msg['msgs'] = [_msg.to_urlencoded()] + msg["msgs"] = [_msg.to_urlencoded()] def test_add_value(): with pytest.raises(ValueError): - DummyMessage(req_str=['1', '2']) + DummyMessage(req_str=["1", "2"]) def test_type_check(): @@ -436,17 +447,19 @@ def test_json_type_error(): val = '{"key":"A byte string"}' m = Message() m.from_json(val) - assert 'key' in m + assert "key" in m -@pytest.mark.parametrize("keytype,alg,enc", [ - ('RSA', 'RSA1_5', 'A128CBC-HS256'), - ('EC', 'ECDH-ES', 'A128GCM'), -]) +@pytest.mark.parametrize( + "keytype,alg,enc", + [ + ("RSA", "RSA1_5", "A128CBC-HS256"), + ("EC", "ECDH-ES", "A128GCM"), + ], +) def test_to_jwe(keytype, alg, enc): - msg = Message(a='foo', b='bar', c='tjoho') - _jwe = msg.to_jwe(KEYJAR.get_encrypt_key(keytype, ''), alg=alg, - enc=enc) + msg = Message(a="foo", b="bar", c="tjoho") + _jwe = msg.to_jwe(KEYJAR.get_encrypt_key(keytype, ""), alg=alg, enc=enc) with pytest.raises(HeaderError): Message().from_jwt(_jwe, KEYJAR, encalg="RSA-OAEP", encenc=enc) with pytest.raises(HeaderError): @@ -458,41 +471,41 @@ def test_to_jwe(keytype, alg, enc): k = new_rsa_key() NEW_KID = k.kid kb.append(k) -NEW_KEYJAR.add_kb('', kb) +NEW_KEYJAR.add_kb("", kb) def test_no_suitable_keys(): - keytype = 'RSA' - alg = 'RS256' - msg = Message(a='foo', b='bar', c='tjoho') - _jwt = msg.to_jwt(NEW_KEYJAR.get_signing_key(keytype, '', kid=NEW_KID), alg) + keytype = "RSA" + alg = "RS256" + msg = Message(a="foo", b="bar", c="tjoho") + _jwt = msg.to_jwt(NEW_KEYJAR.get_signing_key(keytype, "", kid=NEW_KID), alg) with pytest.raises(NoSuitableSigningKeys): Message().from_jwt(_jwt, KEYJAR) def test_only_extras(): - m = DummyMessage(foo='bar', extra='value') + m = DummyMessage(foo="bar", extra="value") assert m.only_extras() - m['req_str'] = 'string' + m["req_str"] = "string" assert m.only_extras() is False def test_weed(): - m = DummyMessage(foo='bar', extra='value') - m['req_str'] = 'string' + m = DummyMessage(foo="bar", extra="value") + m["req_str"] = "string" - assert set(m.keys()) == {'req_str', 'foo', 'extra'} + assert set(m.keys()) == {"req_str", "foo", "extra"} m.weed() - assert set(m.keys()) == {'req_str'} + assert set(m.keys()) == {"req_str"} def test_msg_ser(): - assert msg_ser('a.b.c', 'dict') == 'a.b.c' + assert msg_ser("a.b.c", "dict") == "a.b.c" with pytest.raises(MessageException): - msg_ser([1, 2], 'dict') + msg_ser([1, 2], "dict") with pytest.raises(OidcMsgError): - msg_ser([1, 2], 'list') + msg_ser([1, 2], "list") def test_error_description(): diff --git a/tests/test_05_oauth2.py b/tests/test_05_oauth2.py index 773fc3f..10dfe81 100755 --- a/tests/test_05_oauth2.py +++ b/tests/test_05_oauth2.py @@ -26,7 +26,7 @@ from oidcmsg.oauth2 import factory from oidcmsg.oauth2 import is_error_message -__author__ = 'Roland Hedberg' +__author__ = "Roland Hedberg" keys = [ {"type": "RSA", "use": ["sig"]}, @@ -43,14 +43,14 @@ KEYJAR = build_keyjar(keys) IKEYJAR = build_keyjar(keys) -IKEYJAR.import_jwks(IKEYJAR.export_jwks(private=True), 'issuer') -del IKEYJAR[''] +IKEYJAR.import_jwks(IKEYJAR.export_jwks(private=True), "issuer") +del IKEYJAR[""] KEYJARS = {} -for iss in ['A', 'B', 'C']: +for iss in ["A", "B", "C"]: _kj = build_keyjar(keym) - _kj.import_jwks(_kj.export_jwks(private=True) ,iss) - del _kj[''] + _kj.import_jwks(_kj.export_jwks(private=True), iss) + del _kj[""] KEYJARS[iss] = _kj @@ -111,52 +111,69 @@ def test_authz_req_urlencoded(self): assert query_string_compare(ue, "response_type=code&client_id=foobar") def test_urlencoded_with_redirect_uri(self): - ar = AuthorizationRequest(response_type=["code"], client_id="foobar", - redirect_uri="http://foobar.example.com/oaclient", - state="cold") + ar = AuthorizationRequest( + response_type=["code"], + client_id="foobar", + redirect_uri="http://foobar.example.com/oaclient", + state="cold", + ) ue = ar.to_urlencoded() - assert query_string_compare(ue, - "state=cold&redirect_uri=http%3A%2F%2Ffoobar.example.com" - "%2Foaclient&" - "response_type=code&client_id=foobar") + assert query_string_compare( + ue, + "state=cold&redirect_uri=http%3A%2F%2Ffoobar.example.com" + "%2Foaclient&" + "response_type=code&client_id=foobar", + ) def test_urlencoded_resp_type_token(self): - ar = AuthorizationRequest(response_type=["token"], - client_id="s6BhdRkqt3", - redirect_uri="https://client.example.com/cb", - state="xyz") + ar = AuthorizationRequest( + response_type=["token"], + client_id="s6BhdRkqt3", + redirect_uri="https://client.example.com/cb", + state="xyz", + ) ue = ar.to_urlencoded() - assert query_string_compare(ue, - "state=xyz&redirect_uri=https%3A%2F%2Fclient.example.com%2Fcb" - "&response_type=token&" - "client_id=s6BhdRkqt3") + assert query_string_compare( + ue, + "state=xyz&redirect_uri=https%3A%2F%2Fclient.example.com%2Fcb" + "&response_type=token&" + "client_id=s6BhdRkqt3", + ) def test_deserialize_urlencoded(self): - ar = AuthorizationRequest(response_type=["code"], - client_id="foobar") + ar = AuthorizationRequest(response_type=["code"], client_id="foobar") urlencoded = ar.to_urlencoded() ar2 = AuthorizationRequest().deserialize(urlencoded, "urlencoded") assert ar == ar2 def test_urlencoded_with_scope(self): - ar = AuthorizationRequest(response_type=["code"], client_id="foobar", - redirect_uri="http://foobar.example.com/oaclient", - scope=["foo", "bar"], state="cold") + ar = AuthorizationRequest( + response_type=["code"], + client_id="foobar", + redirect_uri="http://foobar.example.com/oaclient", + scope=["foo", "bar"], + state="cold", + ) ue = ar.to_urlencoded() - assert query_string_compare(ue, - "scope=foo+bar&state=cold&redirect_uri=http%3A%2F%2Ffoobar" - ".example.com%2Foaclient&" - "response_type=code&client_id=foobar") + assert query_string_compare( + ue, + "scope=foo+bar&state=cold&redirect_uri=http%3A%2F%2Ffoobar" + ".example.com%2Foaclient&" + "response_type=code&client_id=foobar", + ) def test_deserialize_urlencoded_multiple_params(self): - ar = AuthorizationRequest(response_type=["code"], - client_id="foobar", - redirect_uri="http://foobar.example.com/oaclient", - scope=["foo", "bar"], state="cold") + ar = AuthorizationRequest( + response_type=["code"], + client_id="foobar", + redirect_uri="http://foobar.example.com/oaclient", + scope=["foo", "bar"], + state="cold", + ) urlencoded = ar.to_urlencoded() ar2 = AuthorizationRequest().deserialize(urlencoded, "urlencoded") @@ -169,83 +186,90 @@ def test_urlencoded_missing_required(self): def test_urlencoded_invalid_scope(self): args = { - "response_type": [10], "client_id": "foobar", + "response_type": [10], + "client_id": "foobar", "redirect_uri": "http://foobar.example.com/oaclient", - "scope": ["foo", "bar"], "state": "cold" + "scope": ["foo", "bar"], + "state": "cold", } with pytest.raises(DecodeError): AuthorizationRequest(**args) def test_urlencoded_deserialize_state(self): - txt = "scope=foo+bar&state=-11&redirect_uri=http%3A%2F%2Ffoobar" \ - ".example.com%2Foaclient&response_type=code&" \ - "client_id=foobar" + txt = ( + "scope=foo+bar&state=-11&redirect_uri=http%3A%2F%2Ffoobar" + ".example.com%2Foaclient&response_type=code&" + "client_id=foobar" + ) ar = AuthorizationRequest().deserialize(txt, "urlencoded") assert ar["state"] == "-11" def test_urlencoded_deserialize_response_type(self): - txt = "scope=openid&state=id-6a3fc96caa7fd5cb1c7d00ed66937134&" \ - "redirect_uri=http%3A%2F%2Flocalhost%3A8087authz&response_type" \ - "=code&client_id=a1b2c3" + txt = ( + "scope=openid&state=id-6a3fc96caa7fd5cb1c7d00ed66937134&" + "redirect_uri=http%3A%2F%2Flocalhost%3A8087authz&response_type" + "=code&client_id=a1b2c3" + ) ar = AuthorizationRequest().deserialize(txt, "urlencoded") assert ar["scope"] == ["openid"] assert ar["response_type"] == ["code"] def test_req_json_serialize(self): - ar = AuthorizationRequest(response_type=["code"], - client_id="foobar") + ar = AuthorizationRequest(response_type=["code"], client_id="foobar") js_obj = json.loads(ar.serialize(method="json")) - expected_js_obj = { - "response_type": "code", - "client_id": "foobar" - } + expected_js_obj = {"response_type": "code", "client_id": "foobar"} assert js_obj == expected_js_obj def test_json_multiple_params(self): - ar = AuthorizationRequest(response_type=["code"], - client_id="foobar", - redirect_uri="http://foobar.example.com/oaclient", - state="cold") + ar = AuthorizationRequest( + response_type=["code"], + client_id="foobar", + redirect_uri="http://foobar.example.com/oaclient", + state="cold", + ) ue_obj = json.loads(ar.serialize(method="json")) expected_ue_obj = { "response_type": "code", "state": "cold", "redirect_uri": "http://foobar.example.com/oaclient", - "client_id": "foobar" + "client_id": "foobar", } assert ue_obj == expected_ue_obj def test_json_resp_type_token(self): - ar = AuthorizationRequest(response_type=["token"], - client_id="s6BhdRkqt3", - redirect_uri="https://client.example.com/cb", - state="xyz") + ar = AuthorizationRequest( + response_type=["token"], + client_id="s6BhdRkqt3", + redirect_uri="https://client.example.com/cb", + state="xyz", + ) ue_obj = json.loads(ar.serialize(method="json")) expected_ue_obj = { "state": "xyz", "redirect_uri": "https://client.example.com/cb", "response_type": "token", - "client_id": "s6BhdRkqt3" + "client_id": "s6BhdRkqt3", } assert ue_obj == expected_ue_obj def test_json_serialize_deserialize(self): - ar = AuthorizationRequest(response_type=["code"], - client_id="foobar") + ar = AuthorizationRequest(response_type=["code"], client_id="foobar") jtxt = ar.serialize(method="json") ar2 = AuthorizationRequest().deserialize(jtxt, "json") assert ar == ar2 def test_verify(self): - query = 'redirect_uri=http%3A%2F%2Flocalhost%3A8080%2Fauthz' \ - '&response_type=code&client_id=0123456789' + query = ( + "redirect_uri=http%3A%2F%2Flocalhost%3A8080%2Fauthz" + "&response_type=code&client_id=0123456789" + ) ar = AuthorizationRequest().deserialize(query, "urlencoded") assert ar.verify() @@ -255,7 +279,7 @@ def test_load_dict(self): "state": "id-6da9ca0cc23959f5f33e8becd9b08cae", "redirect_uri": "http://localhost:8087authz", "response_type": ["code"], - "client_id": "a1b2c3" + "client_id": "a1b2c3", } arq = AuthorizationRequest(**bib) @@ -272,7 +296,7 @@ def test_json_serizalize_deserialize_multiple_params(self): "state": "id-b0be8bb64118c3ec5f70093a1174b039", "redirect_uri": "http://localhost:8087authz", "response_type": ["code"], - "client_id": "a1b2c3" + "client_id": "a1b2c3", } arq = AuthorizationRequest(**argv) @@ -286,13 +310,11 @@ def test_json_serizalize_deserialize_multiple_params(self): assert jarq["client_id"] == "a1b2c3" def test_multiple_response_types_urlencoded(self): - ar = AuthorizationRequest(response_type=["code", "token"], - client_id="foobar") + ar = AuthorizationRequest(response_type=["code", "token"], client_id="foobar") ue = ar.to_urlencoded() - ue_splits = ue.split('&') - expected_ue_splits = "response_type=code+token&client_id=foobar".split( - '&') + ue_splits = ue.split("&") + expected_ue_splits = "response_type=code+token&client_id=foobar".split("&") assert _eq(ue_splits, expected_ue_splits) are = AuthorizationRequest().deserialize(ue, "urlencoded") @@ -300,14 +322,14 @@ def test_multiple_response_types_urlencoded(self): assert _eq(are["response_type"], ["code", "token"]) def test_multiple_scopes_urlencoded(self): - ar = AuthorizationRequest(response_type=["code", "token"], - client_id="foobar", - scope=["openid", "foxtrot"]) + ar = AuthorizationRequest( + response_type=["code", "token"], client_id="foobar", scope=["openid", "foxtrot"] + ) ue = ar.to_urlencoded() - ue_splits = ue.split('&') - expected_ue_splits = "scope=openid+foxtrot&response_type=code+token" \ - "&client_id=foobar".split( - '&') + ue_splits = ue.split("&") + expected_ue_splits = ( + "scope=openid+foxtrot&response_type=code+token" "&client_id=foobar".split("&") + ) assert _eq(ue_splits, expected_ue_splits) are = AuthorizationRequest().deserialize(ue, "urlencoded") @@ -316,14 +338,10 @@ def test_multiple_scopes_urlencoded(self): assert _eq(are["scope"], ["openid", "foxtrot"]) def test_multiple_response_types_json(self): - ar = AuthorizationRequest(response_type=["code", "token"], - client_id="foobar") + ar = AuthorizationRequest(response_type=["code", "token"], client_id="foobar") ue = ar.to_json() ue_obj = json.loads(ue) - expected_ue_obj = { - "response_type": "code token", - "client_id": "foobar" - } + expected_ue_obj = {"response_type": "code token", "client_id": "foobar"} assert ue_obj == expected_ue_obj are = AuthorizationRequest().deserialize(ue, "json") @@ -331,15 +349,15 @@ def test_multiple_response_types_json(self): assert _eq(are["response_type"], ["code", "token"]) def test_multiple_scopes_json(self): - ar = AuthorizationRequest(response_type=["code", "token"], - client_id="foobar", - scope=["openid", "foxtrot"]) + ar = AuthorizationRequest( + response_type=["code", "token"], client_id="foobar", scope=["openid", "foxtrot"] + ) ue = ar.to_json() ue_obj = json.loads(ue) expected_ue_obj = { "scope": "openid foxtrot", "response_type": "code token", - "client_id": "foobar" + "client_id": "foobar", } assert ue_obj == expected_ue_obj @@ -351,16 +369,14 @@ def test_multiple_scopes_json(self): class TestAuthorizationErrorResponse(object): def test_init(self): - aer = AuthorizationErrorResponse(error="access_denied", - state="xyz") + aer = AuthorizationErrorResponse(error="access_denied", state="xyz") assert aer["error"] == "access_denied" assert aer["state"] == "xyz" def test_extra_params(self): - aer = AuthorizationErrorResponse(error="access_denied", - error_description="brewers has a " - "four game series", - foo="bar") + aer = AuthorizationErrorResponse( + error="access_denied", error_description="brewers has a " "four game series", foo="bar" + ) assert aer["error"] == "access_denied" assert aer["error_description"] == "brewers has a four game series" assert aer["foo"] == "bar" @@ -374,10 +390,9 @@ def test_init(self): assert ter["state"] == "xyz" def test_extra_params(self): - ter = TokenErrorResponse(error="access_denied", - error_description="brewers has a four game " - "series", - foo="bar") + ter = TokenErrorResponse( + error="access_denied", error_description="brewers has a four game " "series", foo="bar" + ) assert ter["error"] == "access_denied" assert ter["error_description"] == "brewers has a four game series" @@ -386,15 +401,14 @@ def test_extra_params(self): class TestAccessTokenResponse(object): def test_json_serialize(self): - at = AccessTokenResponse(access_token="SlAV32hkKG", - token_type="Bearer", expires_in=3600) + at = AccessTokenResponse(access_token="SlAV32hkKG", token_type="Bearer", expires_in=3600) atj = at.serialize(method="json") atj_obj = json.loads(atj) expected_atj_obj = { "token_type": "Bearer", "access_token": "SlAV32hkKG", - "expires_in": 3600 + "expires_in": 3600, } assert atj_obj == expected_atj_obj @@ -405,7 +419,8 @@ def test_multiple_scope(self): expires_in=3600, refresh_token="tGzv3JOkF0XG5Qx2TlKWIA", example_parameter="example_value", - scope=["inner", "outer"]) + scope=["inner", "outer"], + ) assert _eq(atr["scope"], ["inner", "outer"]) @@ -421,48 +436,75 @@ def test_to_urlencoded_extended_omit(self): example_parameter="example_value", scope=["inner", "outer"], extra=["local", "external"], - level=3) + level=3, + ) uec = atr.to_urlencoded() - assert query_string_compare(uec, - "scope=inner+outer&level=3&expires_in=3600&token_type=example" - "&extra=local&" - "extra=external&refresh_token=tGzv3JOkF0XG5Qx2TlKWIA&" - "access_token=2YotnFZFEjr1zCsicMWpAA&example_parameter" - "=example_value") + assert query_string_compare( + uec, + "scope=inner+outer&level=3&expires_in=3600&token_type=example" + "&extra=local&" + "extra=external&refresh_token=tGzv3JOkF0XG5Qx2TlKWIA&" + "access_token=2YotnFZFEjr1zCsicMWpAA&example_parameter" + "=example_value", + ) del atr["extra"] ouec = atr.to_urlencoded() - assert query_string_compare(ouec, - "access_token=2YotnFZFEjr1zCsicMWpAA&refresh_token=tGzv3JOkF0XG5Qx2TlKWIA&" - "level=3&example_parameter=example_value&token_type=example" - "&expires_in=3600&" - "scope=inner+outer") - assert len(uec) == (len(ouec) + len("extra=local") + - len("extra=external") + 2) + assert query_string_compare( + ouec, + "access_token=2YotnFZFEjr1zCsicMWpAA&refresh_token=tGzv3JOkF0XG5Qx2TlKWIA&" + "level=3&example_parameter=example_value&token_type=example" + "&expires_in=3600&" + "scope=inner+outer", + ) + assert len(uec) == (len(ouec) + len("extra=local") + len("extra=external") + 2) atr2 = AccessTokenResponse().deserialize(uec, "urlencoded") - assert _eq(atr2.keys(), ['access_token', 'expires_in', 'token_type', - 'scope', 'refresh_token', 'level', - 'example_parameter', 'extra']) + assert _eq( + atr2.keys(), + [ + "access_token", + "expires_in", + "token_type", + "scope", + "refresh_token", + "level", + "example_parameter", + "extra", + ], + ) atr3 = AccessTokenResponse().deserialize(ouec, "urlencoded") - assert _eq(atr3.keys(), ['access_token', 'expires_in', 'token_type', - 'scope', 'refresh_token', 'level', - 'example_parameter']) + assert _eq( + atr3.keys(), + [ + "access_token", + "expires_in", + "token_type", + "scope", + "refresh_token", + "level", + "example_parameter", + ], + ) class TestAccessTokenRequest(object): def test_extra(self): - atr = AccessTokenRequest(grant_type="authorization_code", - code="SplxlOBeZQQYbYS6WxSbIA", - redirect_uri="https://client.example.com/cb", - extra="foo") + atr = AccessTokenRequest( + grant_type="authorization_code", + code="SplxlOBeZQQYbYS6WxSbIA", + redirect_uri="https://client.example.com/cb", + extra="foo", + ) query = atr.to_urlencoded() - assert query_string_compare(query, - "code=SplxlOBeZQQYbYS6WxSbIA&redirect_uri=https%3A%2F%2Fclient.example.com%2Fcb&" - "grant_type=authorization_code&extra=foo") + assert query_string_compare( + query, + "code=SplxlOBeZQQYbYS6WxSbIA&redirect_uri=https%3A%2F%2Fclient.example.com%2Fcb&" + "grant_type=authorization_code&extra=foo", + ) atr2 = AccessTokenRequest().deserialize(query, "urlencoded") assert atr == atr2 @@ -470,8 +512,7 @@ def test_extra(self): class TestAuthorizationResponse(object): def test_init(self): - atr = AuthorizationResponse(code="SplxlOBeZQQYbYS6WxSbIA", - state="Fun_state", extra="foo") + atr = AuthorizationResponse(code="SplxlOBeZQQYbYS6WxSbIA", state="Fun_state", extra="foo") assert atr["code"] == "SplxlOBeZQQYbYS6WxSbIA" assert atr["state"] == "Fun_state" @@ -480,8 +521,7 @@ def test_init(self): class TestROPCAccessTokenRequest(object): def test_init(self): - ropc = ROPCAccessTokenRequest(grant_type="password", - username="johndoe", password="A3ddj3w") + ropc = ROPCAccessTokenRequest(grant_type="password", username="johndoe", password="A3ddj3w") assert ropc["grant_type"] == "password" assert ropc["username"] == "johndoe" @@ -497,8 +537,7 @@ def test_init(self): class TestRefreshAccessTokenRequest(object): def test_init(self): - ratr = RefreshAccessTokenRequest(refresh_token="ababababab", - client_id="Client_id") + ratr = RefreshAccessTokenRequest(refresh_token="ababababab", client_id="Client_id") assert ratr["grant_type"] == "refresh_token" assert ratr["refresh_token"] == "ababababab" assert ratr["client_id"] == "Client_id" @@ -511,17 +550,11 @@ def test_init(self): request = TokenExchangeRequest( grant_type="urn:ietf:params:oauth:grant-type:token-exchange", subject_token="ababababab", - subject_token_type="urn:ietf:params:oauth:token-type:access_token" - ) - assert ( - request["grant_type"] - == "urn:ietf:params:oauth:grant-type:token-exchange" + subject_token_type="urn:ietf:params:oauth:token-type:access_token", ) + assert request["grant_type"] == "urn:ietf:params:oauth:grant-type:token-exchange" assert request["subject_token"] == "ababababab" - assert ( - request["subject_token_type"] - == "urn:ietf:params:oauth:token-type:access_token" - ) + assert request["subject_token_type"] == "urn:ietf:params:oauth:token-type:access_token" assert request.verify() @@ -532,12 +565,9 @@ def test_init(self): access_token="bababababa", issued_token_type="urn:ietf:params:oauth:token-type:access_token", token_type="Bearer", - expires_in=60 - ) - assert ( - response["issued_token_type"] - == "urn:ietf:params:oauth:token-type:access_token" + expires_in=60, ) + assert response["issued_token_type"] == "urn:ietf:params:oauth:token-type:access_token" assert response["access_token"] == "bababababa" assert response["token_type"] == "Bearer" assert response["expires_in"] == 60 @@ -547,9 +577,11 @@ def test_init(self): class TestResponseMessage_error(object): def test_error_message(self): - err = ResponseMessage(error="invalid_request", - error_description="Something was missing", - error_uri="http://example.com/error_message.html") + err = ResponseMessage( + error="invalid_request", + error_description="Something was missing", + error_uri="http://example.com/error_message.html", + ) ue_str = err.to_urlencoded() del err["error_uri"] @@ -561,26 +593,26 @@ def test_error_message(self): assert is_error_message(err) def test_auth_error_message(self): - resp = AuthorizationResponse(error="invalid_request", - error_description="Something was missing") + resp = AuthorizationResponse( + error="invalid_request", error_description="Something was missing" + ) assert is_error_message(resp) def test_factory(): - dr = factory('ResponseMessage', error='some_error') + dr = factory("ResponseMessage", error="some_error") assert isinstance(dr, ResponseMessage) - assert list(dr.keys()) == ['error'] + assert list(dr.keys()) == ["error"] def test_factory_auth_response(): - ar = factory('AuthorizationResponse', client_id='client1', iss='Issuer', - code='1234567') + ar = factory("AuthorizationResponse", client_id="client1", iss="Issuer", code="1234567") assert isinstance(ar, AuthorizationResponse) - assert ar.verify(client_id='client1', iss='Issuer') + assert ar.verify(client_id="client1", iss="Issuer") def test_set_default(): ar = AccessTokenRequest(set_defaults=False) assert list(ar.keys()) == [] ar.set_defaults() - assert 'grant_type' in ar + assert "grant_type" in ar diff --git a/tests/test_06_oidc.py b/tests/test_06_oidc.py index 1e70d5e..ca17c2a 100755 --- a/tests/test_06_oidc.py +++ b/tests/test_06_oidc.py @@ -59,20 +59,22 @@ from oidcmsg.oidc import verify_id_token from oidcmsg.time_util import utc_time_sans_frac -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), - '..', '..'))) +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) -__author__ = 'Roland Hedberg' +__author__ = "Roland Hedberg" CLIENT_ID = "client_1" -IDTOKEN = IdToken(iss="http://oic.example.org/", sub="sub", - aud=CLIENT_ID, exp=utc_time_sans_frac() + 300, - nonce="N0nce", iat=time.time()) +IDTOKEN = IdToken( + iss="http://oic.example.org/", + sub="sub", + aud=CLIENT_ID, + exp=utc_time_sans_frac() + 300, + nonce="N0nce", + iat=time.time(), +) KC_SYM_S = KeyBundle( - { - "kty": "oct", "key": "abcdefghijklmnop".encode("utf-8"), "use": "sig", - "alg": "HS256" - }) + {"kty": "oct", "key": "abcdefghijklmnop".encode("utf-8"), "use": "sig", "alg": "HS256"} +) def query_string_compare(query_str1, query_str2): @@ -89,37 +91,42 @@ def test_openidschema(): assert ois.verify() is False -@pytest.mark.parametrize("json_param", [ - '{"middle_name":"fo", "updated_at":"20170328081544Z", "sub":"abc"}', - '{"middle_name":true, "updated_at":"20170328081544", "sub":"abc"}', - '{"middle_name":"fo", "updated_at":false, "sub":"abc"}', - '{"middle_name":"fo", "updated_at":"20170328081544Z", "sub":true}' -]) +@pytest.mark.parametrize( + "json_param", + [ + '{"middle_name":"fo", "updated_at":"20170328081544Z", "sub":"abc"}', + '{"middle_name":true, "updated_at":"20170328081544", "sub":"abc"}', + '{"middle_name":"fo", "updated_at":false, "sub":"abc"}', + '{"middle_name":"fo", "updated_at":"20170328081544Z", "sub":true}', + ], +) def test_openidschema_from_json(json_param): with pytest.raises(ValueError): OpenIDSchema().from_json(json_param) -@pytest.mark.parametrize("json_param", [ - '{"email_verified":false, "email":"foo@example.com", "sub":"abc"}', - '{"email_verified":true, "email":"foo@example.com", "sub":"abc"}', - '{"phone_number_verified":false, "phone_number":"+1 555 200000", ' - '"sub":"abc"}', - '{"phone_number_verified":true, "phone_number":"+1 555 20000", ' - '"sub":"abc"}', -]) +@pytest.mark.parametrize( + "json_param", + [ + '{"email_verified":false, "email":"foo@example.com", "sub":"abc"}', + '{"email_verified":true, "email":"foo@example.com", "sub":"abc"}', + '{"phone_number_verified":false, "phone_number":"+1 555 200000", ' '"sub":"abc"}', + '{"phone_number_verified":true, "phone_number":"+1 555 20000", ' '"sub":"abc"}', + ], +) def test_claim_booleans(json_param): assert OpenIDSchema().from_json(json_param) -@pytest.mark.parametrize("json_param", [ - '{"email_verified":"Not", "email":"foo@example.com", "sub":"abc"}', - '{"email_verified":"Sure", "email":"foo@example.com", "sub":"abc"}', - '{"phone_number_verified":"Not", "phone_number":"+1 555 200000", ' - '"sub":"abc"}', - '{"phone_number_verified":"Sure", "phone_number":"+1 555 20000", ' - '"sub":"abc"}', -]) +@pytest.mark.parametrize( + "json_param", + [ + '{"email_verified":"Not", "email":"foo@example.com", "sub":"abc"}', + '{"email_verified":"Sure", "email":"foo@example.com", "sub":"abc"}', + '{"phone_number_verified":"Not", "phone_number":"+1 555 200000", ' '"sub":"abc"}', + '{"phone_number_verified":"Sure", "phone_number":"+1 555 20000", ' '"sub":"abc"}', + ], +) def test_claim_not_booleans(json_param): with pytest.raises(ValueError): OpenIDSchema().from_json(json_param) @@ -133,12 +140,12 @@ def test_claims_deser(): "email": {"essential": True}, "email_verified": {"essential": True}, "picture": None, - "http://example.info/claims/groups": None + "http://example.info/claims/groups": None, }, "id_token": { "auth_time": {"essential": True}, - "acr": {"values": ["urn:mace:incommon:iap:silver"]} - } + "acr": {"values": ["urn:mace:incommon:iap:silver"]}, + }, } claims = claims_deser(json.dumps(_dic), sformat="json") @@ -146,336 +153,333 @@ def test_claims_deser(): def test_claims_deser_dict(): - pre = Claims(name={"essential": True}, nickname=None, - email={"essential": True}, - email_verified={"essential": True}, picture=None) + pre = Claims( + name={"essential": True}, + nickname=None, + email={"essential": True}, + email_verified={"essential": True}, + picture=None, + ) claims = claims_deser(pre.to_json(), sformat="json") - assert _eq(claims.keys(), ['name', 'nickname', 'email', 'email_verified', - 'picture']) + assert _eq(claims.keys(), ["name", "nickname", "email", "email_verified", "picture"]) claims = claims_deser(pre.to_dict(), sformat="dict") - assert _eq(claims.keys(), ['name', 'nickname', 'email', 'email_verified', - 'picture']) + assert _eq(claims.keys(), ["name", "nickname", "email", "email_verified", "picture"]) def test_address_deser(): - pre = AddressClaim(street_address="Kasamark 114", locality="Umea", - country="Sweden") + pre = AddressClaim(street_address="Kasamark 114", locality="Umea", country="Sweden") adc = address_deser(pre.to_json(), sformat="json") - assert _eq(adc.keys(), ['street_address', 'locality', 'country']) + assert _eq(adc.keys(), ["street_address", "locality", "country"]) adc = address_deser(pre.to_dict(), sformat="json") - assert _eq(adc.keys(), ['street_address', 'locality', 'country']) + assert _eq(adc.keys(), ["street_address", "locality", "country"]) def test_msg_ser_json(): - pre = AddressClaim(street_address="Kasamark 114", locality="Umea", - country="Sweden") + pre = AddressClaim(street_address="Kasamark 114", locality="Umea", country="Sweden") ser = msg_ser_json(pre, "json") adc = address_deser(ser, "json") - assert _eq(adc.keys(), ['street_address', 'locality', 'country']) + assert _eq(adc.keys(), ["street_address", "locality", "country"]) def test_msg_ser_json_from_dict(): - ser = msg_ser_json({ - 'street_address': "Kasamark 114", 'locality': "Umea", - 'country': "Sweden" - }, "json") + ser = msg_ser_json( + {"street_address": "Kasamark 114", "locality": "Umea", "country": "Sweden"}, "json" + ) adc = address_deser(ser, "json") - assert _eq(adc.keys(), ['street_address', 'locality', 'country']) + assert _eq(adc.keys(), ["street_address", "locality", "country"]) def test_msg_ser_json_to_dict(): - pre = AddressClaim(street_address="Kasamark 114", locality="Umea", - country="Sweden") + pre = AddressClaim(street_address="Kasamark 114", locality="Umea", country="Sweden") ser = msg_ser_json(pre, "dict") adc = address_deser(ser, "dict") - assert _eq(adc.keys(), ['street_address', 'locality', 'country']) + assert _eq(adc.keys(), ["street_address", "locality", "country"]) def test_msg_ser_dict_to_dict(): - pre = { - 'street_address': "Kasamark 114", 'locality': "Umea", - 'country': "Sweden" - } + pre = {"street_address": "Kasamark 114", "locality": "Umea", "country": "Sweden"} ser = msg_ser_json(pre, "dict") adc = address_deser(ser, "dict") - assert _eq(adc.keys(), ['street_address', 'locality', 'country']) + assert _eq(adc.keys(), ["street_address", "locality", "country"]) def test_msg_ser_urlencoded(): - pre = AddressClaim(street_address="Kasamark 114", locality="Umea", - country="Sweden") + pre = AddressClaim(street_address="Kasamark 114", locality="Umea", country="Sweden") ser = msg_ser(pre.to_dict(), "urlencoded") adc = address_deser(ser, "urlencoded") - assert _eq(adc.keys(), ['street_address', 'locality', 'country']) + assert _eq(adc.keys(), ["street_address", "locality", "country"]) def test_msg_ser_dict(): - pre = AddressClaim(street_address="Kasamark 114", locality="Umea", - country="Sweden") + pre = AddressClaim(street_address="Kasamark 114", locality="Umea", country="Sweden") ser = msg_ser(pre.to_dict(), "dict") adc = address_deser(ser, "dict") - assert _eq(adc.keys(), ['street_address', 'locality', 'country']) + assert _eq(adc.keys(), ["street_address", "locality", "country"]) def test_msg_ser_from_dict(): - pre = { - "street_address": "Kasamark 114", "locality": "Umea", - "country": "Sweden" - } + pre = {"street_address": "Kasamark 114", "locality": "Umea", "country": "Sweden"} ser = msg_ser(pre, "dict") adc = address_deser(ser, "dict") - assert _eq(adc.keys(), ['street_address', 'locality', 'country']) + assert _eq(adc.keys(), ["street_address", "locality", "country"]) def test_claims_ser_json(): - claims = Claims(name={"essential": True}, nickname=None, - email={"essential": True}, - email_verified={"essential": True}, picture=None) + claims = Claims( + name={"essential": True}, + nickname=None, + email={"essential": True}, + email_verified={"essential": True}, + picture=None, + ) claims = claims_deser(claims_ser(claims, "json"), sformat="json") - assert _eq(claims.keys(), ['name', 'nickname', 'email', 'email_verified', - 'picture']) + assert _eq(claims.keys(), ["name", "nickname", "email", "email_verified", "picture"]) def test_claims_ser_from_dict_to_json(): - claims = claims_ser({ - "name": {"essential": True}, - "nickname": None, - "email": {"essential": True}, - "email_verified": {"essential": True}, - "picture": None - }, sformat="json") + claims = claims_ser( + { + "name": {"essential": True}, + "nickname": None, + "email": {"essential": True}, + "email_verified": {"essential": True}, + "picture": None, + }, + sformat="json", + ) cl = Claims().from_json(claims) - assert _eq(cl.keys(), ['name', 'nickname', 'email', 'email_verified', - 'picture']) + assert _eq(cl.keys(), ["name", "nickname", "email", "email_verified", "picture"]) def test_claims_ser_from_dict_to_urlencoded(): - claims = claims_ser({ - "name": {"essential": True}, - "nickname": None, - "email": {"essential": True}, - "email_verified": {"essential": True}, - "picture": None - }, sformat="urlencoded") + claims = claims_ser( + { + "name": {"essential": True}, + "nickname": None, + "email": {"essential": True}, + "email_verified": {"essential": True}, + "picture": None, + }, + sformat="urlencoded", + ) cl = Claims().from_urlencoded(claims) - assert _eq(cl.keys(), ['name', 'nickname', 'email', 'email_verified', - 'picture']) + assert _eq(cl.keys(), ["name", "nickname", "email", "email_verified", "picture"]) def test_claims_ser_from_dict_to_dict(): - claims = claims_ser({ - "name": {"essential": True}, - "nickname": None, - "email": {"essential": True}, - "email_verified": {"essential": True}, - "picture": None - }, sformat="dict") + claims = claims_ser( + { + "name": {"essential": True}, + "nickname": None, + "email": {"essential": True}, + "email_verified": {"essential": True}, + "picture": None, + }, + sformat="dict", + ) cl = Claims(**claims) - assert _eq(cl.keys(), ['name', 'nickname', 'email', 'email_verified', - 'picture']) + assert _eq(cl.keys(), ["name", "nickname", "email", "email_verified", "picture"]) def test_claims_ser_from_dict_to_foo(): with pytest.raises(OidcMsgError): - _ = claims_ser({ - "name": {"essential": True}, - "nickname": None, - "email": {"essential": True}, - "email_verified": {"essential": True}, - "picture": None - }, sformat="foo") + _ = claims_ser( + { + "name": {"essential": True}, + "nickname": None, + "email": {"essential": True}, + "email_verified": {"essential": True}, + "picture": None, + }, + sformat="foo", + ) def test_claims_ser_wrong_type(): with pytest.raises(MessageException): - _ = claims_ser(json.dumps({ - "name": {"essential": True}, - "nickname": None, - "email": {"essential": True}, - "email_verified": {"essential": True}, - "picture": None - }), sformat="dict") + _ = claims_ser( + json.dumps( + { + "name": {"essential": True}, + "nickname": None, + "email": {"essential": True}, + "email_verified": {"essential": True}, + "picture": None, + } + ), + sformat="dict", + ) def test_discovery_request(): - request = { - 'rel': "http://openid.net/specs/connect/1.0/issuer", - 'resource': 'diana@localhost' - } + request = {"rel": "http://openid.net/specs/connect/1.0/issuer", "resource": "diana@localhost"} req = DiscoveryRequest().from_json(json.dumps(request)) - assert set(req.keys()) == {'rel', 'resource'} + assert set(req.keys()) == {"rel", "resource"} def test_discovery_response(): - link = Link(href='https://example.com/op', - rel="http://openid.net/specs/connect/1.0/issuer") + link = Link(href="https://example.com/op", rel="http://openid.net/specs/connect/1.0/issuer") - resp = JRD(subject='diana@localhost', links=[link]) + resp = JRD(subject="diana@localhost", links=[link]) - assert set(resp.keys()) == {'subject', 'links'} + assert set(resp.keys()) == {"subject", "links"} def test_link_ser1(): - link = Link(href='https://example.com/op', - rel="http://openid.net/specs/connect/1.0/issuer") - _js = link_ser(link, 'json') + link = Link(href="https://example.com/op", rel="http://openid.net/specs/connect/1.0/issuer") + _js = link_ser(link, "json") _lnk = json.loads(_js) - assert set(_lnk.keys()) == {'href', 'rel'} + assert set(_lnk.keys()) == {"href", "rel"} def test_link_ser_dict(): - info = { - 'href': 'https://example.com/op', - 'rel': "http://openid.net/specs/connect/1.0/issuer" - } - _js = link_ser(info, 'json') + info = {"href": "https://example.com/op", "rel": "http://openid.net/specs/connect/1.0/issuer"} + _js = link_ser(info, "json") _lnk = json.loads(_js) - assert set(_lnk.keys()) == {'href', 'rel'} + assert set(_lnk.keys()) == {"href", "rel"} - _ue = link_ser(info, 'urlencoded') + _ue = link_ser(info, "urlencoded") assert _ue res = parse_qs(_ue) - assert set(res.keys()) == {'href', 'rel'} + assert set(res.keys()) == {"href", "rel"} class TestProviderConfigurationResponse(object): def test_deserialize(self): resp = { - "authorization_endpoint": - "https://server.example.com/connect/authorize", + "authorization_endpoint": "https://server.example.com/connect/authorize", "issuer": "https://server.example.com", "token_endpoint": "https://server.example.com/connect/token", - "token_endpoint_auth_methods_supported": ["client_secret_basic", - "private_key_jwt"], + "token_endpoint_auth_methods_supported": ["client_secret_basic", "private_key_jwt"], "userinfo_endpoint": "https://server.example.com/connect/user", "check_id_endpoint": "https://server.example.com/connect/check_id", - "refresh_session_endpoint": - "https://server.example.com/connect/refresh_session", - "end_session_endpoint": - "https://server.example.com/connect/end_session", + "refresh_session_endpoint": "https://server.example.com/connect/refresh_session", + "end_session_endpoint": "https://server.example.com/connect/end_session", "jwk_url": "https://server.example.com/jwk.json", - "registration_endpoint": - "https://server.example.com/connect/register", - "scopes_supported": ["openid", "profile", "email", "address", - "phone"], - "response_types_supported": ["code", "code id_token", - "token id_token"], - "acrs_supported": ["1", "2", - "http://id.incommon.org/assurance/bronze"], + "registration_endpoint": "https://server.example.com/connect/register", + "scopes_supported": ["openid", "profile", "email", "address", "phone"], + "response_types_supported": ["code", "code id_token", "token id_token"], + "acrs_supported": ["1", "2", "http://id.incommon.org/assurance/bronze"], "user_id_types_supported": ["public", "pairwise"], - "userinfo_algs_supported": ["HS256", "RS256", "A128CBC", "A128KW", - "RSA1_5"], - "id_token_algs_supported": ["HS256", "RS256", "A128CBC", "A128KW", - "RSA1_5"], - "request_object_algs_supported": ["HS256", "RS256", "A128CBC", - "A128KW", - "RSA1_5"] + "userinfo_algs_supported": ["HS256", "RS256", "A128CBC", "A128KW", "RSA1_5"], + "id_token_algs_supported": ["HS256", "RS256", "A128CBC", "A128KW", "RSA1_5"], + "request_object_algs_supported": ["HS256", "RS256", "A128CBC", "A128KW", "RSA1_5"], } - pcr = ProviderConfigurationResponse().deserialize(json.dumps(resp), - "json") + pcr = ProviderConfigurationResponse().deserialize(json.dumps(resp), "json") # Missing subject_types_supported with pytest.raises(MissingRequiredAttribute): assert pcr.verify() assert _eq(pcr["user_id_types_supported"], ["public", "pairwise"]) - assert _eq(pcr["acrs_supported"], - ["1", "2", "http://id.incommon.org/assurance/bronze"]) + assert _eq(pcr["acrs_supported"], ["1", "2", "http://id.incommon.org/assurance/bronze"]) def test_example_response(self): resp = { "version": "3.0", "issuer": "https://server.example.com", - "authorization_endpoint": - "https://server.example.com/connect/authorize", + "authorization_endpoint": "https://server.example.com/connect/authorize", "token_endpoint": "https://server.example.com/connect/token", - "token_endpoint_auth_methods_supported": ["client_secret_basic", - "private_key_jwt"], + "token_endpoint_auth_methods_supported": ["client_secret_basic", "private_key_jwt"], "token_endpoint_alg_values_supported": ["RS256", "ES256"], "userinfo_endpoint": "https://server.example.com/connect/userinfo", - "check_session_iframe": - "https://server.example.com/connect/check_session", - "end_session_endpoint": - "https://server.example.com/connect/end_session", + "check_session_iframe": "https://server.example.com/connect/check_session", + "end_session_endpoint": "https://server.example.com/connect/end_session", "jwks_uri": "https://server.example.com/jwks.json", - "registration_endpoint": - "https://server.example.com/connect/register", - "scopes_supported": ["openid", "profile", "email", "address", - "phone", "offline_access"], - "response_types_supported": ["code", "code id_token", "id_token", - "token id_token"], - "acr_values_supported": ["urn:mace:incommon:iap:silver", - "urn:mace:incommon:iap:bronze"], + "registration_endpoint": "https://server.example.com/connect/register", + "scopes_supported": [ + "openid", + "profile", + "email", + "address", + "phone", + "offline_access", + ], + "response_types_supported": ["code", "code id_token", "id_token", "token id_token"], + "acr_values_supported": [ + "urn:mace:incommon:iap:silver", + "urn:mace:incommon:iap:bronze", + ], "subject_types_supported": ["public", "pairwise"], - "userinfo_signing_alg_values_supported": ["RS256", "ES256", - "HS256"], + "userinfo_signing_alg_values_supported": ["RS256", "ES256", "HS256"], "userinfo_encryption_alg_values_supported": ["RSA1_5", "A128KW"], - "userinfo_encryption_enc_values_supported": ["A128CBC+HS256", - "A128GCM"], - "id_token_signing_alg_values_supported": ["RS256", "ES256", - "HS256"], + "userinfo_encryption_enc_values_supported": ["A128CBC+HS256", "A128GCM"], + "id_token_signing_alg_values_supported": ["RS256", "ES256", "HS256"], "id_token_encryption_alg_values_supported": ["RSA1_5", "A128KW"], - "id_token_encryption_enc_values_supported": ["A128CBC+HS256", - "A128GCM"], - "request_object_signing_alg_values_supported": ["none", "RS256", - "ES256"], + "id_token_encryption_enc_values_supported": ["A128CBC+HS256", "A128GCM"], + "request_object_signing_alg_values_supported": ["none", "RS256", "ES256"], "display_values_supported": ["page", "popup"], "claim_types_supported": ["normal", "distributed"], - "claims_supported": ["sub", "iss", "auth_time", "acr", "name", - "given_name", "family_name", "nickname", - "profile", - "picture", "website", "email", - "email_verified", - "locale", "zoneinfo", - "http://example.info/claims/groups"], + "claims_supported": [ + "sub", + "iss", + "auth_time", + "acr", + "name", + "given_name", + "family_name", + "nickname", + "profile", + "picture", + "website", + "email", + "email_verified", + "locale", + "zoneinfo", + "http://example.info/claims/groups", + ], "claims_parameter_supported": True, - "service_documentation": - "http://server.example.com/connect/service_documentation.html", - "ui_locales_supported": ["en-US", "en-GB", "en-CA", "fr-FR", - "fr-CA"] + "service_documentation": "http://server.example.com/connect/service_documentation.html", + "ui_locales_supported": ["en-US", "en-GB", "en-CA", "fr-FR", "fr-CA"], } - pcr = ProviderConfigurationResponse().deserialize(json.dumps(resp), - "json") + pcr = ProviderConfigurationResponse().deserialize(json.dumps(resp), "json") assert pcr.verify() rk = list(resp.keys()) # parameters with default value if missing - rk.extend(["grant_types_supported", "request_parameter_supported", - "request_uri_parameter_supported", - "require_request_uri_registration"]) + rk.extend( + [ + "grant_types_supported", + "request_parameter_supported", + "request_uri_parameter_supported", + "require_request_uri_registration", + ] + ) assert sorted(rk) == sorted(list(pcr.keys())) - @pytest.mark.parametrize("required_param", [ - "issuer", - "authorization_endpoint", - "jwks_uri", - "response_types_supported", - "subject_types_supported", - "id_token_signing_alg_values_supported" - ]) + @pytest.mark.parametrize( + "required_param", + [ + "issuer", + "authorization_endpoint", + "jwks_uri", + "response_types_supported", + "subject_types_supported", + "id_token_signing_alg_values_supported", + ], + ) def test_required_parameters(self, required_param): provider_config = { "issuer": "https://server.example.com", - "authorization_endpoint": - "https://server.example.com/connect/authorize", + "authorization_endpoint": "https://server.example.com/connect/authorize", "jwks_uri": "https://server.example.com/jwks.json", "response_types_supported": ["code", "code id_token", "id_token", "token id_token"], "subject_types_supported": ["public", "pairwise"], @@ -489,8 +493,7 @@ def test_required_parameters(self, required_param): def test_token_endpoint_is_not_required_for_implicit_flow_only(self): provider_config = { "issuer": "https://server.example.com", - "authorization_endpoint": - "https://server.example.com/connect/authorize", + "authorization_endpoint": "https://server.example.com/connect/authorize", "jwks_uri": "https://server.example.com/jwks.json", "response_types_supported": ["id_token", "token id_token"], "subject_types_supported": ["public", "pairwise"], @@ -503,13 +506,11 @@ def test_token_endpoint_is_not_required_for_implicit_flow_only(self): def test_token_endpoint_is_required_for_other_than_implicit_flow_only(self): provider_config = { "issuer": "https://server.example.com", - "authorization_endpoint": - "https://server.example.com/connect/authorize", + "authorization_endpoint": "https://server.example.com/connect/authorize", "jwks_uri": "https://server.example.com/jwks.json", "response_types_supported": ["code", "id_token"], "subject_types_supported": ["public", "pairwise"], - "id_token_signing_alg_values_supported": ["RS256", "ES256", - "HS256"], + "id_token_signing_alg_values_supported": ["RS256", "ES256", "HS256"], } with pytest.raises(MissingRequiredAttribute): @@ -520,102 +521,136 @@ class TestRegistrationRequest(object): def test_deserialize(self): msg = { "application_type": "web", - "redirect_uris": ["https://client.example.org/callback", - "https://client.example.org/callback2"], + "redirect_uris": [ + "https://client.example.org/callback", + "https://client.example.org/callback2", + ], "client_name": "My Example", "client_name#ja-Jpan-JP": "クライアント名", "logo_uri": "https://client.example.org/logo.png", "subject_type": "pairwise", - "sector_identifier_uri": - "https://other.example.net/file_of_redirect_uris.json", + "sector_identifier_uri": "https://other.example.net/file_of_redirect_uris.json", "token_endpoint_auth_method": "client_secret_basic", "jwks_uri": "https://client.example.org/my_public_keys.jwks", "userinfo_encrypted_response_alg": "RSA1_5", "userinfo_encrypted_response_enc": "A128CBC+HS256", "contacts": ["ve7jtb@example.org", "mary@example.org"], "request_uris": [ - "https://client.example.org/rf.txt" - "#qpXaRLh_n93TTR9F252ValdatUQvQiJi5BDub2BeznA"] + "https://client.example.org/rf.txt" "#qpXaRLh_n93TTR9F252ValdatUQvQiJi5BDub2BeznA" + ], } reg = RegistrationRequest().deserialize(json.dumps(msg), "json") assert reg.verify() - assert _eq(list(msg.keys()) + ['response_types'], reg.keys()) + assert _eq(list(msg.keys()) + ["response_types"], reg.keys()) def test_registration_request(self): - req = RegistrationRequest(operation="register", default_max_age=10, - require_auth_time=True, default_acr="foo", - application_type="web", - redirect_uris=["https://example.com/authz_cb"]) + req = RegistrationRequest( + operation="register", + default_max_age=10, + require_auth_time=True, + default_acr="foo", + application_type="web", + redirect_uris=["https://example.com/authz_cb"], + ) assert req.verify() js = req.to_json() js_obj = json.loads(js) expected_js_obj = { "redirect_uris": ["https://example.com/authz_cb"], - "application_type": "web", "default_acr": "foo", - "require_auth_time": True, "operation": "register", - "default_max_age": 10, "response_types": ["code"] + "application_type": "web", + "default_acr": "foo", + "require_auth_time": True, + "operation": "register", + "default_max_age": 10, + "response_types": ["code"], } assert js_obj == expected_js_obj - flattened_list_dict = {k: v[0] if isinstance(v, list) else v for k, v in - expected_js_obj.items()} - assert query_string_compare(req.to_urlencoded(), - urlencode(flattened_list_dict)) - - @pytest.mark.parametrize("enc_param", [ - "request_object_encryption_enc", - "id_token_encrypted_response_enc", - "userinfo_encrypted_response_enc", - ]) + flattened_list_dict = { + k: v[0] if isinstance(v, list) else v for k, v in expected_js_obj.items() + } + assert query_string_compare(req.to_urlencoded(), urlencode(flattened_list_dict)) + + @pytest.mark.parametrize( + "enc_param", + [ + "request_object_encryption_enc", + "id_token_encrypted_response_enc", + "userinfo_encrypted_response_enc", + ], + ) def test_registration_request_with_coupled_encryption_params(self, enc_param): registration_params = { - "redirect_uris": ["https://example.com/authz_cb"], enc_param: "RS256" + "redirect_uris": ["https://example.com/authz_cb"], + enc_param: "RS256", } registration_req = RegistrationRequest(**registration_params) with pytest.raises(MissingRequiredAttribute): registration_req.verify() def test_deser(self): - req = RegistrationRequest(operation="register", default_max_age=10, - require_auth_time=True, default_acr="foo", - application_type="web", - redirect_uris=[ - "https://example.com/authz_cb"]) - ser_req = req.serialize('urlencoded') + req = RegistrationRequest( + operation="register", + default_max_age=10, + require_auth_time=True, + default_acr="foo", + application_type="web", + redirect_uris=["https://example.com/authz_cb"], + ) + ser_req = req.serialize("urlencoded") deser_req = registration_request_deser(ser_req) - assert set(deser_req.keys()) == {'operation', 'default_max_age', - 'require_auth_time', 'default_acr', - 'application_type', 'redirect_uris', - 'response_types'} + assert set(deser_req.keys()) == { + "operation", + "default_max_age", + "require_auth_time", + "default_acr", + "application_type", + "redirect_uris", + "response_types", + } def test_deser_dict(self): req = { - 'operation': "register", 'default_max_age': 10, - 'require_auth_time': True, 'default_acr': "foo", - 'application_type': "web", - 'redirect_uris': ["https://example.com/authz_cb"] + "operation": "register", + "default_max_age": 10, + "require_auth_time": True, + "default_acr": "foo", + "application_type": "web", + "redirect_uris": ["https://example.com/authz_cb"], } - deser_req = registration_request_deser(req, 'dict') - assert set(deser_req.keys()) == {'operation', 'default_max_age', - 'require_auth_time', 'default_acr', - 'application_type', 'redirect_uris', - 'response_types'} + deser_req = registration_request_deser(req, "dict") + assert set(deser_req.keys()) == { + "operation", + "default_max_age", + "require_auth_time", + "default_acr", + "application_type", + "redirect_uris", + "response_types", + } def test_deser_dict_json(self): req = { - 'operation': "register", 'default_max_age': 10, - 'require_auth_time': True, 'default_acr': "foo", - 'application_type': "web", - 'redirect_uris': ["https://example.com/authz_cb"] + "operation": "register", + "default_max_age": 10, + "require_auth_time": True, + "default_acr": "foo", + "application_type": "web", + "redirect_uris": ["https://example.com/authz_cb"], } - deser_req = registration_request_deser(req, 'json') - assert set(deser_req.keys()) == {'operation', 'default_max_age', - 'require_auth_time', 'default_acr', - 'application_type', 'redirect_uris', - 'response_types'} + deser_req = registration_request_deser(req, "json") + assert set(deser_req.keys()) == { + "operation", + "default_max_age", + "require_auth_time", + "default_acr", + "application_type", + "redirect_uris", + "response_types", + } class TestRegistrationResponse(object): @@ -625,26 +660,26 @@ def test_deserialize(self): "client_secret": "ZJYCqe3GGRvdrudKyZS0XhGv_Z45DuKhCUk0gBR1vZk", "client_secret_expires_at": 1577858400, "registration_access_token": "this.is.an.access.token.value.ffx83", - "registration_client_uri": - "https://server.example.com/connect/register?client_id" - "=s6BhdRkqt3", + "registration_client_uri": "https://server.example.com/connect/register?client_id" + "=s6BhdRkqt3", "token_endpoint_auth_method": "client_secret_basic", "application_type": "web", - "redirect_uris": ["https://client.example.org/callback", - "https://client.example.org/callback2"], + "redirect_uris": [ + "https://client.example.org/callback", + "https://client.example.org/callback2", + ], "client_name": "My Example", "client_name#ja-Jpan-JP": "クライアント名", "logo_uri": "https://client.example.org/logo.png", "subject_type": "pairwise", - "sector_identifier_uri": - "https://other.example.net/file_of_redirect_uris.json", + "sector_identifier_uri": "https://other.example.net/file_of_redirect_uris.json", "jwks_uri": "https://client.example.org/my_public_keys.jwks", "userinfo_encrypted_response_alg": "RSA1_5", "userinfo_encrypted_response_enc": "A128CBC+HS256", "contacts": ["ve7jtb@example.org", "mary@example.org"], "request_uris": [ - "https://client.example.org/rf.txt" - "#qpXaRLh_n93TTR9F252ValdatUQvQiJi5BDub2BeznA"] + "https://client.example.org/rf.txt" "#qpXaRLh_n93TTR9F252ValdatUQvQiJi5BDub2BeznA" + ], } resp = RegistrationResponse().deserialize(json.dumps(msg), "json") @@ -654,16 +689,18 @@ def test_deserialize(self): class TestAuthorizationRequest(object): def test_deserialize(self): - query = "response_type=token%20id_token&client_id=0acf77d4-b486-4c99" \ - "-bd76-074ed6a64ddf&redirect_uri=https%3A%2F%2Fclient.example" \ - ".com%2Fcb&scope=openid%20profile&state=af0ifjsldkj&nonce=n" \ - "-0S6_WzA2Mj" + query = ( + "response_type=token%20id_token&client_id=0acf77d4-b486-4c99" + "-bd76-074ed6a64ddf&redirect_uri=https%3A%2F%2Fclient.example" + ".com%2Fcb&scope=openid%20profile&state=af0ifjsldkj&nonce=n" + "-0S6_WzA2Mj" + ) req = AuthorizationRequest().deserialize(query, "urlencoded") - assert _eq(req.keys(), - ['nonce', 'state', 'redirect_uri', 'response_type', - 'client_id', 'scope']) + assert _eq( + req.keys(), ["nonce", "state", "redirect_uri", "response_type", "client_id", "scope"] + ) assert req["response_type"] == ["token", "id_token"] assert req["scope"] == ["openid", "profile"] @@ -683,40 +720,38 @@ def test_verify_nonce(self): "client_id": "foobar", "redirect_uri": "http://foobar.example.com/oaclient", "response_type": ["code", "id_token"], - "scope": "openid" + "scope": "openid", } ar = AuthorizationRequest(**args) with pytest.raises(MissingRequiredAttribute): ar.verify() - ar['nonce'] = 'abcdefgh' + ar["nonce"] = "abcdefgh" assert ar.verify() with pytest.raises(ValueError): - assert ar.verify(nonce='12345678') + assert ar.verify(nonce="12345678") def test_claims(self): args = { "client_id": "foobar", "redirect_uri": "http://foobar.example.com/oaclient", "response_type": "code", - 'scope': 'openid', - 'claims': { - "userinfo": - { - "given_name": {"essential": True}, - "nickname": None, - "email": {"essential": True}, - "email_verified": {"essential": True}, - "picture": None, - "http://example.info/claims/groups": None - }, - "id_token": - { - "auth_time": {"essential": True}, - "acr": {"values": ["urn:mace:incommon:iap:silver"]} - } - } + "scope": "openid", + "claims": { + "userinfo": { + "given_name": {"essential": True}, + "nickname": None, + "email": {"essential": True}, + "email_verified": {"essential": True}, + "picture": None, + "http://example.info/claims/groups": None, + }, + "id_token": { + "auth_time": {"essential": True}, + "acr": {"values": ["urn:mace:incommon:iap:silver"]}, + }, + }, } ar = AuthorizationRequest(**args) assert ar.verify() @@ -736,40 +771,45 @@ def test_request(self): "response_type": "code", "scope": "openid", "nonce": "some value", - "extra": 'attribute' + "extra": "attribute", } ar = AuthorizationRequest(**args) keyjar = KeyJar() - keyjar.add_symmetric('', "SomeTestPassword") - keyjar.add_symmetric('foobar', "SomeTestPassword") - _signed_jwt = make_openid_request(ar, keyjar, 'foobar', 'HS256', 'barfoo') - ar['request'] = _signed_jwt - del ar['nonce'] - del ar['extra'] - ar['scope'] = ['openid', 'email'] + keyjar.add_symmetric("", "SomeTestPassword") + keyjar.add_symmetric("foobar", "SomeTestPassword") + _signed_jwt = make_openid_request(ar, keyjar, "foobar", "HS256", "barfoo") + ar["request"] = _signed_jwt + del ar["nonce"] + del ar["extra"] + ar["scope"] = ["openid", "email"] res = ar.verify(keyjar=keyjar) assert res - assert 'extra' in ar - assert 'nonce' in ar - assert ar['scope'] == ['openid'] + assert "extra" in ar + assert "nonce" in ar + assert ar["scope"] == ["openid"] class TestAccessTokenResponse(object): def test_ok_idtoken(self): idval = { - 'nonce': 'KUEYfRM2VzKDaaKD', 'sub': 'EndUserSubject', - 'iss': 'https://alpha.cloud.nds.rub.de', 'aud': 'TestClient' + "nonce": "KUEYfRM2VzKDaaKD", + "sub": "EndUserSubject", + "iss": "https://alpha.cloud.nds.rub.de", + "aud": "TestClient", } idts = IdToken(**idval) keyjar = KeyJar() - keyjar.add_symmetric('', "SomeTestPassword") - keyjar.add_symmetric('https://alpha.cloud.nds.rub.de', "SomeTestPassword") - _signed_jwt = idts.to_jwt(key=keyjar.get_signing_key('oct'), algorithm="HS256", - lifetime=300) + keyjar.add_symmetric("", "SomeTestPassword") + keyjar.add_symmetric("https://alpha.cloud.nds.rub.de", "SomeTestPassword") + _signed_jwt = idts.to_jwt( + key=keyjar.get_signing_key("oct"), algorithm="HS256", lifetime=300 + ) _info = { - "access_token": "accessTok", "id_token": _signed_jwt, - "token_type": "Bearer", "expires_in": 3600 + "access_token": "accessTok", + "id_token": _signed_jwt, + "token_type": "Bearer", + "expires_in": 3600, } at = AccessTokenResponse(**_info) @@ -777,24 +817,28 @@ def test_ok_idtoken(self): def test_faulty_idtoken(self): idval = { - 'nonce': 'KUEYfRM2VzKDaaKD', 'sub': 'EndUserSubject', - 'iss': 'https://alpha.cloud.nds.rub.de', 'aud': 'TestClient' + "nonce": "KUEYfRM2VzKDaaKD", + "sub": "EndUserSubject", + "iss": "https://alpha.cloud.nds.rub.de", + "aud": "TestClient", } idts = IdToken(**idval) keyjar = KeyJar() - keyjar.add_symmetric('', "SomeTestPassword") - keyjar.add_symmetric('https://alpha.cloud.nds.rub.de', - "SomeTestPassword") - _signed_jwt = idts.to_jwt(key=keyjar.get_signing_key('oct'), - algorithm="HS256", lifetime=300) + keyjar.add_symmetric("", "SomeTestPassword") + keyjar.add_symmetric("https://alpha.cloud.nds.rub.de", "SomeTestPassword") + _signed_jwt = idts.to_jwt( + key=keyjar.get_signing_key("oct"), algorithm="HS256", lifetime=300 + ) # Mess with the signed id_token p = _signed_jwt.split(".") p[2] = "aaa" _faulty_signed_jwt = ".".join(p) _info = { - "access_token": "accessTok", "id_token": _faulty_signed_jwt, - "token_type": "Bearer", "expires_in": 3600 + "access_token": "accessTok", + "id_token": _faulty_signed_jwt, + "token_type": "Bearer", + "expires_in": 3600, } at = AccessTokenResponse(**_info) @@ -803,20 +847,24 @@ def test_faulty_idtoken(self): def test_wrong_alg(self): idval = { - 'nonce': 'KUEYfRM2VzKDaaKD', 'sub': 'EndUserSubject', - 'iss': 'https://alpha.cloud.nds.rub.de', 'aud': 'TestClient' + "nonce": "KUEYfRM2VzKDaaKD", + "sub": "EndUserSubject", + "iss": "https://alpha.cloud.nds.rub.de", + "aud": "TestClient", } idts = IdToken(**idval) keyjar = KeyJar() - keyjar.add_symmetric('', "SomeTestPassword") - keyjar.add_symmetric('https://alpha.cloud.nds.rub.de', - "SomeTestPassword") - _signed_jwt = idts.to_jwt(key=keyjar.get_signing_key('oct'), - algorithm="HS256", lifetime=300) + keyjar.add_symmetric("", "SomeTestPassword") + keyjar.add_symmetric("https://alpha.cloud.nds.rub.de", "SomeTestPassword") + _signed_jwt = idts.to_jwt( + key=keyjar.get_signing_key("oct"), algorithm="HS256", lifetime=300 + ) _info = { - "access_token": "accessTok", "id_token": _signed_jwt, - "token_type": "Bearer", "expires_in": 3600 + "access_token": "accessTok", + "id_token": _signed_jwt, + "token_type": "Bearer", + "expires_in": 3600, } at = AccessTokenResponse(**_info) @@ -826,84 +874,81 @@ def test_wrong_alg(self): def test_at_hash(): lifetime = 3600 - _token = {'access_token': 'accessTok'} + _token = {"access_token": "accessTok"} idval = { - 'nonce': 'KUEYfRM2VzKDaaKD', 'sub': 'EndUserSubject', - 'iss': 'https://alpha.cloud.nds.rub.de', 'aud': 'TestClient' + "nonce": "KUEYfRM2VzKDaaKD", + "sub": "EndUserSubject", + "iss": "https://alpha.cloud.nds.rub.de", + "aud": "TestClient", } idval.update(_token) idts = IdToken(**idval) keyjar = KeyJar() - keyjar.add_symmetric('', "SomeTestPassword") - keyjar.add_symmetric('https://alpha.cloud.nds.rub.de', - "SomeTestPassword") - _signed_jwt = idts.to_jwt(key=keyjar.get_signing_key('oct'), - algorithm="HS256", lifetime=lifetime) - - _info = { - "id_token": _signed_jwt, "token_type": "Bearer", - "expires_in": lifetime - } + keyjar.add_symmetric("", "SomeTestPassword") + keyjar.add_symmetric("https://alpha.cloud.nds.rub.de", "SomeTestPassword") + _signed_jwt = idts.to_jwt( + key=keyjar.get_signing_key("oct"), algorithm="HS256", lifetime=lifetime + ) + + _info = {"id_token": _signed_jwt, "token_type": "Bearer", "expires_in": lifetime} _info.update(_token) at = AuthorizationResponse(**_info) assert at.verify(keyjar=keyjar, sigalg="HS256") - assert 'at_hash' in at[verified_claim_name('id_token')] + assert "at_hash" in at[verified_claim_name("id_token")] def test_c_hash(): lifetime = 3600 - _token = {'code': 'grant'} + _token = {"code": "grant"} idval = { - 'nonce': 'KUEYfRM2VzKDaaKD', 'sub': 'EndUserSubject', - 'iss': 'https://alpha.cloud.nds.rub.de', 'aud': 'TestClient' + "nonce": "KUEYfRM2VzKDaaKD", + "sub": "EndUserSubject", + "iss": "https://alpha.cloud.nds.rub.de", + "aud": "TestClient", } idval.update(_token) idts = IdToken(**idval) keyjar = KeyJar() - keyjar.add_symmetric('', "SomeTestPassword") - keyjar.add_symmetric('https://alpha.cloud.nds.rub.de', - "SomeTestPassword") - _signed_jwt = idts.to_jwt(key=keyjar.get_signing_key('oct'), - algorithm="HS256", lifetime=lifetime) - - _info = { - "id_token": _signed_jwt, "token_type": "Bearer", - "expires_in": lifetime - } + keyjar.add_symmetric("", "SomeTestPassword") + keyjar.add_symmetric("https://alpha.cloud.nds.rub.de", "SomeTestPassword") + _signed_jwt = idts.to_jwt( + key=keyjar.get_signing_key("oct"), algorithm="HS256", lifetime=lifetime + ) + + _info = {"id_token": _signed_jwt, "token_type": "Bearer", "expires_in": lifetime} _info.update(_token) at = AuthorizationResponse(**_info) r = at.verify(keyjar=keyjar, sigalg="HS256") - assert 'c_hash' in at[verified_claim_name('id_token')] + assert "c_hash" in at[verified_claim_name("id_token")] def test_missing_c_hash(): lifetime = 3600 - _token = {'code': 'grant'} + _token = {"code": "grant"} idval = { - 'nonce': 'KUEYfRM2VzKDaaKD', 'sub': 'EndUserSubject', - 'iss': 'https://alpha.cloud.nds.rub.de', 'aud': 'TestClient' + "nonce": "KUEYfRM2VzKDaaKD", + "sub": "EndUserSubject", + "iss": "https://alpha.cloud.nds.rub.de", + "aud": "TestClient", } # idval.update(_token) idts = IdToken(**idval) keyjar = KeyJar() - keyjar.add_symmetric('', "SomeTestPassword") - keyjar.add_symmetric('https://alpha.cloud.nds.rub.de', - "SomeTestPassword") + keyjar.add_symmetric("", "SomeTestPassword") + keyjar.add_symmetric("https://alpha.cloud.nds.rub.de", "SomeTestPassword") - _signed_jwt = idts.to_jwt(key=keyjar.get_signing_key('oct'), - algorithm="HS256", lifetime=lifetime) + _signed_jwt = idts.to_jwt( + key=keyjar.get_signing_key("oct"), algorithm="HS256", lifetime=lifetime + ) - _info = { - "id_token": _signed_jwt, "token_type": "Bearer", - "expires_in": lifetime - } + _info = {"id_token": _signed_jwt, "token_type": "Bearer", "expires_in": lifetime} _info.update(_token) at = AuthorizationResponse(**_info) @@ -914,19 +959,18 @@ def test_missing_c_hash(): def test_id_token(): _now = time_util.utc_time_sans_frac() - idt = IdToken(**{ - "sub": "553df2bcf909104751cfd8b2", - "aud": [ - "5542958437706128204e0000", - "554295ce3770612820620000" - ], - "auth_time": 1441364872, - "azp": "554295ce3770612820620000", - "at_hash": "L4Ign7TCAD_EppRbHAuCyw", - "iat": _now, - "exp": _now + 3600, - "iss": "https://sso.qa.7pass.ctf.prosiebensat1.com" - }) + idt = IdToken( + **{ + "sub": "553df2bcf909104751cfd8b2", + "aud": ["5542958437706128204e0000", "554295ce3770612820620000"], + "auth_time": 1441364872, + "azp": "554295ce3770612820620000", + "at_hash": "L4Ign7TCAD_EppRbHAuCyw", + "iat": _now, + "exp": _now + 3600, + "iss": "https://sso.qa.7pass.ctf.prosiebensat1.com", + } + ) idt.verify() @@ -934,19 +978,18 @@ def test_id_token(): def test_id_token_expired(): _now = time_util.utc_time_sans_frac() - idt = IdToken(**{ - "sub": "553df2bcf909104751cfd8b2", - "aud": [ - "5542958437706128204e0000", - "554295ce3770612820620000" - ], - "auth_time": 1441364872, - "azp": "554295ce3770612820620000", - "at_hash": "L4Ign7TCAD_EppRbHAuCyw", - "iat": _now - 200, - "exp": _now - 100, - "iss": "https://sso.qa.7pass.ctf.prosiebensat1.com" - }) + idt = IdToken( + **{ + "sub": "553df2bcf909104751cfd8b2", + "aud": ["5542958437706128204e0000", "554295ce3770612820620000"], + "auth_time": 1441364872, + "azp": "554295ce3770612820620000", + "at_hash": "L4Ign7TCAD_EppRbHAuCyw", + "iat": _now - 200, + "exp": _now - 100, + "iss": "https://sso.qa.7pass.ctf.prosiebensat1.com", + } + ) with pytest.raises(EXPError): idt.verify() @@ -955,19 +998,18 @@ def test_id_token_expired(): def test_id_token_iat_in_the_future(): _now = time_util.utc_time_sans_frac() - idt = IdToken(**{ - "sub": "553df2bcf909104751cfd8b2", - "aud": [ - "5542958437706128204e0000", - "554295ce3770612820620000" - ], - "auth_time": 1441364872, - "azp": "554295ce3770612820620000", - "at_hash": "L4Ign7TCAD_EppRbHAuCyw", - "iat": _now + 600, - "exp": _now + 1200, - "iss": "https://sso.qa.7pass.ctf.prosiebensat1.com" - }) + idt = IdToken( + **{ + "sub": "553df2bcf909104751cfd8b2", + "aud": ["5542958437706128204e0000", "554295ce3770612820620000"], + "auth_time": 1441364872, + "azp": "554295ce3770612820620000", + "at_hash": "L4Ign7TCAD_EppRbHAuCyw", + "iat": _now + 600, + "exp": _now + 1200, + "iss": "https://sso.qa.7pass.ctf.prosiebensat1.com", + } + ) with pytest.raises(IATError): idt.verify() @@ -976,27 +1018,29 @@ def test_id_token_iat_in_the_future(): def test_id_token_exp_before_iat(): _now = time_util.utc_time_sans_frac() - idt = IdToken(**{ - "sub": "553df2bcf909104751cfd8b2", - "aud": [ - "5542958437706128204e0000", - "554295ce3770612820620000" - ], - "auth_time": 1441364872, - "azp": "554295ce3770612820620000", - "at_hash": "L4Ign7TCAD_EppRbHAuCyw", - "iat": _now + 50, - "exp": _now, - "iss": "https://sso.qa.7pass.ctf.prosiebensat1.com" - }) + idt = IdToken( + **{ + "sub": "553df2bcf909104751cfd8b2", + "aud": ["5542958437706128204e0000", "554295ce3770612820620000"], + "auth_time": 1441364872, + "azp": "554295ce3770612820620000", + "at_hash": "L4Ign7TCAD_EppRbHAuCyw", + "iat": _now + 50, + "exp": _now, + "iss": "https://sso.qa.7pass.ctf.prosiebensat1.com", + } + ) with pytest.raises(IATError): idt.verify(skew=100) + class TestAccessTokenRequest(object): def test_example(self): - _txt = 'grant_type=authorization_code&code=SplxlOBeZQQYbYS6WxSbIA' \ - '&redirect_uri=https%3A%2F%2Fclient.example.org%2Fcb' + _txt = ( + "grant_type=authorization_code&code=SplxlOBeZQQYbYS6WxSbIA" + "&redirect_uri=https%3A%2F%2Fclient.example.org%2Fcb" + ) atr = AccessTokenRequest().from_urlencoded(_txt) assert atr.verify() @@ -1004,10 +1048,10 @@ def test_example(self): class TestAuthnToken(object): def test_example(self): at = AuthnToken( - iss='https://example.com', - sub='https://example,org', - aud=['https://example.org/token'], # Array of strings or string - jti='abcdefghijkl', + iss="https://example.com", + sub="https://example,org", + aud=["https://example.org/token"], # Array of strings or string + jti="abcdefghijkl", exp=utc_time_sans_frac() + 3600, ) assert at.verify() @@ -1015,497 +1059,545 @@ def test_example(self): class TestAuthorizationErrorResponse(object): def test_allowed_err(self): - aer = AuthorizationErrorResponse(error='interaction_required') + aer = AuthorizationErrorResponse(error="interaction_required") assert aer.verify() def test_not_allowed_err(self): - aer = AuthorizationErrorResponse(error='other_error') + aer = AuthorizationErrorResponse(error="other_error") with pytest.raises(NotAllowedValue): assert aer.verify() -@pytest.mark.parametrize("bdate", [ - "1971-11-23", "0000-11-23", "1971" -]) +@pytest.mark.parametrize("bdate", ["1971-11-23", "0000-11-23", "1971"]) def test_birthdate(bdate): - uinfo = OpenIDSchema(birthdate=bdate, sub='jarvis') + uinfo = OpenIDSchema(birthdate=bdate, sub="jarvis") uinfo.verify() def test_factory(): - dr = factory('DiscoveryRequest', resource='local@domain', - rel="http://openid.net/specs/connect/1.0/issuer") + dr = factory( + "DiscoveryRequest", + resource="local@domain", + rel="http://openid.net/specs/connect/1.0/issuer", + ) assert isinstance(dr, DiscoveryRequest) - assert set(dr.keys()) == {'resource', 'rel'} + assert set(dr.keys()) == {"resource", "rel"} def test_factory_chain(): - dr = factory('ResponseMessage', error='some_error') + dr = factory("ResponseMessage", error="some_error") assert isinstance(dr, ResponseMessage) - assert list(dr.keys()) == ['error'] + assert list(dr.keys()) == ["error"] def test_dict_deser(): - _info = {'foo': 'bar'} + _info = {"foo": "bar"} # supposed to output JSON - _jinfo = dict_deser(_info, 'dict') + _jinfo = dict_deser(_info, "dict") assert _jinfo == json.dumps(_info) - _jinfo2 = dict_deser(_jinfo, 'dict') + _jinfo2 = dict_deser(_jinfo, "dict") assert _jinfo == _jinfo2 with pytest.raises(ValueError): - _ = dict_deser(_info, 'foo') + _ = dict_deser(_info, "foo") def test_claims_match(): - assert claims_match(['val'], None) - assert claims_match('val', {'value': 'val'}) - assert claims_match('val', {'value': 'other'}) is False - assert claims_match('val', {'values': ['val', 'other']}) - assert claims_match('val', {'value': 'val', 'essential': True}) - assert claims_match('val', {'value': 'other', 'essential': True}) is False - assert claims_match('val', {'essential': True}) + assert claims_match(["val"], None) + assert claims_match("val", {"value": "val"}) + assert claims_match("val", {"value": "other"}) is False + assert claims_match("val", {"values": ["val", "other"]}) + assert claims_match("val", {"value": "val", "essential": True}) + assert claims_match("val", {"value": "other", "essential": True}) is False + assert claims_match("val", {"essential": True}) def test_factory_2(): - inst = factory('ROPCAccessTokenRequest', username='me', password='text', - scope='mar') + inst = factory("ROPCAccessTokenRequest", username="me", password="text", scope="mar") assert isinstance(inst, ROPCAccessTokenRequest) def test_link_deser(): - link = Link(href='https://example.com/op', - rel="http://openid.net/specs/connect/1.0/issuer") + link = Link(href="https://example.com/op", rel="http://openid.net/specs/connect/1.0/issuer") - jl = link_ser(link, 'json') - l2 = link_deser([jl], 'json') + jl = link_ser(link, "json") + l2 = link_deser([jl], "json") assert isinstance(l2[0], Link) def test_link_deser_dict(): - link = Link(href='https://example.com/op', - rel="http://openid.net/specs/connect/1.0/issuer") + link = Link(href="https://example.com/op", rel="http://openid.net/specs/connect/1.0/issuer") - l2 = link_deser([link.to_dict()], 'json') + l2 = link_deser([link.to_dict()], "json") assert isinstance(l2[0], Link) def test_proper_path(): - p = proper_path('foo/bar') - assert p == './foo/bar/' + p = proper_path("foo/bar") + assert p == "./foo/bar/" - p = proper_path('/foo/bar') - assert p == './foo/bar/' + p = proper_path("/foo/bar") + assert p == "./foo/bar/" - p = proper_path('./foo/bar') - assert p == './foo/bar/' + p = proper_path("./foo/bar") + assert p == "./foo/bar/" - p = proper_path('../foo/bar') - assert p == './foo/bar/' + p = proper_path("../foo/bar") + assert p == "./foo/bar/" def test_verify_id_token(): - idt = IdToken(**{ - "sub": "553df2bcf909104751cfd8b2", - "aud": [ - "5542958437706128204e0000", - "554295ce3770612820620000" - ], - "auth_time": 1441364872, - "azp": "554295ce3770612820620000", - }) + idt = IdToken( + **{ + "sub": "553df2bcf909104751cfd8b2", + "aud": ["5542958437706128204e0000", "554295ce3770612820620000"], + "auth_time": 1441364872, + "azp": "554295ce3770612820620000", + } + ) kj = KeyJar() - kj.add_symmetric("", 'dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ', ['sig']) - kj.add_symmetric("https://sso.qa.7pass.ctf.prosiebensat1.com", - 'dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ', ['sig']) - packer = JWT(kj, sign_alg='HS256', - iss="https://sso.qa.7pass.ctf.prosiebensat1.com", - lifetime=3600) + kj.add_symmetric("", "dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ", ["sig"]) + kj.add_symmetric( + "https://sso.qa.7pass.ctf.prosiebensat1.com", "dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ", ["sig"] + ) + packer = JWT( + kj, sign_alg="HS256", iss="https://sso.qa.7pass.ctf.prosiebensat1.com", lifetime=3600 + ) _jws = packer.pack(payload=idt.to_dict()) msg = AuthorizationResponse(id_token=_jws) - vidt = verify_id_token(msg, keyjar=kj, - iss="https://sso.qa.7pass.ctf.prosiebensat1.com", - client_id="554295ce3770612820620000") + vidt = verify_id_token( + msg, + keyjar=kj, + iss="https://sso.qa.7pass.ctf.prosiebensat1.com", + client_id="554295ce3770612820620000", + ) assert vidt def test_verify_id_token_wrong_issuer(): - idt = IdToken(**{ - "sub": "553df2bcf909104751cfd8b2", - "aud": [ - "5542958437706128204e0000", - "554295ce3770612820620000" - ], - "auth_time": 1441364872, - "azp": "554295ce3770612820620000", - }) + idt = IdToken( + **{ + "sub": "553df2bcf909104751cfd8b2", + "aud": ["5542958437706128204e0000", "554295ce3770612820620000"], + "auth_time": 1441364872, + "azp": "554295ce3770612820620000", + } + ) kj = KeyJar() - kj.add_symmetric("", 'dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ', ['sig']) - kj.add_symmetric("https://sso.qa.7pass.ctf.prosiebensat1.com", - 'dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ', ['sig']) - packer = JWT(kj, sign_alg='HS256', iss="https://example.com/as", - lifetime=3600) + kj.add_symmetric("", "dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ", ["sig"]) + kj.add_symmetric( + "https://sso.qa.7pass.ctf.prosiebensat1.com", "dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ", ["sig"] + ) + packer = JWT(kj, sign_alg="HS256", iss="https://example.com/as", lifetime=3600) _jws = packer.pack(payload=idt.to_dict()) msg = AuthorizationResponse(id_token=_jws) with pytest.raises(ValueError): - verify_id_token(msg, keyjar=kj, - iss="https://sso.qa.7pass.ctf.prosiebensat1.com", - client_id="554295ce3770612820620000") + verify_id_token( + msg, + keyjar=kj, + iss="https://sso.qa.7pass.ctf.prosiebensat1.com", + client_id="554295ce3770612820620000", + ) def test_verify_id_token_wrong_aud(): - idt = IdToken(**{ - "sub": "553df2bcf909104751cfd8b2", - "aud": [ - "5542958437706128204e0000", - "554295ce3770612820620000" - ], - "auth_time": 1441364872, - "azp": "554295ce3770612820620000", - }) + idt = IdToken( + **{ + "sub": "553df2bcf909104751cfd8b2", + "aud": ["5542958437706128204e0000", "554295ce3770612820620000"], + "auth_time": 1441364872, + "azp": "554295ce3770612820620000", + } + ) kj = KeyJar() - kj.add_symmetric("", 'dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ', ['sig']) - kj.add_symmetric("https://sso.qa.7pass.ctf.prosiebensat1.com", - 'dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ', ['sig']) - packer = JWT(kj, sign_alg='HS256', iss="https://example.com/as", - lifetime=3600) + kj.add_symmetric("", "dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ", ["sig"]) + kj.add_symmetric( + "https://sso.qa.7pass.ctf.prosiebensat1.com", "dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ", ["sig"] + ) + packer = JWT(kj, sign_alg="HS256", iss="https://example.com/as", lifetime=3600) _jws = packer.pack(payload=idt.to_dict()) msg = AuthorizationResponse(id_token=_jws) with pytest.raises(ValueError): - verify_id_token(msg, keyjar=kj, - iss="https://sso.qa.7pass.ctf.prosiebensat1.com", - client_id="aaaaaaaaaaaaaaaaaaaa") + verify_id_token( + msg, + keyjar=kj, + iss="https://sso.qa.7pass.ctf.prosiebensat1.com", + client_id="aaaaaaaaaaaaaaaaaaaa", + ) def test_verify_id_token_mismatch_aud_azp(): - idt = IdToken(**{ - "sub": "553df2bcf909104751cfd8b2", - "aud": [ - "5542958437706128204e0000", - "554295ce3770612820620000" - ], - "auth_time": 1441364872, - "azp": "aaaaaaaaaaaaaaaaaaaa", - }) + idt = IdToken( + **{ + "sub": "553df2bcf909104751cfd8b2", + "aud": ["5542958437706128204e0000", "554295ce3770612820620000"], + "auth_time": 1441364872, + "azp": "aaaaaaaaaaaaaaaaaaaa", + } + ) kj = KeyJar() - kj.add_symmetric("", 'dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ', ['sig']) - kj.add_symmetric("https://sso.qa.7pass.ctf.prosiebensat1.com", - 'dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ', ['sig']) - packer = JWT(kj, sign_alg='HS256', iss="https://example.com/as", - lifetime=3600) + kj.add_symmetric("", "dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ", ["sig"]) + kj.add_symmetric( + "https://sso.qa.7pass.ctf.prosiebensat1.com", "dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ", ["sig"] + ) + packer = JWT(kj, sign_alg="HS256", iss="https://example.com/as", lifetime=3600) _jws = packer.pack(payload=idt.to_dict()) msg = AuthorizationResponse(id_token=_jws) with pytest.raises(ValueError): - verify_id_token(msg, keyjar=kj, - iss="https://sso.qa.7pass.ctf.prosiebensat1.com", - client_id="aaaaaaaaaaaaaaaaaaaa") + verify_id_token( + msg, + keyjar=kj, + iss="https://sso.qa.7pass.ctf.prosiebensat1.com", + client_id="aaaaaaaaaaaaaaaaaaaa", + ) def test_verify_id_token_c_hash(): - code = 'AccessCode1' + code = "AccessCode1" lhsh = left_hash(code) - idt = IdToken(**{ - "sub": "553df2bcf909104751cfd8b2", - "aud": [ - "5542958437706128204e0000", - "554295ce3770612820620000" - ], - "auth_time": 1441364872, - "azp": "554295ce3770612820620000", - "c_hash": lhsh - }) + idt = IdToken( + **{ + "sub": "553df2bcf909104751cfd8b2", + "aud": ["5542958437706128204e0000", "554295ce3770612820620000"], + "auth_time": 1441364872, + "azp": "554295ce3770612820620000", + "c_hash": lhsh, + } + ) kj = KeyJar() - kj.add_symmetric("", 'dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ', ['sig']) - kj.add_symmetric("https://sso.qa.7pass.ctf.prosiebensat1.com", - 'dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ', ['sig']) - packer = JWT(kj, sign_alg='HS256', - iss="https://sso.qa.7pass.ctf.prosiebensat1.com", - lifetime=3600) + kj.add_symmetric("", "dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ", ["sig"]) + kj.add_symmetric( + "https://sso.qa.7pass.ctf.prosiebensat1.com", "dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ", ["sig"] + ) + packer = JWT( + kj, sign_alg="HS256", iss="https://sso.qa.7pass.ctf.prosiebensat1.com", lifetime=3600 + ) _jws = packer.pack(payload=idt.to_dict()) msg = AuthorizationResponse(code=code, id_token=_jws) - verify_id_token(msg, check_hash=True, keyjar=kj, - iss="https://sso.qa.7pass.ctf.prosiebensat1.com", - client_id="554295ce3770612820620000") + verify_id_token( + msg, + check_hash=True, + keyjar=kj, + iss="https://sso.qa.7pass.ctf.prosiebensat1.com", + client_id="554295ce3770612820620000", + ) def test_verify_id_token_c_hash_fail(): - code = 'AccessCode1' + code = "AccessCode1" lhsh = left_hash(code) - idt = IdToken(**{ - "sub": "553df2bcf909104751cfd8b2", - "aud": [ - "5542958437706128204e0000", - "554295ce3770612820620000" - ], - "auth_time": 1441364872, - "azp": "554295ce3770612820620000", - "c_hash": lhsh - }) + idt = IdToken( + **{ + "sub": "553df2bcf909104751cfd8b2", + "aud": ["5542958437706128204e0000", "554295ce3770612820620000"], + "auth_time": 1441364872, + "azp": "554295ce3770612820620000", + "c_hash": lhsh, + } + ) kj = KeyJar() - kj.add_symmetric("", 'dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ', ['sig']) - kj.add_symmetric("https://sso.qa.7pass.ctf.prosiebensat1.com", - 'dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ', ['sig']) - packer = JWT(kj, sign_alg='HS256', - iss="https://sso.qa.7pass.ctf.prosiebensat1.com", - lifetime=3600) + kj.add_symmetric("", "dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ", ["sig"]) + kj.add_symmetric( + "https://sso.qa.7pass.ctf.prosiebensat1.com", "dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ", ["sig"] + ) + packer = JWT( + kj, sign_alg="HS256", iss="https://sso.qa.7pass.ctf.prosiebensat1.com", lifetime=3600 + ) _jws = packer.pack(payload=idt.to_dict()) msg = AuthorizationResponse(code="AccessCode289", id_token=_jws) with pytest.raises(CHashError): - verify_id_token(msg, check_hash=True, keyjar=kj, - iss="https://sso.qa.7pass.ctf.prosiebensat1.com", - client_id="554295ce3770612820620000") + verify_id_token( + msg, + check_hash=True, + keyjar=kj, + iss="https://sso.qa.7pass.ctf.prosiebensat1.com", + client_id="554295ce3770612820620000", + ) def test_verify_id_token_at_hash(): - token = 'AccessTokenWhichCouldBeASignedJWT' + token = "AccessTokenWhichCouldBeASignedJWT" lhsh = left_hash(token) - idt = IdToken(**{ - "sub": "553df2bcf909104751cfd8b2", - "aud": [ - "5542958437706128204e0000", - "554295ce3770612820620000" - ], - "auth_time": 1441364872, - "azp": "554295ce3770612820620000", - "at_hash": lhsh - }) + idt = IdToken( + **{ + "sub": "553df2bcf909104751cfd8b2", + "aud": ["5542958437706128204e0000", "554295ce3770612820620000"], + "auth_time": 1441364872, + "azp": "554295ce3770612820620000", + "at_hash": lhsh, + } + ) kj = KeyJar() - kj.add_symmetric("", 'dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ', ['sig']) - kj.add_symmetric("https://sso.qa.7pass.ctf.prosiebensat1.com", - 'dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ', ['sig']) - packer = JWT(kj, sign_alg='HS256', - iss="https://sso.qa.7pass.ctf.prosiebensat1.com", - lifetime=3600) + kj.add_symmetric("", "dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ", ["sig"]) + kj.add_symmetric( + "https://sso.qa.7pass.ctf.prosiebensat1.com", "dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ", ["sig"] + ) + packer = JWT( + kj, sign_alg="HS256", iss="https://sso.qa.7pass.ctf.prosiebensat1.com", lifetime=3600 + ) _jws = packer.pack(payload=idt.to_dict()) msg = AuthorizationResponse(access_token=token, id_token=_jws) - verify_id_token(msg, check_hash=True, keyjar=kj, - iss="https://sso.qa.7pass.ctf.prosiebensat1.com", - client_id="554295ce3770612820620000") + verify_id_token( + msg, + check_hash=True, + keyjar=kj, + iss="https://sso.qa.7pass.ctf.prosiebensat1.com", + client_id="554295ce3770612820620000", + ) def test_verify_id_token_at_hash_fail(): - token = 'AccessTokenWhichCouldBeASignedJWT' - token2 = 'ACompletelyOtherAccessToken' + token = "AccessTokenWhichCouldBeASignedJWT" + token2 = "ACompletelyOtherAccessToken" lhsh = left_hash(token) - idt = IdToken(**{ - "sub": "553df2bcf909104751cfd8b2", - "aud": [ - "5542958437706128204e0000", - "554295ce3770612820620000" - ], - "auth_time": 1441364872, - "azp": "554295ce3770612820620000", - "at_hash": lhsh - }) + idt = IdToken( + **{ + "sub": "553df2bcf909104751cfd8b2", + "aud": ["5542958437706128204e0000", "554295ce3770612820620000"], + "auth_time": 1441364872, + "azp": "554295ce3770612820620000", + "at_hash": lhsh, + } + ) kj = KeyJar() - kj.add_symmetric("", 'dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ', ['sig']) - kj.add_symmetric("https://sso.qa.7pass.ctf.prosiebensat1.com", - 'dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ', ['sig']) - packer = JWT(kj, sign_alg='HS256', - iss="https://sso.qa.7pass.ctf.prosiebensat1.com", - lifetime=3600) + kj.add_symmetric("", "dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ", ["sig"]) + kj.add_symmetric( + "https://sso.qa.7pass.ctf.prosiebensat1.com", "dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ", ["sig"] + ) + packer = JWT( + kj, sign_alg="HS256", iss="https://sso.qa.7pass.ctf.prosiebensat1.com", lifetime=3600 + ) _jws = packer.pack(payload=idt.to_dict()) msg = AuthorizationResponse(access_token=token2, id_token=_jws) with pytest.raises(AtHashError): - verify_id_token(msg, check_hash=True, keyjar=kj, - iss="https://sso.qa.7pass.ctf.prosiebensat1.com", - client_id="554295ce3770612820620000") + verify_id_token( + msg, + check_hash=True, + keyjar=kj, + iss="https://sso.qa.7pass.ctf.prosiebensat1.com", + client_id="554295ce3770612820620000", + ) def test_verify_id_token_missing_at_hash(): - token = 'AccessTokenWhichCouldBeASignedJWT' - - idt = IdToken(**{ - "sub": "553df2bcf909104751cfd8b2", - "aud": [ - "5542958437706128204e0000", - "554295ce3770612820620000" - ], - "auth_time": 1441364872, - "azp": "554295ce3770612820620000", - }) + token = "AccessTokenWhichCouldBeASignedJWT" + + idt = IdToken( + **{ + "sub": "553df2bcf909104751cfd8b2", + "aud": ["5542958437706128204e0000", "554295ce3770612820620000"], + "auth_time": 1441364872, + "azp": "554295ce3770612820620000", + } + ) kj = KeyJar() - kj.add_symmetric("", 'dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ', ['sig']) - kj.add_symmetric("https://sso.qa.7pass.ctf.prosiebensat1.com", - 'dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ', ['sig']) - packer = JWT(kj, sign_alg='HS256', - iss="https://sso.qa.7pass.ctf.prosiebensat1.com", - lifetime=3600) + kj.add_symmetric("", "dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ", ["sig"]) + kj.add_symmetric( + "https://sso.qa.7pass.ctf.prosiebensat1.com", "dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ", ["sig"] + ) + packer = JWT( + kj, sign_alg="HS256", iss="https://sso.qa.7pass.ctf.prosiebensat1.com", lifetime=3600 + ) _jws = packer.pack(payload=idt.to_dict()) msg = AuthorizationResponse(access_token=token, id_token=_jws) with pytest.raises(MissingRequiredAttribute): - verify_id_token(msg, check_hash=True, keyjar=kj, - iss="https://sso.qa.7pass.ctf.prosiebensat1.com", - client_id="554295ce3770612820620000") + verify_id_token( + msg, + check_hash=True, + keyjar=kj, + iss="https://sso.qa.7pass.ctf.prosiebensat1.com", + client_id="554295ce3770612820620000", + ) def test_verify_id_token_missing_c_hash(): - code = 'AccessCode1' - - idt = IdToken(**{ - "sub": "553df2bcf909104751cfd8b2", - "aud": [ - "5542958437706128204e0000", - "554295ce3770612820620000" - ], - "auth_time": 1441364872, - "azp": "554295ce3770612820620000", - }) + code = "AccessCode1" + + idt = IdToken( + **{ + "sub": "553df2bcf909104751cfd8b2", + "aud": ["5542958437706128204e0000", "554295ce3770612820620000"], + "auth_time": 1441364872, + "azp": "554295ce3770612820620000", + } + ) kj = KeyJar() - kj.add_symmetric("", 'dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ', ['sig']) - kj.add_symmetric("https://sso.qa.7pass.ctf.prosiebensat1.com", - 'dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ', ['sig']) - packer = JWT(kj, sign_alg='HS256', - iss="https://sso.qa.7pass.ctf.prosiebensat1.com", - lifetime=3600) + kj.add_symmetric("", "dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ", ["sig"]) + kj.add_symmetric( + "https://sso.qa.7pass.ctf.prosiebensat1.com", "dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ", ["sig"] + ) + packer = JWT( + kj, sign_alg="HS256", iss="https://sso.qa.7pass.ctf.prosiebensat1.com", lifetime=3600 + ) _jws = packer.pack(payload=idt.to_dict()) msg = AuthorizationResponse(code=code, id_token=_jws) with pytest.raises(MissingRequiredAttribute): - verify_id_token(msg, check_hash=True, keyjar=kj, - iss="https://sso.qa.7pass.ctf.prosiebensat1.com", - client_id="554295ce3770612820620000") + verify_id_token( + msg, + check_hash=True, + keyjar=kj, + iss="https://sso.qa.7pass.ctf.prosiebensat1.com", + client_id="554295ce3770612820620000", + ) def test_verify_id_token_at_hash_and_chash(): - token = 'AccessTokenWhichCouldBeASignedJWT' + token = "AccessTokenWhichCouldBeASignedJWT" at_hash = left_hash(token) - code = 'AccessCode1' + code = "AccessCode1" c_hash = left_hash(code) - idt = IdToken(**{ - "sub": "553df2bcf909104751cfd8b2", - "aud": [ - "5542958437706128204e0000", - "554295ce3770612820620000" - ], - "auth_time": 1441364872, - "azp": "554295ce3770612820620000", - "at_hash": at_hash, - 'c_hash': c_hash - }) + idt = IdToken( + **{ + "sub": "553df2bcf909104751cfd8b2", + "aud": ["5542958437706128204e0000", "554295ce3770612820620000"], + "auth_time": 1441364872, + "azp": "554295ce3770612820620000", + "at_hash": at_hash, + "c_hash": c_hash, + } + ) kj = KeyJar() - kj.add_symmetric("", 'dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ', ['sig']) - kj.add_symmetric("https://sso.qa.7pass.ctf.prosiebensat1.com", - 'dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ', ['sig']) - packer = JWT(kj, sign_alg='HS256', - iss="https://sso.qa.7pass.ctf.prosiebensat1.com", - lifetime=3600) + kj.add_symmetric("", "dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ", ["sig"]) + kj.add_symmetric( + "https://sso.qa.7pass.ctf.prosiebensat1.com", "dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ", ["sig"] + ) + packer = JWT( + kj, sign_alg="HS256", iss="https://sso.qa.7pass.ctf.prosiebensat1.com", lifetime=3600 + ) _jws = packer.pack(payload=idt.to_dict()) msg = AuthorizationResponse(access_token=token, id_token=_jws, code=code) - verify_id_token(msg, check_hash=True, keyjar=kj, - iss="https://sso.qa.7pass.ctf.prosiebensat1.com", - client_id="554295ce3770612820620000") + verify_id_token( + msg, + check_hash=True, + keyjar=kj, + iss="https://sso.qa.7pass.ctf.prosiebensat1.com", + client_id="554295ce3770612820620000", + ) def test_verify_id_token_missing_iss(): - idt = IdToken(**{ - "sub": "553df2bcf909104751cfd8b2", - "aud": [ - "5542958437706128204e0000", - "554295ce3770612820620000" - ], - "auth_time": 1441364872, - "azp": "554295ce3770612820620000", - }) + idt = IdToken( + **{ + "sub": "553df2bcf909104751cfd8b2", + "aud": ["5542958437706128204e0000", "554295ce3770612820620000"], + "auth_time": 1441364872, + "azp": "554295ce3770612820620000", + } + ) kj = KeyJar() - kj.add_symmetric("", 'dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ', ['sig']) - kj.add_symmetric("https://sso.qa.7pass.ctf.prosiebensat1.com", - 'dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ', ['sig']) - packer = JWT(kj, sign_alg='HS256', lifetime=3600) + kj.add_symmetric("", "dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ", ["sig"]) + kj.add_symmetric( + "https://sso.qa.7pass.ctf.prosiebensat1.com", "dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ", ["sig"] + ) + packer = JWT(kj, sign_alg="HS256", lifetime=3600) _jws = packer.pack(payload=idt.to_dict()) msg = AuthorizationResponse(id_token=_jws) with pytest.raises(MissingRequiredAttribute): - verify_id_token(msg, check_hash=True, keyjar=kj, - iss="https://sso.qa.7pass.ctf.prosiebensat1.com", - client_id="554295ce3770612820620000") + verify_id_token( + msg, + check_hash=True, + keyjar=kj, + iss="https://sso.qa.7pass.ctf.prosiebensat1.com", + client_id="554295ce3770612820620000", + ) def test_verify_id_token_iss_not_in_keyjar(): - idt = IdToken(**{ - "sub": "553df2bcf909104751cfd8b2", - "aud": [ - "5542958437706128204e0000", - "554295ce3770612820620000" - ], - "auth_time": 1441364872, - "azp": "554295ce3770612820620000", - }) + idt = IdToken( + **{ + "sub": "553df2bcf909104751cfd8b2", + "aud": ["5542958437706128204e0000", "554295ce3770612820620000"], + "auth_time": 1441364872, + "azp": "554295ce3770612820620000", + } + ) kj = KeyJar() - kj.add_symmetric("", 'dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ', ['sig']) - kj.add_symmetric("https://sso.qa.7pass.ctf.prosiebensat1.com", - 'dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ', ['sig']) - packer = JWT(kj, sign_alg='HS256', lifetime=3600, - iss='https://example.com/op') + kj.add_symmetric("", "dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ", ["sig"]) + kj.add_symmetric( + "https://sso.qa.7pass.ctf.prosiebensat1.com", "dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ", ["sig"] + ) + packer = JWT(kj, sign_alg="HS256", lifetime=3600, iss="https://example.com/op") _jws = packer.pack(payload=idt.to_dict()) msg = AuthorizationResponse(id_token=_jws) with pytest.raises(ValueError): - verify_id_token(msg, check_hash=True, keyjar=kj, - iss="https://sso.qa.7pass.ctf.prosiebensat1.com", - client_id="554295ce3770612820620000") + verify_id_token( + msg, + check_hash=True, + keyjar=kj, + iss="https://sso.qa.7pass.ctf.prosiebensat1.com", + client_id="554295ce3770612820620000", + ) def test_wrong_sign_alg(): - idt = IdToken(**{ - "sub": "553df2bcf909104751cfd8b2", - "aud": [ - "5542958437706128204e0000", - "554295ce3770612820620000" - ], - "auth_time": 1441364872, - "azp": "554295ce3770612820620000", - }) + idt = IdToken( + **{ + "sub": "553df2bcf909104751cfd8b2", + "aud": ["5542958437706128204e0000", "554295ce3770612820620000"], + "auth_time": 1441364872, + "azp": "554295ce3770612820620000", + } + ) kj = KeyJar() - kj.add_symmetric("", 'dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ', ['sig']) - kj.add_symmetric("554295ce3770612820620000", 'dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ', ['sig']) - packer = JWT(kj, sign_alg='HS256', lifetime=3600, iss='https://example.com/op') + kj.add_symmetric("", "dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ", ["sig"]) + kj.add_symmetric("554295ce3770612820620000", "dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ", ["sig"]) + packer = JWT(kj, sign_alg="HS256", lifetime=3600, iss="https://example.com/op") _jws = packer.pack(payload=idt.to_dict()) msg = AuthorizationResponse(id_token=_jws) with pytest.raises(UnsupportedAlgorithm): - verify_id_token(msg, check_hash=True, keyjar=kj, - iss="https://example.com/op", - client_id="554295ce3770612820620000", - allowed_sign_alg="RS256") + verify_id_token( + msg, + check_hash=True, + keyjar=kj, + iss="https://example.com/op", + client_id="554295ce3770612820620000", + allowed_sign_alg="RS256", + ) def test_correct_sign_alg(): - idt = IdToken(**{ - "sub": "553df2bcf909104751cfd8b2", - "aud": [ - "5542958437706128204e0000", - "554295ce3770612820620000" - ], - "auth_time": 1441364872, - "azp": "554295ce3770612820620000", - }) + idt = IdToken( + **{ + "sub": "553df2bcf909104751cfd8b2", + "aud": ["5542958437706128204e0000", "554295ce3770612820620000"], + "auth_time": 1441364872, + "azp": "554295ce3770612820620000", + } + ) kj = KeyJar() - kj.add_symmetric("", 'dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ', ['sig']) - kj.add_symmetric("https://example.com/op", 'dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ', ['sig']) - kj.add_symmetric("554295ce3770612820620000", 'dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ', ['sig']) - packer = JWT(kj, sign_alg='HS256', lifetime=3600, iss='https://example.com/op') + kj.add_symmetric("", "dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ", ["sig"]) + kj.add_symmetric("https://example.com/op", "dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ", ["sig"]) + kj.add_symmetric("554295ce3770612820620000", "dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ", ["sig"]) + packer = JWT(kj, sign_alg="HS256", lifetime=3600, iss="https://example.com/op") _jws = packer.pack(payload=idt.to_dict()) msg = AuthorizationResponse(id_token=_jws) - assert verify_id_token(msg, check_hash=True, keyjar=kj, - iss="https://example.com/op", - client_id="554295ce3770612820620000", - allowed_sign_alg="HS256") + assert verify_id_token( + msg, + check_hash=True, + keyjar=kj, + iss="https://example.com/op", + client_id="554295ce3770612820620000", + allowed_sign_alg="HS256", + ) diff --git a/tests/test_07_session.py b/tests/test_07_session.py index dc41fb2..d7bc53b 100644 --- a/tests/test_07_session.py +++ b/tests/test_07_session.py @@ -23,23 +23,26 @@ from oidcmsg.time_util import utc_time_sans_frac CLIENT_ID = "client_1" -ISS = 'https://example.com' - -IDTOKEN = IdToken(iss=ISS, sub="sub", - aud=CLIENT_ID, exp=utc_time_sans_frac() + 300, - nonce="N0nce", iat=time.time()) +ISS = "https://example.com" + +IDTOKEN = IdToken( + iss=ISS, + sub="sub", + aud=CLIENT_ID, + exp=utc_time_sans_frac() + 300, + nonce="N0nce", + iat=time.time(), +) KC_SYM_S = KeyBundle( - { - "kty": "oct", "key": "abcdefghijklmnop".encode("utf-8"), "use": "sig", - "alg": "HS256" - }) + {"kty": "oct", "key": "abcdefghijklmnop".encode("utf-8"), "use": "sig", "alg": "HS256"} +) NOW = utc_time_sans_frac() KEYDEF = [ {"type": "EC", "crv": "P-256", "use": ["sig"]}, - {"type": "EC", "crv": "P-256", "use": ["enc"]} - ] + {"type": "EC", "crv": "P-256", "use": ["enc"]}, +] def full_path(local_file): @@ -47,16 +50,22 @@ def full_path(local_file): return os.path.join(_dirname, local_file) -CLI_KEY = init_key_jar(public_path=full_path('pub_client.jwks'), - private_path=full_path('priv_client.jwks'), - key_defs=KEYDEF, issuer_id=CLIENT_ID) +CLI_KEY = init_key_jar( + public_path=full_path("pub_client.jwks"), + private_path=full_path("priv_client.jwks"), + key_defs=KEYDEF, + issuer_id=CLIENT_ID, +) -ISS_KEY = init_key_jar(public_path=full_path('pub_iss.jwks'), - private_path=full_path('priv_iss.jwks'), - key_defs=KEYDEF, issuer_id=ISS) +ISS_KEY = init_key_jar( + public_path=full_path("pub_iss.jwks"), + private_path=full_path("priv_iss.jwks"), + key_defs=KEYDEF, + issuer_id=ISS, +) -ISS_KEY.import_jwks_as_json(open(full_path('pub_client.jwks')).read(), CLIENT_ID) -CLI_KEY.import_jwks_as_json(open(full_path('pub_iss.jwks')).read(),ISS) +ISS_KEY.import_jwks_as_json(open(full_path("pub_client.jwks")).read(), CLIENT_ID) +CLI_KEY.import_jwks_as_json(open(full_path("pub_iss.jwks")).read(), ISS) class TestEndSessionResponse(object): @@ -68,43 +77,45 @@ class TestEndSessionRequest(object): def test_example(self): _symkey = KC_SYM_S.get(alg2keytype("HS256")) esreq = EndSessionRequest( - id_token_hint=IDTOKEN.to_jwt(key=_symkey, algorithm="HS256", - lifetime=300), + id_token_hint=IDTOKEN.to_jwt(key=_symkey, algorithm="HS256", lifetime=300), redirect_url="http://example.org/jqauthz", - state="state0") + state="state0", + ) request = EndSessionRequest().from_urlencoded(esreq.to_urlencoded()) keyjar = KeyJar() for _key in _symkey: - keyjar.add_symmetric('', _key.key) + keyjar.add_symmetric("", _key.key) keyjar.add_symmetric(ISS, _key.key) keyjar.add_symmetric(CLIENT_ID, _key.key) request.verify(keyjar=keyjar) assert isinstance(request, EndSessionRequest) - assert set(request.keys()) == {verified_claim_name('id_token_hint'), - 'id_token_hint', 'redirect_url', 'state'} + assert set(request.keys()) == { + verified_claim_name("id_token_hint"), + "id_token_hint", + "redirect_url", + "state", + } assert request["state"] == "state0" - assert request[ - verified_claim_name("id_token_hint")]["aud"] == ["client_1"] + assert request[verified_claim_name("id_token_hint")]["aud"] == ["client_1"] class TestCheckSessionRequest(object): def test_example(self): _symkey = KC_SYM_S.get(alg2keytype("HS256")) csr = CheckSessionRequest( - id_token=IDTOKEN.to_jwt(key=_symkey, algorithm="HS256", - lifetime=300)) + id_token=IDTOKEN.to_jwt(key=_symkey, algorithm="HS256", lifetime=300) + ) keyjar = KeyJar() - keyjar.add_kb('', KC_SYM_S) + keyjar.add_kb("", KC_SYM_S) with pytest.raises(ValueError): assert csr.verify(keyjar=keyjar) - def test_example_1(self): _symkey = KC_SYM_S.get(alg2keytype("HS256")) csr = CheckSessionRequest( - id_token=IDTOKEN.to_jwt(key=_symkey, algorithm="HS256", - lifetime=300)) + id_token=IDTOKEN.to_jwt(key=_symkey, algorithm="HS256", lifetime=300) + ) keyjar = KeyJar() keyjar.add_kb(ISS, KC_SYM_S) assert csr.verify(keyjar=keyjar) @@ -117,12 +128,12 @@ def test_example(self): "nickname": None, "email": {"essential": True}, "verified": {"essential": True}, - "picture": None - } + "picture": None, + } - cr = ClaimsRequest(userinfo=Claims(**claims), - id_token=Claims(auth_time=None, - acr={"values": ["2"]})) + cr = ClaimsRequest( + userinfo=Claims(**claims), id_token=Claims(auth_time=None, acr={"values": ["2"]}) + ) cr.verify() _url = cr.to_urlencoded() cr1 = ClaimsRequest().from_urlencoded(_url) @@ -141,10 +152,8 @@ def test_logout_token_1(): "iat": NOW, "jti": "bWJq", "sid": "08a5019c-17e1-4977-8f42-65a12843ea02", - "events": { - BACK_CHANNEL_LOGOUT_EVENT: {} - } - } + "events": {BACK_CHANNEL_LOGOUT_EVENT: {}}, + } lt = LogoutToken(**val) assert lt.verify() @@ -156,10 +165,8 @@ def test_logout_token_2(): "aud": [CLIENT_ID], "iat": NOW, "jti": "bWJq", - "events": { - BACK_CHANNEL_LOGOUT_EVENT: {} - } - } + "events": {BACK_CHANNEL_LOGOUT_EVENT: {}}, + } lt = LogoutToken(**val) assert lt.verify() @@ -171,10 +178,8 @@ def test_logout_token_3(): "iat": NOW, "jti": "bWJq", "sid": "08a5019c-17e1-4977-8f42-65a12843ea02", - "events": { - BACK_CHANNEL_LOGOUT_EVENT: {} - } - } + "events": {BACK_CHANNEL_LOGOUT_EVENT: {}}, + } lt = LogoutToken(**val) assert lt.verify() @@ -185,10 +190,8 @@ def test_logout_token_4(): "aud": [CLIENT_ID], "iat": NOW, "jti": "bWJq", - "events": { - BACK_CHANNEL_LOGOUT_EVENT: {} - } - } + "events": {BACK_CHANNEL_LOGOUT_EVENT: {}}, + } lt = LogoutToken(**val) with pytest.raises(ValueError): lt.verify() @@ -200,10 +203,8 @@ def test_logout_token_5(): "aud": [CLIENT_ID], "iat": NOW, "jti": "bWJq", - "events": { - BACK_CHANNEL_LOGOUT_EVENT: {'foo':'bar'} - } - } + "events": {BACK_CHANNEL_LOGOUT_EVENT: {"foo": "bar"}}, + } lt = LogoutToken(**val) with pytest.raises(ValueError): lt.verify() @@ -215,10 +216,8 @@ def test_logout_token_6(): "aud": [CLIENT_ID], "iat": NOW, "jti": "bWJq", - "events": { - "http://schemas.openid.net/event/foo": {} - } - } + "events": {"http://schemas.openid.net/event/foo": {}}, + } lt = LogoutToken(**val) with pytest.raises(ValueError): lt.verify() @@ -230,11 +229,8 @@ def test_logout_token_7(): "aud": [CLIENT_ID], "iat": NOW, "jti": "bWJq", - "events": { - BACK_CHANNEL_LOGOUT_EVENT: {}, - "http://schemas.openid.net/event/foo": {} - } - } + "events": {BACK_CHANNEL_LOGOUT_EVENT: {}, "http://schemas.openid.net/event/foo": {}}, + } lt = LogoutToken(**val) with pytest.raises(ValueError): lt.verify() @@ -247,11 +243,9 @@ def test_logout_token_with_nonce(): "iat": NOW, "jti": "bWJq", "sid": "08a5019c-17e1-4977-8f42-65a12843ea02", - "events": { - BACK_CHANNEL_LOGOUT_EVENT: {} - }, - "nonce": "1234567890" - } + "events": {BACK_CHANNEL_LOGOUT_EVENT: {}}, + "nonce": "1234567890", + } lt = LogoutToken(**val) with pytest.raises(MessageException): lt.verify() @@ -261,13 +255,11 @@ def test_logout_token_wrong_iat(): val = { "iss": ISS, "aud": [CLIENT_ID], - "iat": NOW+10, + "iat": NOW + 10, "jti": "bWJq", "sid": "08a5019c-17e1-4977-8f42-65a12843ea02", - "events": { - BACK_CHANNEL_LOGOUT_EVENT: {} - } - } + "events": {BACK_CHANNEL_LOGOUT_EVENT: {}}, + } lt = LogoutToken(**val) with pytest.raises(ValueError): lt.verify() @@ -283,13 +275,11 @@ def test_logout_token_wrong_aud(): "iat": NOW, "jti": "bWJq", "sid": "08a5019c-17e1-4977-8f42-65a12843ea02", - "events": { - BACK_CHANNEL_LOGOUT_EVENT: {} - } - } + "events": {BACK_CHANNEL_LOGOUT_EVENT: {}}, + } lt = LogoutToken(**val) with pytest.raises(NotForMe): - lt.verify(aud='deep_purple') + lt.verify(aud="deep_purple") lt.verify(aud=CLIENT_ID) @@ -301,13 +291,11 @@ def test_logout_token_wrong_iss(): "iat": NOW, "jti": "bWJq", "sid": "08a5019c-17e1-4977-8f42-65a12843ea02", - "events": { - BACK_CHANNEL_LOGOUT_EVENT: {} - } - } + "events": {BACK_CHANNEL_LOGOUT_EVENT: {}}, + } lt = LogoutToken(**val) with pytest.raises(NotForMe): - lt.verify(iss='deep_purple') + lt.verify(iss="deep_purple") lt.verify(iss=ISS) @@ -319,12 +307,10 @@ def test_back_channel_logout_request(): "iat": NOW, "jti": "bWJq", "sid": "08a5019c-17e1-4977-8f42-65a12843ea02", - "events": { - BACK_CHANNEL_LOGOUT_EVENT: {} - } - } + "events": {BACK_CHANNEL_LOGOUT_EVENT: {}}, + } lt = LogoutToken(**val) - signer = JWS(lt.to_json(), alg='ES256') + signer = JWS(lt.to_json(), alg="ES256") _jws = signer.sign_compact(keys=ISS_KEY.get_signing_key(issuer_id=ISS)) bclr = BackChannelLogoutRequest(logout_token=_jws) @@ -334,9 +320,9 @@ def test_back_channel_logout_request(): _request = BackChannelLogoutRequest().from_urlencoded(_req) - assert 'logout_token' in _request + assert "logout_token" in _request _verified = _request.verify(keyjar=CLI_KEY, iss=ISS, aud=CLIENT_ID, skew=30) assert _verified - assert set(_request.keys()) == {'logout_token', '__verified_logout_token'} + assert set(_request.keys()) == {"logout_token", "__verified_logout_token"} diff --git a/tests/test_10_identity_assurance.py b/tests/test_10_identity_assurance.py index b0a05ed..ccc9257 100644 --- a/tests/test_10_identity_assurance.py +++ b/tests/test_10_identity_assurance.py @@ -35,22 +35,18 @@ def test_verification_element(): assert ve - s = '2020-01-11T11:00:00+0100' + s = "2020-01-11T11:00:00+0100" ve_2 = VerificationElement(trust_framework="TrustAreUs") ve_2["time"] = s - assert quote_plus('2020-01-11T11:00:00+0100') in ve_2.to_urlencoded() + assert quote_plus("2020-01-11T11:00:00+0100") in ve_2.to_urlencoded() def test_verified_claims(): s = { "userinfo": { "verified_claims": { - "claims": { - "given_name": None, - "family_name": None, - "birthdate": None - } + "claims": {"given_name": None, "family_name": None, "birthdate": None} } } } @@ -61,25 +57,20 @@ def test_verified_claims(): def test_verfication_element_from_dict(): d = { - "verification": { - "trust_framework": "eidas_ial_substantial" - }, + "verification": {"trust_framework": "eidas_ial_substantial"}, "claims": { "given_name": "Max", "family_name": "Meier", "birthdate": "1956-01-28", - "place_of_birth": { - "country": "DE", - "locality": "Musterstadt" - }, + "place_of_birth": {"country": "DE", "locality": "Musterstadt"}, "nationality": "DE", "address": { "locality": "Maxstadt", "postal_code": "12344", "country": "DE", - "street_address": "An der Sanddüne 22" - } - } + "street_address": "An der Sanddüne 22", + }, + }, } v = VerifiedClaims(**d) assert v @@ -101,23 +92,16 @@ def test_userinfo_response(): "method": "pipp", "document": { "type": "idcard", - "issuer": { - "name": "Stadt Augsburg", - "country": "DE" - }, + "issuer": {"name": "Stadt Augsburg", "country": "DE"}, "number": "53554554", "date_of_issuance": "2012-04-23", - "date_of_expiry": "2022-04-22" - } + "date_of_expiry": "2022-04-22", + }, } - ] + ], }, - "claims": { - "given_name": "Max", - "family_name": "Meier", - "birthdate": "1956-01-28" - } - } + "claims": {"given_name": "Max", "family_name": "Meier", "birthdate": "1956-01-28"}, + }, } v = VerifiedClaims(**resp["verified_claims"]) @@ -138,11 +122,7 @@ def test_userinfo_claims_request_5_1_1(): userinfo_claims = { "userinfo": { "verified_claims": { - "claims": { - "given_name": None, - "family_name": None, - "birthdate": None - } + "claims": {"given_name": None, "family_name": None, "birthdate": None} } } } @@ -159,7 +139,7 @@ def test_userinfo_claims_request_5_1_2(): "claims": { "given_name": {"essential": True}, "family_name": {"essential": True}, - "birthdate": None + "birthdate": None, } } } @@ -177,14 +157,10 @@ def test_userinfo_claims_request_5_1_3(): "claims": { "given_name": { "essential": True, - "purpose": "To make communication look more personal" + "purpose": "To make communication look more personal", }, - "family_name": { - "essential": True - }, - "birthdate": { - "purpose": "To send you best wishes on your birthday" - } + "family_name": {"essential": True}, + "birthdate": {"purpose": "To send you best wishes on your birthday"}, } } } @@ -196,13 +172,7 @@ def test_userinfo_claims_request_5_1_3(): def test_userinfo_claims_request_5_1_4(): - userinfo_claims = { - "userinfo": { - "verified_claims": { - "claims": None - } - } - } + userinfo_claims = {"userinfo": {"verified_claims": {"claims": None}}} icr = IDAClaimsRequest(**userinfo_claims["userinfo"]) icr.verify() @@ -211,13 +181,7 @@ def test_userinfo_claims_request_5_1_4(): def test_userinfo_claims_request_5_2_1(): verified_claims = { - "verified_claims": { - "verification": { - "time": None, - "evidence": None - }, - "claims": None - } + "verified_claims": {"verification": {"time": None, "evidence": None}, "claims": None} } icr = IDAClaimsRequest(**verified_claims) @@ -228,16 +192,8 @@ def test_userinfo_claims_request_5_2_1(): def test_userinfo_claims_request_5_2_2(): verified_claims = { "verified_claims": { - "verification": { - "time": None, - "evidence": [ - { - "method": None, - "document": None - } - ] - }, - "claims": None + "verification": {"time": None, "evidence": [{"method": None, "document": None}]}, + "claims": None, } } @@ -254,15 +210,11 @@ def test_userinfo_claims_request_5_2_3(): "evidence": [ { "method": None, - "document": { - "issuer": None, - "number": None, - "date_of_issuance": None - } + "document": {"issuer": None, "number": None, "date_of_issuance": None}, } - ] + ], }, - "claims": None + "claims": None, } } @@ -276,29 +228,16 @@ def test_userinfo_claims_request_5_3_1(): "userinfo": { "verified_claims": { "verification": { - "trust_framework": { - "value": "de_aml" - }, + "trust_framework": {"value": "de_aml"}, "evidence": [ { - "type": { - "value": "id_document" - }, - "method": { - "value": "pipp" - }, - "document": { - "type": { - "values": [ - "idcard", - "passport" - ] - } - } + "type": {"value": "id_document"}, + "method": {"value": "pipp"}, + "document": {"type": {"values": ["idcard", "passport"]}}, } - ] + ], }, - "claims": None + "claims": None, } } } @@ -311,14 +250,7 @@ def test_userinfo_claims_request_5_3_1(): def test_userinfo_claims_request_5_3_2(): userinfo_claims = { "userinfo": { - "verified_claims": { - "verification": { - "date": { - "max_age": 63113852 - } - }, - "claims": None - } + "verified_claims": {"verification": {"date": {"max_age": 63113852}}, "claims": None} } } @@ -340,33 +272,27 @@ def test_example_6_1(): "method": "pipp", "document": { "type": "idcard", - "issuer": { - "name": "Stadt Augsburg", - "country": "DE" - }, + "issuer": {"name": "Stadt Augsburg", "country": "DE"}, "number": "53554554", "date_of_issuance": "2012-04-23", - "date_of_expiry": "2022-04-22" - } + "date_of_expiry": "2022-04-22", + }, } - ] + ], }, "claims": { "given_name": "Max", "family_name": "Meier", "birthdate": "1956-01-28", - "place_of_birth": { - "country": "DE", - "locality": "Musterstadt" - }, + "place_of_birth": {"country": "DE", "locality": "Musterstadt"}, "nationality": "DE", "address": { "locality": "Maxstadt", "postal_code": "12344", "country": "DE", - "street_address": "An der Sanddüne 22" - } - } + "street_address": "An der Sanddüne 22", + }, + }, } } @@ -389,14 +315,11 @@ def test_example_6_2(): "method": "pipp", "document": { "document_type": "de_erp_replacement_idcard", - "issuer": { - "name": "Stadt Augsburg", - "country": "DE" - }, + "issuer": {"name": "Stadt Augsburg", "country": "DE"}, "number": "53554554", "date_of_issuance": "2012-04-23", - "date_of_expiry": "2022-04-22" - } + "date_of_expiry": "2022-04-22", + }, }, { "type": "utility_bill", @@ -404,28 +327,25 @@ def test_example_6_2(): "name": "Stadtwerke Musterstadt", "country": "DE", "region": "Thüringen", - "street_address": "Energiestrasse 33" + "street_address": "Energiestrasse 33", }, - "date": "2013-01-31" - } - ] + "date": "2013-01-31", + }, + ], }, "claims": { "given_name": "Max", "family_name": "Meier", "birthdate": "1956-01-28", - "place_of_birth": { - "country": "DE", - "locality": "Musterstadt" - }, + "place_of_birth": {"country": "DE", "locality": "Musterstadt"}, "nationality": "DE", "address": { "locality": "Maxstadt", "postal_code": "12344", "country": "DE", - "street_address": "An der Sanddüne 22" - } - } + "street_address": "An der Sanddüne 22", + }, + }, } } @@ -441,25 +361,20 @@ def test_example_6_2(): def test_example_6_3(): verified_claims = { "verified_claims": { - "verification": { - "trust_framework": "eidas_ial_substantial" - }, + "verification": {"trust_framework": "eidas_ial_substantial"}, "claims": { "given_name": "Max", "family_name": "Meier", "birthdate": "1956-01-28", - "place_of_birth": { - "country": "DE", - "locality": "Musterstadt" - }, + "place_of_birth": {"country": "DE", "locality": "Musterstadt"}, "nationality": "DE", "address": { "locality": "Maxstadt", "postal_code": "12344", "country": "DE", - "street_address": "An der Sanddüne 22" - } - } + "street_address": "An der Sanddüne 22", + }, + }, } } @@ -484,23 +399,16 @@ def test_example_6_4_2(): "method": "pipp", "document": { "type": "idcard", - "issuer": { - "name": "Stadt Augsburg", - "country": "DE" - }, + "issuer": {"name": "Stadt Augsburg", "country": "DE"}, "number": "53554554", "date_of_issuance": "2012-04-23", - "date_of_expiry": "2022-04-22" - } + "date_of_expiry": "2022-04-22", + }, } - ] + ], }, - "claims": { - "given_name": "Max", - "family_name": "Meier", - "birthdate": "1956-01-28" - } - } + "claims": {"given_name": "Max", "family_name": "Meier", "birthdate": "1956-01-28"}, + }, } vc = VerifiedClaims(**userinfo_response["verified_claims"]) @@ -518,4 +426,4 @@ def test_construct_5_2_1(): verified_claims["claims"] = None _val = verified_claims.to_json() - assert _val == '{"verification": {"time": null, "evidence": null}, "claims": null}' \ No newline at end of file + assert _val == '{"verification": {"time": null, "evidence": null}, "claims": null}' diff --git a/tests/test_11_impexp.py b/tests/test_11_impexp.py index f838614..cc81f7c 100644 --- a/tests/test_11_impexp.py +++ b/tests/test_11_impexp.py @@ -19,7 +19,7 @@ class ImpExpTest(ImpExp): "message": AuthorizationRequest, "response_class": object, "key_bundle": KeyBundle, - "bundles": [KeyBundle] + "bundles": [KeyBundle], } @@ -28,8 +28,12 @@ def test_dump_load(): b.string = "foo" b.list = ["a", "b", "c"] b.dict = {"a": 1, "b": 2} - b.message = AuthorizationRequest(scope="openid", redirect_uri="https://example.com/cb", - response_type="code", client_id="abcdefg") + b.message = AuthorizationRequest( + scope="openid", + redirect_uri="https://example.com/cb", + response_type="code", + client_id="abcdefg", + ) b.response_class = AuthorizationResponse b.key_bundle = build_key_bundle(key_conf=KEYSPEC) b.bundles = [build_key_bundle(key_conf=KEYSPEC)] @@ -56,8 +60,12 @@ def test_flush(): b.string = "foo" b.list = ["a", "b", "c"] b.dict = {"a": 1, "b": 2} - b.message = AuthorizationRequest(scope="openid", redirect_uri="https://example.com/cb", - response_type="code", client_id="abcdefg") + b.message = AuthorizationRequest( + scope="openid", + redirect_uri="https://example.com/cb", + response_type="code", + client_id="abcdefg", + ) b.response_class = AuthorizationResponse b.key_bundle = build_key_bundle(key_conf=KEYSPEC) b.bundles = [build_key_bundle(key_conf=KEYSPEC)] diff --git a/tests/test_12_context.py b/tests/test_12_context.py index 12d1de8..e34b0be 100644 --- a/tests/test_12_context.py +++ b/tests/test_12_context.py @@ -7,25 +7,25 @@ KEYDEF = [ {"type": "EC", "crv": "P-256", "use": ["sig"]}, - {"type": "EC", "crv": "P-256", "use": ["enc"]} + {"type": "EC", "crv": "P-256", "use": ["enc"]}, ] JWKS = { "keys": [ { - "n": - 'zkpUgEgXICI54blf6iWiD2RbMDCOO1jV0VSff1MFFnujM4othfMsad7H1kRo50YM5S' - '_X9TdvrpdOfpz5aBaKFhT6Ziv0nhtcekq1eRl8mjBlvGKCE5XGk-0LFSDwvqgkJoFY' - 'Inq7bu0a4JEzKs5AyJY75YlGh879k1Uu2Sv3ZZOunfV1O1Orta-NvS-aG_jN5cstVb' - 'CGWE20H0vFVrJKNx0Zf-u-aA-syM4uX7wdWgQ-owoEMHge0GmGgzso2lwOYf_4znan' - 'LwEuO3p5aabEaFoKNR4K6GjQcjBcYmDEE4CtfRU9AEmhcD1kleiTB9TjPWkgDmT9MX' - 'sGxBHf3AKT5w', - "e": "AQAB", "kty": "RSA", "kid": "rsa1" + "n": "zkpUgEgXICI54blf6iWiD2RbMDCOO1jV0VSff1MFFnujM4othfMsad7H1kRo50YM5S" + "_X9TdvrpdOfpz5aBaKFhT6Ziv0nhtcekq1eRl8mjBlvGKCE5XGk-0LFSDwvqgkJoFY" + "Inq7bu0a4JEzKs5AyJY75YlGh879k1Uu2Sv3ZZOunfV1O1Orta-NvS-aG_jN5cstVb" + "CGWE20H0vFVrJKNx0Zf-u-aA-syM4uX7wdWgQ-owoEMHge0GmGgzso2lwOYf_4znan" + "LwEuO3p5aabEaFoKNR4K6GjQcjBcYmDEE4CtfRU9AEmhcD1kleiTB9TjPWkgDmT9MX" + "sGxBHf3AKT5w", + "e": "AQAB", + "kty": "RSA", + "kid": "rsa1", }, { - "k": - 'YTEyZjBlMDgxMGI4YWU4Y2JjZDFiYTFlZTBjYzljNDU3YWM0ZWNiNzhmNmFlYTNkNTY0NzMzYjE', - "kty": "oct" + "k": "YTEyZjBlMDgxMGI4YWU4Y2JjZDFiYTFlZTBjYzljNDU3YWM0ZWNiNzhmNmFlYTNkNTY0NzMzYjE", + "kty": "oct", }, ] } @@ -42,51 +42,48 @@ def test_dump_load(): class TestDumpLoad(object): @pytest.fixture(autouse=True) def setup(self): - self.conf = { - 'issuer': 'https://example.com' - } + self.conf = {"issuer": "https://example.com"} def test_context_with_entity_id_no_keys(self): - c = OidcContext(self.conf, entity_id='https://example.com') + c = OidcContext(self.conf, entity_id="https://example.com") mem = c.dump() c2 = OidcContext().load(mem) assert c2.keyjar.owners() == [] - assert c2.issuer == 'https://example.com' + assert c2.issuer == "https://example.com" def test_context_with_entity_id_and_keys(self): conf = copy.deepcopy(self.conf) - conf['keys'] = {'key_defs': KEYDEF} - c = OidcContext(conf, entity_id='https://example.com') + conf["keys"] = {"key_defs": KEYDEF} + c = OidcContext(conf, entity_id="https://example.com") mem = c.dump() c2 = OidcContext().load(mem) - assert set(c2.keyjar.owners()) == {'', 'https://example.com'} + assert set(c2.keyjar.owners()) == {"", "https://example.com"} def test_context_with_entity_id_and_jwks(self): conf = copy.deepcopy(self.conf) - conf['jwks'] = JWKS - c = OidcContext(conf, entity_id='https://example.com') + conf["jwks"] = JWKS + c = OidcContext(conf, entity_id="https://example.com") mem = c.dump() c2 = OidcContext().load(mem) - assert set(c2.keyjar.owners()) == {'', 'https://example.com'} - assert len(c2.keyjar.get('sig', 'RSA')) == 1 - assert len(c2.keyjar.get('sig', 'RSA', issuer_id='https://example.com')) == 1 - assert len(c2.keyjar.get('sig', 'oct')) == 1 - assert len(c2.keyjar.get('sig', 'oct', issuer_id='https://example.com')) == 1 + assert set(c2.keyjar.owners()) == {"", "https://example.com"} + assert len(c2.keyjar.get("sig", "RSA")) == 1 + assert len(c2.keyjar.get("sig", "RSA", issuer_id="https://example.com")) == 1 + assert len(c2.keyjar.get("sig", "oct")) == 1 + assert len(c2.keyjar.get("sig", "oct", issuer_id="https://example.com")) == 1 def test_context_restore(self): conf = copy.deepcopy(self.conf) - conf['keys'] = {'key_defs': KEYDEF} + conf["keys"] = {"key_defs": KEYDEF} - c = OidcContext(conf, entity_id='https://example.com') + c = OidcContext(conf, entity_id="https://example.com") mem = c.dump() c2 = OidcContext().load(mem) - assert set(c2.keyjar.owners()) == {'', 'https://example.com'} - assert len(c2.keyjar.get('sig', 'EC')) == 1 - assert len(c2.keyjar.get('enc', 'EC')) == 1 - assert len(c.keyjar.get('sig', 'RSA')) == 0 - assert len(c.keyjar.get('sig', 'oct')) == 0 - + assert set(c2.keyjar.owners()) == {"", "https://example.com"} + assert len(c2.keyjar.get("sig", "EC")) == 1 + assert len(c2.keyjar.get("enc", "EC")) == 1 + assert len(c.keyjar.get("sig", "RSA")) == 0 + assert len(c.keyjar.get("sig", "oct")) == 0 diff --git a/tests/test_13_dump_item.py b/tests/test_13_dump_item.py new file mode 100644 index 0000000..a026c29 --- /dev/null +++ b/tests/test_13_dump_item.py @@ -0,0 +1,33 @@ +from cryptojwt import KeyBundle +from cryptojwt.key_bundle import build_key_bundle +from cryptojwt.utils import qualified_name + +from oidcmsg.item import DLDict + +KEYSPEC = [ + {"type": "RSA", "use": ["sig"]}, + {"type": "EC", "crv": "P-256", "use": ["sig"]}, +] + +KEYSPEC_2 = [ + {"type": "RSA", "use": ["sig"]}, + {"type": "EC", "crv": "P-256", "use": ["sig"]}, + {"type": "EC", "crv": "P-384", "use": ["sig"]}, +] + + +def test_dl_dict(): + _dict = DLDict() + _kb1 = build_key_bundle(key_conf=KEYSPEC) + _dict["a"] = _kb1 + _kb2 = build_key_bundle(key_conf=KEYSPEC_2) + _dict["b"] = _kb2 + + dump = _dict.dump() + + _dict_copy = DLDict().load(dump) + + assert set(_dict_copy.keys()) == {"a", "b"} + + kb1_copy = _dict_copy["a"] + assert len(kb1_copy.keys()) == 2