Skip to content

Commit

Permalink
fix: strict ModReduce in emulated fields (#1224)
Browse files Browse the repository at this point in the history
* feat: add modReduced field to Element

* feat: add ReduceStrict method

* feat: add ToBitsCanonical method

* fix: in IsZero check against p also

* test: implement test cases

* feat: add WithCanonicalRepresentation option

* docs: refer to ReduceStrict

* refactor: use std slices

* docs: add TODO for improving ToBitsCanonical method

* feat: set modReduced to false during element init

* fix: always return if modReduced

* feat: remove most significant bit when modulus power of two
  • Loading branch information
ivokub authored Jul 25, 2024
1 parent c36ff9e commit af21593
Show file tree
Hide file tree
Showing 11 changed files with 390 additions and 61 deletions.
20 changes: 20 additions & 0 deletions std/algebra/algopts/algopts.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ type algebraCfg struct {
NbScalarBits int
FoldMulti bool
CompleteArithmetic bool
ToBitsCanonical bool
}

// AlgebraOption allows modifying algebraic operation behaviour.
Expand Down Expand Up @@ -57,6 +58,25 @@ func WithCompleteArithmetic() AlgebraOption {
}
}

// WithCanonicalBitRepresentation enforces the marshalling methods to assert
// that the bit representation is in canonical form. For field elements this
// means that the bits represent a number less than the modulus.
//
// This option is useful when performing direct comparison between the bit form
// of two elements. It can be avoided when the bit representation is used in
// other cases, such as computing a challenge using a hash function, where
// non-canonical bit representation leads to incorrect challenge (which in turn
// makes the verification fail).
func WithCanonicalBitRepresentation() AlgebraOption {
return func(ac *algebraCfg) error {
if ac.ToBitsCanonical {
return fmt.Errorf("WithCanonicalBitRepresentation already set")
}
ac.ToBitsCanonical = true
return nil
}
}

// NewConfig applies all given options and returns a configuration to be used.
func NewConfig(opts ...AlgebraOption) (*algebraCfg, error) {
ret := new(algebraCfg)
Expand Down
37 changes: 26 additions & 11 deletions std/algebra/emulated/sw_emulated/point.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@ package sw_emulated
import (
"fmt"
"math/big"
"slices"

"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/algebra/algopts"
"github.com/consensys/gnark/std/math/emulated"
"github.com/consensys/gnark/std/math/emulated/emparams"
"golang.org/x/exp/slices"
)

// New returns a new [Curve] instance over the base field Base and scalar field
Expand Down Expand Up @@ -101,26 +101,41 @@ type AffinePoint[Base emulated.FieldParams] struct {

// MarshalScalar marshals the scalar into bits. Compatible with scalar
// marshalling in gnark-crypto.
func (c *Curve[B, S]) MarshalScalar(s emulated.Element[S]) []frontend.Variable {
func (c *Curve[B, S]) MarshalScalar(s emulated.Element[S], opts ...algopts.AlgebraOption) []frontend.Variable {
cfg, err := algopts.NewConfig(opts...)
if err != nil {
panic(fmt.Sprintf("parse opts: %v", err))
}
var fr S
nbBits := 8 * ((fr.Modulus().BitLen() + 7) / 8)
sReduced := c.scalarApi.Reduce(&s)
res := c.scalarApi.ToBits(sReduced)[:nbBits]
for i, j := 0, nbBits-1; i < j; {
res[i], res[j] = res[j], res[i]
i++
j--
var sReduced *emulated.Element[S]
if cfg.ToBitsCanonical {
sReduced = c.scalarApi.ReduceStrict(&s)
} else {
sReduced = c.scalarApi.Reduce(&s)
}
res := c.scalarApi.ToBits(sReduced)[:nbBits]
slices.Reverse(res)
return res
}

// MarshalG1 marshals the affine point into bits. The output is compatible with
// the point marshalling in gnark-crypto.
func (c *Curve[B, S]) MarshalG1(p AffinePoint[B]) []frontend.Variable {
func (c *Curve[B, S]) MarshalG1(p AffinePoint[B], opts ...algopts.AlgebraOption) []frontend.Variable {
cfg, err := algopts.NewConfig(opts...)
if err != nil {
panic(fmt.Sprintf("parse opts: %v", err))
}
var fp B
nbBits := 8 * ((fp.Modulus().BitLen() + 7) / 8)
x := c.baseApi.Reduce(&p.X)
y := c.baseApi.Reduce(&p.Y)
var x, y *emulated.Element[B]
if cfg.ToBitsCanonical {
x = c.baseApi.ReduceStrict(&p.X)
y = c.baseApi.ReduceStrict(&p.Y)
} else {
x = c.baseApi.Reduce(&p.X)
y = c.baseApi.Reduce(&p.Y)
}
bx := c.baseApi.ToBits(x)[:nbBits]
by := c.baseApi.ToBits(y)[:nbBits]
slices.Reverse(bx)
Expand Down
4 changes: 2 additions & 2 deletions std/algebra/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,10 @@ type Curve[FR emulated.FieldParams, G1El G1ElementT] interface {

// MarshalG1 returns the binary decomposition G1.X || G1.Y. It matches the
// output of gnark-crypto's Marshal method on G1 points.
MarshalG1(G1El) []frontend.Variable
MarshalG1(G1El, ...algopts.AlgebraOption) []frontend.Variable

// MarshalScalar returns the binary decomposition of the argument.
MarshalScalar(emulated.Element[FR]) []frontend.Variable
MarshalScalar(emulated.Element[FR], ...algopts.AlgebraOption) []frontend.Variable

// Select sets p1 if b=1, p2 if b=0, and returns it. b must be boolean constrained
Select(b frontend.Variable, p1 *G1El, p2 *G1El) *G1El
Expand Down
34 changes: 24 additions & 10 deletions std/algebra/native/sw_bls12377/pairing2.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package sw_bls12377
import (
"fmt"
"math/big"
"slices"

"github.com/consensys/gnark-crypto/ecc"
bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377"
Expand Down Expand Up @@ -36,25 +37,38 @@ func NewCurve(api frontend.API) (*Curve, error) {
}

// MarshalScalar returns
func (c *Curve) MarshalScalar(s Scalar) []frontend.Variable {
func (c *Curve) MarshalScalar(s Scalar, opts ...algopts.AlgebraOption) []frontend.Variable {
cfg, err := algopts.NewConfig(opts...)
if err != nil {
panic(fmt.Sprintf("parse opts: %v", err))
}
nbBits := 8 * ((ScalarField{}.Modulus().BitLen() + 7) / 8)
ss := c.fr.Reduce(&s)
x := c.fr.ToBits(ss)
for i, j := 0, nbBits-1; i < j; {
x[i], x[j] = x[j], x[i]
i++
j--
var ss *emulated.Element[ScalarField]
if cfg.ToBitsCanonical {
ss = c.fr.ReduceStrict(&s)
} else {
ss = c.fr.Reduce(&s)
}
x := c.fr.ToBits(ss)[:nbBits]
slices.Reverse(x)
return x
}

// MarshalG1 returns [P.X || P.Y] in binary. Both P.X and P.Y are
// in little endian.
func (c *Curve) MarshalG1(P G1Affine) []frontend.Variable {
func (c *Curve) MarshalG1(P G1Affine, opts ...algopts.AlgebraOption) []frontend.Variable {
cfg, err := algopts.NewConfig(opts...)
if err != nil {
panic(fmt.Sprintf("parse opts: %v", err))
}
nbBits := 8 * ((ecc.BLS12_377.BaseField().BitLen() + 7) / 8)
bOpts := []bits.BaseConversionOption{bits.WithNbDigits(nbBits)}
if !cfg.ToBitsCanonical {
bOpts = append(bOpts, bits.OmitModulusCheck())
}
res := make([]frontend.Variable, 2*nbBits)
x := bits.ToBinary(c.api, P.X, bits.WithNbDigits(nbBits))
y := bits.ToBinary(c.api, P.Y, bits.WithNbDigits(nbBits))
x := bits.ToBinary(c.api, P.X, bOpts...)
y := bits.ToBinary(c.api, P.Y, bOpts...)
for i := 0; i < nbBits; i++ {
res[i] = x[nbBits-1-i]
res[i+nbBits] = y[nbBits-1-i]
Expand Down
34 changes: 24 additions & 10 deletions std/algebra/native/sw_bls24315/pairing2.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package sw_bls24315
import (
"fmt"
"math/big"
"slices"

"github.com/consensys/gnark-crypto/ecc"
bls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315"
Expand Down Expand Up @@ -36,25 +37,38 @@ func NewCurve(api frontend.API) (*Curve, error) {
}

// MarshalScalar returns
func (c *Curve) MarshalScalar(s Scalar) []frontend.Variable {
func (c *Curve) MarshalScalar(s Scalar, opts ...algopts.AlgebraOption) []frontend.Variable {
cfg, err := algopts.NewConfig(opts...)
if err != nil {
panic(fmt.Sprintf("parse opts: %v", err))
}
nbBits := 8 * ((ScalarField{}.Modulus().BitLen() + 7) / 8)
ss := c.fr.Reduce(&s)
x := c.fr.ToBits(ss)
for i, j := 0, nbBits-1; i < j; {
x[i], x[j] = x[j], x[i]
i++
j--
var ss *emulated.Element[ScalarField]
if cfg.ToBitsCanonical {
ss = c.fr.ReduceStrict(&s)
} else {
ss = c.fr.Reduce(&s)
}
x := c.fr.ToBits(ss)[:nbBits]
slices.Reverse(x)
return x
}

// MarshalG1 returns [P.X || P.Y] in binary. Both P.X and P.Y are
// in little endian.
func (c *Curve) MarshalG1(P G1Affine) []frontend.Variable {
func (c *Curve) MarshalG1(P G1Affine, opts ...algopts.AlgebraOption) []frontend.Variable {
cfg, err := algopts.NewConfig(opts...)
if err != nil {
panic(fmt.Sprintf("parse opts: %v", err))
}
nbBits := 8 * ((ecc.BLS24_315.BaseField().BitLen() + 7) / 8)
bOpts := []bits.BaseConversionOption{bits.WithNbDigits(nbBits)}
if !cfg.ToBitsCanonical {
bOpts = append(bOpts, bits.OmitModulusCheck())
}
res := make([]frontend.Variable, 2*nbBits)
x := bits.ToBinary(c.api, P.X, bits.WithNbDigits(nbBits))
y := bits.ToBinary(c.api, P.Y, bits.WithNbDigits(nbBits))
x := bits.ToBinary(c.api, P.X, bOpts...)
y := bits.ToBinary(c.api, P.Y, bOpts...)
for i := 0; i < nbBits; i++ {
res[i] = x[nbBits-1-i]
res[i+nbBits] = y[nbBits-1-i]
Expand Down
12 changes: 12 additions & 0 deletions std/math/emulated/element.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ type Element[T FieldParams] struct {
// enforcement info in the Element to prevent modifying the witness.
internal bool

// modReduced indicates that the element has been reduced modulo the modulus
// and we have asserted that the integer value of the element is strictly
// less than the modulus. This is required for some operations which depend
// on the bit-representation of the element (ToBits, exponentiation etc.).
modReduced bool

isEvaluated bool
evaluation frontend.Variable `gnark:"-"`
}
Expand Down Expand Up @@ -95,6 +101,11 @@ func (e *Element[T]) GnarkInitHook() {
*e = ValueOf[T](0)
e.internal = false // we need to constrain in later.
}
// set modReduced to false - in case the circuit is compiled we may change
// the value for an existing element. If we don't reset it here, then during
// second compilation we may take a shortPath where we assume that modReduce
// flag is set.
e.modReduced = false
}

// copy makes a deep copy of the element.
Expand All @@ -104,5 +115,6 @@ func (e *Element[T]) copy() *Element[T] {
copy(r.Limbs, e.Limbs)
r.overflow = e.overflow
r.internal = e.internal
r.modReduced = e.modReduced
return &r
}
Loading

0 comments on commit af21593

Please sign in to comment.