Skip to content

Commit

Permalink
fc 8 layers
Browse files Browse the repository at this point in the history
  • Loading branch information
namnc committed Mar 8, 2024
1 parent 3e778f2 commit c9e83e9
Show file tree
Hide file tree
Showing 5 changed files with 674 additions and 96 deletions.
202 changes: 136 additions & 66 deletions src/assets/circuit.circom
Original file line number Diff line number Diff line change
Expand Up @@ -103,108 +103,178 @@ template div_relu(k) {
out <== switcher.outL;
}

template test() {
signal input in[2];
signal output out[3];

component l0 = fc(2, 3);
signal input w0[3][2];
signal input b0[3];
for (var i = 0; i < 3; i++) {
for (var j = 0; j < 2; j++) {
l0.weights[i][j] <== w0[i][j];
}
l0.biases[i] <== b0[i];
}
// l0.weights <== w0;
// l0.biases <== b0;
for (var k = 0; k < 2; k++) {
l0.in[k] <== in[k];
}
template network() {
// Structure from python example
// self.fc1 = nn.Linear(2, 32)
// self.fc2 = nn.Linear(32, 64)
// self.fc3 = nn.Linear(64, 128)
// self.fc4 = nn.Linear(128, 4)

for (var k = 0; k < 3; k++) {
out[k] <== l0.out[k];
}
}
var in_len = 2;
var out_len = 4;

template network() {
// var in_len = 3;
// var out_len = 5;
signal input in[3];
signal output out[5];

component l0 = fc(3, 5);
signal input w0[5][3];
signal input b0[5];
for (var i = 0; i < 5; i++) {
for (var j = 0; j < 3; j++) {
var l0_w = in_len;
var l0_h = 32;

var l1_w = l0_h;
var l1_h = 64;

var l2_w = l1_h;
var l2_h = 128;

var l3_w = l2_h;
var l3_h = 256;

var l4_w = l3_h;
var l4_h = 512;

var l5_w = l4_h;
var l5_h = 1024;

var l6_w = l5_h;
var l6_h = 2048;

var l7_w = l6_h;
var l7_h = out_len;

signal input in[in_len];
signal output out[out_len];

component l0 = fc(l0_w, l0_h);
signal input w0[l0_h][l0_w];
signal input b0[l0_h];
for (var i = 0; i < l0_h; i++) {
for (var j = 0; j < l0_w; j++) {
l0.weights[i][j] <== w0[i][j];
}
l0.biases[i] <== b0[i];
}
// l0.weights <== w0;
// l0.biases <== b0;
for (var k = 0; k < 3; k++) {
for (var k = 0; k < in_len; k++) {
l0.in[k] <== in[k];
}

component l1 = fc(5, 7);
signal input w1[7][5];
signal input b1[7];
for (var i = 0; i < 7; i++) {
for (var j = 0; j < 5; j++) {
component l1 = fc(l1_w, l1_h);
signal input w1[l1_h][l1_w];
signal input b1[l1_h];
for (var i = 0; i < l1_h; i++) {
for (var j = 0; j < l1_w; j++) {
l1.weights[i][j] <== w1[i][j];
}
l1.biases[i] <== b1[i];
}
// l1.weights <== w1;
// l1.biases <== b1;
for (var k = 0; k < 5; k++) {
for (var k = 0; k < l0_h; k++) {
l1.in[k] <== l0.out[k];
}
// l1.in <== l0.out;

component l2 = fc_no_relu(7, 5);
signal input w2[5][7];
signal input b2[5];
for (var i = 0; i < 5; i++) {
for (var j = 0; j < 7; j++) {
component l2 = fc(l2_w, l2_h);
signal input w2[l2_h][l2_w];
signal input b2[l2_h];
for (var i = 0; i < l2_h; i++) {
for (var j = 0; j < l2_w; j++) {
l2.weights[i][j] <== w2[i][j];
}
l2.biases[i] <== b2[i];
}
// l2.weights <== w2;
// l2.biases <== b2;
for (var k = 0; k < 7; k++) {
for (var k = 0; k < l1_h; k++) {
l2.in[k] <== l1.out[k];
}
// l2.in <== l1.out;

for (var k = 0; k < 5; k++) {
out[k] <== l2.out[k];
component l3 = fc(l3_w, l3_h);
signal input w3[l3_h][l3_w];
signal input b3[l3_h];
for (var i = 0; i < l3_h; i++) {
for (var j = 0; j < l3_w; j++) {
l3.weights[i][j] <== w3[i][j];
}
l3.biases[i] <== b3[i];
}
// out <== l2.out;
}

// component main = test();
// component main = network();

template InnerProd () {
// l2.weights <== w2;
// l2.biases <== b2;
for (var k = 0; k < l2_h; k++) {
l3.in[k] <== l2.out[k];
}
// l2.in <== l1.out;

// Declaration of signals
signal input input_A[3];
signal input input_B[3];
signal output ip;
component l4 = fc(l4_w, l4_h);
signal input w4[l4_h][l4_w];
signal input b4[l4_h];
for (var i = 0; i < l4_h; i++) {
for (var j = 0; j < l4_w; j++) {
l4.weights[i][j] <== w4[i][j];
}
l4.biases[i] <== b4[i];
}
for (var k = 0; k < l4_h; k++) {
l4.in[k] <== l3.out[k];
}

signal sum[3];
component l5 = fc(l5_w, l5_h);
signal input w5[l5_h][l5_w];
signal input b5[l5_h];
for (var i = 0; i < l5_h; i++) {
for (var j = 0; j < l5_w; j++) {
l5.weights[i][j] <== w5[i][j];
}
l5.biases[i] <== b5[i];
}
for (var k = 0; k < l5_h; k++) {
l5.in[k] <== l4.out[k];
}

sum[0] <== input_A[0]*input_B[0];
component l6 = fc(l6_w, l6_h);
signal input w6[l6_h][l6_w];
signal input b6[l6_h];
for (var i = 0; i < l6_h; i++) {
for (var j = 0; j < l6_w; j++) {
l6.weights[i][j] <== w6[i][j];
}
l6.biases[i] <== b6[i];
}
for (var k = 0; k < l6_h; k++) {
l6.in[k] <== l5.out[k];
}

for (var i = 1; i < 3; i++) {
sum[i] <== sum[i-1] + input_A[i] * input_B[i];
}
component l7 = fc(l7_w, l7_h);
signal input w7[l7_h][l7_w];
signal input b7[l7_h];
for (var i = 0; i < l7_h; i++) {
for (var j = 0; j < l7_w; j++) {
l7.weights[i][j] <== w7[i][j];
}
l7.biases[i] <== b7[i];
}
for (var k = 0; k < l7_h; k++) {
l7.in[k] <== l6.out[k];
}

ip <== sum[2];
// component l8 = fc_no_relu(l8_w, l8_h);
// signal input w8[l8_h][l8_w];
// signal input b8[l8_h];
// for (var i = 0; i < l8_h; i++) {
// for (var j = 0; j < l8_w; j++) {
// l8.weights[i][j] <== w8[i][j];
// }
// l8.biases[i] <== b8[i];
// }
// // l3.weights <== w2;
// // l3.biases <== b2;
// for (var k = 0; k < l8_h; k++) {
// l8.in[k] <== l7.out[k];
// }
// // l3.in <== l1.out;

for (var k = 0; k < out_len; k++) {
out[k] <== l7.out[k];
}
// out <== l2.out;
}

component main = InnerProd();
component main = network();
Loading

0 comments on commit c9e83e9

Please sign in to comment.