Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: direct multivariate polynomial evaluation in non-native #1299

Merged
merged 14 commits into from
Nov 22, 2024
11 changes: 11 additions & 0 deletions std/math/emulated/doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,17 @@ identity at a random point:
where e(X) is a polynomial used for carrying the overflows of the left- and
right-hand side of the above equation.

This approach can be extended to the case when the left hand side is not a
simple multiplication, but rather any evaluation of a multivariate polynomial.
So in essence we can check the correctness of any polynomial evaluation modulo
r:

F(x_1, x_2, ..., x_n) = c + z*r

through the following identity:

F(x_1(X), x_2(X), ..., x_n(X)) = c(X) + z(X) * r(X) + (2^w' - X) e(X).

# Subtraction

We perform subtraction limb-wise between the elements x and y. However, we have
Expand Down
148 changes: 148 additions & 0 deletions std/math/emulated/element_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1277,3 +1277,151 @@ func TestIsZeroEdgeCases(t *testing.T) {
testIsZeroEdgeCases[BN254Fr](t)
testIsZeroEdgeCases[emparams.Mod1e512](t)
}

type PolyEvalCircuit[T FieldParams] struct {
Inputs []Element[T]
TermsByIndices [][]int
Coeffs []int
Expected Element[T]
}

func (c *PolyEvalCircuit[T]) Define(api frontend.API) error {
// withEval
f, err := NewField[T](api)
if err != nil {
return err
}
// reconstruct the terms from the inputs and the indices
terms := make([][]*Element[T], len(c.TermsByIndices))
for i := range terms {
terms[i] = make([]*Element[T], len(c.TermsByIndices[i]))
for j := range terms[i] {
terms[i][j] = &c.Inputs[c.TermsByIndices[i][j]]
}
}
resEval := f.Eval(terms, c.Coeffs)

// withSum
addTerms := make([]*Element[T], len(c.TermsByIndices))
for i, term := range c.TermsByIndices {
termVal := f.One()
for j := range term {
termVal = f.Mul(termVal, &c.Inputs[term[j]])
}
addTerms[i] = f.MulConst(termVal, big.NewInt(int64(c.Coeffs[i])))
}
resSum := f.Sum(addTerms...)

// mul no reduce
addTerms2 := make([]*Element[T], len(c.TermsByIndices))
for i, term := range c.TermsByIndices {
termVal := f.One()
for j := range term {
termVal = f.MulNoReduce(termVal, &c.Inputs[term[j]])
}
addTerms2[i] = f.MulConst(termVal, big.NewInt(int64(c.Coeffs[i])))
}
resNoReduce := f.Sum(addTerms2...)
resReduced := f.Reduce(resNoReduce)

// assertions
f.AssertIsEqual(resEval, &c.Expected)
f.AssertIsEqual(resSum, &c.Expected)
f.AssertIsEqual(resNoReduce, &c.Expected)
f.AssertIsEqual(resReduced, &c.Expected)

return nil
}

func TestPolyEval(t *testing.T) {
testPolyEval[Goldilocks](t)
testPolyEval[BN254Fr](t)
testPolyEval[emparams.Mod1e512](t)
}

func testPolyEval[T FieldParams](t *testing.T) {
const nbInputs = 2
assert := test.NewAssert(t)
var fp T
var err error
// 2*x^3 + 3*x^2 y + 4*x y^2 + 5*y^3 assuming we have inputs w=[x, y], then
// we can represent by the indices of the inputs:
// 2*x^3 + 3*x^2 y + 4*x y^2 + 5*y^3 -> 2*x*x*x + 3*x*x*y + 4*x*y*y + 5*y*y*y -> 2*w[0]*w[0]*w[0] + 3*w[0]*w[0]*w[1] + 4*w[0]*w[1]*w[1] + 5*w[1]*w[1]*w[1]
// the following variable gives the indices of the inputs. For givin the
// circuit this is better as then we can easily reference to the inputs by
// index.
toMulByIndex := [][]int{{0, 0, 0}, {0, 0, 1}, {0, 1, 1}, {1, 1, 1}}
coefficients := []int{2, 3, 4, 5}
inputs := make([]*big.Int, nbInputs)
assignmentInput := make([]Element[T], nbInputs)
for i := range inputs {
inputs[i], err = rand.Int(rand.Reader, fp.Modulus())
assert.NoError(err)
}
for i := range inputs {
assignmentInput[i] = ValueOf[T](inputs[i])
}
expected := new(big.Int)
for i, term := range toMulByIndex {
termVal := new(big.Int).SetInt64(int64(coefficients[i]))
for j := range term {
termVal.Mul(termVal, inputs[term[j]])
}
expected.Add(expected, termVal)
}
expected.Mod(expected, fp.Modulus())

assignment := &PolyEvalCircuit[T]{
Inputs: assignmentInput,
Expected: ValueOf[T](expected),
}
assert.CheckCircuit(&PolyEvalCircuit[T]{Inputs: make([]Element[T], nbInputs), TermsByIndices: toMulByIndex, Coeffs: coefficients}, test.WithValidAssignment(assignment))
}

type PolyEvalNegativeCoefficient[T FieldParams] struct {
Inputs []Element[T]
Res Element[T]
}

func (c *PolyEvalNegativeCoefficient[T]) Define(api frontend.API) error {
f, err := NewField[T](api)
if err != nil {
return err
}
// x - y
coefficients := []int{1, -1}
res := f.Eval([][]*Element[T]{{&c.Inputs[0]}, {&c.Inputs[1]}}, coefficients)
f.AssertIsEqual(res, &c.Res)
return nil
}

func TestPolyEvalNegativeCoefficient(t *testing.T) {
testPolyEvalNegativeCoefficient[Goldilocks](t)
testPolyEvalNegativeCoefficient[BN254Fr](t)
testPolyEvalNegativeCoefficient[emparams.Mod1e512](t)
}

func testPolyEvalNegativeCoefficient[T FieldParams](t *testing.T) {
t.Skip("not implemented yet")
assert := test.NewAssert(t)
var fp T
fmt.Println("modulus", fp.Modulus())
var err error
const nbInputs = 2
inputs := make([]*big.Int, nbInputs)
assignmentInput := make([]Element[T], nbInputs)
for i := range inputs {
inputs[i], err = rand.Int(rand.Reader, fp.Modulus())
assert.NoError(err)
}
for i := range inputs {
fmt.Println("input", i, inputs[i])
assignmentInput[i] = ValueOf[T](inputs[i])
}
expected := new(big.Int).Sub(inputs[0], inputs[1])
expected.Mod(expected, fp.Modulus())
fmt.Println("expected", expected)
assignment := &PolyEvalNegativeCoefficient[T]{Inputs: assignmentInput, Res: ValueOf[T](expected)}
err = test.IsSolved(&PolyEvalNegativeCoefficient[T]{Inputs: make([]Element[T], nbInputs)}, assignment, testCurve.ScalarField())
assert.NoError(err)
}
16 changes: 14 additions & 2 deletions std/math/emulated/field.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ type Field[T FieldParams] struct {
constrainedLimbs map[[16]byte]struct{}
checker frontend.Rangechecker

mulChecks []mulCheck[T]
deferredChecks []deferredChecker
}

type ctxKey[T FieldParams] struct{}
Expand Down Expand Up @@ -103,7 +103,7 @@ func NewField[T FieldParams](native frontend.API) (*Field[T], error) {
return nil, fmt.Errorf("elements with limb length %d does not fit into scalar field", f.fParams.BitsPerLimb())
}

native.Compiler().Defer(f.performMulChecks)
native.Compiler().Defer(f.performDeferredChecks)
if storer, ok := native.(kvstore.Store); ok {
storer.SetKeyValue(ctxKey[T]{}, f)
}
Expand Down Expand Up @@ -282,3 +282,15 @@ func max[T constraints.Ordered](a ...T) T {
}
return m
}

func sum[T constraints.Ordered](a ...T) T {
if len(a) == 0 {
var f T
return f
}
m := a[0]
for _, v := range a[1:] {
m += v
}
return m
}
Loading