From 840a8676a06ca0f6e7a56f9226c7df8cb66bb694 Mon Sep 17 00:00:00 2001 From: Weiji Guo Date: Mon, 8 Jan 2024 21:14:28 +0800 Subject: [PATCH 01/16] added circuit implementation of ExpandMsgXmd --- std/hash/tofield/doc.go | 3 + std/hash/tofield/expand.go | 102 +++++++++++++++++++++ std/hash/tofield/expand_test.go | 158 ++++++++++++++++++++++++++++++++ 3 files changed, 263 insertions(+) create mode 100644 std/hash/tofield/doc.go create mode 100644 std/hash/tofield/expand.go create mode 100644 std/hash/tofield/expand_test.go diff --git a/std/hash/tofield/doc.go b/std/hash/tofield/doc.go new file mode 100644 index 000000000..b1612201a --- /dev/null +++ b/std/hash/tofield/doc.go @@ -0,0 +1,3 @@ +// Package tofield provides ZKP circuits for expanding messages to field elements, according to +// RFC9380 (section 5.3.1). +package tofield diff --git a/std/hash/tofield/expand.go b/std/hash/tofield/expand.go new file mode 100644 index 000000000..f6d3e07e1 --- /dev/null +++ b/std/hash/tofield/expand.go @@ -0,0 +1,102 @@ +package tofield + +import ( + "errors" + + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/hash/sha2" + "github.com/consensys/gnark/std/math/uints" +) + +const ( + block_size = 64 +) + +// ExpandMsgXmd expands msg to a slice of lenInBytes bytes according to RFC9380 (section 5.3.1) +// Spec: https://datatracker.ietf.org/doc/html/rfc9380#name-expand_message_xmd (hashutils.go) +// Implementation was adapted from gnark-crypto/field/hash.ExpandMsgXmd. +func ExpandMsgXmd(api frontend.API, msg []uints.U8, dst []byte, lenInBytes int) ([]uints.U8, error) { + h, e := sha2.New(api) + if e != nil { + return nil, e + } + + ell := (lenInBytes + h.Size() - 1) / h.Size() // ceil(len_in_bytes / b_in_bytes) + if ell > 255 { + return nil, errors.New("invalid lenInBytes") + } + if len(dst) > 255 { + return nil, errors.New("invalid domain size (>255 bytes)") + } + sizeDomain := uint8(len(dst)) + + dst_prime := make([]uints.U8, len(dst)+1) + copy(dst_prime, uints.NewU8Array(dst)) + dst_prime[len(dst)] = uints.NewU8(uint8(sizeDomain)) + + Z_pad_raw := make([]uint8, block_size) + Z_pad := uints.NewU8Array(Z_pad_raw) + h.Write(Z_pad) + h.Write(msg) + h.Write([]uints.U8{uints.NewU8(uint8(lenInBytes >> 8)), uints.NewU8(uint8(lenInBytes)), uints.NewU8(0)}) + h.Write(dst_prime) + b0 := h.Sum() + + h, e = sha2.New(api) + if e != nil { + return nil, e + } + h.Write(b0) + h.Write([]uints.U8{uints.NewU8(1)}) + h.Write(dst_prime) + b1 := h.Sum() + + res := make([]uints.U8, lenInBytes) + copy(res[:h.Size()], b1) + + for i := 2; i <= ell; i++ { + h, e = sha2.New(api) + if e != nil { + return nil, e + } + + // b_i = H(strxor(b₀, b_(i - 1)) ∥ I2OSP(i, 1) ∥ DST_prime) + strxor := make([]uints.U8, h.Size()) + for j := 0; j < h.Size(); j++ { + strxor[j], e = xor(api, b0[j], b1[j]) + if e != nil { + return res, e + } + } + h.Write(strxor) + h.Write([]uints.U8{uints.NewU8(uint8(i))}) + h.Write(dst_prime) + b1 = h.Sum() + copy(res[h.Size()*(i-1):min(h.Size()*i, len(res))], b1) + } + + return res, nil +} + +func xor(api frontend.API, a, b uints.U8) (uints.U8, error) { + aBits := api.ToBinary(a.Val, 8) + bBits := api.ToBinary(b.Val, 8) + cBits := make([]frontend.Variable, 8) + + for i := 0; i < 8; i++ { + cBits[i] = api.Xor(aBits[i], bBits[i]) + } + + uapi, err := uints.New[uints.U32](api) + if err != nil { + return uints.NewU8(255), err + } + return uapi.ByteValueOf(api.FromBinary(cBits...)), nil +} + +func min(a, b int) int { + if a < b { + return a + } + return b +} diff --git a/std/hash/tofield/expand_test.go b/std/hash/tofield/expand_test.go new file mode 100644 index 000000000..1df5271c5 --- /dev/null +++ b/std/hash/tofield/expand_test.go @@ -0,0 +1,158 @@ +package tofield + +import ( + "encoding/hex" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/math/uints" + "github.com/consensys/gnark/test" +) + +type expandMsgXmdCircuit struct { + Msg []uints.U8 + Dst []uint8 + Len int + Expected []uints.U8 +} + +type expandMsgXmdTestCase struct { + msg string + lenInBytes int + uniformBytesHex string +} + +func (c *expandMsgXmdCircuit) Define(api frontend.API) error { + uapi, err := uints.New[uints.U32](api) + if err != nil { + return err + } + expanded, err := ExpandMsgXmd(api, c.Msg, c.Dst, c.Len) + if err != nil { + return err + } + + for i := 0; i < c.Len; i++ { + uapi.ByteAssertEq(expanded[i], c.Expected[i]) + } + + return nil +} + +// adapted from gnark-crypto/field/hash/hashutils_test.go +func TestExpandMsgXmd(t *testing.T) { + //name := "expand_message_xmd" + dst := "QUUX-V01-CS02-with-expander-SHA256-128" + //hash := "SHA256" + //k := 128 + + testCases := []expandMsgXmdTestCase{ + { + "", + 0x20, + "68a985b87eb6b46952128911f2a4412bbc302a9d759667f87f7a21d803f07235", + }, + + { + "abc", + 0x20, + "d8ccab23b5985ccea865c6c97b6e5b8350e794e603b4b97902f53a8a0d605615", + }, + + { + "abcdef0123456789", + 0x20, + "eff31487c770a893cfb36f912fbfcbff40d5661771ca4b2cb4eafe524333f5c1", + }, + + { + "q128_qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqq", + 0x20, + "b23a1d2b4d97b2ef7785562a7e8bac7eed54ed6e97e29aa51bfe3f12ddad1ff9", + }, + + { + "a512_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", + 0x20, + "4623227bcc01293b8c130bf771da8c298dede7383243dc0993d2d94823958c4c", + }, + { + "", + 0x80, + "af84c27ccfd45d41914fdff5df25293e221afc53d8ad2ac06d5e3e29485dadbee0d121587713a3e0dd4d5e69e93eb7cd4f5df4cd103e188cf60cb02edc3edf18eda8576c412b18ffb658e3dd6ec849469b979d444cf7b26911a08e63cf31f9dcc541708d3491184472c2c29bb749d4286b004ceb5ee6b9a7fa5b646c993f0ced", + }, + { + "", + 0x20, + "68a985b87eb6b46952128911f2a4412bbc302a9d759667f87f7a21d803f07235", + }, + { + "abc", + 0x80, + "abba86a6129e366fc877aab32fc4ffc70120d8996c88aee2fe4b32d6c7b6437a647e6c3163d40b76a73cf6a5674ef1d890f95b664ee0afa5359a5c4e07985635bbecbac65d747d3d2da7ec2b8221b17b0ca9dc8a1ac1c07ea6a1e60583e2cb00058e77b7b72a298425cd1b941ad4ec65e8afc50303a22c0f99b0509b4c895f40", + }, + { + "abcdef0123456789", + 0x80, + "ef904a29bffc4cf9ee82832451c946ac3c8f8058ae97d8d629831a74c6572bd9ebd0df635cd1f208e2038e760c4994984ce73f0d55ea9f22af83ba4734569d4bc95e18350f740c07eef653cbb9f87910d833751825f0ebefa1abe5420bb52be14cf489b37fe1a72f7de2d10be453b2c9d9eb20c7e3f6edc5a60629178d9478df", + }, + { + "q128_qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqq", + 0x80, + "80be107d0884f0d881bb460322f0443d38bd222db8bd0b0a5312a6fedb49c1bbd88fd75d8b9a09486c60123dfa1d73c1cc3169761b17476d3c6b7cbbd727acd0e2c942f4dd96ae3da5de368d26b32286e32de7e5a8cb2949f866a0b80c58116b29fa7fabb3ea7d520ee603e0c25bcaf0b9a5e92ec6a1fe4e0391d1cdbce8c68a", + }, + { + "a512_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", + 0x80, + "546aff5444b5b79aa6148bd81728704c32decb73a3ba76e9e75885cad9def1d06d6792f8a7d12794e90efed817d96920d728896a4510864370c207f99bd4a608ea121700ef01ed879745ee3e4ceef777eda6d9e5e38b90c86ea6fb0b36504ba4a45d22e86f6db5dd43d98a294bebb9125d5b794e9d2a81181066eb954966a487", + }, + //test cases not in the standard + { + "", + 0x30, + "3808e9bb0ade2df3aa6f1b459eb5058a78142f439213ddac0c97dcab92ae5a8408d86b32bbcc87de686182cbdf65901f", + }, + { + "abc", + 0x30, + "2b877f5f0dfd881405426c6b87b39205ef53a548b0e4d567fc007cb37c6fa1f3b19f42871efefca518ac950c27ac4e28", + }, + { + "abcdef0123456789", + 0x30, + "226da1780b06e59723714f80da9a63648aebcfc1f08e0db87b5b4d16b108da118214c1450b0e86f9cefeb44903fd3aba", + }, + { + "q128_qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqq", + 0x30, + "12b23ae2e888f442fd6d0d85d90a0d7ed5337d38113e89cdc7c22db91bd0abaec1023e9a8f0ef583a111104e2f8a0637", + }, + { + "a512_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", + 0x30, + "1aaee90016547a85ab4dc55e4f78a364c2e239c0e58b05753453c63e6e818334005e90d9ce8f047bddab9fbb315f8722", + }, + } + + for _, testCase := range testCases { + uniformBytes := make([]uint8, len(testCase.uniformBytesHex)>>1) + hex.Decode(uniformBytes, []uint8(testCase.uniformBytesHex)) + witness := expandMsgXmdCircuit{ + Msg: uints.NewU8Array([]uint8(testCase.msg)), + Dst: []uint8(dst), + Len: testCase.lenInBytes, + Expected: uints.NewU8Array(uniformBytes), + } + circuit := expandMsgXmdCircuit{ + Msg: uints.NewU8Array(make([]uint8, len(testCase.msg))), + Dst: []uint8(dst), + Len: testCase.lenInBytes, + Expected: uints.NewU8Array(uniformBytes), + } + err := test.IsSolved(&circuit, &witness, ecc.BN254.ScalarField()) + if err != nil { + t.Fatal(err) + } + } +} From 7fb1b95ec23297a54045cb46fd53cfbd3e37d484 Mon Sep 17 00:00:00 2001 From: Weiji Guo Date: Sat, 27 Jan 2024 15:56:58 +0800 Subject: [PATCH 02/16] implemented HashToG2 for bls12-381 --- std/algebra/emulated/sw_bls12381/g2.go | 58 +++ std/algebra/emulated/sw_bls12381/g2_test.go | 71 ++++ .../emulated/sw_bls12381/hash_to_g2.go | 371 ++++++++++++++++++ .../emulated/sw_bls12381/hash_to_g2_test.go | 167 ++++++++ std/algebra/emulated/sw_bls12381/hints.go | 44 +++ std/hints.go | 2 + 6 files changed, 713 insertions(+) create mode 100644 std/algebra/emulated/sw_bls12381/hash_to_g2.go create mode 100644 std/algebra/emulated/sw_bls12381/hash_to_g2_test.go create mode 100644 std/algebra/emulated/sw_bls12381/hints.go diff --git a/std/algebra/emulated/sw_bls12381/g2.go b/std/algebra/emulated/sw_bls12381/g2.go index c607c5069..91daab066 100644 --- a/std/algebra/emulated/sw_bls12381/g2.go +++ b/std/algebra/emulated/sw_bls12381/g2.go @@ -13,6 +13,7 @@ type G2 struct { *fields_bls12381.Ext2 u1, w *emulated.Element[BaseField] v *fields_bls12381.E2 + api frontend.API } type g2AffP struct { @@ -50,6 +51,7 @@ func NewG2(api frontend.API) *G2 { w: &w, u1: &u1, v: &v, + api: api, } } @@ -96,6 +98,18 @@ func (g2 *G2) psi(q *G2Affine) *G2Affine { } } +func (g2 *G2) psi2(q *G2Affine) *G2Affine { + x := g2.Ext2.MulByElement(&q.P.X, g2.w) + y := g2.Ext2.Neg(&q.P.Y) + + return &G2Affine{ + P: g2AffP{ + X: *x, + Y: *y, + }, + } +} + func (g2 *G2) scalarMulBySeed(q *G2Affine) *G2Affine { z := g2.triple(q) @@ -136,6 +150,50 @@ func (g2 G2) add(p, q *G2Affine) *G2Affine { } } +// The add() function is not complete. If p == q, then λ is 0, while it should be (3p.x²)/2*p.y according to double() +// Moreover, the AssertIsEqual() check will fail within the DivUnchecked() call as the div hint will return 0. +// +// Provide addUnified to handle this situation at the cost of additional constraints. Useful when not certain if p != q. +// Note, ClearCofactor (covered in TestClearCofactorTestSolve of hash_to_g2_test.go) might fail without this function. +func (g2 G2) addUnified(p, q *G2Affine) *G2Affine { + // if p != q, compute λ = (q.y-p.y)/(q.x-p.x) + qypy := g2.Ext2.Sub(&q.P.Y, &p.P.Y) + qxpx := g2.Ext2.Sub(&q.P.X, &p.P.X) + + // else if p == q, compute λ = (3p.x²)/2*p.y + xx3a := g2.Square(&p.P.X) + xx3a = g2.MulByConstElement(xx3a, big.NewInt(3)) + y2 := g2.Double(&p.P.Y) + + z := g2.Ext2.IsZero(qxpx) + deltaX := g2.Select(z, y2, qxpx) + deltaY := g2.Select(z, xx3a, qypy) + λ := g2.Ext2.DivUnchecked(deltaY, deltaX) + + // if p == -q, result should zero + zeroConst := g2.Zero() + z1 := g2.Ext2.Add(&q.P.Y, &p.P.Y) + z2 := g2.Ext2.IsZero(z1) + z3 := g2.api.And(z, z2) + + // xr = λ²-p.x-q.x + λλ := g2.Ext2.Square(λ) + qxpx = g2.Ext2.Add(&p.P.X, &q.P.X) + xr := g2.Ext2.Sub(λλ, qxpx) + + // yr = λ(p.x-r.x) - p.y + pxrx := g2.Ext2.Sub(&p.P.X, xr) + λpxrx := g2.Ext2.Mul(λ, pxrx) + yr := g2.Ext2.Sub(λpxrx, &p.P.Y) + + return &G2Affine{ + P: g2AffP{ + X: *g2.Select(z3, zeroConst, xr), + Y: *g2.Select(z3, zeroConst, yr), + }, + } +} + func (g2 G2) neg(p *G2Affine) *G2Affine { xr := &p.P.X yr := g2.Ext2.Neg(&p.P.Y) diff --git a/std/algebra/emulated/sw_bls12381/g2_test.go b/std/algebra/emulated/sw_bls12381/g2_test.go index 9d4a90d0e..467925bea 100644 --- a/std/algebra/emulated/sw_bls12381/g2_test.go +++ b/std/algebra/emulated/sw_bls12381/g2_test.go @@ -37,6 +37,77 @@ func TestAddG2TestSolve(t *testing.T) { assert.NoError(err) } +func TestAddG2FailureCaseTestSolve(t *testing.T) { + assert := test.NewAssert(t) + _, in1 := randomG1G2Affines() + var res bls12381.G2Affine + res.Double(&in1) + witness := addG2Circuit{ + In1: NewG2Affine(in1), + In2: NewG2Affine(in1), + Res: NewG2Affine(res), + } + err := test.IsSolved(&addG2Circuit{}, &witness, ecc.BN254.ScalarField()) + // the add() function cannot handle identical inputs + assert.Error(err) +} + +type addG2UnifiedCircuit struct { + In1, In2 G2Affine + Res G2Affine +} + +func (c *addG2UnifiedCircuit) Define(api frontend.API) error { + g2 := NewG2(api) + res := g2.addUnified(&c.In1, &c.In2) + g2.AssertIsEqual(res, &c.Res) + return nil +} + +func TestAddG2UnifiedTestSolveAdd(t *testing.T) { + assert := test.NewAssert(t) + _, in1 := randomG1G2Affines() + _, in2 := randomG1G2Affines() + var res bls12381.G2Affine + res.Add(&in1, &in2) + witness := addG2UnifiedCircuit{ + In1: NewG2Affine(in1), + In2: NewG2Affine(in2), + Res: NewG2Affine(res), + } + err := test.IsSolved(&addG2UnifiedCircuit{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +func TestAddG2UnifiedTestSolveDbl(t *testing.T) { + assert := test.NewAssert(t) + _, in1 := randomG1G2Affines() + var res bls12381.G2Affine + res.Double(&in1) + witness := addG2UnifiedCircuit{ + In1: NewG2Affine(in1), + In2: NewG2Affine(in1), + Res: NewG2Affine(res), + } + err := test.IsSolved(&addG2UnifiedCircuit{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +func TestAddG2UnifiedTestSolveNeg(t *testing.T) { + assert := test.NewAssert(t) + _, in1 := randomG1G2Affines() + var in2, res bls12381.G2Affine + in2.Neg(&in1) + res.Add(&in1, &in2) + witness := addG2UnifiedCircuit{ + In1: NewG2Affine(in1), + In2: NewG2Affine(in2), + Res: NewG2Affine(res), + } + err := test.IsSolved(&addG2UnifiedCircuit{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + type doubleG2Circuit struct { In1 G2Affine Res G2Affine diff --git a/std/algebra/emulated/sw_bls12381/hash_to_g2.go b/std/algebra/emulated/sw_bls12381/hash_to_g2.go new file mode 100644 index 000000000..10266f383 --- /dev/null +++ b/std/algebra/emulated/sw_bls12381/hash_to_g2.go @@ -0,0 +1,371 @@ +package sw_bls12381 + +import ( + "math/big" + "slices" + + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/algebra/emulated/fields_bls12381" + "github.com/consensys/gnark/std/hash/tofield" + "github.com/consensys/gnark/std/math/emulated" + "github.com/consensys/gnark/std/math/emulated/emparams" + "github.com/consensys/gnark/std/math/uints" +) + +const ( + security_level = 128 + len_per_base_element = 64 +) + +func HashToG2(api frontend.API, msg []uints.U8, dst []byte) (*G2Affine, error) { + fp, e := emulated.NewField[emulated.BLS12381Fp](api) + if e != nil { + return &G2Affine{}, e + } + ext2 := fields_bls12381.NewExt2(api) + mapper := newMapper(api, ext2, fp) + g2 := NewG2(api) + + // Steps: + // 1. u = hash_to_field(msg, 2) + // 2. Q0 = map_to_curve(u[0]) + // 3. Q1 = map_to_curve(u[1]) + // 4. R = Q0 + Q1 # Point addition + // 5. P = clear_cofactor(R) + // 6. return P + lenPerBaseElement := len_per_base_element + lenInBytes := lenPerBaseElement * 4 + uniformBytes, e := tofield.ExpandMsgXmd(api, msg, dst, lenInBytes) + if e != nil { + return &G2Affine{}, e + } + + ele1 := bytesToElement(api, fp, uniformBytes[:lenPerBaseElement]) + ele2 := bytesToElement(api, fp, uniformBytes[lenPerBaseElement:lenPerBaseElement*2]) + ele3 := bytesToElement(api, fp, uniformBytes[lenPerBaseElement*2:lenPerBaseElement*3]) + ele4 := bytesToElement(api, fp, uniformBytes[lenPerBaseElement*3:]) + + // we will still do iso_map before point addition, as we do not have point addition in E' (yet) + Q0 := mapper.mapToCurve(fields_bls12381.E2{A0: *ele1, A1: *ele2}) + Q1 := mapper.mapToCurve(fields_bls12381.E2{A0: *ele3, A1: *ele4}) + Q0 = mapper.isogeny(&Q0.P.X, &Q0.P.Y) + Q1 = mapper.isogeny(&Q1.P.X, &Q1.P.Y) + + R := g2.addUnified(Q0, Q1) + + return clearCofactor(g2, fp, R), nil +} + +func bytesToElement(api frontend.API, fp *emulated.Field[emulated.BLS12381Fp], data []uints.U8) *emulated.Element[emulated.BLS12381Fp] { + // data in BE, need to convert to LE + slices.Reverse(data) + + bits := make([]frontend.Variable, len(data)*8) + for i := 0; i < len(data); i++ { + u8 := data[i] + u8Bits := api.ToBinary(u8.Val, 8) + for j := 0; j < 8; j++ { + bits[i*8+j] = u8Bits[j] + } + } + + cutoff := 17 + tailBits, headBits := bits[:cutoff*8], bits[cutoff*8:] + tail := fp.FromBits(tailBits...) + head := fp.FromBits(headBits...) + + byteMultiplier := big.NewInt(256) + headMultiplier := byteMultiplier.Exp(byteMultiplier, big.NewInt(int64(cutoff)), big.NewInt(0)) + head = fp.MulConst(head, headMultiplier) + + return fp.Add(head, tail) +} + +type sswuMapper struct { + A, B, Z fields_bls12381.E2 + ext2 *fields_bls12381.Ext2 + fp *emulated.Field[emulated.BLS12381Fp] + api frontend.API + iso *isogeny +} + +func newMapper(api frontend.API, ext2 *fields_bls12381.Ext2, fp *emulated.Field[emulated.BLS12381Fp]) *sswuMapper { + coeff_a := fields_bls12381.E2{ + A0: emulated.ValueOf[emparams.BLS12381Fp](0), + A1: emulated.ValueOf[emparams.BLS12381Fp](240), + } + coeff_b := fields_bls12381.E2{ + A0: emulated.ValueOf[emparams.BLS12381Fp](1012), + A1: emulated.ValueOf[emparams.BLS12381Fp](1012), + } + + one := emulated.ValueOf[emulated.BLS12381Fp](1) + two := emulated.ValueOf[emulated.BLS12381Fp](2) + zeta := fields_bls12381.E2{ + A0: *fp.Neg(&two), + A1: *fp.Neg(&one), + } + + return &sswuMapper{ + A: coeff_a, + B: coeff_b, + Z: zeta, + ext2: ext2, + fp: fp, + api: api, + iso: newIsogeny(), + } +} + +// Apply the Simplified SWU for the E' curve (RFC 9380 Section 6.6.3) +func (m sswuMapper) mapToCurve(u fields_bls12381.E2) *G2Affine { + // SSWU Steps: + // 1. tv1 = u^2 + tv1 := m.ext2.Square(&u) + // 2. tv1 = Z * tv1 + tv1 = m.ext2.Mul(&m.Z, tv1) + // 3. tv2 = tv1^2 + tv2 := m.ext2.Square(tv1) + // 4. tv2 = tv2 + tv1 + tv2 = m.ext2.Add(tv2, tv1) + // 5. tv3 = tv2 + 1 + tv3 := m.ext2.Add(tv2, m.ext2.One()) + // 6. tv3 = B * tv3 + tv3 = m.ext2.Mul(&m.B, tv3) + // 7. tv4 = CMOV(Z, -tv2, tv2 != 0) + s1 := m.ext2.IsZero(tv2) + tv4 := m.ext2.Select(s1, &m.Z, m.ext2.Neg(tv2)) + // 8. tv4 = A * tv4 + tv4 = m.ext2.Mul(&m.A, tv4) + // 9. tv2 = tv3^2 + tv2 = m.ext2.Square(tv3) + // 10. tv6 = tv4^2 + tv6 := m.ext2.Square(tv4) + // 11. tv5 = A * tv6 + tv5 := m.ext2.Mul(&m.A, tv6) + // 12. tv2 = tv2 + tv5 + tv2 = m.ext2.Add(tv2, tv5) + // 13. tv2 = tv2 * tv3 + tv2 = m.ext2.Mul(tv2, tv3) + // 14. tv6 = tv6 * tv4 + tv6 = m.ext2.Mul(tv6, tv4) + // 15. tv5 = B * tv6 + tv5 = m.ext2.Mul(&m.B, tv6) + // 16. tv2 = tv2 + tv5 + tv2 = m.ext2.Add(tv2, tv5) + // 17. x = tv1 * tv3 + x := m.ext2.Mul(tv1, tv3) + // 18. (is_gx1_square, y1) = sqrt_ratio(tv2, tv6) + isGx1Square, y1 := m.sqrtRatio(tv2, tv6) + // 19. y = tv1 * u + y := m.ext2.Mul(tv1, &u) + // 20. y = y * y1 + y = m.ext2.Mul(y, y1) + // 21. x = CMOV(x, tv3, is_gx1_square) + x = m.ext2.Select(isGx1Square, tv3, x) + // 22. y = CMOV(y, y1, is_gx1_square) + y = m.ext2.Select(isGx1Square, y1, y) + // 23. e1 = sgn0(u) == sgn0(y) + sgn0U := m.sgn0(&u) + sgn0Y := m.sgn0(y) + diff := m.api.Sub(sgn0U, sgn0Y) + e1 := m.api.IsZero(diff) + // 24. y = CMOV(-y, y, e1) + yNeg := m.ext2.Neg(y) + y = m.ext2.Select(e1, y, yNeg) + // 25. x = x / tv4 + x = m.ext2.DivUnchecked(x, tv4) + // 26. return (x, y) + return &G2Affine{ + P: g2AffP{X: *x, Y: *y}, + } +} + +func (m sswuMapper) sgn0(x *fields_bls12381.E2) frontend.Variable { + // Steps for sgn0_m_eq_2 + // 1. sign_0 = x_0 mod 2 + x0 := m.fp.ToBits(&x.A0) + sign0 := x0[0] + // 2. zero_0 = x_0 == 0 + zero0 := m.fp.IsZero(&x.A0) + // 3. sign_1 = x_1 mod 2 + x1 := m.fp.ToBits(&x.A1) + sign1 := x1[0] + // 4. s = sign_0 OR (zero_0 AND sign_1) # Avoid short-circuit logic ops + tv := m.api.And(zero0, sign1) + s := m.api.Or(sign0, tv) + // 5. return s + return s +} + +// Let's not mechanically translate the spec algorithm (Section F.2.1) into R1CS circuits. +// We could simply compute the result as a hint, then apply proper constraints, which is: +// for output of (b, y) +// +// b1 := {b = True AND y^2 * v = u} +// b2 := {b = False AND y^2 * v = Z * u} +// AssertTrue: {b1 OR b2} +func (m sswuMapper) sqrtRatio(u, v *fields_bls12381.E2) (frontend.Variable, *fields_bls12381.E2) { + // Steps + // 1. extract the base values of u, v, then compute G2SqrtRatio with gnark-crypto + x, err := m.fp.NewHint(GetHints()[0], 3, &u.A0, &u.A1, &v.A0, &v.A1) + if err != nil { + panic("failed to calculate sqrtRatio with gnark-crypto " + err.Error()) + } + + b := m.fp.IsZero(x[0]) + y := fields_bls12381.E2{A0: *x[1], A1: *x[2]} + + // 2. apply constraints + // b1 := {b = True AND y^2 * v = u} + m.api.AssertIsBoolean(b) + y2 := m.ext2.Square(&y) + y2v := m.ext2.Mul(y2, v) + bY2vu := m.ext2.IsZero(m.ext2.Sub(y2v, u)) + b1 := m.api.And(b, bY2vu) + + // b2 := {b = False AND y^2 * v = Z * u} + uZ := m.ext2.Mul(&m.Z, u) + bY2vZu := m.ext2.IsZero(m.ext2.Sub(y2v, uZ)) + nb := m.api.IsZero(b) + b2 := m.api.And(nb, bY2vZu) + + cmp := m.api.Or(b1, b2) + m.api.AssertIsEqual(cmp, 1) + + return b, &y +} + +type g2Polynomial []fields_bls12381.E2 + +func (p g2Polynomial) eval(m *sswuMapper, at fields_bls12381.E2) (pAt *fields_bls12381.E2) { + pAt = &p[len(p)-1] + + for i := len(p) - 2; i >= 0; i-- { + pAt = m.ext2.Mul(pAt, &at) + pAt = m.ext2.Add(pAt, &p[i]) + } + + return +} + +type isogeny struct { + x_numerator, x_denominator, y_numerator, y_denominator g2Polynomial +} + +func newIsogeny() *isogeny { + return &isogeny{ + x_numerator: g2Polynomial([]fields_bls12381.E2{ + *e2FromStrings( + "889424345604814976315064405719089812568196182208668418962679585805340366775741747653930584250892369786198727235542", + "889424345604814976315064405719089812568196182208668418962679585805340366775741747653930584250892369786198727235542"), + *e2FromStrings( + "0", + "2668273036814444928945193217157269437704588546626005256888038757416021100327225242961791752752677109358596181706522"), + *e2FromStrings( + "2668273036814444928945193217157269437704588546626005256888038757416021100327225242961791752752677109358596181706526", + "1334136518407222464472596608578634718852294273313002628444019378708010550163612621480895876376338554679298090853261"), + *e2FromStrings( + "3557697382419259905260257622876359250272784728834673675850718343221361467102966990615722337003569479144794908942033", + "0"), + }), + x_denominator: g2Polynomial([]fields_bls12381.E2{ + *e2FromStrings( + "0", + "4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129015664037894272559715"), + *e2FromStrings( + "12", + "4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129015664037894272559775"), + *e2FromStrings( + "1", + "0"), + }), + y_numerator: g2Polynomial([]fields_bls12381.E2{ + *e2FromStrings( + "3261222600550988246488569487636662646083386001431784202863158481286248011511053074731078808919938689216061999863558", + "3261222600550988246488569487636662646083386001431784202863158481286248011511053074731078808919938689216061999863558"), + *e2FromStrings( + "0", + "889424345604814976315064405719089812568196182208668418962679585805340366775741747653930584250892369786198727235518"), + *e2FromStrings( + "2668273036814444928945193217157269437704588546626005256888038757416021100327225242961791752752677109358596181706524", + "1334136518407222464472596608578634718852294273313002628444019378708010550163612621480895876376338554679298090853263"), + *e2FromStrings( + "2816510427748580758331037284777117739799287910327449993381818688383577828123182200904113516794492504322962636245776", + "0"), + }), + y_denominator: g2Polynomial([]fields_bls12381.E2{ + *e2FromStrings( + "4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129015664037894272559355", + "4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129015664037894272559355"), + *e2FromStrings( + "0", + "4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129015664037894272559571"), + *e2FromStrings( + "18", + "4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129015664037894272559769"), + *e2FromStrings( + "1", + "0"), + }), + } +} + +// Map the point from E' to E +func (m sswuMapper) isogeny(x, y *fields_bls12381.E2) *G2Affine { + xn := m.iso.x_numerator.eval(&m, *x) + + xd := m.iso.x_denominator.eval(&m, *x) + xdInv := m.ext2.Inverse(xd) + + yn := m.iso.y_numerator.eval(&m, *x) + yn = m.ext2.Mul(yn, y) + + yd := m.iso.y_denominator.eval(&m, *x) + ydInv := m.ext2.Inverse(yd) + + return &G2Affine{ + P: g2AffP{ + X: *m.ext2.Mul(xn, xdInv), + Y: *m.ext2.Mul(yn, ydInv), + }, + } +} + +func e2FromStrings(x, y string) *fields_bls12381.E2 { + A0, _ := new(big.Int).SetString(x, 10) + A1, _ := new(big.Int).SetString(y, 10) + + a0 := emulated.ValueOf[emulated.BLS12381Fp](A0) + a1 := emulated.ValueOf[emulated.BLS12381Fp](A1) + + return &fields_bls12381.E2{A0: a0, A1: a1} +} + +// Follow RFC 9380 Apendix G.3 to compute efficiently. +func clearCofactor(g2 *G2, fp *emulated.Field[emparams.BLS12381Fp], p *G2Affine) *G2Affine { + // Steps: + // 1. t1 = c1 * P + // c1 = -15132376222941642752 + t1 := g2.scalarMulBySeed(p) + // 2. t2 = psi(P) + t2 := g2.psi(p) + // 3. t3 = 2 * P + t3 := g2.double(p) + // 4. t3 = psi2(t3) + t3 = g2.psi2(t3) + // 5. t3 = t3 - t2 + t3 = g2.sub(t3, t2) + // 6. t2 = t1 + t2 + t2 = g2.addUnified(t1, t2) + // 7. t2 = c1 * t2 + t2 = g2.scalarMulBySeed(t2) + // 8. t3 = t3 + t2 + t3 = g2.addUnified(t3, t2) + // 9. t3 = t3 - t1 + t3 = g2.sub(t3, t1) + // 10. Q = t3 - P + Q := g2.sub(t3, p) + // 11. return Q + return Q +} diff --git a/std/algebra/emulated/sw_bls12381/hash_to_g2_test.go b/std/algebra/emulated/sw_bls12381/hash_to_g2_test.go new file mode 100644 index 000000000..4c2e8250f --- /dev/null +++ b/std/algebra/emulated/sw_bls12381/hash_to_g2_test.go @@ -0,0 +1,167 @@ +package sw_bls12381 + +import ( + "encoding/hex" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381" + bls12381fp "github.com/consensys/gnark-crypto/ecc/bls12-381/fp" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/algebra/emulated/fields_bls12381" + "github.com/consensys/gnark/std/hash/tofield" + "github.com/consensys/gnark/std/math/emulated" + "github.com/consensys/gnark/std/math/uints" + "github.com/consensys/gnark/test" +) + +func getMsgs() []string { + return []string{"", "a", "ab", "abc", "abcd", "abcde", "abcdef", "abcdefg", "1", "2", "3", "4", "5"} +} + +func getDst() []byte { + dstHex := "412717974da474d0f8c420f320ff81e8432adb7c927d9bd082b4fb4d16c0a236" + dst := make([]byte, len(dstHex)/2) + hex.Decode(dst, []byte(dstHex)) + return dst +} + +type hashToFieldCircuit struct { + msg []byte + dst []byte +} + +func (c *hashToFieldCircuit) Define(api frontend.API) error { + msg := uints.NewU8Array(c.msg) + uniformBytes, _ := tofield.ExpandMsgXmd(api, msg, c.dst, 64) + fp, _ := emulated.NewField[emulated.BLS12381Fp](api) + + ele := bytesToElement(api, fp, uniformBytes) + + rawEles, _ := bls12381fp.Hash(c.msg, c.dst, 1) + wrappedEle := fp.NewElement(rawEles[0]) + + fp.AssertIsEqual(ele, wrappedEle) + + return nil +} + +func TestHashToFieldTestSolve(t *testing.T) { + assert := test.NewAssert(t) + + for _, msg := range getMsgs() { + + witness := hashToFieldCircuit{ + msg: []byte(msg), + dst: getDst(), + } + err := test.IsSolved(&hashToFieldCircuit{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + } +} + +type mapToCurveCircuit struct { + msg []byte + dst []byte +} + +func (c *mapToCurveCircuit) Define(api frontend.API) error { + msg := uints.NewU8Array(c.msg) + uniformBytes, _ := tofield.ExpandMsgXmd(api, msg, c.dst, 128) + fp, _ := emulated.NewField[emulated.BLS12381Fp](api) + ext2 := fields_bls12381.NewExt2(api) + mapper := newMapper(api, ext2, fp) + + ele1 := bytesToElement(api, fp, uniformBytes[:64]) + ele2 := bytesToElement(api, fp, uniformBytes[64:]) + e := fields_bls12381.E2{A0: *ele1, A1: *ele2} + affine := mapper.mapToCurve(e) + + rawEles, _ := bls12381fp.Hash(c.msg, c.dst, 2) + rawAffine := bls12381.MapToCurve2(&bls12381.E2{A0: rawEles[0], A1: rawEles[1]}) + wrappedRawAffine := NewG2Affine(rawAffine) + + g2 := NewG2(api) + g2.AssertIsEqual(affine, &wrappedRawAffine) + + return nil +} + +func TestMapToCurveTestSolve(t *testing.T) { + assert := test.NewAssert(t) + + for _, msg := range getMsgs() { + + witness := hashToFieldCircuit{ + msg: []byte(msg), + dst: getDst(), + } + err := test.IsSolved(&mapToCurveCircuit{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + } +} + +type clearCofactorCircuit struct { + In G2Affine + Res G2Affine +} + +func (c *clearCofactorCircuit) Define(api frontend.API) error { + g2 := NewG2(api) + fp, _ := emulated.NewField[emulated.BLS12381Fp](api) + res := clearCofactor(g2, fp, &c.In) + g2.AssertIsEqual(res, &c.Res) + return nil +} + +func TestClearCofactorTestSolve(t *testing.T) { + assert := test.NewAssert(t) + _, in := randomG1G2Affines() + + inAffine := NewG2Affine(in) + + in.ClearCofactor(&in) + circuit := clearCofactorCircuit{ + In: inAffine, + Res: NewG2Affine(in), + } + witness := clearCofactorCircuit{ + In: inAffine, + Res: NewG2Affine(in), + } + err := test.IsSolved(&circuit, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type hashToG2Circuit struct { + msg []byte + dst []byte +} + +func (c *hashToG2Circuit) Define(api frontend.API) error { + res, e := HashToG2(api, uints.NewU8Array(c.msg), c.dst) + if e != nil { + return e + } + + expected, _ := bls12381.HashToG2(c.msg, c.dst) + wrappedRes := NewG2Affine(expected) + + g2 := NewG2(api) + g2.AssertIsEqual(res, &wrappedRes) + return nil +} + +func TestHashToG2TestSolve(t *testing.T) { + assert := test.NewAssert(t) + + for _, msg := range getMsgs() { + + witness := hashToG2Circuit{ + msg: []uint8(msg), + dst: getDst(), + } + err := test.IsSolved(&hashToG2Circuit{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + } +} diff --git a/std/algebra/emulated/sw_bls12381/hints.go b/std/algebra/emulated/sw_bls12381/hints.go new file mode 100644 index 000000000..6d6bbc443 --- /dev/null +++ b/std/algebra/emulated/sw_bls12381/hints.go @@ -0,0 +1,44 @@ +package sw_bls12381 + +import ( + "fmt" + "math/big" + + bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fp" + "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/std/math/emulated" +) + +func GetHints() []solver.Hint { + return []solver.Hint{ + sqrtRatioHint, + } +} + +func init() { + solver.RegisterHint(GetHints()...) +} + +func sqrtRatioHint(_ *big.Int, inputs []*big.Int, outputs []*big.Int) error { + return emulated.UnwrapHint(inputs, outputs, func(field *big.Int, inputs, outputs []*big.Int) error { + if len(inputs) != 4 { + return fmt.Errorf("expecting 4 inputs") + } + if len(outputs) != 3 { + return fmt.Errorf("expecting 3 outputs") + } + + var z0, z1, u0, u1, v0, v1 fp.Element + u0.SetBigInt(inputs[0]) + u1.SetBigInt(inputs[1]) + v0.SetBigInt(inputs[2]) + v1.SetBigInt(inputs[3]) + + b := bls12381.G2SqrtRatio(&z0, &z1, &u0, &u1, &v0, &v1) + outputs[0].SetUint64(b) + z0.BigInt(outputs[1]) + z1.BigInt(outputs[2]) + return nil + }) +} diff --git a/std/hints.go b/std/hints.go index c95ef7b95..a8e6cf9a3 100644 --- a/std/hints.go +++ b/std/hints.go @@ -4,6 +4,7 @@ import ( "sync" "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/std/algebra/emulated/sw_bls12381" "github.com/consensys/gnark/std/algebra/emulated/sw_emulated" "github.com/consensys/gnark/std/algebra/native/sw_bls12377" "github.com/consensys/gnark/std/algebra/native/sw_bls24315" @@ -42,4 +43,5 @@ func registerHints() { solver.RegisterHint(logderivarg.GetHints()...) solver.RegisterHint(bitslice.GetHints()...) solver.RegisterHint(sw_emulated.GetHints()...) + solver.RegisterHint(sw_bls12381.GetHints()...) } From b0b11d2aa8a6ef575e363226b3371319e8daf62f Mon Sep 17 00:00:00 2001 From: Weiji Guo Date: Sun, 28 Jan 2024 20:10:36 +0800 Subject: [PATCH 03/16] revised G2.addUnified function based on the Brier and Joye algorithm --- std/algebra/emulated/sw_bls12381/g2.go | 78 ++++++++++++--------- std/algebra/emulated/sw_bls12381/g2_test.go | 55 ++++++++++++--- 2 files changed, 91 insertions(+), 42 deletions(-) diff --git a/std/algebra/emulated/sw_bls12381/g2.go b/std/algebra/emulated/sw_bls12381/g2.go index 91daab066..a6e612c38 100644 --- a/std/algebra/emulated/sw_bls12381/g2.go +++ b/std/algebra/emulated/sw_bls12381/g2.go @@ -150,46 +150,56 @@ func (g2 G2) add(p, q *G2Affine) *G2Affine { } } -// The add() function is not complete. If p == q, then λ is 0, while it should be (3p.x²)/2*p.y according to double() -// Moreover, the AssertIsEqual() check will fail within the DivUnchecked() call as the div hint will return 0. -// -// Provide addUnified to handle this situation at the cost of additional constraints. Useful when not certain if p != q. -// Note, ClearCofactor (covered in TestClearCofactorTestSolve of hash_to_g2_test.go) might fail without this function. +// Follow sw_emulated.Curve.AddUnified to implement the Brier and Joye algorithm +// to handle edge cases, i.e., p == q, p == 0 or/and q == 0 func (g2 G2) addUnified(p, q *G2Affine) *G2Affine { - // if p != q, compute λ = (q.y-p.y)/(q.x-p.x) - qypy := g2.Ext2.Sub(&q.P.Y, &p.P.Y) - qxpx := g2.Ext2.Sub(&q.P.X, &p.P.X) - - // else if p == q, compute λ = (3p.x²)/2*p.y - xx3a := g2.Square(&p.P.X) - xx3a = g2.MulByConstElement(xx3a, big.NewInt(3)) - y2 := g2.Double(&p.P.Y) - - z := g2.Ext2.IsZero(qxpx) - deltaX := g2.Select(z, y2, qxpx) - deltaY := g2.Select(z, xx3a, qypy) - λ := g2.Ext2.DivUnchecked(deltaY, deltaX) - - // if p == -q, result should zero - zeroConst := g2.Zero() - z1 := g2.Ext2.Add(&q.P.Y, &p.P.Y) - z2 := g2.Ext2.IsZero(z1) - z3 := g2.api.And(z, z2) - // xr = λ²-p.x-q.x - λλ := g2.Ext2.Square(λ) - qxpx = g2.Ext2.Add(&p.P.X, &q.P.X) - xr := g2.Ext2.Sub(λλ, qxpx) + // selector1 = 1 when p is (0,0) and 0 otherwise + selector1 := g2.api.And(g2.Ext2.IsZero(&p.P.X), g2.Ext2.IsZero(&p.P.Y)) + // selector2 = 1 when q is (0,0) and 0 otherwise + selector2 := g2.api.And(g2.Ext2.IsZero(&q.P.X), g2.Ext2.IsZero(&q.P.Y)) + + // λ = ((p.x+q.x)² - p.x*q.x + a)/(p.y + q.y) + pxqx := g2.Ext2.Mul(&p.P.X, &q.P.X) + pxplusqx := g2.Ext2.Add(&p.P.X, &q.P.X) + num := g2.Ext2.Mul(pxplusqx, pxplusqx) + num = g2.Ext2.Sub(num, pxqx) + denum := g2.Ext2.Add(&p.P.Y, &q.P.Y) + // if p.y + q.y = 0, assign dummy 1 to denum and continue + selector3 := g2.Ext2.IsZero(denum) + denum = g2.Ext2.Select(selector3, g2.Ext2.One(), denum) + λ := g2.Ext2.DivUnchecked(num, denum) // we already know that denum won't be zero + + // x = λ^2 - p.x - q.x + xr := g2.Ext2.Mul(λ, λ) + xr = g2.Ext2.Sub(xr, pxplusqx) + + // y = λ(p.x - xr) - p.y + yr := g2.Ext2.Sub(&p.P.X, xr) + yr = g2.Ext2.Mul(yr, λ) + yr = g2.Ext2.Sub(yr, &p.P.Y) + result := &G2Affine{ + P: g2AffP{ + X: *xr, + Y: *yr, + }, + } - // yr = λ(p.x-r.x) - p.y - pxrx := g2.Ext2.Sub(&p.P.X, xr) - λpxrx := g2.Ext2.Mul(λ, pxrx) - yr := g2.Ext2.Sub(λpxrx, &p.P.Y) + zero := g2.Ext2.Zero() + // if p=(0,0) return q + resultX := *g2.Select(selector1, &q.P.X, &result.P.X) + resultY := *g2.Select(selector1, &q.P.Y, &result.P.Y) + // if q=(0,0) return p + resultX = *g2.Select(selector2, &p.P.X, &resultX) + resultY = *g2.Select(selector2, &p.P.Y, &resultY) + // if p.y + q.y = 0, return (0, 0) + resultX = *g2.Select(selector3, zero, &resultX) + resultY = *g2.Select(selector3, zero, &resultY) return &G2Affine{ P: g2AffP{ - X: *g2.Select(z3, zeroConst, xr), - Y: *g2.Select(z3, zeroConst, yr), + X: resultX, + Y: resultY, }, } } diff --git a/std/algebra/emulated/sw_bls12381/g2_test.go b/std/algebra/emulated/sw_bls12381/g2_test.go index 467925bea..534843ccd 100644 --- a/std/algebra/emulated/sw_bls12381/g2_test.go +++ b/std/algebra/emulated/sw_bls12381/g2_test.go @@ -93,19 +93,58 @@ func TestAddG2UnifiedTestSolveDbl(t *testing.T) { assert.NoError(err) } -func TestAddG2UnifiedTestSolveNeg(t *testing.T) { +func TestAddG2UnifiedTestSolveEdgeCases(t *testing.T) { assert := test.NewAssert(t) - _, in1 := randomG1G2Affines() - var in2, res bls12381.G2Affine - in2.Neg(&in1) - res.Add(&in1, &in2) + _, p := randomG1G2Affines() + var np, zero bls12381.G2Affine + np.Neg(&p) + zero.Sub(&p, &p) + + // p + (-p) == (0, 0) witness := addG2UnifiedCircuit{ - In1: NewG2Affine(in1), - In2: NewG2Affine(in2), - Res: NewG2Affine(res), + In1: NewG2Affine(p), + In2: NewG2Affine(np), + Res: NewG2Affine(zero), } err := test.IsSolved(&addG2UnifiedCircuit{}, &witness, ecc.BN254.ScalarField()) assert.NoError(err) + + // (-p) + p == (0, 0) + witness2 := addG2UnifiedCircuit{ + In1: NewG2Affine(np), + In2: NewG2Affine(p), + Res: NewG2Affine(zero), + } + err2 := test.IsSolved(&addG2UnifiedCircuit{}, &witness2, ecc.BN254.ScalarField()) + assert.NoError(err2) + + // p + (0, 0) == p + witness3 := addG2UnifiedCircuit{ + In1: NewG2Affine(p), + In2: NewG2Affine(zero), + Res: NewG2Affine(p), + } + err3 := test.IsSolved(&addG2UnifiedCircuit{}, &witness3, ecc.BN254.ScalarField()) + assert.NoError(err3) + + // (0, 0) + p == p + witness4 := addG2UnifiedCircuit{ + In1: NewG2Affine(zero), + In2: NewG2Affine(p), + Res: NewG2Affine(p), + } + err4 := test.IsSolved(&addG2UnifiedCircuit{}, &witness4, ecc.BN254.ScalarField()) + assert.NoError(err4) + + // (0, 0) + (0, 0) == (0, 0) + witness5 := addG2UnifiedCircuit{ + In1: NewG2Affine(zero), + In2: NewG2Affine(zero), + Res: NewG2Affine(zero), + } + err5 := test.IsSolved(&addG2UnifiedCircuit{}, &witness5, ecc.BN254.ScalarField()) + assert.NoError(err5) + } type doubleG2Circuit struct { From b4da7688710bd8400c8c53d8f744edfdc3e77eba Mon Sep 17 00:00:00 2001 From: Weiji Guo Date: Mon, 29 Jan 2024 19:59:20 +0800 Subject: [PATCH 04/16] fixed G2.sgn0 function and associated unit tests for BLS12-381 --- .../emulated/sw_bls12381/hash_to_g2.go | 6 +- .../emulated/sw_bls12381/hash_to_g2_test.go | 93 ++++++++++++------- 2 files changed, 62 insertions(+), 37 deletions(-) diff --git a/std/algebra/emulated/sw_bls12381/hash_to_g2.go b/std/algebra/emulated/sw_bls12381/hash_to_g2.go index 10266f383..55d8d9a52 100644 --- a/std/algebra/emulated/sw_bls12381/hash_to_g2.go +++ b/std/algebra/emulated/sw_bls12381/hash_to_g2.go @@ -184,12 +184,14 @@ func (m sswuMapper) mapToCurve(u fields_bls12381.E2) *G2Affine { func (m sswuMapper) sgn0(x *fields_bls12381.E2) frontend.Variable { // Steps for sgn0_m_eq_2 // 1. sign_0 = x_0 mod 2 - x0 := m.fp.ToBits(&x.A0) + A0 := m.fp.Reduce(&x.A0) + x0 := m.fp.ToBits(A0) sign0 := x0[0] // 2. zero_0 = x_0 == 0 zero0 := m.fp.IsZero(&x.A0) // 3. sign_1 = x_1 mod 2 - x1 := m.fp.ToBits(&x.A1) + A1 := m.fp.Reduce(&x.A1) + x1 := m.fp.ToBits(A1) sign1 := x1[0] // 4. s = sign_0 OR (zero_0 AND sign_1) # Avoid short-circuit logic ops tv := m.api.And(zero0, sign1) diff --git a/std/algebra/emulated/sw_bls12381/hash_to_g2_test.go b/std/algebra/emulated/sw_bls12381/hash_to_g2_test.go index 4c2e8250f..a0dbc6e52 100644 --- a/std/algebra/emulated/sw_bls12381/hash_to_g2_test.go +++ b/std/algebra/emulated/sw_bls12381/hash_to_g2_test.go @@ -16,7 +16,7 @@ import ( ) func getMsgs() []string { - return []string{"", "a", "ab", "abc", "abcd", "abcde", "abcdef", "abcdefg", "1", "2", "3", "4", "5"} + return []string{"", "a", "ab", "abc", "abcd", "abcde", "abcdef", "abcdefg", "1", "2", "3", "4", "5", "5656565656565656565656565656565656565656565656565656565656565656"} } func getDst() []byte { @@ -27,76 +27,91 @@ func getDst() []byte { } type hashToFieldCircuit struct { - msg []byte - dst []byte + Msg []byte + Dst []byte + Res bls12381fp.Element } func (c *hashToFieldCircuit) Define(api frontend.API) error { - msg := uints.NewU8Array(c.msg) - uniformBytes, _ := tofield.ExpandMsgXmd(api, msg, c.dst, 64) + msg := uints.NewU8Array(c.Msg) + uniformBytes, _ := tofield.ExpandMsgXmd(api, msg, c.Dst, 64) fp, _ := emulated.NewField[emulated.BLS12381Fp](api) ele := bytesToElement(api, fp, uniformBytes) - rawEles, _ := bls12381fp.Hash(c.msg, c.dst, 1) - wrappedEle := fp.NewElement(rawEles[0]) - - fp.AssertIsEqual(ele, wrappedEle) + fp.AssertIsEqual(ele, fp.NewElement(c.Res)) return nil } func TestHashToFieldTestSolve(t *testing.T) { assert := test.NewAssert(t) + dst := getDst() for _, msg := range getMsgs() { + rawEles, _ := bls12381fp.Hash([]byte(msg), dst, 1) + + circuit := hashToFieldCircuit{ + Msg: []byte(msg), + Dst: dst, + Res: rawEles[0], + } witness := hashToFieldCircuit{ - msg: []byte(msg), - dst: getDst(), + Msg: []byte(msg), + Dst: dst, + Res: rawEles[0], } - err := test.IsSolved(&hashToFieldCircuit{}, &witness, ecc.BN254.ScalarField()) + err := test.IsSolved(&circuit, &witness, ecc.BN254.ScalarField()) assert.NoError(err) } } type mapToCurveCircuit struct { - msg []byte - dst []byte + Msg []byte + Dst []byte + Res G2Affine } func (c *mapToCurveCircuit) Define(api frontend.API) error { - msg := uints.NewU8Array(c.msg) - uniformBytes, _ := tofield.ExpandMsgXmd(api, msg, c.dst, 128) + msg := uints.NewU8Array(c.Msg) fp, _ := emulated.NewField[emulated.BLS12381Fp](api) ext2 := fields_bls12381.NewExt2(api) mapper := newMapper(api, ext2, fp) + uniformBytes, _ := tofield.ExpandMsgXmd(api, msg, c.Dst, 128) ele1 := bytesToElement(api, fp, uniformBytes[:64]) ele2 := bytesToElement(api, fp, uniformBytes[64:]) e := fields_bls12381.E2{A0: *ele1, A1: *ele2} affine := mapper.mapToCurve(e) - rawEles, _ := bls12381fp.Hash(c.msg, c.dst, 2) - rawAffine := bls12381.MapToCurve2(&bls12381.E2{A0: rawEles[0], A1: rawEles[1]}) - wrappedRawAffine := NewG2Affine(rawAffine) - g2 := NewG2(api) - g2.AssertIsEqual(affine, &wrappedRawAffine) + g2.AssertIsEqual(affine, &c.Res) return nil } func TestMapToCurveTestSolve(t *testing.T) { assert := test.NewAssert(t) + dst := getDst() for _, msg := range getMsgs() { - witness := hashToFieldCircuit{ - msg: []byte(msg), - dst: getDst(), + rawEles, _ := bls12381fp.Hash([]byte(msg), dst, 2) + rawAffine := bls12381.MapToCurve2(&bls12381.E2{A0: rawEles[0], A1: rawEles[1]}) + wrappedRawAffine := NewG2Affine(rawAffine) + + circuit := mapToCurveCircuit{ + Msg: []byte(msg), + Dst: dst, + Res: wrappedRawAffine, } - err := test.IsSolved(&mapToCurveCircuit{}, &witness, ecc.BN254.ScalarField()) + witness := mapToCurveCircuit{ + Msg: []byte(msg), + Dst: dst, + Res: wrappedRawAffine, + } + err := test.IsSolved(&circuit, &witness, ecc.BN254.ScalarField()) assert.NoError(err) } } @@ -134,34 +149,42 @@ func TestClearCofactorTestSolve(t *testing.T) { } type hashToG2Circuit struct { - msg []byte - dst []byte + Msg []byte + Dst []byte + Res G2Affine } func (c *hashToG2Circuit) Define(api frontend.API) error { - res, e := HashToG2(api, uints.NewU8Array(c.msg), c.dst) + res, e := HashToG2(api, uints.NewU8Array(c.Msg), c.Dst) if e != nil { return e } - expected, _ := bls12381.HashToG2(c.msg, c.dst) - wrappedRes := NewG2Affine(expected) - g2 := NewG2(api) - g2.AssertIsEqual(res, &wrappedRes) + g2.AssertIsEqual(res, &c.Res) return nil } func TestHashToG2TestSolve(t *testing.T) { assert := test.NewAssert(t) + dst := getDst() for _, msg := range getMsgs() { + expected, _ := bls12381.HashToG2([]uint8(msg), dst) + wrappedRes := NewG2Affine(expected) + + circuit := hashToG2Circuit{ + Msg: []uint8(msg), + Dst: dst, + Res: wrappedRes, + } witness := hashToG2Circuit{ - msg: []uint8(msg), - dst: getDst(), + Msg: []uint8(msg), + Dst: dst, + Res: wrappedRes, } - err := test.IsSolved(&hashToG2Circuit{}, &witness, ecc.BN254.ScalarField()) + err := test.IsSolved(&circuit, &witness, ecc.BN254.ScalarField()) assert.NoError(err) } } From 86d12ce7661a4261159f000f6e5de8d2760c3e69 Mon Sep 17 00:00:00 2001 From: Weiji Guo Date: Mon, 29 Jan 2024 20:23:05 +0800 Subject: [PATCH 05/16] implemented BLS signature verification for BLS12-381/G2 --- std/algebra/emulated/sw_bls12381/bls_sig.go | 34 ++++++++ .../emulated/sw_bls12381/bls_sig_test.go | 83 +++++++++++++++++++ 2 files changed, 117 insertions(+) create mode 100644 std/algebra/emulated/sw_bls12381/bls_sig.go create mode 100644 std/algebra/emulated/sw_bls12381/bls_sig_test.go diff --git a/std/algebra/emulated/sw_bls12381/bls_sig.go b/std/algebra/emulated/sw_bls12381/bls_sig.go new file mode 100644 index 000000000..410da6be1 --- /dev/null +++ b/std/algebra/emulated/sw_bls12381/bls_sig.go @@ -0,0 +1,34 @@ +package sw_bls12381 + +import ( + bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/math/uints" +) + +const g2_dst = "BLS_SIG_BLS12381G2_XMD:SHA-256_SSWU_RO_POP_" + +func BlsAssertG2Verification(api frontend.API, pub G1Affine, sig G2Affine, msg []uints.U8) error { + pairing, e := NewPairing(api) + if e != nil { + return e + } + + // public key cannot be infinity + xtest := pairing.g1.curveF.IsZero(&pub.X) + ytest := pairing.g1.curveF.IsZero(&pub.Y) + pubTest := api.Or(xtest, ytest) + api.AssertIsEqual(pubTest, 0) + + // prime order subgroup checks + pairing.AssertIsOnG1(&pub) + pairing.AssertIsOnG2(&sig) + + var g1GNeg bls12381.G1Affine + _, _, g1Gen, _ := bls12381.Generators() + g1GNeg.Neg(&g1Gen) + g1GN := NewG1Affine(g1GNeg) + + h, e := HashToG2(api, msg, []byte(g2_dst)) + return pairing.PairingCheck([]*G1Affine{&g1GN, &pub}, []*G2Affine{&sig, h}) +} diff --git a/std/algebra/emulated/sw_bls12381/bls_sig_test.go b/std/algebra/emulated/sw_bls12381/bls_sig_test.go new file mode 100644 index 000000000..47827ab25 --- /dev/null +++ b/std/algebra/emulated/sw_bls12381/bls_sig_test.go @@ -0,0 +1,83 @@ +package sw_bls12381 + +import ( + "encoding/hex" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/math/uints" + "github.com/consensys/gnark/test" +) + +type blsG2SigCircuit struct { + Pub bls12381.G1Affine + msg []byte + Sig bls12381.G2Affine +} + +func (c *blsG2SigCircuit) Define(api frontend.API) error { + msg := uints.NewU8Array(c.msg) + return BlsAssertG2Verification(api, NewG1Affine(c.Pub), NewG2Affine(c.Sig), msg) +} + +// "pubkey": "0xa491d1b0ecd9bb917989f0e74f0dea0422eac4a873e5e2644f368dffb9a6e20fd6e10c1b77654d067c0618f6e5a7f79a", +// "message": "0x5656565656565656565656565656565656565656565656565656565656565656", +// "signature": "0x882730e5d03f6b42c3abc26d3372625034e1d871b65a8a6b900a56dae22da98abbe1b68f85e49fe7652a55ec3d0591c20767677e33e5cbb1207315c41a9ac03be39c2e7668edc043d6cb1d9fd93033caa8a1c5b0e84bedaeb6c64972503a43eb"}, +// "output": true} +func TestBlsSigTestSolve(t *testing.T) { + assert := test.NewAssert(t) + + msgHex := "5656565656565656565656565656565656565656565656565656565656565656" + pubHex := "a491d1b0ecd9bb917989f0e74f0dea0422eac4a873e5e2644f368dffb9a6e20fd6e10c1b77654d067c0618f6e5a7f79a" + sigHex := "882730e5d03f6b42c3abc26d3372625034e1d871b65a8a6b900a56dae22da98abbe1b68f85e49fe7652a55ec3d0591c20767677e33e5cbb1207315c41a9ac03be39c2e7668edc043d6cb1d9fd93033caa8a1c5b0e84bedaeb6c64972503a43eb" + + msgBytes := make([]byte, len(msgHex)>>1) + hex.Decode(msgBytes, []byte(msgHex)) + pubBytes := make([]byte, len(pubHex)>>1) + hex.Decode(pubBytes, []byte(pubHex)) + sigBytes := make([]byte, len(sigHex)>>1) + hex.Decode(sigBytes, []byte(sigHex)) + + var pub bls12381.G1Affine + _, e := pub.SetBytes(pubBytes) + if e != nil { + t.Fail() + } + var sig bls12381.G2Affine + _, e = sig.SetBytes(sigBytes) + if e != nil { + t.Fail() + } + + var g1GNeg bls12381.G1Affine + _, _, g1Gen, _ := bls12381.Generators() + g1GNeg.Neg(&g1Gen) + + h, e := bls12381.HashToG2(msgBytes, []byte(g2_dst)) + if e != nil { + t.Fail() + } + + b, e := bls12381.PairingCheck([]bls12381.G1Affine{g1GNeg, pub}, []bls12381.G2Affine{sig, h}) + if e != nil { + t.Fail() + } + if !b { + t.Fail() // invalid inputs, won't verify + } + + circuit := blsG2SigCircuit{ + Pub: pub, + msg: msgBytes, + Sig: sig, + } + witness := blsG2SigCircuit{ + Pub: pub, + msg: msgBytes, + Sig: sig, + } + err := test.IsSolved(&circuit, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} From 78bdcf6a20aadf03a659233088bed182c6d91333 Mon Sep 17 00:00:00 2001 From: Weiji Guo Date: Tue, 30 Jan 2024 20:46:22 +0800 Subject: [PATCH 06/16] added benchmarks for HashToG2/BLS12-381 --- .../emulated/sw_bls12381/hash_to_g2_test.go | 76 +++++++++++++++++++ 1 file changed, 76 insertions(+) diff --git a/std/algebra/emulated/sw_bls12381/hash_to_g2_test.go b/std/algebra/emulated/sw_bls12381/hash_to_g2_test.go index a0dbc6e52..683106b5b 100644 --- a/std/algebra/emulated/sw_bls12381/hash_to_g2_test.go +++ b/std/algebra/emulated/sw_bls12381/hash_to_g2_test.go @@ -1,13 +1,17 @@ package sw_bls12381 import ( + "bytes" "encoding/hex" "testing" "github.com/consensys/gnark-crypto/ecc" bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381" bls12381fp "github.com/consensys/gnark-crypto/ecc/bls12-381/fp" + "github.com/consensys/gnark/constraint" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/frontend/cs/scs" "github.com/consensys/gnark/std/algebra/emulated/fields_bls12381" "github.com/consensys/gnark/std/hash/tofield" "github.com/consensys/gnark/std/math/emulated" @@ -188,3 +192,75 @@ func TestHashToG2TestSolve(t *testing.T) { assert.NoError(err) } } + +type hashToG2BenchCircuit struct { + Msg []byte + Dst []byte +} + +func (c *hashToG2BenchCircuit) Define(api frontend.API) error { + _, e := HashToG2(api, uints.NewU8Array(c.Msg), c.Dst) + return e +} + +func BenchmarkHashToG2(b *testing.B) { + + dst := getDst() + + msg := "abcd" + witness := hashToG2BenchCircuit{ + Msg: []uint8(msg), + Dst: dst, + } + w, err := frontend.NewWitness(&witness, ecc.BN254.ScalarField()) + if err != nil { + b.Fatal(err) + } + var ccs constraint.ConstraintSystem + b.Run("compile scs", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + if ccs, err = frontend.Compile(ecc.BN254.ScalarField(), scs.NewBuilder, &hashToG2BenchCircuit{}); err != nil { + b.Fatal(err) + } + } + }) + var buf bytes.Buffer + _, err = ccs.WriteTo(&buf) + if err != nil { + b.Fatal(err) + } + b.Logf("scs size: %d (bytes), nb constraints %d, nbInstructions: %d", buf.Len(), ccs.GetNbConstraints(), ccs.GetNbInstructions()) + b.Run("solve scs", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + if _, err := ccs.Solve(w); err != nil { + b.Fatal(err) + } + } + }) + b.Run("compile r1cs", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + if ccs, err = frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &hashToG2BenchCircuit{}); err != nil { + b.Fatal(err) + } + } + }) + buf.Reset() + _, err = ccs.WriteTo(&buf) + if err != nil { + b.Fatal(err) + } + b.Logf("r1cs size: %d (bytes), nb constraints %d, nbInstructions: %d", buf.Len(), ccs.GetNbConstraints(), ccs.GetNbInstructions()) + + b.Run("solve r1cs", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + if _, err := ccs.Solve(w); err != nil { + b.Fatal(err) + } + } + }) + +} From 8e2fc0eb790995cbab894df2a39c2a2530f02ec9 Mon Sep 17 00:00:00 2001 From: Weiji Guo Date: Sat, 10 Feb 2024 18:02:38 +0800 Subject: [PATCH 07/16] gofmt --- std/hints.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/std/hints.go b/std/hints.go index 54ed85932..33149e9ef 100644 --- a/std/hints.go +++ b/std/hints.go @@ -4,10 +4,10 @@ import ( "sync" "github.com/consensys/gnark/constraint/solver" - "github.com/consensys/gnark/std/algebra/emulated/sw_bls12381" "github.com/consensys/gnark/std/algebra/emulated/fields_bls12381" "github.com/consensys/gnark/std/algebra/emulated/fields_bn254" "github.com/consensys/gnark/std/algebra/emulated/fields_bw6761" + "github.com/consensys/gnark/std/algebra/emulated/sw_bls12381" "github.com/consensys/gnark/std/algebra/emulated/sw_emulated" "github.com/consensys/gnark/std/algebra/native/fields_bls12377" "github.com/consensys/gnark/std/algebra/native/fields_bls24315" From 9afd6f2e65ab1711cd715baa5a9a5f0af4d116a0 Mon Sep 17 00:00:00 2001 From: Weiji Guo Date: Thu, 15 Feb 2024 22:26:19 +0800 Subject: [PATCH 08/16] vkey fingerprint --- std/recursion/plonk/verifier.go | 53 ++++++++++++++++++++++++++++ std/recursion/plonk/verifier_test.go | 5 +++ 2 files changed, 58 insertions(+) diff --git a/std/recursion/plonk/verifier.go b/std/recursion/plonk/verifier.go index 67c776409..ba0dbb335 100644 --- a/std/recursion/plonk/verifier.go +++ b/std/recursion/plonk/verifier.go @@ -25,6 +25,7 @@ import ( "github.com/consensys/gnark/std/algebra/native/sw_bls24315" "github.com/consensys/gnark/std/commitments/kzg" fiatshamir "github.com/consensys/gnark/std/fiat-shamir" + "github.com/consensys/gnark/std/hash/mimc" "github.com/consensys/gnark/std/math/bits" "github.com/consensys/gnark/std/math/emulated" "github.com/consensys/gnark/std/recursion" @@ -310,6 +311,58 @@ type VerifyingKey[FR emulated.FieldParams, G1El algebra.G1ElementT, G2El algebra CircuitVerifyingKey[FR, G1El] } +// FingerPrint() returns the MiMc hash of the VerifyingKey. It could be used to identify a VerifyingKey +// during recursive verification. +func (vk *VerifyingKey[FR, G1El, G2El]) FingerPrint(api frontend.API) (frontend.Variable, error) { + var ret frontend.Variable + mimc, err := mimc.NewMiMC(api) + if err != nil { + return ret, err + } + + mimc.Write(vk.BaseVerifyingKey.NbPublicVariables) + mimc.Write(vk.CircuitVerifyingKey.Size) + mimc.Write(vk.CircuitVerifyingKey.Generator.Limbs[:]...) + + comms := make([]kzg.Commitment[G1El], 0) + comms = append(comms, vk.CircuitVerifyingKey.S[:]...) + comms = append(comms, vk.CircuitVerifyingKey.Ql) + comms = append(comms, vk.CircuitVerifyingKey.Qr) + comms = append(comms, vk.CircuitVerifyingKey.Qm) + comms = append(comms, vk.CircuitVerifyingKey.Qo) + comms = append(comms, vk.CircuitVerifyingKey.Qk) + comms = append(comms, vk.CircuitVerifyingKey.Qcp[:]...) + + for _, comm := range comms { + el := comm.G1El + switch r := any(&el).(type) { + case *sw_bls12377.G1Affine: + mimc.Write(r.X) + mimc.Write(r.Y) + case *sw_bls12381.G1Affine: + mimc.Write(r.X.Limbs[:]...) + mimc.Write(r.Y.Limbs[:]...) + case *sw_bls24315.G1Affine: + mimc.Write(r.X) + mimc.Write(r.Y) + case *sw_bw6761.G1Affine: + mimc.Write(r.X.Limbs[:]...) + mimc.Write(r.Y.Limbs[:]...) + case *sw_bn254.G1Affine: + mimc.Write(r.X.Limbs[:]...) + mimc.Write(r.Y.Limbs[:]...) + default: + return ret, fmt.Errorf("unknown parametric type") + } + } + + mimc.Write(vk.CircuitVerifyingKey.CommitmentConstraintIndexes[:]...) + + result := mimc.Sum() + + return result, nil +} + // ValueOfBaseVerifyingKey assigns the base verification key from the witness. // Use one of the verifiaction keys for the same-sized circuits. func ValueOfBaseVerifyingKey[FR emulated.FieldParams, G1El algebra.G1ElementT, G2El algebra.G2ElementT](vk backend_plonk.VerifyingKey) (BaseVerifyingKey[FR, G1El, G2El], error) { diff --git a/std/recursion/plonk/verifier_test.go b/std/recursion/plonk/verifier_test.go index f621a83f3..195198734 100644 --- a/std/recursion/plonk/verifier_test.go +++ b/std/recursion/plonk/verifier_test.go @@ -37,6 +37,11 @@ func (c *OuterCircuit[FR, G1El, G2El, GtEl]) Define(api frontend.API) error { if err != nil { return fmt.Errorf("new verifier: %w", err) } + fp, err := c.VerifyingKey.FingerPrint(api) + if err != nil { + return fmt.Errorf("new curve for verification keys: %w", err) + } + api.Println(fp) err = verifier.AssertProof(c.VerifyingKey, c.Proof, c.InnerWitness) return err } From be6e2330a873c61a8b6c80cafac19c9c1b2558ac Mon Sep 17 00:00:00 2001 From: Weiji Guo Date: Mon, 19 Feb 2024 12:26:10 +0800 Subject: [PATCH 09/16] veky fingerprint -- furhter tests --- std/recursion/plonk/vkey_fp_test.go | 204 ++++++++++++++++++++++++++++ 1 file changed, 204 insertions(+) create mode 100644 std/recursion/plonk/vkey_fp_test.go diff --git a/std/recursion/plonk/vkey_fp_test.go b/std/recursion/plonk/vkey_fp_test.go new file mode 100644 index 000000000..31c04b6eb --- /dev/null +++ b/std/recursion/plonk/vkey_fp_test.go @@ -0,0 +1,204 @@ +package plonk + +import ( + "fmt" + "math/big" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + native_plonk "github.com/consensys/gnark/backend/plonk" + "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/constraint" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/scs" + "github.com/consensys/gnark/std/algebra" + "github.com/consensys/gnark/std/algebra/emulated/sw_bw6761" + "github.com/consensys/gnark/std/math/emulated" + "github.com/consensys/gnark/test" + "github.com/consensys/gnark/test/unsafekzg" +) + +type OuterCircuitDual[FR emulated.FieldParams, G1El algebra.G1ElementT, G2El algebra.G2ElementT, GtEl algebra.GtElementT] struct { + Proofs []Proof[FR, G1El, G2El] + VerifyingKeys []VerifyingKey[FR, G1El, G2El] `gnark:"-"` + InnerWitnesses []Witness[FR] `gnark:",public"` +} + +func (c *OuterCircuitDual[FR, G1El, G2El, GtEl]) Define(api frontend.API) error { + verifier, err := NewVerifier[FR, G1El, G2El, GtEl](api) + if err != nil { + return fmt.Errorf("new verifier: %w", err) + } + fp, err := c.VerifyingKeys[0].FingerPrint(api) + if err != nil { + return fmt.Errorf("new curve for verification keys: %w", err) + } + api.Println(fp) + err = verifier.AssertProof(c.VerifyingKeys[0], c.Proofs[0], c.InnerWitnesses[0], WithCompleteArithmetic()) + + fp2, err := c.VerifyingKeys[1].FingerPrint(api) + if err != nil { + return fmt.Errorf("new curve for verification keys: %w", err) + } + api.Println(fp2) + // err = verifier.AssertProof(c.VerifyingKeys[1], c.Proofs[1], c.InnerWitnesses[1], WithCompleteArithmetic()) + // same constant value should result same verification key + api.AssertIsEqual(fp, fp2) + + fp3, err := c.VerifyingKeys[2].FingerPrint(api) + if err != nil { + return fmt.Errorf("new curve for verification keys: %w", err) + } + api.Println(fp3) + // err = verifier.AssertProof(c.VerifyingKeys[2], c.Proofs[2], c.InnerWitnesses[2], WithCompleteArithmetic()) + // different constant value should result different verification key + api.AssertIsDifferent(fp, fp3) + + return err +} + +// the constant value (c.multiplier) should impact not only the relationship between X and Y +// but also the circuit structure, meaning the vkey fingerprint will *change* with a different constant value +type InnerCircuitWithConstant struct { + X frontend.Variable + Y frontend.Variable `gnark:",public"` + multiplier int +} + +func (c *InnerCircuitWithConstant) Define(api frontend.API) error { + res := api.Mul(c.X, c.multiplier) + api.AssertIsEqual(res, c.Y) + + return nil +} + +func getInnerCircuitProof(assert *test.Assert, field, outer *big.Int) ([]constraint.ConstraintSystem, []native_plonk.VerifyingKey, []witness.Witness, []native_plonk.Proof) { + + innerCcs, err := frontend.Compile(field, scs.NewBuilder, &InnerCircuitWithConstant{multiplier: 5}) + assert.NoError(err) + + srs, srsLagrange, err := unsafekzg.NewSRS(innerCcs) + assert.NoError(err) + + innerPK, innerVK, err := native_plonk.Setup(innerCcs, srs, srsLagrange) + assert.NoError(err) + + // inner proof + innerAssignment := &InnerCircuitWithConstant{ + X: 3, + Y: 15, + } + innerWitness, err := frontend.NewWitness(innerAssignment, field) + assert.NoError(err) + innerProof, err := native_plonk.Prove(innerCcs, innerPK, innerWitness, GetNativeProverOptions(outer, field)) + + assert.NoError(err) + innerPubWitness, err := innerWitness.Public() + assert.NoError(err) + err = native_plonk.Verify(innerProof, innerVK, innerPubWitness, GetNativeVerifierOptions(outer, field)) + + assert.NoError(err) + + // innerCcs is only needed for nbConstraints/nbPublicVarialbes and .Field() + // so we could reuse generated srs for another CCS instance which only differs in the constant multiplier + innerCcs2, err := frontend.Compile(field, scs.NewBuilder, &InnerCircuitWithConstant{multiplier: 5}) + assert.NoError(err) + + innerPK2, innerVK2, err := native_plonk.Setup(innerCcs2, srs, srsLagrange) + assert.NoError(err) + + // inner proof2 + innerAssignment2 := &InnerCircuitWithConstant{ + X: 3, + Y: 15, + } + innerWitness2, err := frontend.NewWitness(innerAssignment2, field) + assert.NoError(err) + innerProof2, err := native_plonk.Prove(innerCcs2, innerPK2, innerWitness2, GetNativeProverOptions(outer, field)) + + assert.NoError(err) + innerPubWitness2, err := innerWitness2.Public() + assert.NoError(err) + err = native_plonk.Verify(innerProof2, innerVK2, innerPubWitness2, GetNativeVerifierOptions(outer, field)) + + assert.NoError(err) + + innerCcs3, err := frontend.Compile(field, scs.NewBuilder, &InnerCircuitWithConstant{multiplier: 7}) + assert.NoError(err) + + innerPK3, innerVK3, err := native_plonk.Setup(innerCcs3, srs, srsLagrange) + assert.NoError(err) + + // inner proof3 + innerAssignment3 := &InnerCircuitWithConstant{ + X: 3, + Y: 21, + } + innerWitness3, err := frontend.NewWitness(innerAssignment3, field) + assert.NoError(err) + innerProof3, err := native_plonk.Prove(innerCcs3, innerPK3, innerWitness3, GetNativeProverOptions(outer, field)) + + assert.NoError(err) + innerPubWitness3, err := innerWitness3.Public() + assert.NoError(err) + err = native_plonk.Verify(innerProof3, innerVK3, innerPubWitness3, GetNativeVerifierOptions(outer, field)) + + assert.NoError(err) + + return []constraint.ConstraintSystem{innerCcs, innerCcs2, innerCcs3}, + []native_plonk.VerifyingKey{innerVK, innerVK2, innerVK3}, + []witness.Witness{innerPubWitness, innerPubWitness2, innerPubWitness3}, + []native_plonk.Proof{innerProof, innerProof2, innerProof3} +} + +func TestBW6InBN254VkeyFp(t *testing.T) { + + assert := test.NewAssert(t) + innerCcs, innerVK, innerWitness, innerProof := getInnerCircuitProof(assert, ecc.BW6_761.ScalarField(), ecc.BN254.ScalarField()) + + // outer proofs + circuitVk, err := ValueOfVerifyingKey[sw_bw6761.ScalarField, sw_bw6761.G1Affine, sw_bw6761.G2Affine](innerVK[0]) + assert.NoError(err) + circuitWitness, err := ValueOfWitness[sw_bw6761.ScalarField](innerWitness[0]) + assert.NoError(err) + circuitProof, err := ValueOfProof[sw_bw6761.ScalarField, sw_bw6761.G1Affine, sw_bw6761.G2Affine](innerProof[0]) + assert.NoError(err) + + circuitVk2, err := ValueOfVerifyingKey[sw_bw6761.ScalarField, sw_bw6761.G1Affine, sw_bw6761.G2Affine](innerVK[1]) + assert.NoError(err) + circuitWitness2, err := ValueOfWitness[sw_bw6761.ScalarField](innerWitness[1]) + assert.NoError(err) + circuitProof2, err := ValueOfProof[sw_bw6761.ScalarField, sw_bw6761.G1Affine, sw_bw6761.G2Affine](innerProof[1]) + assert.NoError(err) + + circuitVk3, err := ValueOfVerifyingKey[sw_bw6761.ScalarField, sw_bw6761.G1Affine, sw_bw6761.G2Affine](innerVK[2]) + assert.NoError(err) + circuitWitness3, err := ValueOfWitness[sw_bw6761.ScalarField](innerWitness[2]) + assert.NoError(err) + circuitProof3, err := ValueOfProof[sw_bw6761.ScalarField, sw_bw6761.G1Affine, sw_bw6761.G2Affine](innerProof[2]) + assert.NoError(err) + + outerCircuit := &OuterCircuitDual[sw_bw6761.ScalarField, sw_bw6761.G1Affine, sw_bw6761.G2Affine, sw_bw6761.GTEl]{ + InnerWitnesses: []Witness[sw_bw6761.ScalarField]{ + PlaceholderWitness[sw_bw6761.ScalarField](innerCcs[0]), + PlaceholderWitness[sw_bw6761.ScalarField](innerCcs[1]), + PlaceholderWitness[sw_bw6761.ScalarField](innerCcs[2]), + }, + Proofs: []Proof[sw_bw6761.ScalarField, sw_bw6761.G1Affine, sw_bw6761.G2Affine]{ + PlaceholderProof[sw_bw6761.ScalarField, sw_bw6761.G1Affine, sw_bw6761.G2Affine](innerCcs[0]), + PlaceholderProof[sw_bw6761.ScalarField, sw_bw6761.G1Affine, sw_bw6761.G2Affine](innerCcs[1]), + PlaceholderProof[sw_bw6761.ScalarField, sw_bw6761.G1Affine, sw_bw6761.G2Affine](innerCcs[2]), + }, + VerifyingKeys: []VerifyingKey[sw_bw6761.ScalarField, sw_bw6761.G1Affine, sw_bw6761.G2Affine]{ + circuitVk, + circuitVk2, + circuitVk3, + }, + } + outerAssignment := &OuterCircuitDual[sw_bw6761.ScalarField, sw_bw6761.G1Affine, sw_bw6761.G2Affine, sw_bw6761.GTEl]{ + InnerWitnesses: []Witness[sw_bw6761.ScalarField]{circuitWitness, circuitWitness2, circuitWitness3}, + Proofs: []Proof[sw_bw6761.ScalarField, sw_bw6761.G1Affine, sw_bw6761.G2Affine]{circuitProof, circuitProof2, circuitProof3}, + } + err = test.IsSolved(outerCircuit, outerAssignment, ecc.BN254.ScalarField()) + assert.NoError(err) +} From 4a2f9b0ee1b7956e2bd74c88bb37fe7d405a07a6 Mon Sep 17 00:00:00 2001 From: Weiji Guo Date: Mon, 19 Feb 2024 14:53:46 +0800 Subject: [PATCH 10/16] golangci-lint --- std/algebra/emulated/sw_bls12381/bls_sig.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/std/algebra/emulated/sw_bls12381/bls_sig.go b/std/algebra/emulated/sw_bls12381/bls_sig.go index 410da6be1..296948539 100644 --- a/std/algebra/emulated/sw_bls12381/bls_sig.go +++ b/std/algebra/emulated/sw_bls12381/bls_sig.go @@ -30,5 +30,9 @@ func BlsAssertG2Verification(api frontend.API, pub G1Affine, sig G2Affine, msg [ g1GN := NewG1Affine(g1GNeg) h, e := HashToG2(api, msg, []byte(g2_dst)) + if e != nil { + return e + } + return pairing.PairingCheck([]*G1Affine{&g1GN, &pub}, []*G2Affine{&sig, h}) } From de5536a609d6e010a35677ee18e2f0ccbd249a77 Mon Sep 17 00:00:00 2001 From: Weiji Guo Date: Mon, 19 Feb 2024 14:58:10 +0800 Subject: [PATCH 11/16] golangci-lint --- std/recursion/plonk/vkey_fp_test.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/std/recursion/plonk/vkey_fp_test.go b/std/recursion/plonk/vkey_fp_test.go index 31c04b6eb..cec7bf981 100644 --- a/std/recursion/plonk/vkey_fp_test.go +++ b/std/recursion/plonk/vkey_fp_test.go @@ -35,13 +35,15 @@ func (c *OuterCircuitDual[FR, G1El, G2El, GtEl]) Define(api frontend.API) error } api.Println(fp) err = verifier.AssertProof(c.VerifyingKeys[0], c.Proofs[0], c.InnerWitnesses[0], WithCompleteArithmetic()) + if err != nil { + return err + } fp2, err := c.VerifyingKeys[1].FingerPrint(api) if err != nil { return fmt.Errorf("new curve for verification keys: %w", err) } api.Println(fp2) - // err = verifier.AssertProof(c.VerifyingKeys[1], c.Proofs[1], c.InnerWitnesses[1], WithCompleteArithmetic()) // same constant value should result same verification key api.AssertIsEqual(fp, fp2) @@ -50,7 +52,6 @@ func (c *OuterCircuitDual[FR, G1El, G2El, GtEl]) Define(api frontend.API) error return fmt.Errorf("new curve for verification keys: %w", err) } api.Println(fp3) - // err = verifier.AssertProof(c.VerifyingKeys[2], c.Proofs[2], c.InnerWitnesses[2], WithCompleteArithmetic()) // different constant value should result different verification key api.AssertIsDifferent(fp, fp3) From 4f3ae1ded94972c638be01bf480bf1335551acd9 Mon Sep 17 00:00:00 2001 From: Weiji Guo Date: Mon, 19 Feb 2024 21:28:52 +0800 Subject: [PATCH 12/16] typo --- std/recursion/plonk/vkey_fp_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/std/recursion/plonk/vkey_fp_test.go b/std/recursion/plonk/vkey_fp_test.go index cec7bf981..0e508bded 100644 --- a/std/recursion/plonk/vkey_fp_test.go +++ b/std/recursion/plonk/vkey_fp_test.go @@ -52,7 +52,7 @@ func (c *OuterCircuitDual[FR, G1El, G2El, GtEl]) Define(api frontend.API) error return fmt.Errorf("new curve for verification keys: %w", err) } api.Println(fp3) - // different constant value should result different verification key + // different constant values should result in different verification keys api.AssertIsDifferent(fp, fp3) return err From 9480d57639c092e87d44e67633ada7eb177215fa Mon Sep 17 00:00:00 2001 From: Weiji Guo Date: Thu, 29 Feb 2024 22:33:45 +0800 Subject: [PATCH 13/16] added type conversion from frontend.Variable to []U8 --- std/math/uints/uint8.go | 30 +++++++++++++++++++++++++++++ std/math/uints/uint8_test.go | 37 ++++++++++++++++++++++++++++++++++++ 2 files changed, 67 insertions(+) diff --git a/std/math/uints/uint8.go b/std/math/uints/uint8.go index cec591d10..406893294 100644 --- a/std/math/uints/uint8.go +++ b/std/math/uints/uint8.go @@ -24,9 +24,11 @@ package uints import ( "fmt" + "math" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/std/internal/logderivprecomp" + "github.com/consensys/gnark/std/math/bits" "github.com/consensys/gnark/std/math/bitslice" "github.com/consensys/gnark/std/rangecheck" ) @@ -167,6 +169,34 @@ func (bf *BinaryField[T]) ByteValueOf(a frontend.Variable) U8 { return U8{Val: a, internal: true} } +// Convert any varialbe to bits first then to U8 array +// Note that if expectedLen is shorter than actual value, the converted value is *not* +// equal to the original value! +// TODO optimization +func (bf *BinaryField[T]) ByteArrayValueOf(a frontend.Variable, expectedLen ...int) []U8 { + var opt bits.BaseConversionOption + if len(expectedLen) == 1 { + opt = bits.WithNbDigits(expectedLen[0] * 8) + } + + bits := bits.ToBinary(bf.api, a, opt) + lenBits := len(bits) + lenBytes := int(math.Ceil(float64(lenBits) / 8.0)) + + ret := make([]U8, lenBytes) + for i := 0; i < lenBytes; i++ { + b := bits[i*8] + for j := 1; j < 8; j++ { + v := bits[i*8+j] + v = bf.api.Mul(v, 1<> 11)}, ecc.BN254.ScalarField()) assert.NoError(err) } + +type byteArrayValueOfCircuit struct { + In frontend.Variable + Expected []U8 +} + +func (c *byteArrayValueOfCircuit) Define(api frontend.API) error { + uapi, err := New[U32](api) + if err != nil { + return err + } + + res := uapi.ByteArrayValueOf(c.In, 3) + api.AssertIsEqual(len(res), len(c.Expected)) + for i := 0; i < len(res); i++ { + uapi.ByteAssertEq(res[i], c.Expected[i]) + } + + return nil +} + +func TestByteArrayValueOf(t *testing.T) { + assert := test.NewAssert(t) + a, b, c := 13, 17, 19 + p := a + (b << 8) + (c << 16) + expected := NewU8Array([]uint8{uint8(a), uint8(b), uint8(c)}) + circuit := &byteArrayValueOfCircuit{ + Expected: expected, + } + assignment := &byteArrayValueOfCircuit{ + In: frontend.Variable(p), + Expected: expected, + } + + err := test.IsSolved(circuit, assignment, ecc.BN254.ScalarField()) + assert.NoError(err) +} From 982d654f734c80006ffeb67857ef8acbd830cb00 Mon Sep 17 00:00:00 2001 From: Weiji Guo Date: Thu, 29 Feb 2024 23:21:41 +0800 Subject: [PATCH 14/16] fixed array out of bound issue --- std/math/uints/uint8.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/std/math/uints/uint8.go b/std/math/uints/uint8.go index 406893294..980e0f91c 100644 --- a/std/math/uints/uint8.go +++ b/std/math/uints/uint8.go @@ -186,7 +186,7 @@ func (bf *BinaryField[T]) ByteArrayValueOf(a frontend.Variable, expectedLen ...i ret := make([]U8, lenBytes) for i := 0; i < lenBytes; i++ { b := bits[i*8] - for j := 1; j < 8; j++ { + for j := 1; j < 8 && i*8+j < lenBits; j++ { v := bits[i*8+j] v = bf.api.Mul(v, 1< Date: Fri, 1 Mar 2024 11:40:41 +0800 Subject: [PATCH 15/16] modify ByteArrayValueOf implementation and add testcase --- std/math/uints/uint8.go | 11 +++++--- std/math/uints/uint8_test.go | 51 ++++++++++++++++++++++++++++++++---- 2 files changed, 53 insertions(+), 9 deletions(-) diff --git a/std/math/uints/uint8.go b/std/math/uints/uint8.go index 980e0f91c..d716eec67 100644 --- a/std/math/uints/uint8.go +++ b/std/math/uints/uint8.go @@ -175,19 +175,22 @@ func (bf *BinaryField[T]) ByteValueOf(a frontend.Variable) U8 { // TODO optimization func (bf *BinaryField[T]) ByteArrayValueOf(a frontend.Variable, expectedLen ...int) []U8 { var opt bits.BaseConversionOption + var bs []frontend.Variable if len(expectedLen) == 1 { opt = bits.WithNbDigits(expectedLen[0] * 8) + bs = bits.ToBinary(bf.api, a, opt) + } else { + bs = bits.ToBinary(bf.api, a) } - bits := bits.ToBinary(bf.api, a, opt) - lenBits := len(bits) + lenBits := len(bs) lenBytes := int(math.Ceil(float64(lenBits) / 8.0)) ret := make([]U8, lenBytes) for i := 0; i < lenBytes; i++ { - b := bits[i*8] + b := bs[i*8] for j := 1; j < 8 && i*8+j < lenBits; j++ { - v := bits[i*8+j] + v := bs[i*8+j] v = bf.api.Mul(v, 1< Date: Thu, 14 Mar 2024 22:30:13 +0800 Subject: [PATCH 16/16] reactivated Field.Cmp --- std/math/emulated/field_assert.go | 38 +++++++++++++++---------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/std/math/emulated/field_assert.go b/std/math/emulated/field_assert.go index a2809e4eb..6b0ef990b 100644 --- a/std/math/emulated/field_assert.go +++ b/std/math/emulated/field_assert.go @@ -128,25 +128,25 @@ func (f *Field[T]) IsZero(a *Element[T]) frontend.Variable { return f.api.IsZero(limbSum) } -// // Cmp returns: -// // - -1 if a < b -// // - 0 if a = b -// // - 1 if a > b -// // -// // The method internally reduces the element and asserts that the value is less -// // than the modulus. -// func (f *Field[T]) Cmp(a, b *Element[T]) frontend.Variable { -// ca := f.Reduce(a) -// f.AssertIsInRange(ca) -// cb := f.Reduce(b) -// f.AssertIsInRange(cb) -// var res frontend.Variable = 0 -// for i := int(f.fParams.NbLimbs() - 1); i >= 0; i-- { -// lmbCmp := f.api.Cmp(ca.Limbs[i], cb.Limbs[i]) -// res = f.api.Select(f.api.IsZero(res), lmbCmp, res) -// } -// return res -// } +// Cmp returns: +// - -1 if a < b +// - 0 if a = b +// - 1 if a > b +// +// The method internally reduces the element and asserts that the value is less +// than the modulus. +func (f *Field[T]) Cmp(a, b *Element[T]) frontend.Variable { + ca := f.Reduce(a) + f.AssertIsInRange(ca) + cb := f.Reduce(b) + f.AssertIsInRange(cb) + var res frontend.Variable = 0 + for i := int(f.fParams.NbLimbs() - 1); i >= 0; i-- { + lmbCmp := f.api.Cmp(ca.Limbs[i], cb.Limbs[i]) + res = f.api.Select(f.api.IsZero(res), lmbCmp, res) + } + return res +} // TODO(@ivokub) // func (f *Field[T]) AssertIsDifferent(a, b *Element[T]) {