Skip to content

Commit

Permalink
Add support for scalar addition, subtraction, and negation
Browse files Browse the repository at this point in the history
  • Loading branch information
frederickjansen committed May 6, 2024
1 parent 03a59af commit 46832c4
Show file tree
Hide file tree
Showing 3 changed files with 171 additions and 5 deletions.
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "oblivious"
version = "7.0.0"
version = "7.1.0"
description = """\
Python library that serves as an API for common \
cryptographic primitives used to implement OPRF, OT, \
Expand Down Expand Up @@ -33,8 +33,8 @@ mclbn256 = [
]
docs = [
"toml~=0.10.2",
"sphinx~=4.2.0",
"sphinx-rtd-theme~=1.0.0"
"sphinx~=5.3.0",
"sphinx-rtd-theme~=1.2.0"
]
test = [
"pytest~=7.2",
Expand Down
138 changes: 136 additions & 2 deletions src/oblivious/ristretto.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,8 @@ class python:
:obj:`python.pnt <pnt>`, :obj:`python.bas <bas>`,
:obj:`python.can <can>`, :obj:`python.mul <mul>`,
:obj:`python.add <add>`, :obj:`python.sub <sub>`,
:obj:`python.neg <neg>`,
:obj:`python.neg <neg>`, :obj:`python.sad <sad>`,
:obj:`python.ssu <ssu>`,
:obj:`python.point <oblivious.ristretto.python.point>`, and
:obj:`python.scalar <oblivious.ristretto.python.scalar>`.
For example, you can perform addition of points using
Expand Down Expand Up @@ -385,6 +386,54 @@ def smu(s: bytes, t: bytes) -> bytes:
"""
return _sc25519_mul(s, t)

@staticmethod
def sad(s: bytes, t: bytes) -> bytes:
"""
Return the sum of two scalars.
>>> p = scalar.from_int(4)
>>> q = scalar.from_int(2)
>>> sodium.sad(p, q) == sodium.sad(q, p)
True
>>> sodium.sad(p, q).hex()
'0600000000000000000000000000000000000000000000000000000000000000'
>>> sodium.sad(-q, p).hex()
'0200000000000000000000000000000000000000000000000000000000000000'
"""
(s, t) = (int.from_bytes(s, 'little'), int.from_bytes(t, 'little'))
return (
(s + t) % (pow(2, 252) + 27742317777372353535851937790883648493)
).to_bytes(32, 'little')

@staticmethod
def ssu(s: bytes, t: bytes) -> bytes:
"""
Return the result of subtracting the right-hand scalar from the
left-hand scalar.
>>> p = scalar.from_int(4)
>>> q = scalar.from_int(2)
>>> sodium.ssu(p, q).hex()
'0200000000000000000000000000000000000000000000000000000000000000'
"""
(s, t) = (int.from_bytes(s, 'little'), int.from_bytes(t, 'little'))
return (
(s - t) % (pow(2, 252) + 27742317777372353535851937790883648493)
).to_bytes(32, 'little')

@staticmethod
def sne(s: bytes) -> bytes:
"""
Return the additive inverse of a scalar.
>>> p = scalar.from_int(4)
>>> sodium.sne(p).hex()
'e9d3f55c1a631258d69cf7a2def9de1400000000000000000000000000000010'
"""
return (
(pow(2, 252) + 27742317777372353535851937790883648493 - int.from_bytes(s, 'little'))
).to_bytes(32, 'little')

#
# Attempt to load primitives from libsodium, if it is present;
# otherwise, use the rbcl library, if it is present. Otherwise,
Expand Down Expand Up @@ -527,7 +576,7 @@ class encapsulates shared/dynamic library variants of both classes
:obj:`sodium.pnt <pnt>`, :obj:`sodium.bas <bas>`,
:obj:`sodium.can <can>`, :obj:`sodium.mul <mul>`,
:obj:`sodium.add <add>`, :obj:`sodium.sub <sub>`,
:obj:`sodium.neg <neg>`,
:obj:`sodium.neg <neg>`, :obj:`sodium.sne <sne>`,
:obj:`sodium.point <oblivious.ristretto.sodium.point>`, and
:obj:`sodium.scalar <oblivious.ristretto.sodium.scalar>`.
For example, you can perform addition of points using
Expand Down Expand Up @@ -734,6 +783,54 @@ def smu(s: bytes, t: bytes) -> bytes:
bytes(s), bytes(t)
)

@staticmethod
def sad(s: bytes, t: bytes) -> bytes:
"""
Return the sum of two scalars.
>>> s = sodium.scl()
>>> t = sodium.scl()
>>> sodium.sad(s, t) == sodium.sad(t, s)
True
"""
return sodium._call(
sodium._lib.crypto_core_ristretto255_scalarbytes(),
sodium._lib.crypto_core_ristretto255_scalar_add,
bytes(s), bytes(t)
)

@staticmethod
def ssu(p: bytes, q: bytes) -> bytes:
"""
Return the result of subtracting the right-hand scalar from the
left-hand scalar.
>>> p = scalar.hash('123'.encode())
>>> q = scalar.hash('456'.encode())
>>> sodium.ssu(p, q).hex()
'd08dcedb3a8dc87951acd91334a1faed511f49c6e9296780634b858e42347908'
"""
return sodium._call(
_sodium.crypto_core_ristretto255_scalarbytes(),
sodium._lib.crypto_core_ristretto255_scalar_sub,
bytes(p), bytes(q)
)

@staticmethod
def sne(s: bytes) -> bytes:
"""
Return the additive inverse of a scalar.
>>> p = scalar.from_int(4)
>>> sodium.sne(p).hex()
'e9d3f55c1a631258d69cf7a2def9de1400000000000000000000000000000010'
"""
return sodium._call(
_sodium.crypto_core_ristretto255_scalarbytes(),
sodium._lib.crypto_core_ristretto255_scalar_negate,
bytes(s)
)

except: # pylint: disable=W0702 # pragma: no cover
# Exported symbol.
sodium = None # pragma: no cover
Expand Down Expand Up @@ -1091,6 +1188,17 @@ def __invert__(self: scalar) -> scalar:

return self._implementation.scalar(self._implementation.inv(self))

def __neg__(self: scalar) -> scalar:
"""
Return the negation of this instance.
>>> p = scalar.hash('123'.encode())
>>> q = scalar.hash('456'.encode())
>>> ((p + q) + (-q)) == p
True
"""
return self._implementation.scalar(self._implementation.sne(self))

def __mul__(self: scalar, other: Union[scalar, point]) -> Union[scalar, point]:
"""
Multiply the supplied scalar or point by this instance.
Expand Down Expand Up @@ -1155,6 +1263,32 @@ def __rmul__(self: scalar, other: Union[scalar, point]):
'scalar must be on left-hand side of multiplication operator'
)

def __add__(self: scalar, other: scalar) -> scalar:
"""
Return the sum of this instance and another scalar.
>>> p = scalar.hash('123'.encode())
>>> q = scalar.hash('456'.encode())
>>> (p + q).hex()
'69117034205aa81808edae5d89128497ef75f5b71416d97ccfd18760ad117c0e'
>>> p + (q - q) == p
True
"""
return self._implementation.scalar(self._implementation.sad(self, other))

def __sub__(self: scalar, other: scalar) -> scalar:
"""
Return the result of subtracting another scalar from this instance.
>>> p = scalar.hash('123'.encode())
>>> q = scalar.hash('456'.encode())
>>> (p - q).hex()
'd08dcedb3a8dc87951acd91334a1faed511f49c6e9296780634b858e42347908'
>>> p - p == scalar.from_int(0)
True
"""
return self._implementation.scalar(self._implementation.ssu(self, other))

def to_bytes(self: scalar) -> bytes:
"""
Return the bytes-like object that represents this instance.
Expand Down
32 changes: 32 additions & 0 deletions test/test_ristretto.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,20 @@ def test_types_scalar_mul_point(self):
sodium_hidden_and_fallback(hidden, fallback)
self.assertTrue(isinstance(cls.scalar() * cls.point(), cls.point))

def test_types_scalar_add(self):
sodium_hidden_and_fallback(hidden, fallback)
(s0, s1) = (cls.scalar.random(), cls.scalar.random())
self.assertTrue(isinstance(s0 + s1, cls.scalar))

def test_types_scalar_sub(self):
sodium_hidden_and_fallback(hidden, fallback)
(s0, s1) = (cls.scalar.random(), cls.scalar.random())
self.assertTrue(isinstance(s0 - s1, cls.scalar))

def test_types_scalar_neg(self):
sodium_hidden_and_fallback(hidden, fallback)
self.assertTrue(isinstance(-cls.scalar.random(), cls.scalar))

class Test_algebra(TestCase):
"""
Tests of algebraic properties of primitive operations and class methods.
Expand Down Expand Up @@ -689,6 +703,24 @@ def test_algebra_scalar_mul_point_on_left_hand_side(self):
p = cls.point.hash(bytes(POINT_LEN))
self.assertRaises(TypeError, lambda: p * s)

def test_algebra_scalar_add_commute(self):
sodium_hidden_and_fallback(hidden, fallback)
for bs in fountains(SCALAR_LEN + SCALAR_LEN, limit=TRIALS_PER_TEST):
(s0, s1) = (
cls.scalar.hash(bs[:SCALAR_LEN]),
cls.scalar.hash(bs[SCALAR_LEN:])
)
self.assertEqual(cls.sad(s0, s1), cls.sad(s1, s0))

def test_algebra_scalar_add_neg_add_identity(self):
sodium_hidden_and_fallback(hidden, fallback)
for bs in fountains(SCALAR_LEN + SCALAR_LEN, limit=TRIALS_PER_TEST):
(s0, s1) = (
cls.scalar.hash(bs[:SCALAR_LEN]),
cls.scalar.hash(bs[SCALAR_LEN:])
)
self.assertEqual(cls.sad(cls.sad(s0, cls.sne(s0)), s1), s1)

return (
Test_primitives,
Test_classes,
Expand Down

0 comments on commit 46832c4

Please sign in to comment.