diff --git a/std/math/emulated/doc.go b/std/math/emulated/doc.go index 915afea48..48ff30a7f 100644 --- a/std/math/emulated/doc.go +++ b/std/math/emulated/doc.go @@ -111,6 +111,17 @@ identity at a random point: where e(X) is a polynomial used for carrying the overflows of the left- and right-hand side of the above equation. +This approach can be extended to the case when the left hand side is not a +simple multiplication, but rather any evaluation of a multivariate polynomial. +So in essence we can check the correctness of any polynomial evaluation modulo +r: + + F(x_1, x_2, ..., x_n) = c + z*r + +through the following identity: + + F(x_1(X), x_2(X), ..., x_n(X)) = c(X) + z(X) * r(X) + (2^w' - X) e(X). + # Subtraction We perform subtraction limb-wise between the elements x and y. However, we have diff --git a/std/math/emulated/element_test.go b/std/math/emulated/element_test.go index 264d0effc..37389ce15 100644 --- a/std/math/emulated/element_test.go +++ b/std/math/emulated/element_test.go @@ -1277,3 +1277,151 @@ func TestIsZeroEdgeCases(t *testing.T) { testIsZeroEdgeCases[BN254Fr](t) testIsZeroEdgeCases[emparams.Mod1e512](t) } + +type PolyEvalCircuit[T FieldParams] struct { + Inputs []Element[T] + TermsByIndices [][]int + Coeffs []int + Expected Element[T] +} + +func (c *PolyEvalCircuit[T]) Define(api frontend.API) error { + // withEval + f, err := NewField[T](api) + if err != nil { + return err + } + // reconstruct the terms from the inputs and the indices + terms := make([][]*Element[T], len(c.TermsByIndices)) + for i := range terms { + terms[i] = make([]*Element[T], len(c.TermsByIndices[i])) + for j := range terms[i] { + terms[i][j] = &c.Inputs[c.TermsByIndices[i][j]] + } + } + resEval := f.Eval(terms, c.Coeffs) + + // withSum + addTerms := make([]*Element[T], len(c.TermsByIndices)) + for i, term := range c.TermsByIndices { + termVal := f.One() + for j := range term { + termVal = f.Mul(termVal, &c.Inputs[term[j]]) + } + addTerms[i] = f.MulConst(termVal, big.NewInt(int64(c.Coeffs[i]))) + } + resSum := f.Sum(addTerms...) + + // mul no reduce + addTerms2 := make([]*Element[T], len(c.TermsByIndices)) + for i, term := range c.TermsByIndices { + termVal := f.One() + for j := range term { + termVal = f.MulNoReduce(termVal, &c.Inputs[term[j]]) + } + addTerms2[i] = f.MulConst(termVal, big.NewInt(int64(c.Coeffs[i]))) + } + resNoReduce := f.Sum(addTerms2...) + resReduced := f.Reduce(resNoReduce) + + // assertions + f.AssertIsEqual(resEval, &c.Expected) + f.AssertIsEqual(resSum, &c.Expected) + f.AssertIsEqual(resNoReduce, &c.Expected) + f.AssertIsEqual(resReduced, &c.Expected) + + return nil +} + +func TestPolyEval(t *testing.T) { + testPolyEval[Goldilocks](t) + testPolyEval[BN254Fr](t) + testPolyEval[emparams.Mod1e512](t) +} + +func testPolyEval[T FieldParams](t *testing.T) { + const nbInputs = 2 + assert := test.NewAssert(t) + var fp T + var err error + // 2*x^3 + 3*x^2 y + 4*x y^2 + 5*y^3 assuming we have inputs w=[x, y], then + // we can represent by the indices of the inputs: + // 2*x^3 + 3*x^2 y + 4*x y^2 + 5*y^3 -> 2*x*x*x + 3*x*x*y + 4*x*y*y + 5*y*y*y -> 2*w[0]*w[0]*w[0] + 3*w[0]*w[0]*w[1] + 4*w[0]*w[1]*w[1] + 5*w[1]*w[1]*w[1] + // the following variable gives the indices of the inputs. For givin the + // circuit this is better as then we can easily reference to the inputs by + // index. + toMulByIndex := [][]int{{0, 0, 0}, {0, 0, 1}, {0, 1, 1}, {1, 1, 1}} + coefficients := []int{2, 3, 4, 5} + inputs := make([]*big.Int, nbInputs) + assignmentInput := make([]Element[T], nbInputs) + for i := range inputs { + inputs[i], err = rand.Int(rand.Reader, fp.Modulus()) + assert.NoError(err) + } + for i := range inputs { + assignmentInput[i] = ValueOf[T](inputs[i]) + } + expected := new(big.Int) + for i, term := range toMulByIndex { + termVal := new(big.Int).SetInt64(int64(coefficients[i])) + for j := range term { + termVal.Mul(termVal, inputs[term[j]]) + } + expected.Add(expected, termVal) + } + expected.Mod(expected, fp.Modulus()) + + assignment := &PolyEvalCircuit[T]{ + Inputs: assignmentInput, + Expected: ValueOf[T](expected), + } + assert.CheckCircuit(&PolyEvalCircuit[T]{Inputs: make([]Element[T], nbInputs), TermsByIndices: toMulByIndex, Coeffs: coefficients}, test.WithValidAssignment(assignment)) +} + +type PolyEvalNegativeCoefficient[T FieldParams] struct { + Inputs []Element[T] + Res Element[T] +} + +func (c *PolyEvalNegativeCoefficient[T]) Define(api frontend.API) error { + f, err := NewField[T](api) + if err != nil { + return err + } + // x - y + coefficients := []int{1, -1} + res := f.Eval([][]*Element[T]{{&c.Inputs[0]}, {&c.Inputs[1]}}, coefficients) + f.AssertIsEqual(res, &c.Res) + return nil +} + +func TestPolyEvalNegativeCoefficient(t *testing.T) { + testPolyEvalNegativeCoefficient[Goldilocks](t) + testPolyEvalNegativeCoefficient[BN254Fr](t) + testPolyEvalNegativeCoefficient[emparams.Mod1e512](t) +} + +func testPolyEvalNegativeCoefficient[T FieldParams](t *testing.T) { + t.Skip("not implemented yet") + assert := test.NewAssert(t) + var fp T + fmt.Println("modulus", fp.Modulus()) + var err error + const nbInputs = 2 + inputs := make([]*big.Int, nbInputs) + assignmentInput := make([]Element[T], nbInputs) + for i := range inputs { + inputs[i], err = rand.Int(rand.Reader, fp.Modulus()) + assert.NoError(err) + } + for i := range inputs { + fmt.Println("input", i, inputs[i]) + assignmentInput[i] = ValueOf[T](inputs[i]) + } + expected := new(big.Int).Sub(inputs[0], inputs[1]) + expected.Mod(expected, fp.Modulus()) + fmt.Println("expected", expected) + assignment := &PolyEvalNegativeCoefficient[T]{Inputs: assignmentInput, Res: ValueOf[T](expected)} + err = test.IsSolved(&PolyEvalNegativeCoefficient[T]{Inputs: make([]Element[T], nbInputs)}, assignment, testCurve.ScalarField()) + assert.NoError(err) +} diff --git a/std/math/emulated/field.go b/std/math/emulated/field.go index dce4074fa..d7aae8714 100644 --- a/std/math/emulated/field.go +++ b/std/math/emulated/field.go @@ -47,7 +47,7 @@ type Field[T FieldParams] struct { constrainedLimbs map[[16]byte]struct{} checker frontend.Rangechecker - mulChecks []mulCheck[T] + deferredChecks []deferredChecker } type ctxKey[T FieldParams] struct{} @@ -103,7 +103,7 @@ func NewField[T FieldParams](native frontend.API) (*Field[T], error) { return nil, fmt.Errorf("elements with limb length %d does not fit into scalar field", f.fParams.BitsPerLimb()) } - native.Compiler().Defer(f.performMulChecks) + native.Compiler().Defer(f.performDeferredChecks) if storer, ok := native.(kvstore.Store); ok { storer.SetKeyValue(ctxKey[T]{}, f) } @@ -282,3 +282,15 @@ func max[T constraints.Ordered](a ...T) T { } return m } + +func sum[T constraints.Ordered](a ...T) T { + if len(a) == 0 { + var f T + return f + } + m := a[0] + for _, v := range a[1:] { + m += v + } + return m +} diff --git a/std/math/emulated/field_mul.go b/std/math/emulated/field_mul.go index 78785d82b..d6e3bfdfb 100644 --- a/std/math/emulated/field_mul.go +++ b/std/math/emulated/field_mul.go @@ -4,12 +4,59 @@ import ( "fmt" "math/big" "math/bits" + "slices" "github.com/consensys/gnark/frontend" limbs "github.com/consensys/gnark/std/internal/limbcomposition" "github.com/consensys/gnark/std/multicommit" ) +// deferredChecker is an interface for deferring a check in non-native +// arithmetic. The idea of the deferred check is that we do not compute the +// check immediately, but we store the values and the check to be done later. +// This allows us to share the verifier challenge computation between multiple +// checks. +// +// Currently used for multiplication and multivariate evaluation checks. +type deferredChecker interface { + // toCommit outputs the variable which should be committed to. The checker + // then uses the commitment to obtain the verifier challenge for the + // Schwartz-Zippel lemma. + toCommit() []frontend.Variable + // maxLen returns the maximum number of limbs in the deferred check. This is + // used for computing the number of powers of the verifier challenge to + // compute + maxLen() int + + // evalRound1 evaluates the first round of the check at with the random + // challenge, given through its powers at. In the first round we do not + // assume that any of the values is already evaluated as they come directly + // from hint. + // + // The method should store the evaluation result inside the Element and mark + // it as evaluated. If the method is called for already evaluated input then + // should assume that the challenge is the same as the one used for the + // evaluation. + evalRound1(at []frontend.Variable) + // evalRound2 evaluates the second round of the check at a given random point + // at[0]. However, it may happen that some of the values are equal to the + // result from a previous check. In that case we can reuse the evaluation to + // save constraints. + // + // The method should store the evaluation result inside the Element and mark + // it as evaluated. If the method is called for already evaluated input then + // should assume that the challenge is the same as the one used for the + // evaluation. + evalRound2(at []frontend.Variable) + // check checks the correctness of the deferred check. The method should use + // the stored evaluations. We additionally provide the evaluation of + // p(challenge) and (2^t-challenge) as they are static over all checks. + check(api frontend.API, peval frontend.Variable, coef frontend.Variable) + // cleanEvaluations cleans the cached evaluation values. This is necessary for + // ensuring the circuit stability over many compilations. + cleanEvaluations() +} + // mulCheck represents a single multiplication check. Instead of doing a // multiplication exactly where called, we compute the result using hint and // return it. Additionally, we store the correctness check for later checking @@ -62,6 +109,35 @@ type mulCheck[T FieldParams] struct { p *Element[T] // modulus if non-nil } +func (mc *mulCheck[T]) toCommit() []frontend.Variable { + nbToCommit := len(mc.a.Limbs) + len(mc.b.Limbs) + len(mc.r.Limbs) + len(mc.k.Limbs) + len(mc.c.Limbs) + if mc.p != nil { + nbToCommit += len(mc.p.Limbs) + } + toCommit := make([]frontend.Variable, 0, nbToCommit) + toCommit = append(toCommit, mc.a.Limbs...) + toCommit = append(toCommit, mc.b.Limbs...) + toCommit = append(toCommit, mc.r.Limbs...) + toCommit = append(toCommit, mc.k.Limbs...) + toCommit = append(toCommit, mc.c.Limbs...) + if mc.p != nil { + toCommit = append(toCommit, mc.p.Limbs...) + } + return toCommit +} + +func (mc *mulCheck[T]) maxLen() int { + maxLen := len(mc.a.Limbs) + maxLen = max(maxLen, len(mc.b.Limbs)) + maxLen = max(maxLen, len(mc.r.Limbs)) + maxLen = max(maxLen, len(mc.k.Limbs)) + maxLen = max(maxLen, len(mc.c.Limbs)) + if mc.p != nil { + maxLen = max(maxLen, len(mc.p.Limbs)) + } + return maxLen +} + // evalRound1 evaluates first c(X), r(X) and k(X) at a given random point at[0]. // In the first round we do not assume that any of them is already evaluated as // they come directly from hint. @@ -132,7 +208,7 @@ func (f *Field[T]) mulMod(a, b *Element[T], _ uint, p *Element[T]) *Element[T] { r: r, p: p, } - f.mulChecks = append(f.mulChecks, mc) + f.deferredChecks = append(f.deferredChecks, &mc) return r } @@ -156,7 +232,7 @@ func (f *Field[T]) checkZero(a *Element[T], p *Element[T]) { r: r, // expected to be zero on zero limbs. p: p, } - f.mulChecks = append(f.mulChecks, mc) + f.deferredChecks = append(f.deferredChecks, &mc) } // evalWithChallenge represents element a as a polynomial a(X) and evaluates at @@ -184,12 +260,12 @@ func (f *Field[T]) evalWithChallenge(a *Element[T], at []frontend.Variable) *Ele // performMulChecks should be deferred to actually perform all the // multiplication checks. -func (f *Field[T]) performMulChecks(api frontend.API) error { +func (f *Field[T]) performDeferredChecks(api frontend.API) error { // use given api. We are in defer and API may be different to what we have // stored. // there are no multiplication checks, nothing to do - if len(f.mulChecks) == 0 { + if len(f.deferredChecks) == 0 { return nil } @@ -201,23 +277,15 @@ func (f *Field[T]) performMulChecks(api frontend.API) error { // multi-commit and range checks are in different commitment, then we have // problem. var toCommit []frontend.Variable - for i := range f.mulChecks { - toCommit = append(toCommit, f.mulChecks[i].a.Limbs...) - toCommit = append(toCommit, f.mulChecks[i].b.Limbs...) - toCommit = append(toCommit, f.mulChecks[i].r.Limbs...) - toCommit = append(toCommit, f.mulChecks[i].k.Limbs...) - toCommit = append(toCommit, f.mulChecks[i].c.Limbs...) - if f.mulChecks[i].p != nil { - toCommit = append(toCommit, f.mulChecks[i].p.Limbs...) - } + for i := range f.deferredChecks { + toCommit = append(toCommit, f.deferredChecks[i].toCommit()...) } // we give all the inputs as inputs to obtain random verifier challenge. multicommit.WithCommitment(api, func(api frontend.API, commitment frontend.Variable) error { // for efficiency, we compute all powers of the challenge as slice at. coefsLen := int(f.fParams.NbLimbs()) - for i := range f.mulChecks { - coefsLen = max(coefsLen, len(f.mulChecks[i].a.Limbs), len(f.mulChecks[i].b.Limbs), - len(f.mulChecks[i].c.Limbs), len(f.mulChecks[i].k.Limbs)) + for i := range f.deferredChecks { + coefsLen = max(coefsLen, f.deferredChecks[i].maxLen()) } at := make([]frontend.Variable, coefsLen) at[0] = commitment @@ -225,12 +293,12 @@ func (f *Field[T]) performMulChecks(api frontend.API) error { at[i] = api.Mul(at[i-1], commitment) } // evaluate all r, k, c - for i := range f.mulChecks { - f.mulChecks[i].evalRound1(at) + for i := range f.deferredChecks { + f.deferredChecks[i].evalRound1(at) } // assuming r is input to some other multiplication, then is already evaluated - for i := range f.mulChecks { - f.mulChecks[i].evalRound2(at) + for i := range f.deferredChecks { + f.deferredChecks[i].evalRound2(at) } // evaluate p(X) at challenge pval := f.evalWithChallenge(f.Modulus(), at) @@ -239,13 +307,13 @@ func (f *Field[T]) performMulChecks(api frontend.API) error { coef.Lsh(coef, f.fParams.BitsPerLimb()) ccoef := api.Sub(coef, commitment) // verify all mulchecks - for i := range f.mulChecks { - f.mulChecks[i].check(api, pval.evaluation, ccoef) + for i := range f.deferredChecks { + f.deferredChecks[i].check(api, pval.evaluation, ccoef) } // clean cached evaluation. Helps in case we compile the same circuit // multiple times. - for i := range f.mulChecks { - f.mulChecks[i].cleanEvaluations() + for i := range f.deferredChecks { + f.deferredChecks[i].cleanEvaluations() } return nil }, toCommit...) @@ -287,16 +355,12 @@ func (f *Field[T]) callMulHint(a, b *Element[T], isMulMod bool, customMod *Eleme nbCarryLimbs := max(nbMultiplicationResLimbs(len(a.Limbs), len(b.Limbs)), nbMultiplicationResLimbs(int(nbQuoLimbs), int(nbLimbs))) - 1 // we encode the computed parameters and widths to the hint function so can // know how many limbs to expect. - hintInputs := []frontend.Variable{ - nbBits, - nbLimbs, - len(a.Limbs), - nbQuoLimbs, - } modulusLimbs := f.Modulus().Limbs if customMod != nil { modulusLimbs = customMod.Limbs } + hintInputs := make([]frontend.Variable, 0, 4+len(modulusLimbs)+len(a.Limbs)+len(b.Limbs)) + hintInputs = append(hintInputs, nbBits, nbLimbs, len(a.Limbs), nbQuoLimbs) hintInputs = append(hintInputs, modulusLimbs...) hintInputs = append(hintInputs, a.Limbs...) hintInputs = append(hintInputs, b.Limbs...) @@ -367,30 +431,8 @@ func mulHint(field *big.Int, inputs, outputs []*big.Int) error { if err := limbs.Decompose(rem, uint(nbBits), remLimbs); err != nil { return fmt.Errorf("decompose rem: %w", err) } - xp := make([]*big.Int, nbMultiplicationResLimbs(nbALen, nbBLen)) - yp := make([]*big.Int, nbMultiplicationResLimbs(nbQuoLen, nbLimbs)) - for i := range xp { - xp[i] = new(big.Int) - } - for i := range yp { - yp[i] = new(big.Int) - } - tmp := new(big.Int) - // we know compute the schoolbook multiprecision multiplication of a*b and - // r+k*p - for i := 0; i < nbALen; i++ { - for j := 0; j < nbBLen; j++ { - tmp.Mul(alimbs[i], blimbs[j]) - xp[i+j].Add(xp[i+j], tmp) - } - } - for i := 0; i < nbLimbs; i++ { - yp[i].Add(yp[i], remLimbs[i]) - for j := 0; j < nbQuoLen; j++ { - tmp.Mul(quoLimbs[j], plimbs[i]) - yp[i+j].Add(yp[i+j], tmp) - } - } + xp := limbMul(alimbs, blimbs) + yp := limbMul(quoLimbs, plimbs) carry := new(big.Int) for i := range carryLimbs { if i < len(xp) { @@ -510,3 +552,455 @@ func (f *Field[T]) Exp(base, exp *Element[T]) *Element[T] { res = f.Select(expBts[n-1], f.Mul(base, res), res) return res } + +// multivariate represents a multivariate polynomial. It is a list of terms +// where each term is a list of exponents for each variable. The coefficients +// are stored in the same order as the terms. +// +// For example, if there are two inputs x and y and we compute the polynomial +// +// x^2 + 2xy + y^2 +// +// then we have the terms +// +// [[2, 0], [1, 1], [0, 2]] +// +// and coefficients +// +// [1, 1, 1]. +// +// These definitions differ from how we expose the method in the [Field.Eval] +// method - there as we use pointers to the variables themselves, then we can +// allow to give the inputs directly a la +// +// f.Eval([][]*Element[T]{{x,x}, {x,y}, {y,y}}, []int{1, 1, 1}), +// +// but we cannot use the references inside the hint function as we work with +// solved values. +type multivariate[T FieldParams] struct { + Terms [][]int + Coefficients []int +} + +// Eval evaluates the multivariate polynomial. The elements of the inner slices +// are multiplied together and then summed together with the corresponding +// coefficient. +// +// NB! This is experimental API. It does not support negative coefficients. It +// does not check that computing the term wouldn't overflow the field. +// +// For example, for computing the expression x^2 + 2xy + y^2 we would call +// +// f.Eval([][]*Element[T]{{x,x}, {x,y}, {y,y}}, []int{1, 2, 1}) +// +// The method returns the result of the evaluation. +// +// To overcome the problem of not supporting negative coefficients, we can use a +// constant non-native element -1 as one of the inputs. +func (f *Field[T]) Eval(at [][]*Element[T], coefs []int) *Element[T] { + if len(at) != len(coefs) { + panic("terms and coefficients mismatch") + } + // it is the obvious case - when we don't have any inputs then we need to + // evaluate the zero polynomial which is always zero. + if len(at) == 0 { + return f.Zero() + } + // omit the negative coefficients for now. We don't support it for now. + for i := range coefs { + if coefs[i] < 0 { + panic("negative coefficient") + } + } + // initialize the multivariate struct from the inputs. The current method + // takes as input references to the elements. However, the hint function + // works with solved values. So it would be better to work with the exact + // exponents there. + + // we detect all different elements in the inputs. + // + // it would be easier to use a map to store the elements and then use the + // map to get the inputs in the right order. However, for deterministic + // circuit compilation we need to use the same order of inputs. So we use + // slice instead. + var allElems []*Element[T] + for i := range at { + AT_INNER: + for j := range at[i] { + for k := range allElems { + if allElems[k] == at[i][j] { + continue AT_INNER + } + } + allElems = append(allElems, at[i][j]) + } + } + // we already know all different inputs. We now count the number of + // occurrences in every term. + terms := make([][]int, 0, len(at)) + for i := range at { + term := make([]int, len(allElems)) + for j := range at[i] { + term[slices.Index(allElems, at[i][j])]++ + } + terms = append(terms, term) + } + + // ensure that all the elements have the range checks enforced on limbs. + // Necessary in case the input is a witness. + for i := range allElems { + f.enforceWidthConditional(allElems[i]) + } + + // multivariate is used for passing the terms and coefficients to the hint + // in a compact form. + mv := &multivariate[T]{ + Terms: terms, + Coefficients: coefs, + } + + // we call the hint to compute the result. The hint returns the reduced + // result, the quotient and the carries. + k, r, c, err := f.callPolyMvHint(mv, allElems) + if err != nil { + panic(err) + } + + // finally, we store the deferred check which is performed later. The + // `mvCheck` implements the deferredChecker interface, so that we use the + // generic deferred check method. + mvc := mvCheck[T]{ + f: f, + mv: mv, + vals: allElems, + r: r, + k: k, + c: c, + } + + f.deferredChecks = append(f.deferredChecks, &mvc) + return r +} + +// callPolyMvHint computes the multivariate evaluation given by mv at at. It +// returns the remainder (reduced result), the quotient and the carries. The +// computation is performed inside a hint, so it is the callers responsibility to +// perform the deferred multiplication check. +func (f *Field[T]) callPolyMvHint(mv *multivariate[T], at []*Element[T]) (quo, rem, carries *Element[T], err error) { + // first compute the length of the result so that we know how many bits we need for the quotient. + nbLimbs, nbBits := f.fParams.NbLimbs(), f.fParams.BitsPerLimb() + modBits := uint(f.fParams.Modulus().BitLen()) + quoSize := f.polyMvEvalQuoSize(mv, at) + nbQuoLimbs := (uint(quoSize) - modBits + nbBits) / nbBits + nbRemLimbs := nbLimbs + nbCarryLimbs := nbMultiplicationResLimbs(int(nbQuoLimbs), int(nbLimbs)) - 1 + + nbHintInputs := 7 + len(at)*len(mv.Terms) + len(mv.Coefficients) + len(f.Modulus().Limbs) + for i := range at { + nbHintInputs += len(at[i].Limbs) + 1 + } + hintInputs := make([]frontend.Variable, 0, nbHintInputs) + hintInputs = append(hintInputs, nbBits, nbLimbs, len(mv.Terms), len(at), nbQuoLimbs, nbRemLimbs, nbCarryLimbs) + // store the terms in the hint input. First the exponents + for i := range mv.Terms { + for j := range mv.Terms[i] { + hintInputs = append(hintInputs, mv.Terms[i][j]) + } + } + // and now the coefficients + for i := range mv.Coefficients { + hintInputs = append(hintInputs, mv.Coefficients[i]) + } + // finally, we store the modulus and all the inputs + hintInputs = append(hintInputs, f.Modulus().Limbs...) + for i := range at { + // keep in mind that not all inputs may be full length. We need to store + // the length also. + hintInputs = append(hintInputs, len(at[i].Limbs)) + hintInputs = append(hintInputs, at[i].Limbs...) + } + ret, err := f.api.NewHint(polyMvHint, int(nbQuoLimbs)+int(nbRemLimbs)+int(nbCarryLimbs), hintInputs...) + if err != nil { + err = fmt.Errorf("call hint: %w", err) + return + } + quo = f.packLimbs(ret[:nbQuoLimbs], false) + rem = f.packLimbs(ret[nbQuoLimbs:nbQuoLimbs+nbRemLimbs], true) + carries = f.newInternalElement(ret[nbQuoLimbs+nbRemLimbs:], 0) + return quo, rem, carries, nil +} + +// mvCheck is a deferred check for multivariate polynomial evaluation. It +// contains the multivariate polynomial, the values at which it is evaluated and +// the reduced result, quotient and carries. Implements deferredChecker and +// follows mulCheck implementation. +type mvCheck[T FieldParams] struct { + f *Field[T] + mv *multivariate[T] + vals []*Element[T] + r *Element[T] // reduced result + k *Element[T] // quotient + c *Element[T] // carry +} + +func (mc *mvCheck[T]) toCommit() []frontend.Variable { + nbToCommit := len(mc.r.Limbs) + len(mc.k.Limbs) + len(mc.c.Limbs) + for j := range mc.vals { + nbToCommit += len(mc.vals[j].Limbs) + } + toCommit := make([]frontend.Variable, 0, nbToCommit) + toCommit = append(toCommit, mc.r.Limbs...) + toCommit = append(toCommit, mc.k.Limbs...) + toCommit = append(toCommit, mc.c.Limbs...) + for j := range mc.vals { + toCommit = append(toCommit, mc.vals[j].Limbs...) + } + return toCommit +} + +func (mc *mvCheck[T]) maxLen() int { + maxLen := len(mc.r.Limbs) + maxLen = max(maxLen, len(mc.k.Limbs)) + maxLen = max(maxLen, len(mc.c.Limbs)) + for j := range mc.vals { + maxLen = max(maxLen, len(mc.vals[j].Limbs)) + } + return maxLen +} + +func (mc *mvCheck[T]) evalRound1(at []frontend.Variable) { + mc.c = mc.f.evalWithChallenge(mc.c, at) + mc.r = mc.f.evalWithChallenge(mc.r, at) + mc.k = mc.f.evalWithChallenge(mc.k, at) +} + +func (mc *mvCheck[T]) evalRound2(at []frontend.Variable) { + for i := range mc.vals { + mc.vals[i] = mc.f.evalWithChallenge(mc.vals[i], at) + } +} + +func (mc *mvCheck[T]) check(api frontend.API, peval, coef frontend.Variable) { + ls := frontend.Variable(0) + for i, term := range mc.mv.Terms { + termProd := frontend.Variable(mc.mv.Coefficients[i]) + for i, pow := range term { + for j := 0; j < pow; j++ { + termProd = api.Mul(termProd, mc.vals[i].evaluation) + } + } + ls = api.Add(ls, termProd) + } + rs := api.Add(mc.r.evaluation, api.Mul(peval, mc.k.evaluation), api.Mul(mc.c.evaluation, coef)) + api.AssertIsEqual(ls, rs) +} + +func (mc *mvCheck[T]) cleanEvaluations() { + for i := range mc.vals { + mc.vals[i].evaluation = 0 + mc.vals[i].isEvaluated = false + } + mc.r.evaluation = 0 + mc.r.isEvaluated = false + mc.k.evaluation = 0 + mc.k.isEvaluated = false + mc.c.evaluation = 0 + mc.c.isEvaluated = false +} + +// polyMvEvalQuoSize compute the length of the quotient in bits when evaluating +// the multivariate polynomial. The method is used to compute the number of bits +// required to represent the quotient in the hint function. +// +// As it only depends on the bit-length of the inputs, then we can precompute it +// regardless of the actual values. +func (f *Field[T]) polyMvEvalQuoSize(mv *multivariate[T], at []*Element[T]) (quoSize int) { + modBits := f.fParams.Modulus().BitLen() + quoSizes := make([]int, len(mv.Terms)) + for i, term := range mv.Terms { + // for every term, the result length is the sum of the lengths of the + // variables and the coefficient. + var lengths []int + for j, pow := range term { + for k := 0; k < pow; k++ { + lengths = append(lengths, modBits+int(at[j].overflow)) + } + } + lengths = append(lengths, bits.Len(uint(mv.Coefficients[i]))) + quoSizes[i] = sum(lengths...) + } + // and for the full result, it is maximum of the inputs. We also add a bit + // for every term for overflow. + quoSize = max(quoSizes...) + len(quoSizes) + return quoSize +} + +// polyMvHint computes the multivariate evaluation as a hint. Should not be +// called directly, but rather through [Field.callPolyMvHint] method which +// handles the input packing and output unpacking. +func polyMvHint(mod *big.Int, inputs, outputs []*big.Int) error { + if len(inputs) < 7 { + return fmt.Errorf("not enough inputs") + } + var ( + nbBits = int(inputs[0].Int64()) + nbLimbs = int(inputs[1].Int64()) + nbTerms = int(inputs[2].Int64()) + nbVars = int(inputs[3].Int64()) + nbQuoLimbs = int(inputs[4].Int64()) + nbRemLimbs = int(inputs[5].Int64()) + nbCarryLimbs = int(inputs[6].Int64()) + ) + if len(outputs) != nbQuoLimbs+nbRemLimbs+nbCarryLimbs { + return fmt.Errorf("output length mismatch") + } + outPtr := 0 + quoLimbs := outputs[outPtr : outPtr+nbQuoLimbs] + outPtr += nbQuoLimbs + remLimbs := outputs[outPtr : outPtr+nbRemLimbs] + outPtr += nbRemLimbs + carryLimbs := outputs[outPtr : outPtr+nbCarryLimbs] + terms := make([][]int, nbTerms) + ptr := 7 + // read the terms + for i := range terms { + terms[i] = make([]int, nbVars) + for j := range terms[i] { + terms[i][j] = int(inputs[ptr].Int64()) + ptr++ + } + } + // read the coefficients + coeffs := make([]*big.Int, nbTerms) + for i := range coeffs { + coeffs[i] = inputs[ptr] + ptr++ + } + // read the modulus + plimbs := inputs[ptr : ptr+nbLimbs] + ptr += nbLimbs + p := new(big.Int) + if err := limbs.Recompose(plimbs, uint(nbBits), p); err != nil { + return fmt.Errorf("recompose p: %w", err) + } + // read the inputs + varsLimbs := make([][]*big.Int, nbVars) + for i := range varsLimbs { + varsLimbs[i] = make([]*big.Int, int(inputs[ptr].Int64())) + ptr++ + for j := range varsLimbs[i] { + varsLimbs[i][j] = inputs[ptr] + ptr++ + } + } + if ptr != len(inputs) { + return fmt.Errorf("inputs not exhausted") + } + // recompose the inputs in limb-form to *big.Int form + vars := make([]*big.Int, nbVars) + for i := range vars { + vars[i] = new(big.Int) + if err := limbs.Recompose(varsLimbs[i], uint(nbBits), vars[i]); err != nil { + return fmt.Errorf("recompose vars[%d]: %w", i, err) + } + } + + // compute the result on full inputs + fullLhs := new(big.Int) + for i, term := range terms { + termRes := new(big.Int).Set(coeffs[i]) + for i, pow := range term { + for j := 0; j < pow; j++ { + termRes.Mul(termRes, vars[i]) + } + } + fullLhs.Add(fullLhs, termRes) + } + + // compute the result as r + k*p for now + var ( + quo = new(big.Int) + rem = new(big.Int) + ) + if p.Cmp(new(big.Int)) != 0 { + quo.QuoRem(fullLhs, p, rem) + } + // write the remainder and quotient to output + if err := limbs.Decompose(quo, uint(nbBits), quoLimbs); err != nil { + return fmt.Errorf("decompose quo: %w", err) + } + if err := limbs.Decompose(rem, uint(nbBits), remLimbs); err != nil { + return fmt.Errorf("decompose rem: %w", err) + } + + // compute the result on limbs + tmp := new(big.Int) + var lhs []*big.Int + for i, term := range terms { + // collect the variables to be multiplied together + var termVarLimbs [][]*big.Int + for i, pow := range term { + for j := 0; j < pow; j++ { + termVarLimbs = append(termVarLimbs, varsLimbs[i]) + } + } + if len(termVarLimbs) == 0 { + continue + } + termRes := []*big.Int{new(big.Int).Set(coeffs[i])} + // perform limbwise multiplication + for _, toMul := range termVarLimbs { + termRes = limbMul(termRes, toMul) + } + // add current term to the result. Increase the length of necessary when + // required. + for i := len(lhs); i < len(termRes); i++ { + lhs = append(lhs, new(big.Int)) + } + for i := range termRes { + lhs[i].Add(lhs[i], termRes[i]) + } + } + + // compute the result as r + k*p on limbs + rhs := make([]*big.Int, nbMultiplicationResLimbs(nbQuoLimbs, nbLimbs)) + for i := range rhs { + rhs[i] = new(big.Int) + } + for i := 0; i < nbLimbs; i++ { + rhs[i].Add(rhs[i], remLimbs[i]) + for j := 0; j < nbQuoLimbs; j++ { + tmp.Mul(quoLimbs[j], plimbs[i]) + rhs[i+j].Add(rhs[i+j], tmp) + } + } + + // compute the carries + carry := new(big.Int) + for i := range carryLimbs { + if i < len(lhs) { + carry.Add(carry, lhs[i]) + } + if i < len(rhs) { + carry.Sub(carry, rhs[i]) + } + carry.Rsh(carry, uint(nbBits)) + carryLimbs[i] = new(big.Int).Set(carry) + } + + return nil +} + +func limbMul(lhs []*big.Int, rhs []*big.Int) []*big.Int { + tmp := new(big.Int) + res := make([]*big.Int, nbMultiplicationResLimbs(len(lhs), len(rhs))) + for i := range res { + res[i] = new(big.Int) + } + for i := 0; i < len(lhs); i++ { + for j := 0; j < len(rhs); j++ { + res[i+j].Add(res[i+j], tmp.Mul(lhs[i], rhs[j])) + } + } + return res +} diff --git a/std/math/emulated/hints.go b/std/math/emulated/hints.go index 7ebdba29e..28a84b9a8 100644 --- a/std/math/emulated/hints.go +++ b/std/math/emulated/hints.go @@ -24,6 +24,7 @@ func GetHints() []solver.Hint { SqrtHint, mulHint, subPaddingHint, + polyMvHint, } }