From 31607fab77bdcdd63d18e7afd81b63654145945b Mon Sep 17 00:00:00 2001 From: AntonKueltz Date: Wed, 19 Feb 2020 22:10:03 -0800 Subject: [PATCH] fix: add type validation for point arithmetic --- CHANGELOG.md | 1 + docs/conf.py | 4 ++-- fastecdsa/curve.py | 6 ++--- fastecdsa/point.py | 43 +++++++++++++++++---------------- fastecdsa/tests/test_point.py | 45 +++++++++++++++++++++++++++++++++++ fastecdsa/util.py | 18 +++++++++++++- 6 files changed, 91 insertions(+), 26 deletions(-) create mode 100644 fastecdsa/tests/test_point.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 2ee243b..17781b6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ ### Fixed - Curves with no OID are not added to the lookup by OID map +- Type validation for operations of points ## [2.0.0] ### Added diff --git a/docs/conf.py b/docs/conf.py index be0414c..2715db6 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -62,9 +62,9 @@ # built documents. # # The short X.Y version. -version = '2.0' +version = '2.1' # The full version, including alpha/beta/rc tags. -release = '2.0.0' +release = '2.1.0' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. diff --git a/fastecdsa/curve.py b/fastecdsa/curve.py index 251a169..3439fe3 100644 --- a/fastecdsa/curve.py +++ b/fastecdsa/curve.py @@ -55,7 +55,7 @@ def get_curve_by_oid(cls, oid: bytes): """Get a curve via it's object identifier.""" return cls._oid_lookup.get(oid, None) - def is_point_on_curve(self, P) -> bool: + def is_point_on_curve(self, point: (int, int)) -> bool: """ Check if a point lies on this curve. The check is done by evaluating the curve equation :math:`y^2 \equiv x^3 + ax + b \pmod{p}` @@ -63,13 +63,13 @@ def is_point_on_curve(self, P) -> bool: the congruence holds, then the point lies on this curve. Args: - P (long, long): A tuple representing the point :math:`P` as an :math:`(x, y)` coordinate + point (long, long): A tuple representing the point :math:`P` as an :math:`(x, y)` coordinate pair. Returns: bool: :code:`True` if the point lies on this curve, otherwise :code:`False`. """ - x, y, = P[0], P[1] + x, y, = point left = y * y right = (x * x * x) + (self.a * x) + self.b return (left - right) % self.p == 0 diff --git a/fastecdsa/point.py b/fastecdsa/point.py index 40d8e90..8636518 100644 --- a/fastecdsa/point.py +++ b/fastecdsa/point.py @@ -1,5 +1,6 @@ from fastecdsa import curvemath from .curve import P256 +from .util import validate_type class CurveMismatchError(Exception): @@ -49,6 +50,7 @@ def __repr__(self) -> str: return self.__str__() def __eq__(self, other) -> bool: + validate_type(other, Point) return self.x == other.x and self.y == other.y and self.curve is other.curve def __add__(self, other): @@ -61,6 +63,8 @@ def __add__(self, other): Returns: :class:`Point`: A point :math:`R` such that :math:`R = P + Q` """ + validate_type(other, Point) + if self == self.IDENTITY_ELEMENT: return other elif other == self.IDENTITY_ELEMENT: @@ -92,8 +96,9 @@ def __radd__(self, other): | other (:class:`Point`): a point :math:`Q` on the curve Returns: - :class:`Point`: A point :math:`R` such that :math:`R = P + Q` + :class:`Point`: A point :math:`R` such that :math:`R = R + Q` """ + validate_type(other, Point) return self.__add__(other) def __sub__(self, other): @@ -106,6 +111,8 @@ def __sub__(self, other): Returns: :class:`Point`: A point :math:`R` such that :math:`R = P - Q` """ + validate_type(other, Point) + if self == other: return self.IDENTITY_ELEMENT elif other == self.IDENTITY_ELEMENT: @@ -125,26 +132,22 @@ def __mul__(self, scalar: int): Returns: :class:`Point`: A point :math:`R` such that :math:`R = P * d` """ - try: - d = int(scalar) % self.curve.q - except ValueError: - raise TypeError('Curve point multiplication must be by an integer') - else: - if d == 0: - return self.IDENTITY_ELEMENT + validate_type(scalar, int) + if scalar == 0: + return self.IDENTITY_ELEMENT - x, y = curvemath.mul( - str(self.x), - str(self.y), - str(d), - str(self.curve.p), - str(self.curve.a), - str(self.curve.b), - str(self.curve.q), - str(self.curve.gx), - str(self.curve.gy) - ) - return Point(int(x), int(y), self.curve) + x, y = curvemath.mul( + str(self.x), + str(self.y), + str(scalar), + str(self.curve.p), + str(self.curve.a), + str(self.curve.b), + str(self.curve.q), + str(self.curve.gx), + str(self.curve.gy) + ) + return Point(int(x), int(y), self.curve) def __rmul__(self, scalar: int): """Multiply a :class:`Point` on an elliptic curve by an integer. diff --git a/fastecdsa/tests/test_point.py b/fastecdsa/tests/test_point.py new file mode 100644 index 0000000..71f0f86 --- /dev/null +++ b/fastecdsa/tests/test_point.py @@ -0,0 +1,45 @@ +from unittest import TestCase + +from ..curve import W25519 +from ..point import Point + + +class TestTypeValidation(TestCase): + def test_type_validation_add(self): + with self.assertRaises(TypeError): + _ = Point.IDENTITY_ELEMENT + 2 + + with self.assertRaises(TypeError): + _ = W25519.G + 2 + + with self.assertRaises(TypeError): + _ = 2 + Point.IDENTITY_ELEMENT + + with self.assertRaises(TypeError): + _ = 2 + W25519.G + + def test_type_validation_sub(self): + with self.assertRaises(TypeError): + _ = Point.IDENTITY_ELEMENT - 2 + + with self.assertRaises(TypeError): + _ = W25519.G - 2 + + with self.assertRaises(TypeError): + _ = 2 - Point.IDENTITY_ELEMENT + + with self.assertRaises(TypeError): + _ = 2 - W25519.G + + def test_type_validation_mul(self): + with self.assertRaises(TypeError): + _ = Point.IDENTITY_ELEMENT * 1.5 + + with self.assertRaises(TypeError): + _ = W25519.G * 1.5 + + with self.assertRaises(TypeError): + _ = 1.5 * Point.IDENTITY_ELEMENT + + with self.assertRaises(TypeError): + _ = 1.5 * W25519.G diff --git a/fastecdsa/util.py b/fastecdsa/util.py index dd115b7..923eef8 100644 --- a/fastecdsa/util.py +++ b/fastecdsa/util.py @@ -82,7 +82,7 @@ def gen_nonce(self): v = hmac.new(k, v, self.hashfunc).digest() -def _tonelli_shanks(n: int, p: int) -> int: +def _tonelli_shanks(n: int, p: int) -> (int, int): """A generic algorithm for computng modular square roots.""" Q, S = p - 1, 0 while Q % 2 == 0: @@ -147,3 +147,19 @@ def msg_bytes(msg) -> bytes: else: raise ValueError('Msg "{}" of type {} cannot be converted to bytes'.format( msg, type(msg))) + + +def validate_type(instance: object, expected_type: type): + """Validate that instance is an instance of the expected_type. + + Args: + | instance: The object whose type is being checked + | expected_type: The expected type of instance + | var_name: The name of the object + + Raises: + TypeError: If instance is not of type expected_type + """ + if not isinstance(instance, expected_type): + raise TypeError('Expected a value of type {}, got a value of type {}'.format( + expected_type, type(instance)))