Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

replace package smt with https://github.com/mdehoog/gnark-circom-smt/ #4

Merged
merged 4 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions emulated/ecdsa/address_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import (
"github.com/consensys/gnark/test"
"github.com/ethereum/go-ethereum/crypto"
qt "github.com/frankban/quicktest"
internaltest "github.com/vocdoni/gnark-crypto-primitives/test"
"github.com/vocdoni/gnark-crypto-primitives/testutil"
)

type testAddressCircuit struct {
Expand Down Expand Up @@ -54,7 +54,7 @@ func TestAddressDerivation(t *testing.T) {
fmt.Println("constrains", p.NbConstraints())
// hash a test message and sign it
input := crypto.Keccak256Hash([]byte("hello")).Bytes()
testSig, err := internaltest.GenerateAccountAndSign(input)
testSig, err := testutil.GenerateAccountAndSign(input)
c.Assert(err, qt.IsNil)
addrLE := new(big.Int).SetBytes(goSwapEndianness(testSig.Address.Bytes()))
// init inputs
Expand Down
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ require (
github.com/ethereum/go-ethereum v1.14.7
github.com/frankban/quicktest v1.14.6
github.com/iden3/go-iden3-crypto v0.0.17
github.com/mdehoog/gnark-circom-smt v0.0.0-20240224081844-6972718e0548
github.com/mdehoog/poseidon v0.0.0-20240301020106-ba6c393a5802
github.com/vocdoni/arbo v0.0.0-20241120112623-8e1cc943f444
github.com/vocdoni/vocdoni-z-sandbox v0.0.0-20241113074257-1a711ad38a6b
go.vocdoni.io/dvote v1.10.2-0.20241024102542-c1ce6d744bc5
Expand Down
4 changes: 4 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,10 @@ github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mdehoog/gnark-circom-smt v0.0.0-20240224081844-6972718e0548 h1:ClFjcjhUHrQpzkSL9Oebv7bXGbGdeIRfOuUM1i41i94=
github.com/mdehoog/gnark-circom-smt v0.0.0-20240224081844-6972718e0548/go.mod h1:tVwUUzmexHsVX93amrWeyKzciTt4ZRoAPlUElK5qmaI=
github.com/mdehoog/poseidon v0.0.0-20240301020106-ba6c393a5802 h1:CfxXsQOJVRzKn9VPe12SUWQQeUxsBsz7b9PA0A+FOgQ=
github.com/mdehoog/poseidon v0.0.0-20240301020106-ba6c393a5802/go.mod h1:UWNkR2GgyHX6Nz8cBJXdUDyTf14UoEB+5YgncI0rsJE=
github.com/mfridman/interpolate v0.0.2 h1:pnuTK7MQIxxFz1Gr+rjSIx9u7qVjf5VOoM/u6BbAxPY=
github.com/mfridman/interpolate v0.0.2/go.mod h1:p+7uk6oE07mpE/Ik1b8EckO0O4ZXiGAfshKBWLUM9Xg=
github.com/mmcloughlin/addchain v0.4.0 h1:SobOdjm2xLj1KkXN5/n0xTIWyZA2+s99UCY1iPfkHRY=
Expand Down
2 changes: 1 addition & 1 deletion test/utils.go → testutil/utils.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package gnarkcryptoprimitives
package testutil

import (
"fmt"
Expand Down
4 changes: 2 additions & 2 deletions tree/arbo/verifier_bls12377_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (
"github.com/consensys/gnark/test"
qt "github.com/frankban/quicktest"
arbotree "github.com/vocdoni/arbo"
internaltest "github.com/vocdoni/gnark-crypto-primitives/test"
"github.com/vocdoni/gnark-crypto-primitives/testutil"
"go.vocdoni.io/dvote/util"
)

Expand Down Expand Up @@ -55,7 +55,7 @@ func TestVerifierBLS12377(t *testing.T) {
p.Stop()
fmt.Println("constrains", p.NbConstraints())
// generate census proof
testCensus, err := internaltest.GenerateCensusProofForTest(internaltest.CensusTestConfig{
testCensus, err := testutil.GenerateCensusProofForTest(testutil.CensusTestConfig{
Dir: t.TempDir() + "/bls12377",
ValidSiblings: v_siblings,
TotalSiblings: n_siblings,
Expand Down
4 changes: 2 additions & 2 deletions tree/arbo/verifier_bn254_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (
qt "github.com/frankban/quicktest"
arbotree "github.com/vocdoni/arbo"
"github.com/vocdoni/gnark-crypto-primitives/hash/bn254/poseidon"
internaltest "github.com/vocdoni/gnark-crypto-primitives/test"
"github.com/vocdoni/gnark-crypto-primitives/testutil"
"go.vocdoni.io/dvote/util"
)

Expand All @@ -41,7 +41,7 @@ func TestVerifierBN254(t *testing.T) {
p.Stop()
fmt.Println("constrains", p.NbConstraints())
// generate census proof
testCensus, err := internaltest.GenerateCensusProofForTest(internaltest.CensusTestConfig{
testCensus, err := testutil.GenerateCensusProofForTest(testutil.CensusTestConfig{
Dir: t.TempDir() + "/bn254",
ValidSiblings: v_siblings,
TotalSiblings: n_siblings,
Expand Down
20 changes: 20 additions & 0 deletions tree/smt/emulated/hash.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package emulated

import (
"github.com/consensys/gnark/std/math/emulated"

poseidon "github.com/mdehoog/poseidon/circuits/poseidon/emulated"
)

// based on https://github.com/iden3/circomlib/blob/master/circuits/smt/smthash_poseidon.circom

func Hash1[T emulated.FieldParams](field *emulated.Field[T], key, value *emulated.Element[T]) *emulated.Element[T] {
one := emulated.ValueOf[T](1)
inputs := []*emulated.Element[T]{key, value, &one}
return poseidon.Hash(field, inputs)
}

func Hash2[T emulated.FieldParams](field *emulated.Field[T], l, r *emulated.Element[T]) *emulated.Element[T] {
inputs := []*emulated.Element[T]{l, r}
return poseidon.Hash(field, inputs)
}
29 changes: 29 additions & 0 deletions tree/smt/emulated/lev_ins.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package emulated

import (
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/math/emulated"
)

// based on https://github.com/iden3/circomlib/blob/master/circuits/smt/smtlevins.circom

func LevIns[T emulated.FieldParams](api frontend.API, field *emulated.Field[T], enabled frontend.Variable, siblings []*emulated.Element[T]) (levIns []frontend.Variable) {
levels := len(siblings)
levIns = make([]frontend.Variable, levels)
done := make([]frontend.Variable, levels-1)

isZero := make([]frontend.Variable, levels)
for i := 0; i < levels; i++ {
isZero[i] = field.IsZero(siblings[i])
}
api.AssertIsEqual(api.Mul(api.Sub(isZero[levels-1], 1), enabled), 0)

levIns[levels-1] = api.Sub(1, isZero[levels-2])
done[levels-2] = levIns[levels-1]
for i := levels - 2; i > 0; i-- {
levIns[i] = api.Mul(api.Sub(1, done[i]), api.Sub(1, isZero[i-1]))
done[i-1] = api.Add(levIns[i], done[i])
}
levIns[0] = api.Sub(1, done[0])
return levIns
}
67 changes: 67 additions & 0 deletions tree/smt/emulated/processor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package emulated

import (
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/math/emulated"

"github.com/mdehoog/gnark-circom-smt/circuits/smt"
)

// based on https://github.com/iden3/circomlib/blob/cff5ab6288b55ef23602221694a6a38a0239dcc0/circuits/smt/smtprocessor.circom

func Processor[T emulated.FieldParams](api frontend.API, field *emulated.Field[T], oldRoot *emulated.Element[T], siblings []*emulated.Element[T], oldKey, oldValue *emulated.Element[T], isOld0 frontend.Variable, newKey, newValue *emulated.Element[T], fnc0, fnc1 frontend.Variable) (newRoot *emulated.Element[T]) {
levels := len(siblings)
enabled := api.Sub(api.Add(fnc0, fnc1), api.Mul(fnc0, fnc1))
hash1Old := Hash1(field, oldKey, oldValue)
hash1New := Hash1(field, newKey, newValue)
n2bOld := field.ToBits(oldKey)
n2bNew := field.ToBits(newKey)
smtLevIns := LevIns(api, field, enabled, siblings)

xors := make([]frontend.Variable, levels)
for i := 0; i < levels; i++ {
xors[i] = api.Xor(n2bOld[i], n2bNew[i])
}

stTop := make([]frontend.Variable, levels)
stOld0 := make([]frontend.Variable, levels)
stBot := make([]frontend.Variable, levels)
stNew1 := make([]frontend.Variable, levels)
stNa := make([]frontend.Variable, levels)
stUpd := make([]frontend.Variable, levels)
for i := 0; i < levels; i++ {
if i == 0 {
stTop[i], stOld0[i], stBot[i], stNew1[i], stNa[i], stUpd[i] = smt.ProcessorSM(api, xors[i], isOld0, smtLevIns[i], fnc0, enabled, 0, 0, 0, api.Sub(1, enabled), 0)
} else {
stTop[i], stOld0[i], stBot[i], stNew1[i], stNa[i], stUpd[i] = smt.ProcessorSM(api, xors[i], isOld0, smtLevIns[i], fnc0, stTop[i-1], stOld0[i-1], stBot[i-1], stNew1[i-1], stNa[i-1], stUpd[i-1])
}
}

api.AssertIsEqual(api.Add(api.Add(stNa[levels-1], stNew1[levels-1]), api.Add(stOld0[levels-1], stUpd[levels-1])), 1)

levelsOldRoot := make([]*emulated.Element[T], levels)
levelsNewRoot := make([]*emulated.Element[T], levels)
for i := levels - 1; i >= 0; i-- {
if i == levels-1 {
zero := emulated.ValueOf[T](0)
levelsOldRoot[i], levelsNewRoot[i] = ProcessorLevel(api, field, stTop[i], stOld0[i], stBot[i], stNew1[i], stUpd[i], siblings[i], hash1Old, hash1New, n2bNew[i], &zero, &zero)
} else {
levelsOldRoot[i], levelsNewRoot[i] = ProcessorLevel(api, field, stTop[i], stOld0[i], stBot[i], stNew1[i], stUpd[i], siblings[i], hash1Old, hash1New, n2bNew[i], levelsOldRoot[i+1], levelsNewRoot[i+1])
}
}

topSwitcherL, topSwitcherR := Switcher(field, api.Mul(fnc0, fnc1), levelsOldRoot[0], levelsNewRoot[0])
ForceEqualIfEnabled(field, oldRoot, topSwitcherL, enabled)

newRoot = field.Select(enabled, topSwitcherR, oldRoot)

areKeyEquals := IsEqual(field, oldKey, newKey)
in := []frontend.Variable{
api.Sub(1, fnc0),
fnc1,
api.Sub(1, areKeyEquals),
}
keysOk := smt.MultiAnd(api, in)
api.AssertIsEqual(keysOk, 0)
return
}
27 changes: 27 additions & 0 deletions tree/smt/emulated/processor_level.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package emulated

import (
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/math/emulated"
)

// based on https://github.com/iden3/circomlib/blob/cff5ab6288b55ef23602221694a6a38a0239dcc0/circuits/smt/smtprocessorlevel.circom

func ProcessorLevel[T emulated.FieldParams](api frontend.API, field *emulated.Field[T], stTop, stOld0, stBot, stNew1, stUpd frontend.Variable, sibling, old1leaf, new1leaf *emulated.Element[T], newlrbit frontend.Variable, oldChild, newChild *emulated.Element[T]) (oldRoot, newRoot *emulated.Element[T]) {
oldProofHashL, oldProofHashR := Switcher(field, newlrbit, oldChild, sibling)
oldProofHash := Hash2(field, oldProofHashL, oldProofHashR)

am := api.Add(api.Add(stBot, stNew1), stUpd)
oldRoot = mux2(api, field, am, stTop, old1leaf, oldProofHash)

am = api.Add(stTop, stBot)
a := mux2(api, field, am, stNew1, newChild, new1leaf)
b := mux2(api, field, stTop, stNew1, sibling, old1leaf)
newProofHashL, newProofHashR := Switcher(field, newlrbit, a, b)
newProofHash := Hash2(field, newProofHashL, newProofHashR)

am = api.Add(api.Add(stTop, stBot), stNew1)
bm := api.Add(stOld0, stUpd)
newRoot = mux2(api, field, am, bm, newProofHash, new1leaf)
return
}
34 changes: 34 additions & 0 deletions tree/smt/emulated/utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package emulated

import (
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/math/emulated"
)

func IsEqual[T emulated.FieldParams](field *emulated.Field[T], a, b *emulated.Element[T]) frontend.Variable {
return field.IsZero(field.Sub(a, b))
}

func ForceEqualIfEnabled[T emulated.FieldParams](field *emulated.Field[T], a, b *emulated.Element[T], enabled frontend.Variable) {
c := field.Select(enabled, a, b)
field.AssertIsEqual(c, b)
}

// Switcher is [out1, out2] = sel ? [r, l] : [l, r]
func Switcher[T emulated.FieldParams](field *emulated.Field[T], sel frontend.Variable, l, r *emulated.Element[T]) (*emulated.Element[T], *emulated.Element[T]) {
return field.Select(sel, r, l), field.Select(sel, l, r)
}

// mux2 is (out = as ? a : bs ? b : 0)
func mux2[T emulated.FieldParams](api frontend.API, field *emulated.Field[T], as, bs frontend.Variable, a, b *emulated.Element[T]) *emulated.Element[T] {
sel := api.FromBinary(as, bs)
zero := emulated.ValueOf[T](0)
return field.Mux(sel, &zero, a, b, a)
}

// mux3 is (out = as ? a : bs ? b : cs ? c : 0)
func mux3[T emulated.FieldParams](api frontend.API, field *emulated.Field[T], as, bs, cs frontend.Variable, a, b, c *emulated.Element[T]) *emulated.Element[T] {
sel := api.FromBinary(as, bs, cs)
zero := emulated.ValueOf[T](0)
return field.Mux(sel, &zero, a, b, a, c, a, b, a)
}
54 changes: 54 additions & 0 deletions tree/smt/emulated/verifier.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package emulated

import (
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/math/emulated"

"github.com/mdehoog/gnark-circom-smt/circuits/smt"
)

func InclusionVerifier[T emulated.FieldParams](api frontend.API, field *emulated.Field[T], root *emulated.Element[T], siblings []*emulated.Element[T], key, value *emulated.Element[T]) {
Verifier[T](api, field, 1, root, siblings, key, value, 0, key, value, 0)
}

func ExclusionVerifier[T emulated.FieldParams](api frontend.API, field *emulated.Field[T], root *emulated.Element[T], siblings []*emulated.Element[T], oldKey, oldValue *emulated.Element[T], isOld0 frontend.Variable, key *emulated.Element[T]) {
zero := emulated.ValueOf[T](0)
Verifier[T](api, field, 1, root, siblings, oldKey, oldValue, isOld0, key, &zero, 1)
}

func Verifier[T emulated.FieldParams](api frontend.API, field *emulated.Field[T], enabled frontend.Variable, root *emulated.Element[T], siblings []*emulated.Element[T], oldKey, oldValue *emulated.Element[T], isOld0 frontend.Variable, key, value *emulated.Element[T], fnc frontend.Variable) {
nLevels := len(siblings)
hash1Old := Hash1(field, oldKey, oldValue)
hash1New := Hash1(field, key, value)
n2bNew := field.ToBits(key)
smtLevIns := LevIns(api, field, enabled, siblings)

stTop := make([]frontend.Variable, nLevels)
stI0 := make([]frontend.Variable, nLevels)
stIOld := make([]frontend.Variable, nLevels)
stINew := make([]frontend.Variable, nLevels)
stNa := make([]frontend.Variable, nLevels)
for i := 0; i < nLevels; i++ {
if i == 0 {
stTop[i], stI0[i], stIOld[i], stINew[i], stNa[i] = smt.VerifierSM(api, isOld0, smtLevIns[i], fnc, enabled, 0, 0, 0, api.Sub(1, enabled))
} else {
stTop[i], stI0[i], stIOld[i], stINew[i], stNa[i] = smt.VerifierSM(api, isOld0, smtLevIns[i], fnc, stTop[i-1], stI0[i-1], stIOld[i-1], stINew[i-1], stNa[i-1])
}
}
api.AssertIsEqual(api.Add(api.Add(api.Add(stNa[nLevels-1], stIOld[nLevels-1]), stINew[nLevels-1]), stI0[nLevels-1]), 1)

levels := make([]*emulated.Element[T], nLevels)
for i := nLevels - 1; i >= 0; i-- {
if i == nLevels-1 {
zero := emulated.ValueOf[T](0)
levels[i] = VerifierLevel(api, field, stTop[i], stIOld[i], stINew[i], siblings[i], hash1Old, hash1New, n2bNew[i], &zero)
} else {
levels[i] = VerifierLevel(api, field, stTop[i], stIOld[i], stINew[i], siblings[i], hash1Old, hash1New, n2bNew[i], levels[i+1])
}
}

areKeyEquals := IsEqual(field, oldKey, key)
keysOk := smt.MultiAnd(api, []frontend.Variable{fnc, api.Sub(1, isOld0), areKeyEquals, enabled})
api.AssertIsEqual(keysOk, 0)
ForceEqualIfEnabled(field, levels[0], root, enabled)
}
13 changes: 13 additions & 0 deletions tree/smt/emulated/verifier_level.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package emulated

import (
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/math/emulated"
)

func VerifierLevel[T emulated.FieldParams](api frontend.API, field *emulated.Field[T], stTop, stIOld, stINew frontend.Variable, sibling, old1leaf, new1leaf *emulated.Element[T], lrbit frontend.Variable, child *emulated.Element[T]) (root *emulated.Element[T]) {
proofHashL, proofHashR := Switcher(field, lrbit, child, sibling)
proofHash := Hash2(field, proofHashL, proofHashR)
root = mux3(api, field, stTop, stIOld, stINew, proofHash, old1leaf, new1leaf)
return
}
19 changes: 19 additions & 0 deletions tree/smt/hash.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package smt

import (
"github.com/consensys/gnark/frontend"

"github.com/mdehoog/poseidon/circuits/poseidon"
)

// based on https://github.com/iden3/circomlib/blob/master/circuits/smt/smthash_poseidon.circom

func Hash1(api frontend.API, key, value frontend.Variable) frontend.Variable {
inputs := []frontend.Variable{key, value, 1}
return poseidon.Hash(api, inputs)
}

func Hash2(api frontend.API, l, r frontend.Variable) frontend.Variable {
inputs := []frontend.Variable{l, r}
return poseidon.Hash(api, inputs)
}
26 changes: 26 additions & 0 deletions tree/smt/lev_ins.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package smt

import "github.com/consensys/gnark/frontend"

// based on https://github.com/iden3/circomlib/blob/master/circuits/smt/smtlevins.circom

func LevIns(api frontend.API, enabled frontend.Variable, siblings []frontend.Variable) (levIns []frontend.Variable) {
levels := len(siblings)
levIns = make([]frontend.Variable, levels)
done := make([]frontend.Variable, levels-1)

isZero := make([]frontend.Variable, levels)
for i := 0; i < levels; i++ {
isZero[i] = api.IsZero(siblings[i])
}
api.AssertIsEqual(api.Mul(api.Sub(isZero[levels-1], 1), enabled), 0)

levIns[levels-1] = api.Sub(1, isZero[levels-2])
done[levels-2] = levIns[levels-1]
for i := levels - 2; i > 0; i-- {
levIns[i] = api.Mul(api.Sub(1, done[i]), api.Sub(1, isZero[i-1]))
done[i-1] = api.Add(levIns[i], done[i])
}
levIns[0] = api.Sub(1, done[0])
return levIns
}
Loading
Loading