Skip to content

Commit

Permalink
MerkleTreeBN128 to work with any arity power of two
Browse files Browse the repository at this point in the history
  • Loading branch information
RogerTaule committed Mar 11, 2024
1 parent 5bf798f commit 0d13636
Show file tree
Hide file tree
Showing 40 changed files with 676 additions and 334 deletions.
27 changes: 27 additions & 0 deletions .github/workflows/on-pull-request.yml
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,33 @@ jobs:

- name: Check C12
run: npm run test:C12
c12-custom-test:
runs-on: ubuntu-20.04
steps:
- uses: actions/checkout@v3
- uses: actions/setup-node@v3
with:
node-version: '16.17.0'
check-latest: true
cache: "npm"

- name: "Install circom"
run: |
curl https://sh.rustup.rs -sSf -o rust.sh
bash -f rust.sh -y
git clone https://github.com/iden3/circom.git
cd circom
cargo build --release
cargo install --path circom
- name: Install dependencies
run: npm ci

- name: Create tmp directory
run: mkdir tmp

- name: Check C12 with arity 4
run: npm run test:C12:custom

c18-test:
runs-on: ubuntu-20.04
Expand Down
17 changes: 10 additions & 7 deletions circuits.bn128/linearhash.circom
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ include "poseidon.circom";
// Given a list on inputs over GL³, compute the linear hash of the list, mapping from GL³ to BN
// via the map (x,y,z) |-> x + y·2⁶⁴ + z·2¹²⁸, which is injective but not surjective;
// and hashing the resulting BN elements in chunks of 16 using Poseidon.
template LinearHash(nInputs, eSize) {
template LinearHash(nInputs, eSize, arity) {
signal input in[nInputs][eSize];
signal output out;

Expand All @@ -25,19 +25,20 @@ template LinearHash(nInputs, eSize) {
out <== sAc;
nHashes = 0;
} else {
nHashes = (nElements256 - 1)\16 + 1;

nHashes = (nElements256 - 1)\arity +1;
}

component hash[nHashes > 0 ? nHashes - 1 : 0];
var nLastHash;
component lastHash;

for (var i=0; i<nHashes-1; i++) {
hash[i] = PoseidonEx(16, 1);
hash[i] = PoseidonEx(arity, 1);
}

if (nHashes > 0) {
nLastHash = nElements256 - (nHashes - 1)*16;
if (nHashes>0) {
nLastHash = nElements256 - (nHashes - 1)*arity;
lastHash = PoseidonEx(nLastHash, 1);
}

Expand All @@ -58,7 +59,7 @@ template LinearHash(nInputs, eSize) {
sAc = 0;
nAc = 0;
curHashIdx ++;
if (curHashIdx == 16) {
if (curHashIdx == arity) {
curHash++;
curHashIdx = 0;
}
Expand All @@ -72,7 +73,7 @@ template LinearHash(nInputs, eSize) {
hash[curHash].inputs[curHashIdx] <== sAc;
}
curHashIdx ++;
if (curHashIdx == 16) {
if (curHashIdx == arity) {
curHash = 0;
curHashIdx = 0;
}
Expand All @@ -84,13 +85,15 @@ template LinearHash(nInputs, eSize) {
} else {
hash[i].initialState <== hash[i-1].out[0];
}
_ <== hash[i].out;
}
if (nHashes == 1) {
lastHash.initialState <== 0;
} else {
lastHash.initialState <== hash[nHashes-2].out[0];
}

_ <== lastHash.out;
out <== lastHash.out[0];
}
}
Expand Down
86 changes: 26 additions & 60 deletions circuits.bn128/merkle.circom
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
pragma circom 2.1.0;

include "bitify.circom";
include "comparators.circom";
include "poseidon.circom";

/*
Given a leaf value, its sibling path and a key indicating the hashing position for each element in the path, calculate the merkle tree root
- keyBits: number of bits in the key
*/
template Merkle(keyBits) {
var arity = 16;
template Merkle(keyBits, arity) {
var nLevels = 0;
var nBits = log2(arity);
var n = 1 << keyBits;
var nn = n;
while (nn > 1) {
Expand All @@ -21,80 +23,44 @@ template Merkle(keyBits) {
signal input key[keyBits];
signal output root;

signal s[16];
signal a, b, c, d, ab, ac, ad, bc, bd, cd, abc, abd, acd, bcd, abcd;
signal s[arity];

component mNext;
component hash;

component keyNum;

if (nLevels == 0) {
root <== value;
s <== [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0];
(a, b, c, d, ab, ac, ad, bc, bd, cd, abc, abd, acd, bcd, abcd) <== (0,0,0,0,0,0,0,0,0,0,0,0,0,0,0);
} else {
if (keyBits>=1) {
d <== key[0];
} else {
d <== 0;
}
if (keyBits>=2) {
c <== key[1];
} else {
c <== 0;
}
if (keyBits>=3) {
b <== key[2];
} else {
b <== 0;
}
if (keyBits>=4) {
a <== key[3];
} else {
a <== 0;
for(var i = 0; i < arity; i++) {
s[i] <== 0;
}

ab <== a*b;
ac <== a*c;
ad <== a*d;
bc <== b*c;
bd <== b*d;
cd <== c*d;

abc <== ab*c;
abd <== ab*d;
acd <== ac*d;
bcd <== bc*d;

abcd <== ab*cd;

s[0] <== 1-d-c + cd-b + bd + bc-bcd-a + ad + ac-acd + ab-abd-abc + abcd;
s[1] <== d-cd-bd + bcd-ad + acd + abd-abcd;
s[2] <== c-cd-bc + bcd-ac + acd + abc-abcd;
s[3] <== cd-bcd-acd + abcd;
s[4] <== b-bd-bc + bcd-ab + abd + abc-abcd;
s[5] <== bd-bcd-abd + abcd;
s[6] <== bc-bcd-abc + abcd;
s[7] <== bcd-abcd;
s[8] <== a-ad-ac + acd-ab + abd + abc-abcd;
s[9] <== ad-acd-abd + abcd;
s[10] <== ac-acd-abc + abcd;
s[11] <== acd-abcd;
s[12] <== ab-abd-abc + abcd;
s[13] <== abd-abcd;
s[14] <== abc-abcd;
s[15] <== abcd;
} else {
keyNum = Bits2Num(nBits);
for(var i = 0; i < nBits; i++) {
if(keyBits >= i + 1) {
keyNum.in[i] <== key[i];
} else {
keyNum.in[i] <== 0;
}
}

for(var i = 0; i < arity; i++) {
s[i] <== IsEqual()([keyNum.out, i]);
}

hash = Poseidon(arity);

for (var i=0; i<arity; i++) {
hash.inputs[i] <== s[i] * (value - siblings[0][i] ) + siblings[0][i];
hash.inputs[i] <== s[i] * (value - siblings[0][i]) + siblings[0][i];
}

var nextNBits = keyBits - 4;
var nextNBits = keyBits - nBits;
if (nextNBits<0) nextNBits = 0;
var nNext = (n - 1)\arity + 1;

mNext = Merkle(nextNBits);
mNext = Merkle(nextNBits, arity);
mNext.value <== hash.out;

for (var i=0; i<nLevels-1; i++) {
Expand All @@ -104,7 +70,7 @@ template Merkle(keyBits) {
}

for (var i=0; i<nextNBits; i++) {
mNext.key[i] <== key[i+4];
mNext.key[i] <== key[i+nBits];
}

root <== mNext.root;
Expand Down
21 changes: 11 additions & 10 deletions circuits.bn128/merklehash.circom
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,21 @@ include "utils.circom";
- elementsInLinear: Each leave of the merkle tree is made by this number of values.
- nLinears: Number of leaves of the merkle tree
*/
template MerkleHash(eSize, elementsInLinear, nLinears) {
template MerkleHash(eSize, elementsInLinear, nLinears, arity) {
var nBits = log2(nLinears);
assert(1 << nBits == nLinears);
var nLevels = (nBits - 1)\4 +1;
var logArity = log2(arity);
var nLevels = (nBits - 1)\logArity +1;
signal input values[elementsInLinear][eSize];
signal input siblings[nLevels][16]; // Sibling path to calculate the merkle root given a set of values.
signal input siblings[nLevels][arity]; // Sibling path to calculate the merkle root given a set of values.
signal input key[nBits]; // Defines either each element of the sibling path is the left or right one
signal output root; // Root of the merkle tree

// Each leaf in the merkle tree might be composed by multiple values. Therefore, the first step is to
// reduce all those values into a single one by hashing all of them
signal linearHash <== LinearHash(elementsInLinear, eSize)(values);
signal linearHash <== LinearHash(elementsInLinear, eSize, arity)(values);

// Calculate the merkle root
root <== Merkle(nBits)(linearHash, siblings ,key);
root <== Merkle(nBits, arity)(linearHash, siblings ,key);
}


Expand All @@ -34,18 +34,19 @@ template MerkleHash(eSize, elementsInLinear, nLinears) {
- elementsInLinear: Each leave of the merkle tree is made by this number of values.
- nLinears: Number of leaves of the merkle tree
*/
template parallel VerifyMerkleHash(eSize, elementsInLinear, nLinears) {
template parallel VerifyMerkleHash(eSize, elementsInLinear, nLinears, arity) {
var nLeaves = log2(arity);
var nBits = log2(nLinears);
assert(1 << nBits == nLinears);
var nLevels = (nBits - 1)\4 +1;
var nLevels = (nBits - 1)\nLeaves +1;
signal input values[elementsInLinear][eSize];
signal input siblings[nLevels][16]; // Sibling path to calculate the merkle root given a set of values. Why 16 ???
signal input siblings[nLevels][arity]; // Sibling path to calculate the merkle root given a set of values.
signal input key[nBits]; // Defines either each element of the sibling path is the left or right one
signal input root; // Root of the merkle tree
signal input enable; // Boolean that determines either we want to check that roots matches or not

// Calculate the merkle root
signal merkleRoot <== MerkleHash(eSize, elementsInLinear, nLinears)(values, siblings, key);
signal merkleRoot <== MerkleHash(eSize, elementsInLinear, nLinears, arity)(values, siblings, key);

// If enable is set to 1, check that the merkleRoot being calculated matches with the one sent as input
enable * (merkleRoot - root) === 0;
Expand Down
Loading

0 comments on commit 0d13636

Please sign in to comment.