From 1a2a1561a4dad90cdc5384fd351f747a9cd6c960 Mon Sep 17 00:00:00 2001 From: gstarovo Date: Wed, 17 Apr 2024 11:19:54 +0200 Subject: [PATCH] advertising ec point extension format check if the client adverties the uncompressed point format extension error when uncompressed is not supported fix: changes in accepting the format(form ECPointFormat to string) of ec format. fix: tests, keyShares added to tests for checking ecc point extension --- .gitignore | 2 +- scripts/tls.py | 1 + test | 0 tests/tlstest.py | 148 ++++++++++++++++++++++++- tlslite/handshakesettings.py | 18 ++- tlslite/keyexchange.py | 115 ++++++++++++++----- tlslite/session.py | 16 ++- tlslite/tlsconnection.py | 55 +++++++-- unit_tests/test_tlslite_keyexchange.py | 10 +- 9 files changed, 314 insertions(+), 51 deletions(-) delete mode 100644 test diff --git a/.gitignore b/.gitignore index 56433754..daedfe76 100644 --- a/.gitignore +++ b/.gitignore @@ -9,4 +9,4 @@ coverage.xml pylint_report.txt build/ docs/_build/ -htmlcov/ +htmlcov/ \ No newline at end of file diff --git a/scripts/tls.py b/scripts/tls.py index 3159cbc3..1e174e2e 100755 --- a/scripts/tls.py +++ b/scripts/tls.py @@ -399,6 +399,7 @@ def printGoodConnection(connection, seconds): if connection.server_cert_compression_algo: print(" Server compression algorithm used: {0}".format( connection.server_cert_compression_algo)) + print(" Session used ec point format extension: {0}".format(connection.session.ec_point_format)) def printExporter(connection, expLabel, expLength): if expLabel is None: diff --git a/test b/test deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/tlstest.py b/tests/tlstest.py index 9ce40f4d..989e9793 100755 --- a/tests/tlstest.py +++ b/tests/tlstest.py @@ -44,7 +44,7 @@ from xmlrpc import client as xmlrpclib import ssl from tlslite import * -from tlslite.constants import KeyUpdateMessageType +from tlslite.constants import KeyUpdateMessageType, ECPointFormat try: from tack.structures.Tack import Tack @@ -303,6 +303,76 @@ def connect(): test_no += 1 + print("Test {0} - client compressed/uncompressed - uncompressed, TLSv1.2".format(test_no)) + synchro.recv(1) + connection = connect() + settings = HandshakeSettings() + settings.minVersion = (3, 3) + settings.maxVersion = (3, 3) + settings.eccCurves = ["secp256r1", "secp384r1", "secp521r1", "x25519", "x448"] + + connection.handshakeClientCert(settings=settings) + testConnClient(connection) + assert connection.session.ec_point_format == ECPointFormat.uncompressed + connection.close() + + test_no += 1 + + print("Test {0} - client compressed - compressed, TLSv1.2".format(test_no)) + synchro.recv(1) + connection = connect() + settings = HandshakeSettings() + settings.minVersion = (3, 3) + settings.maxVersion = (3, 3) + settings.eccCurves = ["secp256r1", "secp384r1", "secp521r1", "x25519", "x448"] + settings.keyShares = ["secp256r1"] + connection.handshakeClientCert(settings=settings) + testConnClient(connection) + assert connection.session.ec_point_format == ECPointFormat.ansiX962_compressed_prime + connection.close() + + test_no += 1 + + print("Test {0} - client missing uncompressed - error, TLSv1.2".format(test_no)) + synchro.recv(1) + connection = connect() + settings = HandshakeSettings() + settings.minVersion = (3, 3) + settings.maxVersion = (3, 3) + settings.ec_point_formats = [ECPointFormat.ansiX962_compressed_prime] + settings.eccCurves = ["secp256r1", "secp384r1", "secp521r1", "x25519", "x448"] + settings.keyShares = ["secp256r1"] + try: + connection.handshakeClientCert(settings=settings) + assert False + except ValueError as e: + assert "Uncompressed EC point format is not provided" in str(e) + except TLSAbruptCloseError as e: + pass + connection.close() + + test_no += 1 + + print("Test {0} - client comppressed char2 - error, TLSv1.2".format(test_no)) + synchro.recv(1) + connection = connect() + settings = HandshakeSettings() + settings.minVersion = (3, 3) + settings.maxVersion = (3, 3) + settings.ec_point_formats = [ECPointFormat.ansiX962_compressed_char2] + settings.eccCurves = ["secp256r1", "secp384r1", "secp521r1", "x25519", "x448"] + settings.keyShares = ["secp256r1"] + try: + connection.handshakeClientCert(settings=settings) + assert False + except ValueError as e: + assert "Unknown EC point format provided: [2]" in str(e) + except TLSAbruptCloseError as e: + pass + connection.close() + + test_no += 1 + print("Test {0} - mismatched ECDSA curve, TLSv1.2".format(test_no)) synchro.recv(1) connection = connect() @@ -2194,6 +2264,79 @@ def connect(): test_no += 1 + print("Test {0} - server uncompressed ec format - uncompressed, TLSv1.2".format(test_no)) + synchro.send(b'R') + connection = connect() + settings = HandshakeSettings() + settings.minVersion = (3, 1) + settings.maxVersion = (3, 3) + settings.eccCurves = ["secp256r1", "secp384r1", "secp521r1", "x25519", "x448"] + settings.keyShares = ["secp256r1"] + settings.ec_point_formats = [ECPointFormat.uncompressed] + connection.handshakeServer(certChain=x509ecdsaChain, + privateKey=x509ecdsaKey, settings=settings) + testConnServer(connection) + assert connection.session.ec_point_format == ECPointFormat.uncompressed + connection.close() + + test_no += 1 + + print("Test {0} - server compressed ec format - compressed, TLSv1.2".format(test_no)) + synchro.send(b'R') + connection = connect() + settings = HandshakeSettings() + settings.minVersion = (3, 1) + settings.maxVersion = (3, 3) + settings.eccCurves = ["secp256r1", "secp384r1", "secp521r1", "x25519", "x448"] + settings.keyShares = ["secp256r1"] + connection.handshakeServer(certChain=x509ecdsaChain, + privateKey=x509ecdsaKey, settings=settings) + testConnServer(connection) + assert connection.session.ec_point_format == ECPointFormat.ansiX962_compressed_prime + connection.close() + + test_no +=1 + + print("Test {0} - server missing uncompressed in client - error, TLSv1.2".format(test_no)) + synchro.send(b'R') + connection = connect() + settings = HandshakeSettings() + settings.minVersion = (3, 1) + settings.maxVersion = (3, 3) + settings.eccCurves = ["secp256r1", "secp384r1", "secp521r1", "x25519", "x448"] + settings.keyShares = ["secp256r1"] + try: + connection.handshakeServer(certChain=x509ecdsaChain, + privateKey=x509ecdsaKey, settings=settings) + assert False + except ValueError as e: + assert "Uncompressed EC point format is not provided" in str(e) + except TLSAbruptCloseError as e: + pass + connection.close() + + test_no +=1 + + print("Test {0} - client compressed char2 - error, TLSv1.2".format(test_no)) + synchro.send(b'R') + connection = connect() + settings = HandshakeSettings() + settings.minVersion = (3, 1) + settings.maxVersion = (3, 3) + settings.eccCurves = ["secp256r1", "secp384r1", "secp521r1", "x25519", "x448"] + settings.keyShares = ["secp256r1"] + try: + connection.handshakeServer(certChain=x509ecdsaChain, + privateKey=x509ecdsaKey, settings=settings) + assert False + except ValueError as e: + assert "Unknown EC point format provided: [2]" in str(e) + except TLSAbruptCloseError as e: + pass + connection.close() + + test_no +=1 + print("Test {0} - mismatched ECDSA curve, TLSv1.2".format(test_no)) synchro.send(b'R') connection = connect() @@ -3450,7 +3593,7 @@ def heartbeat_response_check(message): assert synchro.recv(1) == b'R' connection.close() - test_no += 1 + test_no +=1 print("Tests {0}-{1} - XMLRPXC server".format(test_no, test_no + 2)) @@ -3483,6 +3626,7 @@ def add(self, x, y): return x + y synchro.close() synchroSocket.close() + test_no += 2 print("Test succeeded") diff --git a/tlslite/handshakesettings.py b/tlslite/handshakesettings.py index 2ccb0841..7cc51045 100644 --- a/tlslite/handshakesettings.py +++ b/tlslite/handshakesettings.py @@ -7,7 +7,7 @@ """Class for setting handshake parameters.""" -from .constants import CertificateType +from .constants import CertificateType, ECPointFormat from .utils import cryptomath from .utils import cipherfactory from .utils.compat import ecdsaAllCurves, int_types, ML_KEM_AVAILABLE @@ -67,6 +67,8 @@ TICKET_CIPHERS = ["chacha20-poly1305", "aes256gcm", "aes128gcm", "aes128ccm", "aes128ccm_8", "aes256ccm", "aes256ccm_8"] PSK_MODES = ["psk_dhe_ke", "psk_ke"] +EC_POINT_FORMATS = [ECPointFormat.ansiX962_compressed_prime, + ECPointFormat.uncompressed] ALL_COMPRESSION_ALGOS_SEND = ["zlib"] if compression_algo_impls["brotli_compress"]: @@ -385,6 +387,10 @@ class HandshakeSettings(object): option is for when a certificate was received/decompressed by this peer. + + :vartype ec_point_formats: list + :ivar ec_point_formats: Enabled point format extension for + elliptic curves. """ def _init_key_settings(self): @@ -432,6 +438,7 @@ def _init_misc_extensions(self): # resumed connections (as tickets are single-use in TLS 1.3 self.ticket_count = 2 self.record_size_limit = 2**14 + 1 # TLS 1.3 includes content type + self.ec_point_formats = list(EC_POINT_FORMATS) # Certificate compression self.certificate_compression_send = list(ALL_COMPRESSION_ALGOS_SEND) @@ -642,6 +649,14 @@ def _sanityCheckExtensions(other): not 64 <= other.record_size_limit <= 2**14 + 1: raise ValueError("record_size_limit cannot exceed 2**14+1 bytes") + bad_ec_ext = [i for i in other.ec_point_formats if + i not in EC_POINT_FORMATS] + if bad_ec_ext: + raise ValueError("Unknown EC point format provided: " + "{0}".format(bad_ec_ext)) + if ECPointFormat.uncompressed not in other.ec_point_formats: + raise ValueError("Uncompressed EC point format is not provided") + HandshakeSettings._sanityCheckEMSExtension(other) if other.certificate_compression_send: @@ -736,6 +751,7 @@ def _copy_extension_settings(self, other): other.sendFallbackSCSV = self.sendFallbackSCSV other.useEncryptThenMAC = self.useEncryptThenMAC other.usePaddingExtension = self.usePaddingExtension + other.ec_point_formats = self.ec_point_formats # session tickets other.padding_cb = self.padding_cb other.ticketKeys = self.ticketKeys diff --git a/tlslite/keyexchange.py b/tlslite/keyexchange.py index 6c49a975..8b726c1e 100644 --- a/tlslite/keyexchange.py +++ b/tlslite/keyexchange.py @@ -12,7 +12,7 @@ TLSDecodeError from .messages import ServerKeyExchange, ClientKeyExchange, CertificateVerify from .constants import SignatureAlgorithm, HashAlgorithm, CipherSuite, \ - ExtensionType, GroupName, ECCurveType, SignatureScheme + ExtensionType, GroupName, ECCurveType, SignatureScheme, ECPointFormat from .utils.ecc import getCurveByName, getPointByteSize from .utils.rsakey import RSAKey from .utils.cryptomath import bytesToNumber, getRandomBytes, powMod, \ @@ -709,14 +709,25 @@ def makeServerKeyExchange(self, sigHash=None): kex = ECDHKeyExchange(self.group_id, self.serverHello.server_version) self.ecdhXs = kex.get_random_private_key() - if isinstance(self.ecdhXs, ecdsa.keys.SigningKey): - ecdhYs = bytearray( - self.ecdhXs.get_verifying_key().to_string( - encoding = 'uncompressed' - ) - ) - else: - ecdhYs = kex.calc_public_value(self.ecdhXs) + ext_negotiated = ECPointFormat.uncompressed + ext_c = self.clientHello.getExtension(ExtensionType.ec_point_formats) + ext_s = self.serverHello.getExtension(ExtensionType.ec_point_formats) + if ext_c: + if ECPointFormat.uncompressed not in ext_c.formats: + raise TLSIllegalParameterException( + "The client does not advertise " + "the uncompressed point format extension.") + if ext_c and ext_s: + try: + ext_negotiated = next((i for i in ext_c.formats \ + if i in ext_s.formats)) + except StopIteration: + raise TLSIllegalParameterException("No common EC point format") + + ext_negotiated = 'uncompressed' if \ + ext_negotiated == ECPointFormat.uncompressed else 'compressed' + + ecdhYs = kex.calc_public_value(self.ecdhXs, ext_negotiated) version = self.serverHello.server_version serverKeyExchange = ServerKeyExchange(self.cipherSuite, version) @@ -734,7 +745,21 @@ def processClientKeyExchange(self, clientKeyExchange): raise TLSDecodeError("No key share") kex = ECDHKeyExchange(self.group_id, self.serverHello.server_version) - return kex.calc_shared_key(self.ecdhXs, ecdhYc) + ext_supported = [ECPointFormat.uncompressed] + ext_c = self.clientHello.getExtension(ExtensionType.ec_point_formats) + ext_s = self.serverHello.getExtension(ExtensionType.ec_point_formats) + if ext_c and ext_s: + ext_supported = [ + ext for ext in ext_c.formats if ext in ext_s.formats + ] + if not ext_supported: + raise TLSIllegalParameterException("No common EC point format") + ext_supported = map( + lambda x: 'uncompressed' if + x == ECPointFormat.uncompressed else + 'compressed', ext_supported + ) + return kex.calc_shared_key(self.ecdhXs, ecdhYc, set(ext_supported)) def processServerKeyExchange(self, srvPublicKey, serverKeyExchange): """Process the server key exchange, return premaster secret""" @@ -752,15 +777,33 @@ def processServerKeyExchange(self, srvPublicKey, serverKeyExchange): kex = ECDHKeyExchange(serverKeyExchange.named_curve, self.serverHello.server_version) ecdhXc = kex.get_random_private_key() - if isinstance(ecdhXc, ecdsa.keys.SigningKey): - self.ecdhYc = bytearray( - ecdhXc.get_verifying_key().to_string( - encoding = 'uncompressed' - ) - ) - else: - self.ecdhYc = kex.calc_public_value(ecdhXc) - return kex.calc_shared_key(ecdhXc, ecdh_Ys) + ext_negotiated = ECPointFormat.uncompressed + ext_supported = [ECPointFormat.uncompressed] + + if self.clientHello: + ext_c = self.clientHello.getExtension( + ExtensionType.ec_point_formats) + ext_s = self.serverHello.getExtension( + ExtensionType.ec_point_formats) + if ext_c and ext_s: + try: + ext_supported = [ + i for i in ext_c.formats if i in ext_s.formats + ] + ext_negotiated = ext_supported[0] + except IndexError: + raise TLSIllegalParameterException( + "No common EC point format") + + ext_negotiated = 'uncompressed' if \ + ext_negotiated == ECPointFormat.uncompressed else 'compressed' + ext_supported = map( + lambda x: 'uncompressed' if + x == ECPointFormat.uncompressed else + 'compressed', ext_supported + ) + self.ecdhYc = kex.calc_public_value(ecdhXc, ext_negotiated) + return kex.calc_shared_key(ecdhXc, ecdh_Ys, set(ext_supported)) def makeClientKeyExchange(self): """Make client key exchange for ECDHE""" @@ -911,7 +954,7 @@ def calc_public_value(self, private, point_format=None): """Calculate the public value from the provided private value.""" raise NotImplementedError("Abstract class") - def calc_shared_key(self, private, peer_share): + def calc_shared_key(self, private, peer_share, valid_point_formats=None): """Calcualte the shared key given our private and remote share value""" raise NotImplementedError("Abstract class") @@ -949,6 +992,7 @@ def calc_public_value(self, private, point_format=None): Calculate the public value for given private value. :param point_format: ignored, used for compatibility with ECDH groups + :rtype: int """ dh_Y = powMod(self.generator, private, self.prime) @@ -969,8 +1013,12 @@ def _normalise_peer_share(self, peer_share): "Key share does not match FFDH prime") return bytesToNumber(peer_share) - def calc_shared_key(self, private, peer_share): - """Calculate the shared key.""" + def calc_shared_key(self, private, peer_share, valid_point_formats=None): + """Calculate the shared key. + + :param valid_point_formats: ignored, used for compatibility with ECDH groups + + :rtype: bytearray""" peer_share = self._normalise_peer_share(peer_share) # First half of RFC 2631, Section 2.1.5. Validate the client's public # key. @@ -989,7 +1037,6 @@ def calc_shared_key(self, private, peer_share): class ECDHKeyExchange(RawDHKeyExchange): """Implementation of the Elliptic Curve Diffie-Hellman key exchange.""" - _x_groups = set((GroupName.x25519, GroupName.x448)) @staticmethod @@ -1049,10 +1096,22 @@ def calc_shared_key(self, private, peer_share, """ Calculate the shared key. + :param bytearray | SigningKey private: private value + + :param bytearray peer_share: public value + :param set(str) valid_point_formats: list of point formats that the peer share can be in; ["uncompressed"] by default. - """ + :rtype: bytearray + :returns: shared key + + :raises TLSIllegalParameterException + when the paramentrs for point are invalid + + :raises TLSDecodeError + when the the valid_point_formats is empty + """ if self.group in self._x_groups: fun, _, size = self._get_fun_gen_size() if len(peer_share) != size: @@ -1071,11 +1130,15 @@ def calc_shared_key(self, private, peer_share, ecdhYc = ecdsa.ellipticcurve.Point( curve.curve, point[0], point[1]) - except (AssertionError, DecodeError): + except AssertionError: raise TLSIllegalParameterException("Invalid ECC point") + except DecodeError: + raise TLSDecodeError("Unexpected error") if isinstance(private, ecdsa.keys.SigningKey): ecdh = ecdsa.ecdh.ECDH(curve=curve, private_key=private) - ecdh.load_received_public_key_bytes(peer_share) + ecdh.load_received_public_key_bytes(peer_share, + valid_encodings= + valid_point_formats) return bytearray(ecdh.generate_sharedsecret_bytes()) S = ecdhYc * private diff --git a/tlslite/session.py b/tlslite/session.py index 0e310b71..c080e373 100644 --- a/tlslite/session.py +++ b/tlslite/session.py @@ -1,4 +1,4 @@ -# Authors: +# Authors: # Trevor Perrin # Dave Baggett (Arcode Corporation) - canonicalCipherName # @@ -72,6 +72,9 @@ class Session(object): :vartype tls_1_0_tickets: list :ivar tls_1_0_tickets: list of TLS 1.2 and earlier session tickets received from the server + + :vartype ec_point_format: int + :ivar ec_point_format: used EC point format for the ECDH key exchange; """ def __init__(self): @@ -94,6 +97,7 @@ def __init__(self): self.resumptionMasterSecret = bytearray(0) self.tickets = None self.tls_1_0_tickets = None + self.ec_point_format = 0 def create(self, masterSecret, sessionID, cipherSuite, srpUsername, clientCertChain, serverCertChain, @@ -102,7 +106,7 @@ def create(self, masterSecret, sessionID, cipherSuite, appProto=bytearray(0), cl_app_secret=bytearray(0), sr_app_secret=bytearray(0), exporterMasterSecret=bytearray(0), resumptionMasterSecret=bytearray(0), tickets=None, - tls_1_0_tickets=None): + tls_1_0_tickets=None, ec_point_format=None): self.masterSecret = masterSecret self.sessionID = sessionID self.cipherSuite = cipherSuite @@ -110,7 +114,7 @@ def create(self, masterSecret, sessionID, cipherSuite, self.clientCertChain = clientCertChain self.serverCertChain = serverCertChain self.tackExt = tackExt - self.tackInHelloExt = tackInHelloExt + self.tackInHelloExt = tackInHelloExt self.serverName = serverName self.resumable = resumable self.encryptThenMAC = encryptThenMAC @@ -123,6 +127,7 @@ def create(self, masterSecret, sessionID, cipherSuite, # NOTE we need a reference copy not a copy of object here! self.tickets = tickets self.tls_1_0_tickets = tls_1_0_tickets + self.ec_point_format = ec_point_format def _clone(self): other = Session() @@ -145,6 +150,7 @@ def _clone(self): other.resumptionMasterSecret = self.resumptionMasterSecret other.tickets = self.tickets other.tls_1_0_tickets = self.tls_1_0_tickets + other.ec_point_format = self.ec_point_format return other def valid(self): @@ -167,7 +173,7 @@ def getTackId(self): return self.tackExt.tack.getTackId() else: return None - + def getBreakSigs(self): if self.tackExt and self.tackExt.break_sigs: return self.tackExt.break_sigs @@ -181,7 +187,7 @@ def getCipherName(self): :returns: The name of the cipher used with this connection. """ return CipherSuite.canonicalCipherName(self.cipherSuite) - + def getMacName(self): """Get the name of the HMAC hash algo used with this connection. diff --git a/tlslite/tlsconnection.py b/tlslite/tlsconnection.py index abb7ce83..ac108da1 100644 --- a/tlslite/tlsconnection.py +++ b/tlslite/tlsconnection.py @@ -671,6 +671,21 @@ def _handshakeClientAsyncHelper(self, srpParams, certParams, anonParams, if alpnExt: alpnProto = alpnExt.protocol_names[0] + ext_ec_point = ECPointFormat.uncompressed + if self.version < (3, 4): + ext_c = clientHello.getExtension(ExtensionType.ec_point_formats) + ext_s = serverHello.getExtension(ExtensionType.ec_point_formats) + if ext_c and ext_s: + try: + ext_ec_point = next((i for i in ext_c.formats \ + if i in ext_s.formats)) + + except StopIteration as alert: + for result in self._sendError( + AlertDescription.illegal_parameter, + str(alert)): + yield result + # Create the session object which is used for resumptions self.session = Session() self.session.create(masterSecret, serverHello.session_id, cipherSuite, @@ -682,7 +697,8 @@ def _handshakeClientAsyncHelper(self, srpParams, certParams, anonParams, appProto=alpnProto, # NOTE it must be a reference not a copy tickets=self.tickets, - tls_1_0_tickets=self.tls_1_0_tickets) + tls_1_0_tickets=self.tls_1_0_tickets, + ec_point_format=ext_ec_point) self._handshakeDone(resumed=False) self._serverRandom = serverHello.random self._clientRandom = clientHello.random @@ -760,7 +776,6 @@ def _clientSendClientHello(self, settings, session, srpUsername, for group_name in settings.keyShares: group_id = getattr(GroupName, group_name) key_share = self._genKeyShareEntry(group_id, (3, 4)) - shares.append(key_share) # if TLS 1.3 is enabled, key_share must always be sent # (unless only static PSK is used) @@ -777,8 +792,9 @@ def _clientSendClientHello(self, settings, session, srpUsername, if next((cipher for cipher in cipherSuites \ if cipher in CipherSuite.ecdhAllSuites), None) is not None: groups.extend(self._curveNamesToList(settings)) - extensions.append(ECPointFormatsExtension().\ - create([ECPointFormat.uncompressed])) + if settings.ec_point_formats: + extensions.append(ECPointFormatsExtension().\ + create(settings.ec_point_formats)) # Advertise FFDHE groups if we have DHE ciphers if next((cipher for cipher in cipherSuites if cipher in CipherSuite.dhAllSuites), None) is not None: @@ -941,6 +957,7 @@ def _clientGetServerHello(self, settings, session, clientHello): hello_retry = None ext = result.getExtension(ExtensionType.supported_versions) + if result.random == TLS_1_3_HRR and ext and ext.version > (3, 3): self.version = ext.version hello_retry = result @@ -1000,7 +1017,6 @@ def _clientGetServerHello(self, settings, session, clientHello): "did sent the key share " "for"): yield result - key_share = self._genKeyShareEntry(group_id, (3, 4)) # old key shares need to be removed @@ -1249,7 +1265,6 @@ def _clientTLS13Handshake(self, settings, session, clientHello, raise TLSIllegalParameterException("Server selected not " "advertised group.") kex = self._getKEX(sr_kex.group, self.version) - shared_sec = kex.calc_shared_key(cl_kex.private, sr_kex.key_exchange) else: @@ -2070,7 +2085,7 @@ def _clientGetKeyFromChain(self, certificate, settings, tack_ext=None): def handshakeServer(self, verifierDB=None, certChain=None, privateKey=None, reqCert=False, sessionCache=None, settings=None, checker=None, - reqCAs=None, + reqCAs = None, tacks=None, activationFlags=0, nextProtos=None, anon=False, alpn=None, sni=None): """Perform a handshake in the role of server. @@ -2351,8 +2366,9 @@ def _handshakeServerAsyncHelper(self, verifierDB, if clientHello.getExtension(ExtensionType.ec_point_formats): # even though the selected cipher may not use ECC, client may want # to send a CA certificate with ECDSA... - extensions.append(ECPointFormatsExtension().create( - [ECPointFormat.uncompressed])) + if settings.ec_point_formats: + extensions.append(ECPointFormatsExtension(). + create(settings.ec_point_formats)) # if client sent Heartbeat extension if clientHello.getExtension(ExtensionType.heartbeat): @@ -2496,6 +2512,21 @@ def _handshakeServerAsyncHelper(self, verifierDB, if clientHello.server_name: serverName = clientHello.server_name.decode("utf-8") + ext_ec_point = ECPointFormat.uncompressed + if version < (3, 4): + ext_c = clientHello.getExtension(ExtensionType.ec_point_formats) + ext_s = serverHello.getExtension(ExtensionType.ec_point_formats) + if ext_c and ext_s: + try: + ext_ec_point = next((i for i in ext_c.formats \ + if i in ext_s.formats)) + + except StopIteration as alert: + for result in self._sendError( + AlertDescription.illegal_parameter, + str(alert)): + yield result + # We'll update the session master secret once it is calculated # in _serverFinished self.session.create(b"", serverHello.session_id, cipherSuite, @@ -2507,7 +2538,8 @@ def _handshakeServerAsyncHelper(self, verifierDB, extendedMasterSecret=self.extendedMasterSecret, appProto=selectedALPN, # NOTE it must be a reference, not a copy! - tickets=self.tickets) + tickets=self.tickets, + ec_point_format=ext_ec_point) # Exchange Finished messages for result in self._serverFinished(premasterSecret, @@ -3232,7 +3264,8 @@ def _ticket_to_session(self, settings, ticket_ext): serverName=ticket.server_name.decode("utf-8") if ticket.server_name else "", encryptThenMAC=ticket.encrypt_then_mac, - extendedMasterSecret=ticket.extended_master_secret) + extendedMasterSecret=ticket.extended_master_secret, + ec_point_format=0) return session def _serverGetClientHello(self, settings, private_key, cert_chain, diff --git a/unit_tests/test_tlslite_keyexchange.py b/unit_tests/test_tlslite_keyexchange.py index ee142620..b960f94d 100644 --- a/unit_tests/test_tlslite_keyexchange.py +++ b/unit_tests/test_tlslite_keyexchange.py @@ -18,16 +18,16 @@ from tlslite.handshakesettings import HandshakeSettings from tlslite.messages import ServerHello, ClientHello, ServerKeyExchange,\ CertificateRequest, ClientKeyExchange -from tlslite.constants import CipherSuite, CertificateType, AlertDescription, \ +from tlslite.constants import CipherSuite, CertificateType, \ HashAlgorithm, SignatureAlgorithm, GroupName, ECCurveType, \ SignatureScheme -from tlslite.errors import TLSLocalAlert, TLSIllegalParameterException, \ +from tlslite.errors import TLSIllegalParameterException, \ TLSDecryptionFailed, TLSInsufficientSecurity, TLSUnknownPSKIdentity, \ TLSInternalError, TLSDecodeError from tlslite.x509 import X509 from tlslite.x509certchain import X509CertChain from tlslite.utils.keyfactory import parsePEMKey -from tlslite.utils.codec import Parser, Writer +from tlslite.utils.codec import Parser from tlslite.utils.cryptomath import bytesToNumber, getRandomBytes, powMod, \ numberToByteArray, isPrime, numBytes from tlslite.mathtls import makeX, makeU, makeK, goodGroupParameters @@ -2523,13 +2523,13 @@ def test_calc_public_value(self): kex = RawDHKeyExchange(None, None) with self.assertRaises(NotImplementedError): - kex.calc_public_value(None) + kex.calc_public_value(None, None) def test_calc_shared_value(self): kex = RawDHKeyExchange(None, None) with self.assertRaises(NotImplementedError): - kex.calc_shared_key(None, None) + kex.calc_shared_key(None, None, None) class TestFFDHKeyExchange(unittest.TestCase):