From 76cc7fc28823c89878beec82dfdd422e5b4a6ec6 Mon Sep 17 00:00:00 2001 From: gstarovo Date: Wed, 17 Apr 2024 11:19:54 +0200 Subject: [PATCH] changes in point extension format --- .github/workflows/ci.yml | 8 ++- tlslite/handshakesettings.py | 17 +++++- tlslite/keyexchange.py | 112 ++++++++++++++++++++++++----------- tlslite/tlsconnection.py | 19 +++--- 4 files changed, 111 insertions(+), 45 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 6581b6fea..64d4eb2b5 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -433,7 +433,13 @@ jobs: COVERALLS_FLAG_NAME: ${{ matrix.name }} COVERALLS_PARALLEL: true COVERALLS_SERVICE_NAME: github - run: coveralls + PY_VERSION: ${{ matrix.python-version }} + run: | + if [[ $PY_VERSION == "2.6" ]]; then + COVERALLS_SKIP_SSL_VERIFY=1 coveralls + else + coveralls + fi - name: Publish coverage to Codeclimate if: ${{ contains(matrix.opt-deps, 'codeclimate') }} env: diff --git a/tlslite/handshakesettings.py b/tlslite/handshakesettings.py index 38e560a2b..8eab3db57 100644 --- a/tlslite/handshakesettings.py +++ b/tlslite/handshakesettings.py @@ -7,7 +7,7 @@ """Class for setting handshake parameters.""" -from .constants import CertificateType +from .constants import CertificateType, ECPointFormat from .utils import cryptomath from .utils import cipherfactory from .utils.compat import ecdsaAllCurves, int_types @@ -61,6 +61,9 @@ TICKET_CIPHERS = ["chacha20-poly1305", "aes256gcm", "aes128gcm", "aes128ccm", "aes128ccm_8", "aes256ccm", "aes256ccm_8"] PSK_MODES = ["psk_dhe_ke", "psk_ke"] +EC_POINT_FORMATS = [ECPointFormat.uncompressed, + ECPointFormat.ansiX962_compressed_char2, + ECPointFormat.ansiX962_compressed_prime] class Keypair(object): @@ -353,6 +356,10 @@ class HandshakeSettings(object): :vartype keyExchangeNames: list :ivar keyExchangeNames: Enabled key exchange types for the connection, influences selected cipher suites. + + :vartype ec_point_formats: list + :ivat ec_point_formats: Enabeled point format extension for + elliptic curves. """ def _init_key_settings(self): @@ -396,6 +403,7 @@ def _init_misc_extensions(self): # resumed connections (as tickets are single-use in TLS 1.3 self.ticket_count = 2 self.record_size_limit = 2**14 + 1 # TLS 1.3 includes content type + self.ec_point_formats = EC_POINT_FORMATS def __init__(self): """Initialise default values for settings.""" @@ -599,6 +607,12 @@ def _sanityCheckExtensions(other): not 64 <= other.record_size_limit <= 2**14 + 1: raise ValueError("record_size_limit cannot exceed 2**14+1 bytes") + bad_ec_ext = [i for i in other.ec_point_formats if + i not in EC_POINT_FORMATS] + if bad_ec_ext: + raise ValueError("Unknown ec point format extension: " + "{0}".format(bad_ec_ext)) + HandshakeSettings._sanityCheckEMSExtension(other) @staticmethod @@ -667,6 +681,7 @@ def _copy_extension_settings(self, other): other.sendFallbackSCSV = self.sendFallbackSCSV other.useEncryptThenMAC = self.useEncryptThenMAC other.usePaddingExtension = self.usePaddingExtension + other.ec_point_formats = self.ec_point_formats # session tickets other.padding_cb = self.padding_cb other.ticketKeys = self.ticketKeys diff --git a/tlslite/keyexchange.py b/tlslite/keyexchange.py index 2242aad3e..dfe0a7e36 100644 --- a/tlslite/keyexchange.py +++ b/tlslite/keyexchange.py @@ -12,7 +12,7 @@ TLSDecodeError from .messages import ServerKeyExchange, ClientKeyExchange, CertificateVerify from .constants import SignatureAlgorithm, HashAlgorithm, CipherSuite, \ - ExtensionType, GroupName, ECCurveType, SignatureScheme + ExtensionType, GroupName, ECCurveType, SignatureScheme, ECPointFormat from .utils.ecc import getCurveByName, getPointByteSize from .utils.rsakey import RSAKey from .utils.cryptomath import bytesToNumber, getRandomBytes, powMod, \ @@ -704,15 +704,16 @@ def makeServerKeyExchange(self, sigHash=None): kex = ECDHKeyExchange(self.group_id, self.serverHello.server_version) self.ecdhXs = kex.get_random_private_key() + ext_negotiated = ECPointFormat.uncompressed + ext_c = self.clientHello.getExtension(ExtensionType.ec_point_formats) + ext_s = self.serverHello.getExtension(ExtensionType.ec_point_formats) + if ext_c and ext_s: + for ext in ext_c.formats: + if ext in ext_s.formats: + ext_negotiated = ext + break - if isinstance(self.ecdhXs, ecdsa.keys.SigningKey): - ecdhYs = bytearray( - self.ecdhXs.get_verifying_key().to_string( - encoding = 'uncompressed' - ) - ) - else: - ecdhYs = kex.calc_public_value(self.ecdhXs) + ecdhYs = kex.calc_public_value(self.ecdhXs, ext_negotiated) version = self.serverHello.server_version serverKeyExchange = ServerKeyExchange(self.cipherSuite, version) @@ -730,7 +731,15 @@ def processClientKeyExchange(self, clientKeyExchange): raise TLSDecodeError("No key share") kex = ECDHKeyExchange(self.group_id, self.serverHello.server_version) - return kex.calc_shared_key(self.ecdhXs, ecdhYc) + ext_supported = [ECPointFormat.uncompressed] + ext_c = self.clientHello.getExtension(ExtensionType.ec_point_formats) + ext_s = self.serverHello.getExtension(ExtensionType.ec_point_formats) + if ext_c and ext_s: + ext_supported = [] + for ext in ext_c.formats: + if ext in ext_s.formats: + ext_supported.append(ext) + return kex.calc_shared_key(self.ecdhXs, ecdhYc, ext_supported) def processServerKeyExchange(self, srvPublicKey, serverKeyExchange): """Process the server key exchange, return premaster secret""" @@ -748,15 +757,17 @@ def processServerKeyExchange(self, srvPublicKey, serverKeyExchange): kex = ECDHKeyExchange(serverKeyExchange.named_curve, self.serverHello.server_version) ecdhXc = kex.get_random_private_key() - if isinstance(ecdhXc, ecdsa.keys.SigningKey): - self.ecdhYc = bytearray( - ecdhXc.get_verifying_key().to_string( - encoding = 'uncompressed' - ) - ) - else: - self.ecdhYc = kex.calc_public_value(ecdhXc) - return kex.calc_shared_key(ecdhXc, ecdh_Ys) + ext_negotiated = ECPointFormat.uncompressed + ext_supported = [ECPointFormat.uncompressed] + ext_c = self.clientHello.getExtension(ExtensionType.ec_point_formats) + ext_s = self.serverHello.getExtension(ExtensionType.ec_point_formats) + if ext_c and ext_s: + ext_supported = [i for i in ext_c.formats if i in ext_s.formats] + if not ext_supported: + raise TLSDecodeError("No negotiated ec point extension.") + ext_negotiated = ext_supported[0] + self.ecdhYc = kex.calc_public_value(ecdhXc, ext_negotiated) + return kex.calc_shared_key(ecdhXc, ecdh_Ys, ext_supported) def makeClientKeyExchange(self): """Make client key exchange for ECDHE""" @@ -903,11 +914,11 @@ def get_random_private_key(self): """ raise NotImplementedError("Abstract class") - def calc_public_value(self, private): + def calc_public_value(self, private, frm_negotiated=None): """Calculate the public value from the provided private value.""" raise NotImplementedError("Abstract class") - def calc_shared_key(self, private, peer_share): + def calc_shared_key(self, private, peer_share, frm_supported=None): """Calcualte the shared key given our private and remote share value""" raise NotImplementedError("Abstract class") @@ -940,9 +951,10 @@ def get_random_private_key(self): needed_bytes = divceil(paramStrength(self.prime) * 2, 8) return bytesToNumber(getRandomBytes(needed_bytes)) - def calc_public_value(self, private): + def calc_public_value(self, private, frm_negotiated=None): """ Calculate the public value for given private value. + Frm_negotiated added for API compatibility, not needed for FFDH. :rtype: int """ @@ -964,8 +976,11 @@ def _normalise_peer_share(self, peer_share): "Key share does not match FFDH prime") return bytesToNumber(peer_share) - def calc_shared_key(self, private, peer_share): - """Calculate the shared key.""" + def calc_shared_key(self, private, peer_share, frm_supported=None): + """Calculate the shared key. + Frm_supported added for API compatibility, not needed for FFDH. + + :rtype: bytearray""" peer_share = self._normalise_peer_share(peer_share) # First half of RFC 2631, Section 2.1.5. Validate the client's public # key. @@ -984,7 +999,6 @@ def calc_shared_key(self, private, peer_share): class ECDHKeyExchange(RawDHKeyExchange): """Implementation of the Elliptic Curve Diffie-Hellman key exchange.""" - _x_groups = set((GroupName.x25519, GroupName.x448)) @staticmethod @@ -1021,20 +1035,50 @@ def _get_fun_gen_size(self): else: return x448, bytearray(X448_G), X448_ORDER_SIZE - def calc_public_value(self, private): + @staticmethod + def _get_point_format(ext): + """Get extension name from the numeric value.""" + transform = {ECPointFormat.uncompressed: 'uncompressed', + ECPointFormat.ansiX962_compressed_char2: 'compressed', + ECPointFormat.ansiX962_compressed_prime: 'compressed'} + return transform[ext] + + def calc_public_value(self, + private, + frm_negotiated=ECPointFormat.uncompressed): """Calculate public value for given private key.""" + point_fmt = self._get_point_format(frm_negotiated) if isinstance(private, ecdsa.keys.SigningKey): - return private.verifying_key.to_string('uncompressed') + return private.verifying_key.to_string(point_fmt) if self.group in self._x_groups: fun, generator, _ = self._get_fun_gen_size() return fun(private, generator) else: curve = getCurveByName(GroupName.toStr(self.group)) point = curve.generator * private - return bytearray(point.to_bytes('uncompressed')) - - def calc_shared_key(self, private, peer_share): - """Calculate the shared key,""" + return bytearray(point.to_bytes(encoding=point_fmt)) + + def calc_shared_key(self, private, peer_share, + frm_supported=set([ECPointFormat.uncompressed])): + """Calculate the shared key. + + :type private: bytearray | SigningKey + :param private: private value + + :type peer_share: bytearray + :param peer_share: public value + + :type frm_supported: set of ECPointFormat + :param frm_supported: supported point format extension for ec + + :rtype: bytearray + :returns: shared key + + :raises TLSIllegalParameterException + when the paramentrs for point are invalid + """ + valid_encodings = set([self._get_point_format(i) \ + for i in frm_supported]) if self.group in self._x_groups: fun, _, size = self._get_fun_gen_size() @@ -1049,7 +1093,8 @@ def calc_shared_key(self, private, peer_share): curve = getCurveByName(GroupName.toRepr(self.group)) try: abstractPoint = ecdsa.ellipticcurve.AbstractPoint() - point = abstractPoint.from_bytes(curve.curve, peer_share) + point = abstractPoint.from_bytes(curve.curve, peer_share, + valid_encodings=valid_encodings) ecdhYc = ecdsa.ellipticcurve.Point( curve.curve, point[0], point[1]) @@ -1057,7 +1102,8 @@ def calc_shared_key(self, private, peer_share): raise TLSIllegalParameterException("Invalid ECC point") if isinstance(private, ecdsa.keys.SigningKey): ecdh = ecdsa.ecdh.ECDH(curve=curve, private_key=private) - ecdh.load_received_public_key_bytes(peer_share) + ecdh.load_received_public_key_bytes(peer_share, + valid_encodings=valid_encodings) return bytearray(ecdh.generate_sharedsecret_bytes()) S = ecdhYc * private diff --git a/tlslite/tlsconnection.py b/tlslite/tlsconnection.py index 582097a71..2910bcd72 100644 --- a/tlslite/tlsconnection.py +++ b/tlslite/tlsconnection.py @@ -745,7 +745,6 @@ def _clientSendClientHello(self, settings, session, srpUsername, for group_name in settings.keyShares: group_id = getattr(GroupName, group_name) key_share = self._genKeyShareEntry(group_id, (3, 4)) - shares.append(key_share) # if TLS 1.3 is enabled, key_share must always be sent # (unless only static PSK is used) @@ -762,8 +761,9 @@ def _clientSendClientHello(self, settings, session, srpUsername, if next((cipher for cipher in cipherSuites \ if cipher in CipherSuite.ecdhAllSuites), None) is not None: groups.extend(self._curveNamesToList(settings)) - extensions.append(ECPointFormatsExtension().\ - create([ECPointFormat.uncompressed])) + if settings.ec_point_formats: + extensions.append(ECPointFormatsExtension().\ + create(settings.ec_point_formats)) # Advertise FFDHE groups if we have DHE ciphers if next((cipher for cipher in cipherSuites if cipher in CipherSuite.dhAllSuites), None) is not None: @@ -838,7 +838,7 @@ def _clientSendClientHello(self, settings, session, srpUsername, session_id, wireCipherSuites, certificateTypes, srpUsername, - reqTack, nextProtos is not None, + reqTack, nextProtos is not None, serverName, extensions=extensions) @@ -915,6 +915,7 @@ def _clientGetServerHello(self, settings, session, clientHello): hello_retry = None ext = result.getExtension(ExtensionType.supported_versions) + if result.random == TLS_1_3_HRR and ext and ext.version > (3, 3): self.version = ext.version hello_retry = result @@ -974,7 +975,6 @@ def _clientGetServerHello(self, settings, session, clientHello): "did sent the key share " "for"): yield result - key_share = self._genKeyShareEntry(group_id, (3, 4)) # old key shares need to be removed @@ -1855,8 +1855,8 @@ def _clientFinished(self, premasterSecret, clientRandom, serverRandom, cipherSuite, clientRandom, serverRandom) - self._calcPendingStates(cipherSuite, masterSecret, - clientRandom, serverRandom, + self._calcPendingStates(cipherSuite, masterSecret, + clientRandom, serverRandom, cipherImplementations) #Exchange ChangeCipherSpec and Finished messages @@ -2270,8 +2270,8 @@ def _handshakeServerAsyncHelper(self, verifierDB, if clientHello.getExtension(ExtensionType.ec_point_formats): # even though the selected cipher may not use ECC, client may want # to send a CA certificate with ECDSA... - extensions.append(ECPointFormatsExtension().create( - [ECPointFormat.uncompressed])) + extensions.append(ECPointFormatsExtension(). + create(settings.ec_point_formats)) # if client sent Heartbeat extension if clientHello.getExtension(ExtensionType.heartbeat): @@ -2710,7 +2710,6 @@ def _serverTLS13Handshake(self, settings, clientHello, cipherSuite, self.ecdhCurve = selected_group kex = self._getKEX(selected_group, version) key_share = self._genKeyShareEntry(selected_group, version) - try: shared_sec = kex.calc_shared_key(key_share.private, cl_key_share.key_exchange)