diff --git a/tlslite/keyexchange.py b/tlslite/keyexchange.py index 8d6e32cf..107a75ed 100644 --- a/tlslite/keyexchange.py +++ b/tlslite/keyexchange.py @@ -720,6 +720,9 @@ def makeServerKeyExchange(self, sigHash=None): except StopIteration: raise TLSIllegalParameterException("No common EC point format") + ext_negotiated = 'uncompressed' if \ + ext_negotiated == ECPointFormat.uncompressed else 'compressed' + ecdhYs = kex.calc_public_value(self.ecdhXs, ext_negotiated) version = self.serverHello.server_version @@ -747,7 +750,12 @@ def processClientKeyExchange(self, clientKeyExchange): ] if not ext_supported: raise TLSIllegalParameterException("No common EC point format") - return kex.calc_shared_key(self.ecdhXs, ecdhYc, ext_supported) + ext_supported = map( + lambda x: 'uncompressed' if + x == ECPointFormat.uncompressed else + 'compressed', ext_supported + ) + return kex.calc_shared_key(self.ecdhXs, ecdhYc, set(ext_supported)) def processServerKeyExchange(self, srvPublicKey, serverKeyExchange): """Process the server key exchange, return premaster secret""" @@ -783,8 +791,15 @@ def processServerKeyExchange(self, srvPublicKey, serverKeyExchange): raise TLSIllegalParameterException( "No common EC point format") + ext_negotiated = 'uncompressed' if \ + ext_negotiated == ECPointFormat.uncompressed else 'compressed' + ext_supported = map( + lambda x: 'uncompressed' if + x == ECPointFormat.uncompressed else + 'compressed', ext_supported + ) self.ecdhYc = kex.calc_public_value(ecdhXc, ext_negotiated) - return kex.calc_shared_key(ecdhXc, ecdh_Ys, ext_supported) + return kex.calc_shared_key(ecdhXc, ecdh_Ys, set(ext_supported)) def makeClientKeyExchange(self): """Make client key exchange for ECDHE""" @@ -931,11 +946,11 @@ def get_random_private_key(self): """ raise NotImplementedError("Abstract class") - def calc_public_value(self, private, frm_negotiated=None): + def calc_public_value(self, private, point_format=None): """Calculate the public value from the provided private value.""" raise NotImplementedError("Abstract class") - def calc_shared_key(self, private, peer_share, frm_supported=None): + def calc_shared_key(self, private, peer_share, valid_point_formats=None): """Calcualte the shared key given our private and remote share value""" raise NotImplementedError("Abstract class") @@ -968,10 +983,11 @@ def get_random_private_key(self): needed_bytes = divceil(paramStrength(self.prime) * 2, 8) return bytesToNumber(getRandomBytes(needed_bytes)) - def calc_public_value(self, private, frm_negotiated=None): + def calc_public_value(self, private, point_format=None): """ Calculate the public value for given private value. - Frm_negotiated added for API compatibility, not needed for FFDH. + + :param point_format: ignored, used for compatibility with ECDH groups :rtype: int """ @@ -993,9 +1009,10 @@ def _normalise_peer_share(self, peer_share): "Key share does not match FFDH prime") return bytesToNumber(peer_share) - def calc_shared_key(self, private, peer_share, frm_supported=None): + def calc_shared_key(self, private, peer_share, valid_point_formats=None): """Calculate the shared key. - Frm_supported added for API compatibility, not needed for FFDH. + + :param valid_point_formats: ignored, used for compatibility with ECDH groups :rtype: bytearray""" peer_share = self._normalise_peer_share(peer_share) @@ -1052,51 +1069,46 @@ def _get_fun_gen_size(self): else: return x448, bytearray(X448_G), X448_ORDER_SIZE - @staticmethod - def _get_point_format(ext): - """Get point format from the numeric value.""" - transform = {ECPointFormat.uncompressed: 'uncompressed', - ECPointFormat.ansiX962_compressed_char2: 'compressed', - ECPointFormat.ansiX962_compressed_prime: 'compressed'} - return transform[ext] - def calc_public_value(self, private, - frm_negotiated=ECPointFormat.uncompressed): - """Calculate public value for given private key.""" - point_fmt = self._get_point_format(frm_negotiated) + point_format='uncompressed'): + """Calculate public value for given private key. + + :param private: Private key for the selected key exchange group. + + :param str point_format: The point format to use for the + ECDH public key. Applies only to NIST curves. + """ if isinstance(private, ecdsa.keys.SigningKey): - return private.verifying_key.to_string(point_fmt) + return private.verifying_key.to_string(point_format) if self.group in self._x_groups: fun, generator, _ = self._get_fun_gen_size() return fun(private, generator) curve = getCurveByName(GroupName.toStr(self.group)) point = curve.generator * private - return bytearray(point.to_bytes(encoding=point_fmt)) + return bytearray(point.to_bytes(encoding=point_format)) def calc_shared_key(self, private, peer_share, - frm_supported=set([ECPointFormat.uncompressed])): + valid_point_formats=set(['uncompressed'])): """Calculate the shared key. - :type private: bytearray | SigningKey - :param private: private value + :param bytearray | SigningKey private: private value - :type peer_share: bytearray - :param peer_share: public value + :param bytearray peer_share: public value - :type frm_supported: set(ECPointFormat) - :param frm_supported: acceptable point formats for public value + :param set(str) valid_point_formats: list of point formats that + the peer share can be in; ["uncompressed"] by default. :rtype: bytearray :returns: shared key :raises TLSIllegalParameterException when the paramentrs for point are invalid - """ - valid_encodings = set([self._get_point_format(i) \ - for i in frm_supported]) + :raises TLSDecodeError + when the the valid_point_formats is empty + """ if self.group in self._x_groups: fun, _, size = self._get_fun_gen_size() if len(peer_share) != size: @@ -1111,17 +1123,19 @@ def calc_shared_key(self, private, peer_share, try: abstractPoint = ecdsa.ellipticcurve.AbstractPoint() point = abstractPoint.from_bytes(curve.curve, peer_share, - valid_encodings=valid_encodings) + valid_encodings=valid_point_formats) ecdhYc = ecdsa.ellipticcurve.Point( curve.curve, point[0], point[1]) - except (AssertionError, DecodeError): + except (AssertionError): raise TLSIllegalParameterException("Invalid ECC point") + except DecodeError as err: + raise TLSDecodeError(f"Unexpected error {err=}, {type(err)=}") from err if isinstance(private, ecdsa.keys.SigningKey): ecdh = ecdsa.ecdh.ECDH(curve=curve, private_key=private) ecdh.load_received_public_key_bytes(peer_share, valid_encodings= - valid_encodings) + valid_point_formats) return bytearray(ecdh.generate_sharedsecret_bytes()) S = ecdhYc * private diff --git a/tlslite/session.py b/tlslite/session.py index eebe64a1..c080e373 100644 --- a/tlslite/session.py +++ b/tlslite/session.py @@ -97,7 +97,7 @@ def __init__(self): self.resumptionMasterSecret = bytearray(0) self.tickets = None self.tls_1_0_tickets = None - self.ec_point_format = None + self.ec_point_format = 0 def create(self, masterSecret, sessionID, cipherSuite, srpUsername, clientCertChain, serverCertChain, diff --git a/tlslite/tlsconnection.py b/tlslite/tlsconnection.py index f4ff28ff..c621f2c7 100644 --- a/tlslite/tlsconnection.py +++ b/tlslite/tlsconnection.py @@ -3242,7 +3242,8 @@ def _ticket_to_session(self, settings, ticket_ext): serverName=ticket.server_name.decode("utf-8") if ticket.server_name else "", encryptThenMAC=ticket.encrypt_then_mac, - extendedMasterSecret=ticket.extended_master_secret) + extendedMasterSecret=ticket.extended_master_secret, + ec_point_format=0) return session def _serverGetClientHello(self, settings, private_key, cert_chain,