diff --git a/fido2/ctap2/extensions.py b/fido2/ctap2/extensions.py index 74583f9..c841e32 100644 --- a/fido2/ctap2/extensions.py +++ b/fido2/ctap2/extensions.py @@ -708,50 +708,59 @@ class SignExtension(Ctap2Extension): NAME = "sign" - def process_create_input(self, inputs): + def make_credential(self, ctap, options): + inputs = options.extensions or {} data = _SignInputs.from_dict(inputs.get("sign")) - if not data or not self.is_supported(): + if not data or not self.is_supported(ctap): return if data.sign or not data.generate_key: raise ValueError("Invalid inputs") - gk = data.generate_key - - selection = ( - self._create_options.authenticator_selection - or AuthenticatorSelectionCriteria() - ) - flags = ( - 0b101 - if selection.user_verification == UserVerificationRequirement.REQUIRED - else 0b001 - ) - outputs = {3: gk.algorithms, 4: flags} + def prepare_inputs(_): + gk = data.generate_key - if gk.ph_data: - outputs[0] = gk.ph_data + selection = ( + options.authenticator_selection or AuthenticatorSelectionCriteria() + ) + flags = ( + 0b101 + if selection.user_verification == UserVerificationRequirement.REQUIRED + else 0b001 + ) + outputs = {3: gk.algorithms, 4: flags} - return outputs + if gk.ph_data: + outputs[0] = gk.ph_data - def process_create_output(self, attestation_response, *args): - data = attestation_response.auth_data.extensions.get(self.NAME) - att_obj = AttestationResponse.from_dict(cbor.decode(data[7])) # type: ignore - cred_data = att_obj.auth_data.credential_data - assert cred_data is not None # nosec - pk = cred_data.public_key + return {self.NAME: outputs} - return { - "sign": _SignOutputs( - generated_key=_SignGeneratedKey( - public_key=cbor.encode(pk), - key_handle=cbor.encode(pk.get_ref()), - ), - signature=data.get(6), + def prepare_outputs(response, *args): + data = response.auth_data.extensions.get(self.NAME) + att_obj = AttestationResponse.from_dict( + cbor.decode(data[7]) # type: ignore ) - } + cred_data = att_obj.auth_data.credential_data + assert cred_data is not None # nosec + pk = cred_data.public_key + + return { + self.NAME: _SignOutputs( + generated_key=_SignGeneratedKey( + public_key=cbor.encode(pk), + key_handle=cbor.encode(pk.get_ref()), + ), + signature=data.get(6), + ) + } - def process_get_input(self, inputs): + return ExtensionProcessor( + prepare_inputs=prepare_inputs, + prepare_outputs=prepare_outputs, + ) + + def get_assertion(self, ctap, options): + inputs = options.extensions or {} data = _SignInputs.from_dict(inputs.get("sign")) if not data or not self.is_supported(): return @@ -759,26 +768,31 @@ def process_get_input(self, inputs): if not data.sign or data.generate_key: raise ValueError("Invalid inputs") - sign = data.sign - by_creds = sign.key_handle_by_credential - - # Make sure all keys are valid IDs from allow_credentials - allow_list = self._get_options.allow_credentials - if not allow_list: - raise ValueError("sign requires allowCredentials") - ids = {websafe_encode(c.id) for c in allow_list} - if ids.difference(by_creds): - raise ValueError("keyHandleByCredential is not valid") - if not self._selected: - return - kh = by_creds[websafe_encode(self._selected.id)] - - return { - 0: sign.ph_data, - 5: [kh], - } + def prepare_inputs(selected): + sign = data.sign + by_creds = sign.key_handle_by_credential + + # Make sure all keys are valid IDs from allow_credentials + allow_list = options.allow_credentials + if not allow_list or not selected: + raise ValueError("sign requires allowCredentials") + ids = {websafe_encode(c.id) for c in allow_list} + if ids.difference(by_creds): + raise ValueError("keyHandleByCredential is not valid") + kh = by_creds[websafe_encode(selected.id)] + + return { + self.NAME: { + 0: sign.ph_data, + 5: [kh], + } + } - def process_get_output(self, assertion_response, *args): - data = assertion_response.auth_data.extensions.get(self.NAME) + def prepare_outputs(response, *args): + data = response.auth_data.extensions.get(self.NAME) + return {self.NAME: _SignOutputs(signature=data[6])} - return {"sign": _SignOutputs(signature=data[6])} + return ExtensionProcessor( + prepare_inputs=prepare_inputs, + prepare_outputs=prepare_outputs, + )