Skip to content

Commit

Permalink
simpler fc
Browse files Browse the repository at this point in the history
  • Loading branch information
namnc committed Feb 27, 2024
1 parent 4e08035 commit e5597e1
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 41 deletions.
60 changes: 46 additions & 14 deletions src/assets/circuit.circom
Original file line number Diff line number Diff line change
Expand Up @@ -103,30 +103,62 @@ template div_relu(k) {
}

template network() {
var in_len = 3;
var out_len = 5;
signal input in[in_len];
signal output out[out_len];
// var in_len = 3;
// var out_len = 5;
signal input in[3];
signal output out[5];

component l0 = fc(3, 5);
signal input w0[5][3];
signal input b0[5];
l0.weights <== w0;
l0.biases <== b0;
l0.in <== in;
for (var i = 0; i < 5; i++) {
for (var j = 0; j < 3; j++) {
l0.weights[i][j] <== w0[i][j];
}
l0.biases[i] <== b0[i];
}
// l0.weights <== w0;
// l0.biases <== b0;
for (var k = 0; k < 3; k++) {
l0.in[k] <== in[k];
}

component l1 = fc(5, 7);
signal input w1[7][5];
signal input b1[7];
l1.weights <== w1;
l1.biases <== b1;
l1.in <== l0.out;
for (var i = 0; i < 7; i++) {
for (var j = 0; j < 5; j++) {
l1.weights[i][j] <== w1[i][j];
}
l1.biases[i] <== b1[i];
}
// l1.weights <== w1;
// l1.biases <== b1;
for (var k = 0; k < 5; k++) {
l1.in[k] <== l0.out[k];
}
// l1.in <== l0.out;

component l2 = fc_no_relu(7, 5);
signal input w2[5][7];
signal input b2[5];
l2.weights <== w2;
l2.biases <== b2;
l2.in <== l1.out;
out <== l2.out;
for (var i = 0; i < 5; i++) {
for (var j = 0; j < 7; j++) {
l2.weights[i][j] <== w2[i][j];
}
l2.biases[i] <== b2[i];
}
// l2.weights <== w2;
// l2.biases <== b2;
for (var k = 0; k < 7; k++) {
l2.in[k] <== l1.out[k];
}
// l2.in <== l1.out;

for (var k = 0; k < 5; k++) {
out[k] <== l2.out[k];
}
// out <== l2.out;
}

component main = network();
86 changes: 59 additions & 27 deletions src/assets/fc.circom
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ template Switcher() {

aux <== (R-L)*sel; // We create aux in order to have only one multiplication
outL <== aux + L;
outR <== -aux + R;
outR <== R - aux;
}


Expand All @@ -32,7 +32,7 @@ template fc (width, height) {
rows[index].weight_vector[index_input] <== weights[index][index_input];
}
rows[index].bias <== biases[index];
relu[index] = div_relu(128, 12);
relu[index] = div_relu(12);
relu[index].in <== rows[index].out;
out[index] <== relu[index].out;
}
Expand Down Expand Up @@ -65,7 +65,7 @@ template dot_product (width) {
signal output out;

inter_accum[0] <== inputs[0]*weight_vector[0];
inter_accum[0]*0 === 0;
// inter_accum[0]*0 === 0;

for(var index = 1; index < width; index++) {
inter_accum[index] <== inputs[index]*weight_vector[index] + inter_accum[index-1];
Expand Down Expand Up @@ -103,30 +103,62 @@ template div_relu(k) {
}

template network() {
var in_len = 32;
var out_len = 100;
signal input in[in_len];
signal output out[out_len];

component l0 = fc(32, 100);
signal input w0[100][32];
signal input b0[100];
l0.weights <== w0;
l0.biases <== b0;
l0.in <== in;
component l1 = fc(100, 200);
signal input w1[200][100];
signal input b1[200];
l1.weights <== w1;
l1.biases <== b1;
l1.in <== l0.out;
component l2 = fc_no_relu(200, 100);
signal input w2[100][200];
signal input b2[100];
l2.weights <== w2;
l2.biases <== b2;
l2.in <== l1.out;
out <== l2.out;
// var in_len = 3;
// var out_len = 5;
signal input in[3];
signal output out[5];

component l0 = fc(3, 5);
signal input w0[5][3];
signal input b0[5];
for (var i = 0; i < 5; i++) {
for (var j = 0; j < 3; j++) {
l0.weights[i][j] <== w0[i][j];
}
l0.biases[i] <== b0[i];
}
// l0.weights <== w0;
// l0.biases <== b0;
for (var k = 0; k < 3; k++) {
l0.in[k] <== in[k];
}

component l1 = fc(5, 7);
signal input w1[7][5];
signal input b1[7];
for (var i = 0; i < 7; i++) {
for (var j = 0; j < 5; j++) {
l1.weights[i][j] <== w1[i][j];
}
l1.biases[i] <== b1[i];
}
// l1.weights <== w1;
// l1.biases <== b1;
for (var k = 0; k < 5; k++) {
l1.in[k] <== l0.out[k];
}
// l1.in <== l0.out;

component l2 = fc_no_relu(7, 5);
signal input w2[5][7];
signal input b2[5];
for (var i = 0; i < 5; i++) {
for (var j = 0; j < 7; j++) {
l2.weights[i][j] <== w2[i][j];
}
l2.biases[i] <== b2[i];
}
// l2.weights <== w2;
// l2.biases <== b2;
for (var k = 0; k < 7; k++) {
l2.in[k] <== l1.out[k];
}
// l2.in <== l1.out;

for (var k = 0; k < 5; k++) {
out[k] <== l2.out[k];
}
// out <== l2.out;
}

component main = network();

0 comments on commit e5597e1

Please sign in to comment.