diff --git a/src/assets/circuit.circom b/src/assets/circuit.circom index 22e40ec..87ae4de 100644 --- a/src/assets/circuit.circom +++ b/src/assets/circuit.circom @@ -103,108 +103,178 @@ template div_relu(k) { out <== switcher.outL; } -template test() { - signal input in[2]; - signal output out[3]; - - component l0 = fc(2, 3); - signal input w0[3][2]; - signal input b0[3]; - for (var i = 0; i < 3; i++) { - for (var j = 0; j < 2; j++) { - l0.weights[i][j] <== w0[i][j]; - } - l0.biases[i] <== b0[i]; - } - // l0.weights <== w0; - // l0.biases <== b0; - for (var k = 0; k < 2; k++) { - l0.in[k] <== in[k]; - } +template network() { + // Structure from python example + // self.fc1 = nn.Linear(2, 32) + // self.fc2 = nn.Linear(32, 64) + // self.fc3 = nn.Linear(64, 128) + // self.fc4 = nn.Linear(128, 4) - for (var k = 0; k < 3; k++) { - out[k] <== l0.out[k]; - } -} + var in_len = 2; + var out_len = 4; -template network() { - // var in_len = 3; - // var out_len = 5; - signal input in[3]; - signal output out[5]; - - component l0 = fc(3, 5); - signal input w0[5][3]; - signal input b0[5]; - for (var i = 0; i < 5; i++) { - for (var j = 0; j < 3; j++) { + var l0_w = in_len; + var l0_h = 32; + + var l1_w = l0_h; + var l1_h = 64; + + var l2_w = l1_h; + var l2_h = 128; + + var l3_w = l2_h; + var l3_h = 256; + + var l4_w = l3_h; + var l4_h = 512; + + var l5_w = l4_h; + var l5_h = 1024; + + var l6_w = l5_h; + var l6_h = 2048; + + var l7_w = l6_h; + var l7_h = out_len; + + signal input in[in_len]; + signal output out[out_len]; + + component l0 = fc(l0_w, l0_h); + signal input w0[l0_h][l0_w]; + signal input b0[l0_h]; + for (var i = 0; i < l0_h; i++) { + for (var j = 0; j < l0_w; j++) { l0.weights[i][j] <== w0[i][j]; } l0.biases[i] <== b0[i]; } // l0.weights <== w0; // l0.biases <== b0; - for (var k = 0; k < 3; k++) { + for (var k = 0; k < in_len; k++) { l0.in[k] <== in[k]; } - component l1 = fc(5, 7); - signal input w1[7][5]; - signal input b1[7]; - for (var i = 0; i < 7; i++) { - for (var j = 0; j < 5; j++) { + component l1 = fc(l1_w, l1_h); + signal input w1[l1_h][l1_w]; + signal input b1[l1_h]; + for (var i = 0; i < l1_h; i++) { + for (var j = 0; j < l1_w; j++) { l1.weights[i][j] <== w1[i][j]; } l1.biases[i] <== b1[i]; } // l1.weights <== w1; // l1.biases <== b1; - for (var k = 0; k < 5; k++) { + for (var k = 0; k < l0_h; k++) { l1.in[k] <== l0.out[k]; } // l1.in <== l0.out; - component l2 = fc_no_relu(7, 5); - signal input w2[5][7]; - signal input b2[5]; - for (var i = 0; i < 5; i++) { - for (var j = 0; j < 7; j++) { + component l2 = fc(l2_w, l2_h); + signal input w2[l2_h][l2_w]; + signal input b2[l2_h]; + for (var i = 0; i < l2_h; i++) { + for (var j = 0; j < l2_w; j++) { l2.weights[i][j] <== w2[i][j]; } l2.biases[i] <== b2[i]; } // l2.weights <== w2; // l2.biases <== b2; - for (var k = 0; k < 7; k++) { + for (var k = 0; k < l1_h; k++) { l2.in[k] <== l1.out[k]; } // l2.in <== l1.out; - for (var k = 0; k < 5; k++) { - out[k] <== l2.out[k]; + component l3 = fc(l3_w, l3_h); + signal input w3[l3_h][l3_w]; + signal input b3[l3_h]; + for (var i = 0; i < l3_h; i++) { + for (var j = 0; j < l3_w; j++) { + l3.weights[i][j] <== w3[i][j]; + } + l3.biases[i] <== b3[i]; } - // out <== l2.out; -} - -// component main = test(); -// component main = network(); - -template InnerProd () { + // l2.weights <== w2; + // l2.biases <== b2; + for (var k = 0; k < l2_h; k++) { + l3.in[k] <== l2.out[k]; + } + // l2.in <== l1.out; - // Declaration of signals - signal input input_A[3]; - signal input input_B[3]; - signal output ip; + component l4 = fc(l4_w, l4_h); + signal input w4[l4_h][l4_w]; + signal input b4[l4_h]; + for (var i = 0; i < l4_h; i++) { + for (var j = 0; j < l4_w; j++) { + l4.weights[i][j] <== w4[i][j]; + } + l4.biases[i] <== b4[i]; + } + for (var k = 0; k < l4_h; k++) { + l4.in[k] <== l3.out[k]; + } - signal sum[3]; + component l5 = fc(l5_w, l5_h); + signal input w5[l5_h][l5_w]; + signal input b5[l5_h]; + for (var i = 0; i < l5_h; i++) { + for (var j = 0; j < l5_w; j++) { + l5.weights[i][j] <== w5[i][j]; + } + l5.biases[i] <== b5[i]; + } + for (var k = 0; k < l5_h; k++) { + l5.in[k] <== l4.out[k]; + } - sum[0] <== input_A[0]*input_B[0]; + component l6 = fc(l6_w, l6_h); + signal input w6[l6_h][l6_w]; + signal input b6[l6_h]; + for (var i = 0; i < l6_h; i++) { + for (var j = 0; j < l6_w; j++) { + l6.weights[i][j] <== w6[i][j]; + } + l6.biases[i] <== b6[i]; + } + for (var k = 0; k < l6_h; k++) { + l6.in[k] <== l5.out[k]; + } - for (var i = 1; i < 3; i++) { - sum[i] <== sum[i-1] + input_A[i] * input_B[i]; - } + component l7 = fc(l7_w, l7_h); + signal input w7[l7_h][l7_w]; + signal input b7[l7_h]; + for (var i = 0; i < l7_h; i++) { + for (var j = 0; j < l7_w; j++) { + l7.weights[i][j] <== w7[i][j]; + } + l7.biases[i] <== b7[i]; + } + for (var k = 0; k < l7_h; k++) { + l7.in[k] <== l6.out[k]; + } - ip <== sum[2]; + // component l8 = fc_no_relu(l8_w, l8_h); + // signal input w8[l8_h][l8_w]; + // signal input b8[l8_h]; + // for (var i = 0; i < l8_h; i++) { + // for (var j = 0; j < l8_w; j++) { + // l8.weights[i][j] <== w8[i][j]; + // } + // l8.biases[i] <== b8[i]; + // } + // // l3.weights <== w2; + // // l3.biases <== b2; + // for (var k = 0; k < l8_h; k++) { + // l8.in[k] <== l7.out[k]; + // } + // // l3.in <== l1.out; + + for (var k = 0; k < out_len; k++) { + out[k] <== l7.out[k]; + } + // out <== l2.out; } -component main = InnerProd(); \ No newline at end of file +component main = network(); \ No newline at end of file diff --git a/src/assets/circuit_bak.circom b/src/assets/circuit_bak.circom new file mode 100644 index 0000000..5a36677 --- /dev/null +++ b/src/assets/circuit_bak.circom @@ -0,0 +1,274 @@ +pragma circom 2.0.0; + +template Switcher() { + signal input sel; + signal input L; + signal input R; + signal output outL; + signal output outR; + + signal aux; + + aux <== (R-L)*sel; // We create aux in order to have only one multiplication + outL <== aux + L; + outR <== R - aux; +} + +template auction(n) { + signal input in[n]; + signal signs[n-1]; + signal maxidx[n-1]; + signal maxprice[n-1]; + signal output idx; + signal output price; + component sws[n-1]; + component sws2[n-1]; + + for (var i = 0; i < n-1; i++) { + signs[i] <== in[i] < in[i+1]; + sws[i] <== Switcher(); + sws[i].sel <== signs[i]; + sws[i].L <== in[i]; + sws[i].R <== in[i+1]; + maxprice[i] <== sws[i].outR; + sws2[i] <== Switcher(); + sws2[i].sel <== signs[i]; + sws2[i].L <== i; + sws2[i].R <== i+1; + maxidx[i] <== sws2[i].outR; + } + + idx <== maxidx[n-1]; + price <== maxprice[n-1]; +} + + +template fc (width, height) { + signal input in[width]; + signal input weights[height][width]; + signal input biases[height]; + signal output out[height]; + + component rows[height]; + + component relu[height]; + + for(var index = 0; index < height; index++) { + rows[index] = dot_product(width); + for(var index_input = 0; index_input < width; index_input++) { + rows[index].inputs[index_input] <== in[index_input]; + rows[index].weight_vector[index_input] <== weights[index][index_input]; + } + rows[index].bias <== biases[index]; + relu[index] = div_relu(12); + relu[index].in <== rows[index].out; + out[index] <== relu[index].out; + } +} + +template fc_no_relu (width, height) { + signal input in[width]; + signal input weights[height][width]; + signal input biases[height]; + signal output out[height]; + + component rows[height]; + + for(var index = 0; index < height; index++) { + rows[index] = dot_product(width); + for(var index_input = 0; index_input < width; index_input++) { + rows[index].inputs[index_input] <== in[index_input]; + rows[index].weight_vector[index_input] <== weights[index][index_input]; + } + rows[index].bias <== biases[index]; + out[index] <== rows[index].out; + } +} + +template dot_product (width) { + signal input inputs[width]; + signal input weight_vector[width]; + signal inter_accum[width]; + signal input bias; + signal output out; + + inter_accum[0] <== inputs[0]*weight_vector[0]; + // inter_accum[0]*0 === 0; + + for(var index = 1; index < width; index++) { + inter_accum[index] <== inputs[index]*weight_vector[index] + inter_accum[index-1]; + } + out <== inter_accum[width-1] + bias; +} + +template ShiftRight(k) { + signal input in; + signal output out; + out <== in; +} + +template Sign() { + signal input in; + signal output sign; + sign <== in < 0; +} + +template div_relu(k) { + signal input in; + signal output out; + component shiftRight = ShiftRight(k); + component sign = Sign(); + + shiftRight.in <== in; + sign.in <== shiftRight.out; + + component switcher = Switcher(); + switcher.sel <== sign.sign; + switcher.L <== shiftRight.out; + switcher.R <== 0; + //switcher.outR*0 === 0; + + out <== switcher.outL; +} + +template test() { + signal input in[2]; + signal output out[3]; + + component l0 = fc(2, 3); + signal input w0[3][2]; + signal input b0[3]; + for (var i = 0; i < 3; i++) { + for (var j = 0; j < 2; j++) { + l0.weights[i][j] <== w0[i][j]; + } + l0.biases[i] <== b0[i]; + } + // l0.weights <== w0; + // l0.biases <== b0; + for (var k = 0; k < 2; k++) { + l0.in[k] <== in[k]; + } + + for (var k = 0; k < 3; k++) { + out[k] <== l0.out[k]; + } +} + +template network() { + // Structure from python example + // self.fc1 = nn.Linear(2, 32) + // self.fc2 = nn.Linear(32, 64) + // self.fc3 = nn.Linear(64, 128) + // self.fc4 = nn.Linear(128, 4) + + var in_len = 2; + var out_len = 4; + + var l0_w = in_len; + var l0_h = 32; + + var l1_w = l0_h; + var l1_h = 64; + + var l2_w = l1_h; + var l2_h = 128; + + var l3_w = l2_h; + var l3_h = out_len; + + signal input in[in_len]; + signal output out[out_len]; + + component l0 = fc(l0_w, l0_h); + signal input w0[l0_h][l0_w]; + signal input b0[l0_h]; + for (var i = 0; i < l0_h; i++) { + for (var j = 0; j < l0_w; j++) { + l0.weights[i][j] <== w0[i][j]; + } + l0.biases[i] <== b0[i]; + } + // l0.weights <== w0; + // l0.biases <== b0; + for (var k = 0; k < in_len; k++) { + l0.in[k] <== in[k]; + } + + component l1 = fc(l1_w, l1_h); + signal input w1[l1_h][l1_w]; + signal input b1[l1_h]; + for (var i = 0; i < l1_h; i++) { + for (var j = 0; j < l1_w; j++) { + l1.weights[i][j] <== w1[i][j]; + } + l1.biases[i] <== b1[i]; + } + // l1.weights <== w1; + // l1.biases <== b1; + for (var k = 0; k < l0_h; k++) { + l1.in[k] <== l0.out[k]; + } + // l1.in <== l0.out; + + component l2 = fc(l2_w, l2_h); + signal input w2[l2_h][l2_w]; + signal input b2[l2_h]; + for (var i = 0; i < l2_h; i++) { + for (var j = 0; j < l2_w; j++) { + l2.weights[i][j] <== w2[i][j]; + } + l2.biases[i] <== b2[i]; + } + // l2.weights <== w2; + // l2.biases <== b2; + for (var k = 0; k < l1_h; k++) { + l2.in[k] <== l1.out[k]; + } + // l2.in <== l1.out; + + component l3 = fc_no_relu(l3_w, l3_h); + signal input w3[l3_h][l3_w]; + signal input b3[l3_h]; + for (var i = 0; i < l3_h; i++) { + for (var j = 0; j < l3_w; j++) { + l3.weights[i][j] <== w3[i][j]; + } + l3.biases[i] <== b3[i]; + } + // l2.weights <== w2; + // l2.biases <== b2; + for (var k = 0; k < l2_h; k++) { + l3.in[k] <== l2.out[k]; + } + // l2.in <== l1.out; + + for (var k = 0; k < out_len; k++) { + out[k] <== l3.out[k]; + } + // out <== l2.out; +} + +// template InnerProd () { + +// // Declaration of signals +// signal input input_A[3]; +// signal input input_B[3]; +// signal output ip; + +// signal sum[3]; + +// sum[0] <== input_A[0]*input_B[0]; + +// for (var i = 1; i < 3; i++) { +// sum[i] <== sum[i-1] + input_A[i] * input_B[i]; +// } + +// ip <== sum[2]; +// } + +// component main = InnerProd(); + +// component main = test(); +component main = network(); + diff --git a/src/assets/fc.circom b/src/assets/fc.circom index ea5113d..1824a7b 100644 --- a/src/assets/fc.circom +++ b/src/assets/fc.circom @@ -103,60 +103,95 @@ template div_relu(k) { } template network() { - // var in_len = 3; - // var out_len = 5; - signal input in[3]; - signal output out[5]; - - component l0 = fc(3, 5); - signal input w0[5][3]; - signal input b0[5]; - for (var i = 0; i < 5; i++) { - for (var j = 0; j < 3; j++) { + // Structure from python example + // self.fc1 = nn.Linear(2, 32) + // self.fc2 = nn.Linear(32, 64) + // self.fc3 = nn.Linear(64, 128) + // self.fc4 = nn.Linear(128, 4) + + var in_len = 2; + var out_len = 4; + + var l0_w = in_len; + var l0_h = 32; + + var l1_w = l0_h; + var l1_h = 64; + + var l2_w = l1_h; + var l2_h = 128; + + var l3_w = l2_h; + var l3_h = out_len; + + signal input in[in_len]; + signal output out[out_len]; + + component l0 = fc(l0_w, l0_h); + signal input w0[l0_h][l0_w]; + signal input b0[l0_h]; + for (var i = 0; i < l0_h; i++) { + for (var j = 0; j < l0_w; j++) { l0.weights[i][j] <== w0[i][j]; } l0.biases[i] <== b0[i]; } // l0.weights <== w0; // l0.biases <== b0; - for (var k = 0; k < 3; k++) { + for (var k = 0; k < in_len; k++) { l0.in[k] <== in[k]; } - component l1 = fc(5, 7); - signal input w1[7][5]; - signal input b1[7]; - for (var i = 0; i < 7; i++) { - for (var j = 0; j < 5; j++) { + component l1 = fc(l1_w, l1_h); + signal input w1[l1_h][l1_w]; + signal input b1[l1_h]; + for (var i = 0; i < l1_h; i++) { + for (var j = 0; j < l1_w; j++) { l1.weights[i][j] <== w1[i][j]; } l1.biases[i] <== b1[i]; } // l1.weights <== w1; // l1.biases <== b1; - for (var k = 0; k < 5; k++) { + for (var k = 0; k < l0_h; k++) { l1.in[k] <== l0.out[k]; } // l1.in <== l0.out; - component l2 = fc_no_relu(7, 5); - signal input w2[5][7]; - signal input b2[5]; - for (var i = 0; i < 5; i++) { - for (var j = 0; j < 7; j++) { + component l2 = fc(l2_w, l2_h); + signal input w2[l2_h][l2_w]; + signal input b2[l2_h]; + for (var i = 0; i < l2_h; i++) { + for (var j = 0; j < l2_w; j++) { l2.weights[i][j] <== w2[i][j]; } l2.biases[i] <== b2[i]; } // l2.weights <== w2; // l2.biases <== b2; - for (var k = 0; k < 7; k++) { + for (var k = 0; k < l1_h; k++) { l2.in[k] <== l1.out[k]; } // l2.in <== l1.out; - for (var k = 0; k < 5; k++) { - out[k] <== l2.out[k]; + component l3 = fc_no_relu(l3_w, l3_h); + signal input w3[l3_h][l3_w]; + signal input b3[l3_h]; + for (var i = 0; i < l3_h; i++) { + for (var j = 0; j < l3_w; j++) { + l3.weights[i][j] <== w3[i][j]; + } + l3.biases[i] <== b3[i]; + } + // l2.weights <== w2; + // l2.biases <== b2; + for (var k = 0; k < l2_h; k++) { + l3.in[k] <== l2.out[k]; + } + // l2.in <== l1.out; + + for (var k = 0; k < out_len; k++) { + out[k] <== l3.out[k]; } // out <== l2.out; } diff --git a/src/assets/fc_lite.circom b/src/assets/fc_lite.circom new file mode 100644 index 0000000..2ed26a4 --- /dev/null +++ b/src/assets/fc_lite.circom @@ -0,0 +1,199 @@ +pragma circom 2.0.0; + +template Switcher() { + signal input sel; + signal input L; + signal input R; + signal output outL; + signal output outR; + + signal aux; + + aux <== (R-L)*sel; // We create aux in order to have only one multiplication + outL <== aux + L; + outR <== R - aux; +} + + +template fc (width, height) { + signal input in[width]; + signal input weights[height][width]; + signal input biases[height]; + signal output out[height]; + + component rows[height]; + + component relu[height]; + + for(var index = 0; index < height; index++) { + rows[index] = dot_product(width); + for(var index_input = 0; index_input < width; index_input++) { + rows[index].inputs[index_input] <== in[index_input]; + rows[index].weight_vector[index_input] <== weights[index][index_input]; + } + rows[index].bias <== biases[index]; + relu[index] = div_relu(12); + relu[index].in <== rows[index].out; + out[index] <== relu[index].out; + } +} + +template fc_no_relu (width, height) { + signal input in[width]; + signal input weights[height][width]; + signal input biases[height]; + signal output out[height]; + + component rows[height]; + + for(var index = 0; index < height; index++) { + rows[index] = dot_product(width); + for(var index_input = 0; index_input < width; index_input++) { + rows[index].inputs[index_input] <== in[index_input]; + rows[index].weight_vector[index_input] <== weights[index][index_input]; + } + rows[index].bias <== biases[index]; + out[index] <== rows[index].out; + } +} + +template dot_product (width) { + signal input inputs[width]; + signal input weight_vector[width]; + signal inter_accum[width]; + signal input bias; + signal output out; + + inter_accum[0] <== inputs[0]*weight_vector[0]; + // inter_accum[0]*0 === 0; + + for(var index = 1; index < width; index++) { + inter_accum[index] <== inputs[index]*weight_vector[index] + inter_accum[index-1]; + } + out <== inter_accum[width-1] + bias; +} + +template ShiftRight(k) { + signal input in; + signal output out; + out <== in; +} + +template Sign() { + signal input in; + signal output sign; +} + +template div_relu(k) { + signal input in; + signal output out; + component shiftRight = ShiftRight(k); + component sign = Sign(); + + shiftRight.in <== in; + sign.in <== shiftRight.out; + + component switcher = Switcher(); + switcher.sel <== sign.sign; + switcher.L <== shiftRight.out; + switcher.R <== 0; + //switcher.outR*0 === 0; + + out <== switcher.outL; +} + +template network() { + // Structure from python example + // self.fc1 = nn.Linear(2, 32) + // self.fc2 = nn.Linear(32, 64) + // self.fc3 = nn.Linear(64, 128) + // self.fc4 = nn.Linear(128, 4) + + var in_len = 2; + var out_len = 4; + + var l0_w = in_len; + var l0_h = 5; + + var l1_w = l0_h; + var l1_h = 7; + + var l2_w = l1_h; + var l2_h = 11; + + var l3_w = l2_h; + var l3_h = out_len; + + signal input in[in_len]; + signal output out[out_len]; + + component l0 = fc(l0_w, l0_h); + signal input w0[l0_h][l0_w]; + signal input b0[l0_h]; + for (var i = 0; i < l0_h; i++) { + for (var j = 0; j < l0_w; j++) { + l0.weights[i][j] <== w0[i][j]; + } + l0.biases[i] <== b0[i]; + } + // l0.weights <== w0; + // l0.biases <== b0; + for (var k = 0; k < in_len; k++) { + l0.in[k] <== in[k]; + } + + component l1 = fc(l1_w, l1_h); + signal input w1[l1_h][l1_w]; + signal input b1[l1_h]; + for (var i = 0; i < l1_h; i++) { + for (var j = 0; j < l1_w; j++) { + l1.weights[i][j] <== w1[i][j]; + } + l1.biases[i] <== b1[i]; + } + // l1.weights <== w1; + // l1.biases <== b1; + for (var k = 0; k < l0_h; k++) { + l1.in[k] <== l0.out[k]; + } + // l1.in <== l0.out; + + component l2 = fc(l2_w, l2_h); + signal input w2[l2_h][l2_w]; + signal input b2[l2_h]; + for (var i = 0; i < l2_h; i++) { + for (var j = 0; j < l2_w; j++) { + l2.weights[i][j] <== w2[i][j]; + } + l2.biases[i] <== b2[i]; + } + // l2.weights <== w2; + // l2.biases <== b2; + for (var k = 0; k < l1_h; k++) { + l2.in[k] <== l1.out[k]; + } + // l2.in <== l1.out; + + component l3 = fc_no_relu(l3_w, l3_h); + signal input w3[l3_h][l3_w]; + signal input b3[l3_h]; + for (var i = 0; i < l3_h; i++) { + for (var j = 0; j < l3_w; j++) { + l3.weights[i][j] <== w3[i][j]; + } + l3.biases[i] <== b3[i]; + } + // l2.weights <== w2; + // l2.biases <== b2; + for (var k = 0; k < l2_h; k++) { + l3.in[k] <== l2.out[k]; + } + // l2.in <== l1.out; + + for (var k = 0; k < out_len; k++) { + out[k] <== l3.out[k]; + } + // out <== l2.out; +} + +component main = network(); \ No newline at end of file diff --git a/src/circuit.rs b/src/circuit.rs index c6f2e2f..6b2297f 100644 --- a/src/circuit.rs +++ b/src/circuit.rs @@ -208,7 +208,7 @@ impl ArithmeticCircuit { match gate_type { AGateType::AAdd => { - println!("{} = {} + {}", o_name, lh_name, rh_name); + // println!("{} = {} + {}", o_name, lh_name, rh_name); }, AGateType::ADiv => todo!(), AGateType::AEq => todo!(), @@ -216,15 +216,15 @@ impl ArithmeticCircuit { AGateType::AGt => todo!(), AGateType::ALEq => todo!(), AGateType::ALt => { - println!("{} = {} < {}", o_name, lh_name, rh_name); + // println!("{} = {} < {}", o_name, lh_name, rh_name); }, AGateType::AMul => { - println!("{} = {} * {}", o_name, lh_name, rh_name); + // println!("{} = {} * {}", o_name, lh_name, rh_name); }, AGateType::ANeq => todo!(), AGateType::ANone => todo!(), AGateType::ASub => { - println!("{} = {} - {}", o_name, lh_name, rh_name); + // println!("{} = {} - {}", o_name, lh_name, rh_name); }, }; @@ -290,7 +290,7 @@ impl ArithmeticCircuit { .retain(|node| node.id != node_a.id && node.id != node_b.id); self.nodes.push(merged_node); - println!("{} = {}", a_name, b_name); + // println!("{} = {}", a_name, b_name); Ok(()) }