Skip to content

Commit

Permalink
fix: add type validation for point arithmetic
Browse files Browse the repository at this point in the history
  • Loading branch information
AntonKueltz committed Feb 20, 2020
1 parent bb65064 commit 31607fa
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 26 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

### Fixed
- Curves with no OID are not added to the lookup by OID map
- Type validation for operations of points

## [2.0.0]
### Added
Expand Down
4 changes: 2 additions & 2 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@
# built documents.
#
# The short X.Y version.
version = '2.0'
version = '2.1'
# The full version, including alpha/beta/rc tags.
release = '2.0.0'
release = '2.1.0'

# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
Expand Down
6 changes: 3 additions & 3 deletions fastecdsa/curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,21 +55,21 @@ def get_curve_by_oid(cls, oid: bytes):
"""Get a curve via it's object identifier."""
return cls._oid_lookup.get(oid, None)

def is_point_on_curve(self, P) -> bool:
def is_point_on_curve(self, point: (int, int)) -> bool:
""" Check if a point lies on this curve.
The check is done by evaluating the curve equation :math:`y^2 \equiv x^3 + ax + b \pmod{p}`
at the given point :math:`(x,y)` with this curve's domain parameters :math:`(a, b, p)`. If
the congruence holds, then the point lies on this curve.
Args:
P (long, long): A tuple representing the point :math:`P` as an :math:`(x, y)` coordinate
point (long, long): A tuple representing the point :math:`P` as an :math:`(x, y)` coordinate
pair.
Returns:
bool: :code:`True` if the point lies on this curve, otherwise :code:`False`.
"""
x, y, = P[0], P[1]
x, y, = point
left = y * y
right = (x * x * x) + (self.a * x) + self.b
return (left - right) % self.p == 0
Expand Down
43 changes: 23 additions & 20 deletions fastecdsa/point.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from fastecdsa import curvemath
from .curve import P256
from .util import validate_type


class CurveMismatchError(Exception):
Expand Down Expand Up @@ -49,6 +50,7 @@ def __repr__(self) -> str:
return self.__str__()

def __eq__(self, other) -> bool:
validate_type(other, Point)
return self.x == other.x and self.y == other.y and self.curve is other.curve

def __add__(self, other):
Expand All @@ -61,6 +63,8 @@ def __add__(self, other):
Returns:
:class:`Point`: A point :math:`R` such that :math:`R = P + Q`
"""
validate_type(other, Point)

if self == self.IDENTITY_ELEMENT:
return other
elif other == self.IDENTITY_ELEMENT:
Expand Down Expand Up @@ -92,8 +96,9 @@ def __radd__(self, other):
| other (:class:`Point`): a point :math:`Q` on the curve
Returns:
:class:`Point`: A point :math:`R` such that :math:`R = P + Q`
:class:`Point`: A point :math:`R` such that :math:`R = R + Q`
"""
validate_type(other, Point)
return self.__add__(other)

def __sub__(self, other):
Expand All @@ -106,6 +111,8 @@ def __sub__(self, other):
Returns:
:class:`Point`: A point :math:`R` such that :math:`R = P - Q`
"""
validate_type(other, Point)

if self == other:
return self.IDENTITY_ELEMENT
elif other == self.IDENTITY_ELEMENT:
Expand All @@ -125,26 +132,22 @@ def __mul__(self, scalar: int):
Returns:
:class:`Point`: A point :math:`R` such that :math:`R = P * d`
"""
try:
d = int(scalar) % self.curve.q
except ValueError:
raise TypeError('Curve point multiplication must be by an integer')
else:
if d == 0:
return self.IDENTITY_ELEMENT
validate_type(scalar, int)
if scalar == 0:
return self.IDENTITY_ELEMENT

x, y = curvemath.mul(
str(self.x),
str(self.y),
str(d),
str(self.curve.p),
str(self.curve.a),
str(self.curve.b),
str(self.curve.q),
str(self.curve.gx),
str(self.curve.gy)
)
return Point(int(x), int(y), self.curve)
x, y = curvemath.mul(
str(self.x),
str(self.y),
str(scalar),
str(self.curve.p),
str(self.curve.a),
str(self.curve.b),
str(self.curve.q),
str(self.curve.gx),
str(self.curve.gy)
)
return Point(int(x), int(y), self.curve)

def __rmul__(self, scalar: int):
"""Multiply a :class:`Point` on an elliptic curve by an integer.
Expand Down
45 changes: 45 additions & 0 deletions fastecdsa/tests/test_point.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from unittest import TestCase

from ..curve import W25519
from ..point import Point


class TestTypeValidation(TestCase):
def test_type_validation_add(self):
with self.assertRaises(TypeError):
_ = Point.IDENTITY_ELEMENT + 2

with self.assertRaises(TypeError):
_ = W25519.G + 2

with self.assertRaises(TypeError):
_ = 2 + Point.IDENTITY_ELEMENT

with self.assertRaises(TypeError):
_ = 2 + W25519.G

def test_type_validation_sub(self):
with self.assertRaises(TypeError):
_ = Point.IDENTITY_ELEMENT - 2

with self.assertRaises(TypeError):
_ = W25519.G - 2

with self.assertRaises(TypeError):
_ = 2 - Point.IDENTITY_ELEMENT

with self.assertRaises(TypeError):
_ = 2 - W25519.G

def test_type_validation_mul(self):
with self.assertRaises(TypeError):
_ = Point.IDENTITY_ELEMENT * 1.5

with self.assertRaises(TypeError):
_ = W25519.G * 1.5

with self.assertRaises(TypeError):
_ = 1.5 * Point.IDENTITY_ELEMENT

with self.assertRaises(TypeError):
_ = 1.5 * W25519.G
18 changes: 17 additions & 1 deletion fastecdsa/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def gen_nonce(self):
v = hmac.new(k, v, self.hashfunc).digest()


def _tonelli_shanks(n: int, p: int) -> int:
def _tonelli_shanks(n: int, p: int) -> (int, int):
"""A generic algorithm for computng modular square roots."""
Q, S = p - 1, 0
while Q % 2 == 0:
Expand Down Expand Up @@ -147,3 +147,19 @@ def msg_bytes(msg) -> bytes:
else:
raise ValueError('Msg "{}" of type {} cannot be converted to bytes'.format(
msg, type(msg)))


def validate_type(instance: object, expected_type: type):
"""Validate that instance is an instance of the expected_type.
Args:
| instance: The object whose type is being checked
| expected_type: The expected type of instance
| var_name: The name of the object
Raises:
TypeError: If instance is not of type expected_type
"""
if not isinstance(instance, expected_type):
raise TypeError('Expected a value of type {}, got a value of type {}'.format(
expected_type, type(instance)))

0 comments on commit 31607fa

Please sign in to comment.