diff --git a/examples/hmac_secret.py b/examples/hmac_secret.py index 559d18f..dfcfb45 100644 --- a/examples/hmac_secret.py +++ b/examples/hmac_secret.py @@ -152,7 +152,7 @@ def request_uv(self, permissions, rd_id): # Only one cred in allowCredentials, only one response. result = result.get_response(0) -output1 = result.extension_results["hmacGetSecret"]["output1"] +output1 = result.extension_results.hmac_get_secret.output1 print("Authenticated, secret:", output1.hex()) # Authenticate again, using two salts to generate two secrets: @@ -173,6 +173,6 @@ def request_uv(self, permissions, rd_id): # Only one cred in allowCredentials, only one response. result = result.get_response(0) -output = result.extension_results["hmacGetSecret"] -print("Old secret:", output["output1"].hex()) -print("New secret:", output["output2"].hex()) +output = result.extension_results.hmac_get_secret +print("Old secret:", output.output1.hex()) +print("New secret:", output.output2.hex()) diff --git a/fido2/client.py b/fido2/client.py index 372925b..3b384d6 100644 --- a/fido2/client.py +++ b/fido2/client.py @@ -379,7 +379,7 @@ def do_make_credential( return AuthenticatorAttestationResponse( client_data, AttestationObject.create(att_obj.fmt, att_obj.auth_data, att_obj.att_stmt), - {}, + ClientExtensionOutputs({}), ) def do_get_assertion( @@ -1103,7 +1103,7 @@ def make_credential(self, options, **kwargs): logger.info("New credential registered") return AuthenticatorAttestationResponse( - client_data, AttestationObject(result), extensions + client_data, AttestationObject(result), ClientExtensionOutputs(extensions) ) def get_assertion(self, options, **kwargs): @@ -1151,5 +1151,5 @@ def get_assertion(self, options, **kwargs): user=user, ) ], - extensions, + ClientExtensionOutputs(extensions), ) diff --git a/fido2/ctap2/extensions.py b/fido2/ctap2/extensions.py index 1405245..b2df7ca 100644 --- a/fido2/ctap2/extensions.py +++ b/fido2/ctap2/extensions.py @@ -44,6 +44,31 @@ import warnings +class ClientExtensionOutputs(Mapping[str, Any]): + def __init__(self, outputs: Mapping[str, Any]): + self._members = {k: v for k, v in outputs.items() if v is not None} + + def __iter__(self): + return iter(self._members) + + def __len__(self): + return len(self._members) + + def __getitem__(self, key): + value = self._members[key] + if isinstance(value, bytes): + return websafe_encode(value) + return dict(value) if isinstance(value, Mapping) else value + + def __getattr__(self, key): + parts = key.split("_") + name = parts[0] + "".join(p.title() for p in parts[1:]) + return self._members.get(name) + + def __repr__(self): + return repr(dict(self)) + + class Ctap2Extension(abc.ABC): """Base class for Ctap2 extensions. Subclasses are instantiated for a single request, if the Authenticator supports @@ -143,7 +168,7 @@ class _PrfValues(_JsonDataObject): @dataclass(eq=False, frozen=True) class _PrfInputs(_JsonDataObject): eval: Optional[_PrfValues] = None - evalByCredential: Optional[Mapping[str, _PrfValues]] = None + eval_by_credential: Optional[Mapping[str, _PrfValues]] = None @dataclass(eq=False, frozen=True) @@ -186,11 +211,10 @@ def process_get_input(self, inputs): if not self.is_supported(): return - data = inputs.get("prf") - if data: - prf = _PrfInputs.from_dict(data) + prf = _PrfInputs.from_dict(inputs.get("prf")) + if prf: secrets = prf.eval - by_creds = prf.evalByCredential + by_creds = prf.eval_by_credential if by_creds: # Make sure all keys are valid IDs from allow_credentials allow_list = self._get_options.allow_credentials @@ -213,11 +237,10 @@ def process_get_input(self, inputs): ) self.prf = True else: - data = inputs.get("hmacGetSecret") - if not data or not self._allow_hmac_secret: + get_secret = _HmacGetSecretInput.from_dict(inputs.get("hmacGetSecret")) + if not get_secret or not self._allow_hmac_secret: return - res = _HmacGetSecretInput.from_dict(data) - salts = res.salt1, res.salt2 or b"" + salts = get_secret.salt1, get_secret.salt2 or b"" self.prf = False if not ( @@ -279,7 +302,7 @@ def is_supported(self): return super().is_supported() and self.ctap.info.options.get("largeBlobs") def process_create_input(self, inputs): - data = _LargeBlobInputs.from_dict(inputs.get("largeBlob", {})) + data = _LargeBlobInputs.from_dict(inputs.get("largeBlob")) if data: if data.read or data.write: raise ValueError("Invalid set of parameters") @@ -295,12 +318,13 @@ def process_create_output(self, attestation_response, *args): } def get_get_permissions(self, inputs): - if _LargeBlobInputs.from_dict(inputs.get("largeBlob", {})).write: + data = _LargeBlobInputs.from_dict(inputs.get("largeBlob")) + if data and data.write: return ClientPin.PERMISSION.LARGE_BLOB_WRITE return ClientPin.PERMISSION(0) def process_get_input(self, inputs): - data = _LargeBlobInputs.from_dict(inputs.get("largeBlob", {})) + data = _LargeBlobInputs.from_dict(inputs.get("largeBlob")) if data: if data.support or (data.read and data.write): raise ValueError("Invalid set of parameters") @@ -310,7 +334,7 @@ def process_get_input(self, inputs): self._action = True else: self._action = data.write - return True if data else None + return True def process_get_output(self, assertion_response, token, pin_protocol): blob_key = assertion_response.large_blob_key @@ -413,24 +437,3 @@ def process_create_output(self, attestation_response, *args): ) rk = selection.require_resident_key return {"credProps": _CredPropsOutputs(rk=rk)} - - -class ClientExtensionOutputs(Mapping[str, Any]): - def __init__(self, outputs: Mapping[str, Any]): - self._members = {k: v for k, v in outputs.items() if v is not None} - - def __iter__(self): - return iter(self._members) - - def __len__(self): - return len(self._members) - - def __getitem__(self, key): - value = self._members[key] - return dict(value) if isinstance(value, Mapping) else value - - def __getattr__(self, key): - return self._members.get(key) - - def __repr__(self): - return repr(dict(self)) diff --git a/fido2/webauthn.py b/fido2/webauthn.py index 33eedc8..ba7a6fc 100644 --- a/fido2/webauthn.py +++ b/fido2/webauthn.py @@ -573,6 +573,13 @@ def from_dict(cls, data: Optional[Mapping[str, Any]]): data = value return super().from_dict(data) + @classmethod + def _parse_value(cls, t, value): + if t == Optional[Mapping[str, Any]]: + # Don't convert extension_results + return value + return super()._parse_value(t, value) + @dataclass(eq=False, frozen=True) class AuthenticatorAssertionResponse(_WebAuthnDataObject): @@ -596,6 +603,13 @@ def from_dict(cls, data: Optional[Mapping[str, Any]]): data = value return super().from_dict(data) + @classmethod + def _parse_value(cls, t, value): + if t == Optional[Mapping[str, Any]]: + # Don't convert extension_results + return value + return super()._parse_value(t, value) + @dataclass(eq=False, frozen=True) class RegistrationResponse(_WebAuthnDataObject):