diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f40e352..109f058 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -120,8 +120,8 @@ jobs: SONAR_TOKEN: ${{ secrets.SONAR_TOKEN }} with: args: > - -Dsonar.projectKey=bytemare_crypto - -Dsonar.organization=bytemare-github + -Dsonar.projectKey=crypto + -Dsonar.organization=bytemare -Dsonar.go.coverage.reportPaths=.github/coverage.out -Dsonar.sources=. -Dsonar.test.exclusions=tests/** diff --git a/README.md b/README.md index dfe9c0d..833896f 100644 --- a/README.md +++ b/README.md @@ -68,10 +68,12 @@ type Scalar interface { LessOrEqual(Scalar) int IsZero() bool Set(Scalar) Scalar - SetInt(big.Int) error + SetUInt64(uint64) Scalar Copy() Scalar Encode() []byte Decode(in []byte) error + Hex() string + HexDecode([]byte) error encoding.BinaryMarshaler encoding.BinaryUnmarshaler } @@ -95,6 +97,8 @@ type Element interface { Encode() []byte XCoordinate() []byte Decode(data []byte) error + Hex() string + HexDecode([]byte) error encoding.BinaryMarshaler encoding.BinaryUnmarshaler } diff --git a/element.go b/element.go index 1e144d6..c22b5ae 100644 --- a/element.go +++ b/element.go @@ -132,6 +132,20 @@ func (e *Element) Decode(data []byte) error { return nil } +// Hex returns the fixed-sized hexadecimal encoding of e. +func (e *Element) Hex() string { + return e.Element.Hex() +} + +// DecodeHex sets e to the decoding of the hex encoded element. +func (e *Element) DecodeHex(h string) error { + if err := e.Element.DecodeHex(h); err != nil { + return fmt.Errorf("element DecodeHex: %w", err) + } + + return nil +} + // MarshalBinary returns the compressed byte encoding of the element. func (e *Element) MarshalBinary() ([]byte, error) { dec, err := e.Element.MarshalBinary() diff --git a/go.mod b/go.mod index c79dbf6..5ec10b5 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( filippo.io/edwards25519 v1.1.0 filippo.io/nistec v0.0.3 github.com/bytemare/hash2curve v0.3.0 - github.com/bytemare/secp256k1 v0.1.2 + github.com/bytemare/secp256k1 v0.1.4 github.com/gtank/ristretto255 v0.1.2 ) diff --git a/go.sum b/go.sum index 8cf3e82..e560c3c 100644 --- a/go.sum +++ b/go.sum @@ -6,8 +6,8 @@ github.com/bytemare/hash v0.3.0 h1:RqFMt3mqpF7UxLdjBrsOZm/2cz0cQiAOnYc9gDLopWE= github.com/bytemare/hash v0.3.0/go.mod h1:YKOBchL0l8hRLFinVCL8YUKokGNIMhrWEHPHo3EV7/M= github.com/bytemare/hash2curve v0.3.0 h1:41Npcbc+u/E252A5aCMtxDcz7JPkkX1QzShneTFm4eg= github.com/bytemare/hash2curve v0.3.0/go.mod h1:itj45U8uqvCtWC0eCswIHVHswXcEHkpFui7gfJdPSfQ= -github.com/bytemare/secp256k1 v0.1.2 h1:aM+p/+0y1h0SZWqS/yzjGPzffVFubJvwLjUgodFEWOo= -github.com/bytemare/secp256k1 v0.1.2/go.mod h1:Pxb9miDs8PTt5mOktvvXiRflvLxI1wdxbXrc6IYsaho= +github.com/bytemare/secp256k1 v0.1.4 h1:6F1yP6RiUiWwH7AsGHsHktmHm24QcetdDcc39roBd2M= +github.com/bytemare/secp256k1 v0.1.4/go.mod h1:Pxb9miDs8PTt5mOktvvXiRflvLxI1wdxbXrc6IYsaho= github.com/gtank/ristretto255 v0.1.2 h1:JEqUCPA1NvLq5DwYtuzigd7ss8fwbYay9fi4/5uMzcc= github.com/gtank/ristretto255 v0.1.2/go.mod h1:Ph5OpO6c7xKUGROZfWVLiJf9icMDwUeIvY4OmlYW69o= golang.org/x/crypto v0.22.0 h1:g1v0xeRhjcugydODzvb3mEM9SQ0HGp9s/nh3COQ/C30= diff --git a/internal/edwards25519/element.go b/internal/edwards25519/element.go index f279aab..2ac2da1 100644 --- a/internal/edwards25519/element.go +++ b/internal/edwards25519/element.go @@ -9,6 +9,7 @@ package edwards25519 import ( + "encoding/hex" "fmt" ed "filippo.io/edwards25519" @@ -163,6 +164,21 @@ func (e *Element) Decode(data []byte) error { return nil } +// Hex returns the fixed-sized hexadecimal encoding of e. +func (e *Element) Hex() string { + return hex.EncodeToString(e.Encode()) +} + +// DecodeHex sets e to the decoding of the hex encoded element. +func (e *Element) DecodeHex(h string) error { + b, err := hex.DecodeString(h) + if err != nil { + return fmt.Errorf("%w", err) + } + + return e.Decode(b) +} + // MarshalBinary returns the compressed byte encoding of the element. func (e *Element) MarshalBinary() (data []byte, err error) { return e.Encode(), nil diff --git a/internal/edwards25519/scalar.go b/internal/edwards25519/scalar.go index 8a40af2..4328263 100644 --- a/internal/edwards25519/scalar.go +++ b/internal/edwards25519/scalar.go @@ -9,6 +9,8 @@ package edwards25519 import ( + "encoding/binary" + "encoding/hex" "fmt" "math/big" @@ -268,18 +270,20 @@ func (s *Scalar) Set(scalar internal.Scalar) internal.Scalar { return s } -// SetInt sets s to i modulo the field order, and returns an error if one occurs. -func (s *Scalar) SetInt(i *big.Int) error { - a := new(big.Int).Set(i) +// SetUInt64 sets s to i modulo the field order, and returns an error if one occurs. +func (s *Scalar) SetUInt64(i uint64) internal.Scalar { + encoded := make([]byte, canonicalEncodingLength) + binary.LittleEndian.PutUint64(encoded, i) - bytes := make([]byte, 32) - bytes = a.Mod(a, &order).FillBytes(bytes) - - for j, k := 0, len(bytes)-1; j < k; j, k = j+1, k-1 { - bytes[j], bytes[k] = bytes[k], bytes[j] + sc, err := decodeScalar(encoded) + if err != nil { + // This cannot happen, since any uint64 is smaller than the order. + panic(fmt.Sprintf("unexpected decoding of uint64 scalar: %s", err)) } - return s.Decode(bytes) + s.set(sc) + + return s } func (s *Scalar) copy() *Scalar { @@ -325,6 +329,21 @@ func (s *Scalar) Decode(in []byte) error { return nil } +// Hex returns the fixed-sized hexadecimal encoding of s. +func (s *Scalar) Hex() string { + return hex.EncodeToString(s.Encode()) +} + +// DecodeHex sets s to the decoding of the hex encoded scalar. +func (s *Scalar) DecodeHex(h string) error { + b, err := hex.DecodeString(h) + if err != nil { + return fmt.Errorf("%w", err) + } + + return s.Decode(b) +} + // MarshalBinary returns the compressed byte encoding of the scalar. func (s *Scalar) MarshalBinary() (data []byte, err error) { return s.Encode(), nil diff --git a/internal/element.go b/internal/element.go index b020282..6b7e72b 100644 --- a/internal/element.go +++ b/internal/element.go @@ -9,7 +9,9 @@ // Package internal defines simple and abstract APIs to group Elements and Scalars. package internal -import "encoding" +import ( + "encoding" +) // Element interface abstracts common operations on an Element in a prime-order Group. type Element interface { @@ -55,6 +57,12 @@ type Element interface { // Decode sets the receiver to a decoding of the input data, and returns an error on failure. Decode(data []byte) error + // Hex returns the fixed-sized hexadecimal encoding of e. + Hex() string + + // DecodeHex sets e to the decoding of the hex encoded element. + DecodeHex(h string) error + // BinaryMarshaler implementation. encoding.BinaryMarshaler diff --git a/internal/misc.go b/internal/misc.go index 670fca9..171da7f 100644 --- a/internal/misc.go +++ b/internal/misc.go @@ -10,6 +10,7 @@ package internal import ( cryptorand "crypto/rand" + "encoding" "errors" "fmt" ) @@ -52,6 +53,30 @@ var ( ErrParamScalarInvalidEncoding = errors.New("invalid scalar encoding") ) +// An Encoder can encode itself to machine or human-readable forms. +type Encoder interface { + // Encode returns the compressed byte encoding. + Encode() []byte + + // Hex returns the fixed-sized hexadecimal encoding. + Hex() string + + // BinaryMarshaler implementation. + encoding.BinaryMarshaler +} + +// A Decoder can encode itself to machine or human-readable forms. +type Decoder interface { + // Decode sets the receiver to a decoding of the input data, and returns an error on failure. + Decode(data []byte) error + + // DecodeHex sets the receiver to the decoding of the hex encoded input. + DecodeHex(h string) error + + // BinaryUnmarshaler implementation. + encoding.BinaryUnmarshaler +} + // RandomBytes returns random bytes of length len (wrapper for crypto/rand). func RandomBytes(length int) []byte { random := make([]byte, length) diff --git a/internal/nist/element.go b/internal/nist/element.go index c282513..063fe77 100644 --- a/internal/nist/element.go +++ b/internal/nist/element.go @@ -10,6 +10,7 @@ package nist import ( "crypto/subtle" + "encoding/hex" "fmt" "github.com/bytemare/crypto/internal" @@ -214,12 +215,27 @@ func (e *Element[P]) XCoordinate() []byte { // Decode sets the receiver to a decoding of the input data, and returns an error on failure. func (e *Element[P]) Decode(data []byte) error { if _, err := e.p.SetBytes(data); err != nil { - return fmt.Errorf("nist element Decode: %w", err) + return fmt.Errorf("%w", err) } return nil } +// Hex returns the fixed-sized hexadecimal encoding of e. +func (e *Element[P]) Hex() string { + return hex.EncodeToString(e.Encode()) +} + +// DecodeHex sets e to the decoding of the hex encoded element. +func (e *Element[P]) DecodeHex(h string) error { + b, err := hex.DecodeString(h) + if err != nil { + return fmt.Errorf("%w", err) + } + + return e.Decode(b) +} + // MarshalBinary returns the compressed byte encoding of the element. func (e *Element[P]) MarshalBinary() ([]byte, error) { return e.Encode(), nil diff --git a/internal/nist/scalar.go b/internal/nist/scalar.go index 5f42498..38f56ff 100644 --- a/internal/nist/scalar.go +++ b/internal/nist/scalar.go @@ -10,6 +10,7 @@ package nist import ( "crypto/subtle" + "encoding/hex" "fmt" "math/big" @@ -112,7 +113,7 @@ func (s *Scalar) Pow(scalar internal.Scalar) internal.Scalar { return s.One() } - if scalar.Equal(scalar.Copy().One()) == 1 { + if scalar.Equal(newScalar(s.field).One()) == 1 { return s } @@ -181,12 +182,10 @@ func (s *Scalar) Set(scalar internal.Scalar) internal.Scalar { return s } -// SetInt sets s to i modulo the field order, and returns an error if one occurs. -func (s *Scalar) SetInt(i *big.Int) error { - s.scalar.Set(i) - s.field.Mod(&s.scalar) - - return nil +// SetUInt64 sets s to i modulo the field order, and returns an error if one occurs. +func (s *Scalar) SetUInt64(i uint64) internal.Scalar { + s.scalar.SetUint64(i) + return s } // Copy returns a copy of the Scalar. @@ -233,6 +232,21 @@ func (s *Scalar) Decode(in []byte) error { return nil } +// Hex returns the fixed-sized hexadecimal encoding of s. +func (s *Scalar) Hex() string { + return hex.EncodeToString(s.Encode()) +} + +// DecodeHex sets s to the decoding of the hex encoded scalar. +func (s *Scalar) DecodeHex(h string) error { + b, err := hex.DecodeString(h) + if err != nil { + return fmt.Errorf("%w", err) + } + + return s.Decode(b) +} + // MarshalBinary returns the compressed byte encoding of the scalar. func (s *Scalar) MarshalBinary() ([]byte, error) { return s.Encode(), nil diff --git a/internal/ristretto/element.go b/internal/ristretto/element.go index ab8895d..164d420 100644 --- a/internal/ristretto/element.go +++ b/internal/ristretto/element.go @@ -10,6 +10,7 @@ package ristretto import ( + "encoding/hex" "fmt" "github.com/gtank/ristretto255" @@ -169,6 +170,21 @@ func (e *Element) Decode(data []byte) error { return nil } +// Hex returns the fixed-sized hexadecimal encoding of e. +func (e *Element) Hex() string { + return hex.EncodeToString(e.Encode()) +} + +// DecodeHex sets e to the decoding of the hex encoded element. +func (e *Element) DecodeHex(h string) error { + b, err := hex.DecodeString(h) + if err != nil { + return fmt.Errorf("%w", err) + } + + return e.Decode(b) +} + // MarshalBinary returns the compressed byte encoding of the element. func (e *Element) MarshalBinary() ([]byte, error) { return e.Encode(), nil diff --git a/internal/ristretto/scalar.go b/internal/ristretto/scalar.go index c677542..a90422a 100644 --- a/internal/ristretto/scalar.go +++ b/internal/ristretto/scalar.go @@ -10,6 +10,8 @@ package ristretto import ( + "encoding/binary" + "encoding/hex" "fmt" "math/big" @@ -258,18 +260,20 @@ func (s *Scalar) Set(scalar internal.Scalar) internal.Scalar { return s } -// SetInt sets s to i modulo the field order, and returns an error if one occurs. -func (s *Scalar) SetInt(i *big.Int) error { - a := new(big.Int).Set(i) +// SetUInt64 sets s to i modulo the field order, and returns an error if one occurs. +func (s *Scalar) SetUInt64(i uint64) internal.Scalar { + encoded := make([]byte, canonicalEncodingLength) + binary.LittleEndian.PutUint64(encoded, i) - bytes := make([]byte, 32) - bytes = a.Mod(a, &order).FillBytes(bytes) - - for j, k := 0, len(bytes)-1; j < k; j, k = j+1, k-1 { - bytes[j], bytes[k] = bytes[k], bytes[j] + sc, err := decodeScalar(encoded) + if err != nil { + // This cannot happen, since any uint64 is smaller than the order. + panic(fmt.Sprintf("unexpected decoding of uint64 scalar: %s", err)) } - return s.Decode(bytes) + s.set(sc) + + return s } func (s *Scalar) copy() *Scalar { @@ -315,6 +319,21 @@ func (s *Scalar) Decode(in []byte) error { return nil } +// Hex returns the fixed-sized hexadecimal encoding of s. +func (s *Scalar) Hex() string { + return hex.EncodeToString(s.Encode()) +} + +// DecodeHex sets s to the decoding of the hex encoded scalar. +func (s *Scalar) DecodeHex(h string) error { + b, err := hex.DecodeString(h) + if err != nil { + return fmt.Errorf("%w", err) + } + + return s.Decode(b) +} + // MarshalBinary returns the compressed byte encoding of the scalar. func (s *Scalar) MarshalBinary() ([]byte, error) { return s.Encode(), nil diff --git a/internal/scalar.go b/internal/scalar.go index 00d6663..bedc2dc 100644 --- a/internal/scalar.go +++ b/internal/scalar.go @@ -11,7 +11,6 @@ package internal import ( "encoding" - "math/big" ) // Scalar interface abstracts common operations on scalars in a prime-order Group. @@ -53,8 +52,8 @@ type Scalar interface { // Set sets the receiver to the value of the argument scalar, and returns the receiver. Set(Scalar) Scalar - // SetInt sets s to i modulo the field order, and returns an error if one occurs. - SetInt(i *big.Int) error + // SetUInt64 sets s to i modulo the field order, and returns an error if one occurs. + SetUInt64(i uint64) Scalar // Copy returns a copy of the receiver. Copy() Scalar @@ -65,6 +64,12 @@ type Scalar interface { // Decode sets the receiver to a decoding of the input data, and returns an error on failure. Decode(in []byte) error + // Hex returns the fixed-sized hexadecimal encoding of s. + Hex() string + + // DecodeHex sets s to the decoding of the hex encoded scalar. + DecodeHex(h string) error + // BinaryMarshaler returns a byte representation of the element. encoding.BinaryMarshaler diff --git a/internal/secp256k1/element.go b/internal/secp256k1/element.go index bea9e95..bbb26aa 100644 --- a/internal/secp256k1/element.go +++ b/internal/secp256k1/element.go @@ -9,6 +9,7 @@ package secp256k1 import ( + "encoding/hex" "fmt" "github.com/bytemare/secp256k1" @@ -134,6 +135,21 @@ func (e *Element) Decode(data []byte) error { return nil } +// Hex returns the fixed-sized hexadecimal encoding of e. +func (e *Element) Hex() string { + return hex.EncodeToString(e.Encode()) +} + +// DecodeHex sets e to the decoding of the hex encoded element. +func (e *Element) DecodeHex(h string) error { + b, err := hex.DecodeString(h) + if err != nil { + return fmt.Errorf("%w", err) + } + + return e.Decode(b) +} + // MarshalBinary returns the compressed byte encoding of the element. func (e *Element) MarshalBinary() (data []byte, err error) { return e.Encode(), nil diff --git a/internal/secp256k1/scalar.go b/internal/secp256k1/scalar.go index 1db134c..2c2b5aa 100644 --- a/internal/secp256k1/scalar.go +++ b/internal/secp256k1/scalar.go @@ -10,7 +10,6 @@ package secp256k1 import ( "fmt" - "math/big" "github.com/bytemare/secp256k1" @@ -96,7 +95,7 @@ func (s *Scalar) Pow(scalar internal.Scalar) internal.Scalar { return s.One() } - if scalar.Equal(scalar.Copy().One()) == 1 { + if scalar.Equal(newScalar().One()) == 1 { return s } @@ -146,13 +145,10 @@ func (s *Scalar) Set(scalar internal.Scalar) internal.Scalar { return s } -// SetInt sets s to i modulo the field order, and returns an error if one occurs. -func (s *Scalar) SetInt(i *big.Int) error { - if err := s.scalar.SetInt(i); err != nil { - return fmt.Errorf("%w", err) - } - - return nil +// SetUInt64 sets s to i modulo the field order, and returns an error if one occurs. +func (s *Scalar) SetUInt64(i uint64) internal.Scalar { + s.scalar.SetUInt64(i) + return s } // Copy returns a copy of the receiver. @@ -178,6 +174,20 @@ func (s *Scalar) Decode(in []byte) error { return nil } +// Hex returns the fixed-sized hexadecimal encoding of s. +func (s *Scalar) Hex() string { + return s.scalar.Hex() +} + +// DecodeHex sets s to the decoding of the hex encoded scalar. +func (s *Scalar) DecodeHex(h string) error { + if err := s.scalar.DecodeHex(h); err != nil { + return fmt.Errorf("%w", err) + } + + return nil +} + // MarshalBinary returns the compressed byte encoding of the scalar. func (s *Scalar) MarshalBinary() (data []byte, err error) { return s.Encode(), nil diff --git a/scalar.go b/scalar.go index b228e28..59fb470 100644 --- a/scalar.go +++ b/scalar.go @@ -11,7 +11,6 @@ package crypto import ( "fmt" - "math/big" "github.com/bytemare/crypto/internal" ) @@ -129,13 +128,10 @@ func (s *Scalar) Set(scalar *Scalar) *Scalar { return s } -// SetInt sets s to i modulo the field order, and returns an error if one occurs. -func (s *Scalar) SetInt(i *big.Int) error { - if err := s.Scalar.SetInt(i); err != nil { - return fmt.Errorf("scalar: %w", err) - } - - return nil +// SetUInt64 sets s to i modulo the field order, and returns an error if one occurs. +func (s *Scalar) SetUInt64(i uint64) *Scalar { + s.Scalar.SetUInt64(i) + return s } // Copy returns a copy of the receiver. @@ -157,6 +153,20 @@ func (s *Scalar) Decode(data []byte) error { return nil } +// Hex returns the fixed-sized hexadecimal encoding of s. +func (s *Scalar) Hex() string { + return s.Scalar.Hex() +} + +// DecodeHex sets s to the decoding of the hex encoded scalar. +func (s *Scalar) DecodeHex(h string) error { + if err := s.Scalar.DecodeHex(h); err != nil { + return fmt.Errorf("scalar DecodeHex: %w", err) + } + + return nil +} + // MarshalBinary implements the encoding.BinaryMarshaler interface. func (s *Scalar) MarshalBinary() ([]byte, error) { dec, err := s.Scalar.MarshalBinary() diff --git a/tests/element_test.go b/tests/element_test.go index b0dc51b..4cd23a7 100644 --- a/tests/element_test.go +++ b/tests/element_test.go @@ -15,16 +15,15 @@ import ( "math/big" "testing" - "github.com/bytemare/secp256k1" - "github.com/bytemare/crypto" "github.com/bytemare/crypto/internal" ) const ( - errExpectedEquality = "expected equality" - errExpectedIdentity = "expected identity" - errWrongGroup = "wrong group" + errUnExpectedEquality = "unexpected equality" + errExpectedEquality = "expected equality" + errExpectedIdentity = "expected identity" + errWrongGroup = "wrong group" ) func testElementCopySet(t *testing.T, element, other *crypto.Element) { @@ -41,12 +40,12 @@ func testElementCopySet(t *testing.T, element, other *crypto.Element) { // Verify than operations on one don't affect the other element.Add(element) if element.Equal(other) == 1 { - t.Fatalf("Unexpected equality") + t.Fatalf(errUnExpectedEquality) } other.Double().Double() if element.Equal(other) == 1 { - t.Fatalf("Unexpected equality") + t.Fatalf(errUnExpectedEquality) } // Verify setting to nil sets to identity @@ -166,7 +165,27 @@ func TestElement_EncodedLength(t *testing.T) { func TestElement_Decode_OutOfBounds(t *testing.T) { testAll(t, func(group *testGroup) { - expected := errors.New("invalid point encoding") + decodeErr := "element Decode: " + unmarshallBinaryErr := "element UnmarshalBinary: " + errMessage := "" + switch group.group { + case crypto.Ristretto255Sha512: + errMessage = "invalid Ristretto encoding" + case crypto.P256Sha256: + errMessage = "invalid P256 element encoding" + case crypto.P384Sha384: + errMessage = "invalid P384Element encoding" + case crypto.P521Sha512: + errMessage = "invalid P521Element encoding" + case crypto.Edwards25519Sha512: + errMessage = "edwards25519: invalid point encoding" + case crypto.Secp256k1: + errMessage = "invalid point encoding" + } + + decodeErr += errMessage + unmarshallBinaryErr += errMessage + encoded := make([]byte, group.group.ElementLength()) y := big.NewInt(0) @@ -189,7 +208,13 @@ func TestElement_Decode_OutOfBounds(t *testing.T) { t.Fatalf("non registered group %s", group.group) } - if err := secp256k1.NewElement().Decode(encoded[:]); err == nil || err.Error() != expected.Error() { + expected := errors.New(decodeErr) + if err := group.group.NewElement().Decode(encoded[:]); err == nil || err.Error() != expected.Error() { + t.Errorf("expected error %q, got %v", expected, err) + } + + expected = errors.New(unmarshallBinaryErr) + if err := group.group.NewElement().UnmarshalBinary(encoded[:]); err == nil || err.Error() != expected.Error() { t.Errorf("expected error %q, got %v", expected, err) } }) @@ -272,10 +297,7 @@ func TestElement_Vectors_Mult(t *testing.T) { t.Fatalf("expected equality for %d", i) } - if err := s.SetInt(big.NewInt(int64(i + 2))); err != nil { - t.Fatal(err) - } - + s.SetUInt64(uint64(i + 2)) base.Base().Multiply(s) } }) @@ -297,6 +319,10 @@ func elementTestEqual(t *testing.T, g crypto.Group) { base := g.Base() base2 := g.Base() + if base.Equal(nil) != 0 { + t.Fatal(errUnExpectedEquality) + } + if base.Equal(base2) != 1 { t.Fatal(errExpectedEquality) } diff --git a/tests/groups_test.go b/tests/groups_test.go index 98bfece..cc46430 100644 --- a/tests/groups_test.go +++ b/tests/groups_test.go @@ -64,9 +64,9 @@ func TestNonAvailability(t *testing.T) { func TestGroup_Base(t *testing.T) { testAll(t, func(group *testGroup) { - if hex.EncodeToString(group.group.Base().Encode()) != group.basePoint { + if group.group.Base().Hex() != group.basePoint { t.Fatalf("Got wrong base element\n\tgot : %s\n\twant: %s", - hex.EncodeToString(group.group.Base().Encode()), + group.group.Base().Hex(), group.basePoint) } }) diff --git a/tests/h2c_test.go b/tests/h2c_test.go index 2796a92..69449e4 100644 --- a/tests/h2c_test.go +++ b/tests/h2c_test.go @@ -150,11 +150,11 @@ func (v *h2cVector) run(t *testing.T) { } func verifyEncoding(p *crypto.Element, function, expected string) error { - if hex.EncodeToString(p.Encode()) != expected { + if p.Hex() != expected { return fmt.Errorf("Unexpected %s output.\n\tExpected %q\n\tgot %q", function, expected, - hex.EncodeToString(p.Encode()), + p.Hex(), ) } diff --git a/tests/ristretto_hash_test.go b/tests/ristretto_hash_test.go index 37220e9..e154dde 100644 --- a/tests/ristretto_hash_test.go +++ b/tests/ristretto_hash_test.go @@ -76,7 +76,7 @@ func TestRistretto_HashToGroup(t *testing.T) { t.Fatalf( "Mappings do not match.\n\tExpected: %v\n\tActual: %v\n", hex.EncodeToString(v.encodedElement), - hex.EncodeToString(e.Encode()), + e.Hex(), ) } }) diff --git a/tests/ristretto_test.go b/tests/ristretto_test.go index ba33d34..4033a29 100644 --- a/tests/ristretto_test.go +++ b/tests/ristretto_test.go @@ -103,13 +103,8 @@ func TestRistrettoScalar(t *testing.T) { for _, tt := range ristrettoTests { t.Run(tt.name, func(t *testing.T) { // Grab the bytes of the encoding - encoding, err := hex.DecodeString(tt.scalar) - if err != nil { - t.Fatalf("#%s: bad hex encoding in test vector: %v", tt.name, err) - } - s := ristretto.Group{}.NewScalar() - err = s.Decode(encoding) + err := s.DecodeHex(tt.scalar) if tt.scal == false { if err == nil { @@ -160,15 +155,9 @@ func TestRistrettoElement(t *testing.T) { for _, tt := range ristrettoTests { t.Run(tt.name, func(t *testing.T) { - // Grab the bytes of the encoding - encoding, err := hex.DecodeString(tt.element) - if err != nil { - t.Fatalf("%s: bad hex encoding in test vector: %v", tt.name, err) - } - // Test decoding e := ristretto.Group{}.NewElement() - err = e.Decode(encoding) + err = e.DecodeHex(tt.element) if tt.elem == false { if err == nil { diff --git a/tests/scalar_test.go b/tests/scalar_test.go index 2c1f609..10a63e1 100644 --- a/tests/scalar_test.go +++ b/tests/scalar_test.go @@ -9,9 +9,13 @@ package group_test import ( + "bytes" + "encoding/binary" "encoding/hex" "errors" + "math" "math/big" + "slices" "testing" "github.com/bytemare/crypto" @@ -81,12 +85,12 @@ func testScalarCopySet(t *testing.T, scalar, other *crypto.Scalar) { // Verify than operations on one don't affect the other scalar.Add(scalar) if scalar.Equal(other) == 1 { - t.Fatalf("Unexpected equality") + t.Fatalf(errUnExpectedEquality) } other.Invert() if scalar.Equal(other) == 1 { - t.Fatalf("Unexpected equality") + t.Fatalf(errUnExpectedEquality) } // Verify setting to nil sets to 0 @@ -112,39 +116,31 @@ func TestScalarSet(t *testing.T) { }) } -func TestScalarSetInt(t *testing.T) { +func TestScalar_SetUInt64(t *testing.T) { testAll(t, func(group *testGroup) { - i := big.NewInt(0) - - s := group.group.NewScalar() - if err := s.SetInt(i); err != nil { - t.Fatal(err) - } - + s := group.group.NewScalar().SetUInt64(0) if !s.IsZero() { t.Fatal("expected 0") } - i = big.NewInt(1) - if err := s.SetInt(i); err != nil { - t.Fatal(err) - } - + s.SetUInt64(1) if s.Equal(group.group.NewScalar().One()) != 1 { t.Fatal("expected 1") } - order, ok := new(big.Int).SetString(group.group.Order(), 10) - if !ok { - t.Fatal("conversion error") - } + // uint64 max value is 18,446,744,073,709,551,615 + s.SetUInt64(math.MaxUint64) + ref := make([]byte, group.group.ScalarLength()) - if err := s.SetInt(order); err != nil { - t.Fatal(err) + switch group.group { + case crypto.Ristretto255Sha512, crypto.Edwards25519Sha512: + binary.LittleEndian.PutUint64(ref, math.MaxUint64) + default: + binary.BigEndian.PutUint64(ref[group.group.ScalarLength()-8:], math.MaxUint64) } - if !s.IsZero() { - t.Fatalf("expected 0, got %v\n%v", s.Encode(), order) + if bytes.Compare(ref, s.Encode()) != 0 { + t.Fatalf("expected %q, got %q", hex.EncodeToString(ref), s.Hex()) } }) } @@ -164,16 +160,36 @@ func TestScalar_EncodedLength(t *testing.T) { func TestScalar_Decode_OutOfBounds(t *testing.T) { testAll(t, func(group *testGroup) { + decodeErrPrefix := "scalar Decode: " + unmarshallBinaryErrPrefix := "scalar UnmarshalBinary: " + switch group.group { + case crypto.Ristretto255Sha512: + unmarshallBinaryErrPrefix += "ristretto: " + case crypto.P256Sha256, crypto.P384Sha384, crypto.P521Sha512: + unmarshallBinaryErrPrefix += "nist: " + case crypto.Edwards25519Sha512: + unmarshallBinaryErrPrefix += "edwards25519: " + case crypto.Secp256k1: + break + } + // Decode invalid length + errMessage := "invalid scalar length" encoded := make([]byte, 2) big.NewInt(1).FillBytes(encoded) - expected := errors.New("scalar Decode: invalid scalar length") + expected := errors.New(decodeErrPrefix + errMessage) if err := group.group.NewScalar().Decode(encoded); err == nil || err.Error() != expected.Error() { t.Errorf("expected error %q, got %v", expected, err) } + expected = errors.New(unmarshallBinaryErrPrefix + errMessage) + if err := group.group.NewScalar().UnmarshalBinary(encoded); err == nil || err.Error() != expected.Error() { + t.Errorf("expected error %q, got %v", expected, err) + } + // Decode a scalar higher than order + errMessage = "invalid scalar encoding" encoded = make([]byte, group.group.ScalarLength()) order, ok := new(big.Int).SetString(group.group.Order(), 0) @@ -184,10 +200,15 @@ func TestScalar_Decode_OutOfBounds(t *testing.T) { order.Add(order, big.NewInt(1)) order.FillBytes(encoded) - expected = errors.New("scalar Decode: invalid scalar encoding") + expected = errors.New(decodeErrPrefix + errMessage) if err := group.group.NewScalar().Decode(encoded); err == nil || err.Error() != expected.Error() { t.Errorf("expected error %q, got %v", expected, err) } + + expected = errors.New(unmarshallBinaryErrPrefix + errMessage) + if err := group.group.NewScalar().UnmarshalBinary(encoded); err == nil || err.Error() != expected.Error() { + t.Errorf("expected error %q, got %v", expected, err) + } }) } @@ -239,7 +260,7 @@ func scalarTestOne(t *testing.T, g crypto.Group) { func scalarTestRandom(t *testing.T, g crypto.Group) { r := g.NewScalar().Random() if r.Equal(g.NewScalar().Zero()) == 1 { - t.Fatalf("random scalar is zero: %v", hex.EncodeToString(r.Encode())) + t.Fatalf("random scalar is zero: %v", r.Hex()) } } @@ -247,6 +268,10 @@ func scalarTestEqual(t *testing.T, g crypto.Group) { zero := g.NewScalar().Zero() zero2 := g.NewScalar().Zero() + if g.NewScalar().Random().Equal(nil) != 0 { + t.Fatal(errUnExpectedEquality) + } + if zero.Equal(zero2) != 1 { t.Fatal(errExpectedEquality) } @@ -259,7 +284,7 @@ func scalarTestEqual(t *testing.T, g crypto.Group) { random2 := g.NewScalar().Random() if random.Equal(random2) == 1 { - t.Fatal("unexpected equality") + t.Fatal(errUnExpectedEquality) } } @@ -268,6 +293,10 @@ func scalarTestLessOrEqual(t *testing.T, g crypto.Group) { one := g.NewScalar().One() two := g.NewScalar().One().Add(one) + if g.NewScalar().Random().LessOrEqual(nil) != 0 { + t.Fatal(errUnExpectedEquality) + } + if zero.LessOrEqual(one) != 1 { t.Fatal("expected 0 < 1") } @@ -344,9 +373,7 @@ func scalarTestPow(t *testing.T, g crypto.Group) { s = g.NewScalar().One() s.Add(s.Copy().One()) s2 := s.Copy().Multiply(s) - if err := exp.SetInt(big.NewInt(2)); err != nil { - t.Fatal(err) - } + exp.SetUInt64(2) if s.Pow(exp).Equal(s2) != 1 { t.Fatal("expected s**2 = s*s") @@ -356,55 +383,30 @@ func scalarTestPow(t *testing.T, g crypto.Group) { s = g.NewScalar().Random() s3 := s.Copy().Multiply(s) s3.Multiply(s) - _ = exp.SetInt(big.NewInt(3)) + exp.SetUInt64(3) if s.Pow(exp).Equal(s3) != 1 { t.Fatal("expected s**3 = s*s*s") } // 5**7 = 78125 = 00000000 00000001 00110001 00101101 = 1 49 45 - iBase := big.NewInt(5) - iExp := big.NewInt(7) - order, ok := new(big.Int).SetString(g.Order(), 0) - if !ok { - t.Fatal(ok) - } - iResult := new(big.Int).Exp(iBase, iExp, order) - result := g.NewScalar() - if err := result.SetInt(iResult); err != nil { - t.Fatal(err) - } + result := g.NewScalar().SetUInt64(uint64(math.Pow(5, 7))) + s.SetUInt64(5) + exp.SetUInt64(7) - if err := s.SetInt(iBase); err != nil { - t.Fatal(err) - } - if err := exp.SetInt(iExp); err != nil { - t.Fatal(err) - } res := s.Pow(exp) if res.Equal(result) != 1 { t.Fatal("expected 5**7 = 78125") } // 3**255 = 11F1B08E87EC42C5D83C3218FC83C41DCFD9F4428F4F92AF1AAA80AA46162B1F71E981273601F4AD1DD4709B5ACA650265A6AB - iBase = big.NewInt(3) - iExp = big.NewInt(255) - order, ok = new(big.Int).SetString(g.Order(), 0) - if !ok { - t.Fatal(ok) - } - iResult = new(big.Int).Exp(iBase, iExp, order) - result = g.NewScalar() - if err := result.SetInt(iResult); err != nil { - t.Fatal(err) - } + iBase := big.NewInt(3) + iExp := big.NewInt(255) + result = bigIntExp(t, g, iBase, iExp) + + s.SetUInt64(3) + exp.SetUInt64(255) - if err := s.SetInt(iBase); err != nil { - t.Fatal(err) - } - if err := exp.SetInt(iExp); err != nil { - t.Fatal(err) - } res = s.Pow(exp) if res.Equal(result) != 1 { t.Fatal( @@ -416,18 +418,10 @@ func scalarTestPow(t *testing.T, g crypto.Group) { // 7945232487465**513 iBase.SetInt64(7945232487465) iExp.SetInt64(513) - iResult = iResult.Exp(iBase, iExp, order) - if err := result.SetInt(iResult); err != nil { - t.Fatal(err) - } + result = bigIntExp(t, g, iBase, iExp) - if err := s.SetInt(iBase); err != nil { - t.Fatal(err) - } - - if err := exp.SetInt(iExp); err != nil { - t.Fatal(err) - } + s.SetUInt64(7945232487465) + exp.SetUInt64(513) res = s.Pow(exp) if res.Equal(result) != 1 { @@ -460,17 +454,36 @@ func scalarTestPow(t *testing.T, g crypto.Group) { iExp.SetBytes(exp.Encode()) } - iResult.Exp(iBase, iExp, order) - - if err := result.SetInt(iResult); err != nil { - t.Fatal(err) - } + result = bigIntExp(t, g, iBase, iExp) if s.Pow(exp).Equal(result) != 1 { t.Fatal("expected equality on random numbers") } } +func bigIntExp(t *testing.T, g crypto.Group, base, exp *big.Int) *crypto.Scalar { + order, ok := new(big.Int).SetString(g.Order(), 0) + if !ok { + t.Fatal(ok) + } + + r := new(big.Int).Exp(base, exp, order) + + b := make([]byte, g.ScalarLength()) + r.FillBytes(b) + + if g == crypto.Ristretto255Sha512 || g == crypto.Edwards25519Sha512 { + slices.Reverse(b) + } + + result := g.NewScalar() + if err := result.Decode(b); err != nil { + t.Fatal(err) + } + + return result +} + func scalarTestInvert(t *testing.T, g crypto.Group) { s := g.NewScalar().Random() sqr := s.Copy().Multiply(s) diff --git a/tests/utils_test.go b/tests/utils_test.go index c885c42..b669018 100644 --- a/tests/utils_test.go +++ b/tests/utils_test.go @@ -99,19 +99,31 @@ func decodeElement(t *testing.T, g crypto.Group, input string) *crypto.Element { type serde interface { Encode() []byte Decode(data []byte) error + Hex() string + DecodeHex(h string) error encoding.BinaryMarshaler encoding.BinaryUnmarshaler } func testEncoding(t *testing.T, thing1, thing2 serde) { + // empty string + if err := thing2.DecodeHex(""); err == nil { + t.Fatal("expected error on empty string") + } + encoded := thing1.Encode() marshalled, _ := thing1.MarshalBinary() + hexed := thing1.Hex() if !bytes.Equal(encoded, marshalled) { t.Fatalf("Encode() and MarshalBinary() are expected to have the same output."+ "\twant: %v\tgot : %v", encoded, marshalled) } + if hex.EncodeToString(encoded) != hexed { + t.Fatalf("Failed hex encoding, want %q, got %q", hex.EncodeToString(encoded), hexed) + } + if err := thing2.Decode(nil); err == nil { t.Fatal("expected error on Decode() with nil input") } @@ -123,6 +135,10 @@ func testEncoding(t *testing.T, thing1, thing2 serde) { if err := thing2.UnmarshalBinary(encoded); err != nil { t.Fatalf("UnmarshalBinary() failed on a valid encoding: %v", err) } + + if err := thing2.DecodeHex(hexed); err != nil { + t.Fatalf("DecodeHex() failed on valid hex encoding: %v", err) + } } func TestEncoding(t *testing.T) { @@ -136,3 +152,58 @@ func TestEncoding(t *testing.T) { testEncoding(t, element, g.NewElement()) }) } + +func testDecodingHexFails(t *testing.T, thing1, thing2 serde) { + // empty string + if err := thing2.DecodeHex(""); err == nil { + t.Fatal("expected error on empty string") + } + + // malformed string + hexed := thing1.Hex() + malformed := []rune(hexed) + malformed[0] = []rune("_")[0] + + if err := thing2.DecodeHex(string(malformed)); err == nil { + t.Fatal("expected error on malformed string") + } else { + t.Log(err) + } +} + +func TestEncoding_Hex_Fails(t *testing.T) { + testAll(t, func(group *testGroup) { + g := group.group + scalar := g.NewScalar().Random() + testEncoding(t, scalar, g.NewScalar()) + + scalar = g.NewScalar().Random() + element := g.Base().Multiply(scalar) + testEncoding(t, element, g.NewElement()) + + // Hex fails + testDecodingHexFails(t, scalar, g.NewScalar()) + testDecodingHexFails(t, element, g.NewElement()) + + // Doesn't yield the same decoded result + scalar = g.NewScalar().Random() + s := g.NewScalar() + if err := s.DecodeHex(scalar.Hex()); err != nil { + t.Fatalf("unexpected error on valid encoding: %s", err) + } + + if s.Equal(scalar) != 1 { + t.Fatal(errExpectedEquality) + } + + element = g.Base().Multiply(scalar) + e := g.NewElement() + if err := e.DecodeHex(element.Hex()); err != nil { + t.Fatalf("unexpected error on valid encoding: %s", err) + } + + if e.Equal(element) != 1 { + t.Fatal(errExpectedEquality) + } + }) +}