Skip to content

Commit

Permalink
Merge pull request #1 from gstarovo/point_ext_changes
Browse files Browse the repository at this point in the history
changes in point extension format
  • Loading branch information
gstarovo authored Apr 17, 2024
2 parents cbb78ab + 08c512e commit dd118da
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 66 deletions.
15 changes: 15 additions & 0 deletions tlslite/handshakesettings.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@
TICKET_CIPHERS = ["chacha20-poly1305", "aes256gcm", "aes128gcm", "aes128ccm",
"aes128ccm_8", "aes256ccm", "aes256ccm_8"]
PSK_MODES = ["psk_dhe_ke", "psk_ke"]
EC_POINT_FORMATS = [ECPointFormat.uncompressed,
ECPointFormat.ansiX962_compressed_char2,
ECPointFormat.ansiX962_compressed_prime]


class Keypair(object):
Expand Down Expand Up @@ -353,6 +356,10 @@ class HandshakeSettings(object):
:vartype keyExchangeNames: list
:ivar keyExchangeNames: Enabled key exchange types for the connection,
influences selected cipher suites.
:vartype ec_point_formats: list
:ivat ec_point_formats: Enabeled point format extension for
elliptic curves.
"""

def _init_key_settings(self):
Expand Down Expand Up @@ -396,6 +403,7 @@ def _init_misc_extensions(self):
# resumed connections (as tickets are single-use in TLS 1.3
self.ticket_count = 2
self.record_size_limit = 2**14 + 1 # TLS 1.3 includes content type
self.ec_point_formats = EC_POINT_FORMATS

def __init__(self):
"""Initialise default values for settings."""
Expand Down Expand Up @@ -602,6 +610,12 @@ def _sanityCheckExtensions(other):
not 64 <= other.record_size_limit <= 2**14 + 1:
raise ValueError("record_size_limit cannot exceed 2**14+1 bytes")

bad_ec_ext = [i for i in other.ec_point_formats if
i not in EC_POINT_FORMATS]
if bad_ec_ext:
raise ValueError("Unknown ec point format extension: "
"{0}".format(bad_ec_ext))

HandshakeSettings._sanityCheckEMSExtension(other)

@staticmethod
Expand Down Expand Up @@ -670,6 +684,7 @@ def _copy_extension_settings(self, other):
other.sendFallbackSCSV = self.sendFallbackSCSV
other.useEncryptThenMAC = self.useEncryptThenMAC
other.usePaddingExtension = self.usePaddingExtension
other.ec_point_formats = self.ec_point_formats
# session tickets
other.padding_cb = self.padding_cb
other.ticketKeys = self.ticketKeys
Expand Down
89 changes: 57 additions & 32 deletions tlslite/keyexchange.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
TLSDecodeError
from .messages import ServerKeyExchange, ClientKeyExchange, CertificateVerify
from .constants import SignatureAlgorithm, HashAlgorithm, CipherSuite, \
ExtensionType, GroupName, ECCurveType, SignatureScheme
ExtensionType, GroupName, ECCurveType, SignatureScheme, ECPointFormat
from .utils.ecc import getCurveByName, getPointByteSize
from .utils.rsakey import RSAKey
from .utils.cryptomath import bytesToNumber, getRandomBytes, powMod, \
Expand Down Expand Up @@ -704,14 +704,14 @@ def makeServerKeyExchange(self, sigHash=None):

kex = ECDHKeyExchange(self.group_id, self.serverHello.server_version)
self.ecdhXs = kex.get_random_private_key()
ext_negotiated = 0
ext_negotiated = ECPointFormat.uncompressed
ext_c = self.clientHello.getExtension(ExtensionType.ec_point_formats)
ext_s = self.serverHello.getExtension(ExtensionType.ec_point_formats)
if ext_c and ext_s:
ext_negotiated = None
for ext in ext_c.formats:
if ext in ext_s.formats and ext_negotiated is None:
if ext in ext_s.formats:
ext_negotiated = ext
break

ecdhYs = kex.calc_public_value(self.ecdhXs, ext_negotiated)

Expand All @@ -731,7 +731,8 @@ def processClientKeyExchange(self, clientKeyExchange):
raise TLSDecodeError("No key share")

kex = ECDHKeyExchange(self.group_id, self.serverHello.server_version)
ext_supported = [0]

ext_supported = [ECPointFormat.uncompressed]
ext_c = self.clientHello.getExtension(ExtensionType.ec_point_formats)
ext_s = self.serverHello.getExtension(ExtensionType.ec_point_formats)
if ext_c and ext_s:
Expand All @@ -757,18 +758,17 @@ def processServerKeyExchange(self, srvPublicKey, serverKeyExchange):
kex = ECDHKeyExchange(serverKeyExchange.named_curve,
self.serverHello.server_version)
ecdhXc = kex.get_random_private_key()
ext_negotiated = 0
ext_supported = [0]

ext_negotiated = ECPointFormat.uncompressed
ext_supported = [ECPointFormat.uncompressed]
ext_c = self.clientHello.getExtension(ExtensionType.ec_point_formats)
ext_s = self.serverHello.getExtension(ExtensionType.ec_point_formats)
if ext_c and ext_s:
ext_supported = []
for ext in ext_c.formats:
ext_negotiated = None
if ext in ext_s.formats:
ext_supported.append(ext)
if ext_negotiated is None:
ext_negotiated = ext
ext_supported = [i for i in ext_c.formats if i in ext_s.formats]
if not ext_supported:
raise TLSDecodeError("No negotiated ec point extension.")
ext_negotiated = ext_supported[0]

self.ecdhYc = kex.calc_public_value(ecdhXc, ext_negotiated)
return kex.calc_shared_key(ecdhXc, ecdh_Ys, ext_supported)

Expand Down Expand Up @@ -917,11 +917,11 @@ def get_random_private_key(self):
"""
raise NotImplementedError("Abstract class")

def calc_public_value(self, private, frm_negotiated):
def calc_public_value(self, private, frm_negotiated=None):
"""Calculate the public value from the provided private value."""
raise NotImplementedError("Abstract class")

def calc_shared_key(self, private, peer_share, frm_supported_):
def calc_shared_key(self, private, peer_share, frm_supported=None):
"""Calcualte the shared key given our private and remote share value"""
raise NotImplementedError("Abstract class")

Expand Down Expand Up @@ -957,6 +957,7 @@ def get_random_private_key(self):
def calc_public_value(self, private, frm_negotiated=None):
"""
Calculate the public value for given private value.
Frm_negotiated added for API compatibility, not needed for FFDH.
:rtype: int
"""
Expand All @@ -978,8 +979,11 @@ def _normalise_peer_share(self, peer_share):
"Key share does not match FFDH prime")
return bytesToNumber(peer_share)

def calc_shared_key(self, private, peer_share, frm_supported_=None):
"""Calculate the shared key."""
def calc_shared_key(self, private, peer_share, frm_supported=None):
"""Calculate the shared key.
Frm_supported added for API compatibility, not needed for FFDH.
:rtype: bytearray"""
peer_share = self._normalise_peer_share(peer_share)
# First half of RFC 2631, Section 2.1.5. Validate the client's public
# key.
Expand All @@ -998,7 +1002,6 @@ def calc_shared_key(self, private, peer_share, frm_supported_=None):

class ECDHKeyExchange(RawDHKeyExchange):
"""Implementation of the Elliptic Curve Diffie-Hellman key exchange."""

_x_groups = set((GroupName.x25519, GroupName.x448))

@staticmethod
Expand Down Expand Up @@ -1036,29 +1039,51 @@ def _get_fun_gen_size(self):
return x448, bytearray(X448_G), X448_ORDER_SIZE

@staticmethod
def _get_ext_name(ext):

def _get_point_format(ext):
"""Get extension name from the numeric value."""
transform = {0: 'uncompressed', 1: 'compressed', 2: 'compressed'}
transform = {ECPointFormat.uncompressed: 'uncompressed',
ECPointFormat.ansiX962_compressed_char2: 'compressed',
ECPointFormat.ansiX962_compressed_prime: 'compressed'}
return transform[ext]

def calc_public_value(self, private, frm_negotiated=0):
def calc_public_value(self,
private,
frm_negotiated=ECPointFormat.uncompressed):
"""Calculate public value for given private key."""
extension = self._get_ext_name(frm_negotiated)
point_fmt = self._get_point_format(frm_negotiated)
if isinstance(private, ecdsa.keys.SigningKey):
return private.verifying_key.to_string(extension)
return private.verifying_key.to_string(point_fmt)
if self.group in self._x_groups:
fun, generator, _ = self._get_fun_gen_size()
return fun(private, generator)
else:
curve = getCurveByName(GroupName.toStr(self.group))
point = curve.generator * private
return bytearray(point.to_bytes(encoding=extension))

def calc_shared_key(self, private, peer_share, frm_supported_=set([0])):
"""Calculate the shared key,"""
frm_supported = set()
for ext in frm_supported_:
frm_supported.add(self._get_ext_name(ext))
return bytearray(point.to_bytes(encoding=point_fmt))

def calc_shared_key(self, private, peer_share,
frm_supported=set([ECPointFormat.uncompressed])):
"""Calculate the shared key.
:type private: bytearray | SigningKey
:param private: private value
:type peer_share: bytearray
:param peer_share: public value
:type frm_supported: set of ECPointFormat
:param frm_supported: supported point format extension for ec
:rtype: bytearray
:returns: shared key
:raises TLSIllegalParameterException
when the paramentrs for point are invalid
"""
valid_encodings = set([self._get_point_format(i) \
for i in frm_supported])

if self.group in self._x_groups:
fun, _, size = self._get_fun_gen_size()
Expand All @@ -1074,7 +1099,7 @@ def calc_shared_key(self, private, peer_share, frm_supported_=set([0])):
try:
abstractPoint = ecdsa.ellipticcurve.AbstractPoint()
point = abstractPoint.from_bytes(curve.curve, peer_share,
valid_encodings=frm_supported)
valid_encodings=valid_encodings)
ecdhYc = ecdsa.ellipticcurve.Point(
curve.curve, point[0], point[1])

Expand All @@ -1083,7 +1108,7 @@ def calc_shared_key(self, private, peer_share, frm_supported_=set([0])):
if isinstance(private, ecdsa.keys.SigningKey):
ecdh = ecdsa.ecdh.ECDH(curve=curve, private_key=private)
ecdh.load_received_public_key_bytes(peer_share,
valid_encodings=frm_supported)
valid_encodings=valid_encodings)
return bytearray(ecdh.generate_sharedsecret_bytes())
S = ecdhYc * private

Expand Down
41 changes: 7 additions & 34 deletions tlslite/tlsconnection.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,8 +744,7 @@ def _clientSendClientHello(self, settings, session, srpUsername,
shares = []
for group_name in settings.keyShares:
group_id = getattr(GroupName, group_name)
key_share = self._genKeyShareEntry(group_id, (3, 4),
settings.ecPointFormats[0])
key_share = self._genKeyShareEntry(group_id, (3, 4))
shares.append(key_share)
# if TLS 1.3 is enabled, key_share must always be sent
# (unless only static PSK is used)
Expand All @@ -762,8 +761,9 @@ def _clientSendClientHello(self, settings, session, srpUsername,
if next((cipher for cipher in cipherSuites \
if cipher in CipherSuite.ecdhAllSuites), None) is not None:
groups.extend(self._curveNamesToList(settings))
extensions.append(ECPointFormatsExtension().\
create(settings.ecPointFormats))
if settings.ec_point_formats:
extensions.append(ECPointFormatsExtension().\
create(settings.ec_point_formats))
# Advertise FFDHE groups if we have DHE ciphers
if next((cipher for cipher in cipherSuites
if cipher in CipherSuite.dhAllSuites), None) is not None:
Expand Down Expand Up @@ -975,16 +975,7 @@ def _clientGetServerHello(self, settings, session, clientHello):
"did sent the key share "
"for"):
yield result
ext_negotiated = 0
ext_c = clientHello.getExtension(ExtensionType.ec_point_formats)
ext_s = hello_retry.getExtension(ExtensionType.ec_point_formats)
if ext_c and ext_s:
ext_negotiated = None
for ext in ext_c.formats:
if ext in ext_s.formats and ext_negotiated is None:
ext_negotiated = ext
key_share = self._genKeyShareEntry(group_id, (3, 4),
ext_negotiated)
key_share = self._genKeyShareEntry(group_id, (3, 4))

# old key shares need to be removed
cl_key_share_ext.client_shares = [key_share]
Expand Down Expand Up @@ -2288,7 +2279,7 @@ def _handshakeServerAsyncHelper(self, verifierDB,
# even though the selected cipher may not use ECC, client may want
# to send a CA certificate with ECDSA...
extensions.append(ECPointFormatsExtension().
create(settings.ecPointFormats))
create(settings.ec_point_formats))

# if client sent Heartbeat extension
if clientHello.getExtension(ExtensionType.heartbeat):
Expand Down Expand Up @@ -2726,25 +2717,7 @@ def _serverTLS13Handshake(self, settings, clientHello, cipherSuite,
(psk is None and privateKey):
self.ecdhCurve = selected_group
kex = self._getKEX(selected_group, version)

ext_negotiated = 0
ext_supported = set([0])
ext_c = clientHello.getExtension(ExtensionType.ec_point_formats)
if ext_c and settings.ecPointFormats:
ext_negotiated = None
ext_supported = set()
for ext in ext_c.formats:
if ext in settings.ecPointFormats:
ext_supported.add(ext)
if ext_negotiated is None:
ext_negotiated = ext
if len(ext_supported) == 0:
raise TLSHandshakeFailure(
"No negotiated point extension")

key_share = self._genKeyShareEntry(selected_group,
version,
ext_negotiated)
key_share = self._genKeyShareEntry(selected_group, version)
try:
shared_sec = kex.calc_shared_key(key_share.private,
cl_key_share.key_exchange,
Expand Down

0 comments on commit dd118da

Please sign in to comment.