From b139bdaf50d5c9b5336f659f5ab5e310fb3f917e Mon Sep 17 00:00:00 2001 From: Nam Ngo Date: Tue, 27 Feb 2024 23:06:39 +0700 Subject: [PATCH 01/10] print add gate/connection info --- src/circuit.rs | 27 ++++++++++++++++++++++++++- src/process.rs | 17 +++++++++++++---- src/runtime.rs | 6 +++++- 3 files changed, 44 insertions(+), 6 deletions(-) diff --git a/src/circuit.rs b/src/circuit.rs index 768b179..3a670e9 100644 --- a/src/circuit.rs +++ b/src/circuit.rs @@ -170,6 +170,9 @@ impl ArithmeticCircuit { lhs_id: u32, rhs_id: u32, output_id: u32, + lh_name: String, + rh_name: String, + o_name: String ) -> Result<(), CircuitError> { // Check that the inputs are declared if !self.contains_var(&lhs_id) @@ -179,6 +182,26 @@ impl ArithmeticCircuit { return Err(CircuitError::VariableNotDeclared); } + match gate_type { + AGateType::AAdd => { + println!("DEBUG =================================================== {:?} = {:?} + {:?}", o_name, lh_name, rh_name); + }, + AGateType::ADiv => todo!(), + AGateType::AEq => todo!(), + AGateType::AGEq => todo!(), + AGateType::AGt => todo!(), + AGateType::ALEq => todo!(), + AGateType::ALt => todo!(), + AGateType::AMul => { + println!("DEBUG =================================================== {:?} = {:?} * {:?}", o_name, lh_name, rh_name); + }, + AGateType::ANeq => todo!(), + AGateType::ANone => todo!(), + AGateType::ASub => { + println!("DEBUG =================================================== {:?} = {:?} - {:?}", o_name, lh_name, rh_name); + }, + }; + // Get the signal nodes let lhs_node = self.get_signal_node(lhs_id)?; let rhs_node = self.get_signal_node(rhs_id)?; @@ -200,7 +223,7 @@ impl ArithmeticCircuit { /// Creates a connection between two signals in the circuit. /// This is done by finding the nodes that contain the signals and merging them. - pub fn add_connection(&mut self, a: u32, b: u32) -> Result<(), CircuitError> { + pub fn add_connection(&mut self, a: u32, b: u32, a_name: String, b_name: String) -> Result<(), CircuitError> { // Check that the endpoints are declared if !self.contains_var(&a) || !self.contains_var(&b) { return Err(CircuitError::VariableNotDeclared); @@ -241,6 +264,8 @@ impl ArithmeticCircuit { .retain(|node| node.id != node_a.id && node.id != node_b.id); self.nodes.push(merged_node); + println!("DEBUG =================================================== {:?} = {:?}", a_name, b_name); + Ok(()) } diff --git a/src/process.rs b/src/process.rs index 587e2ab..4974856 100644 --- a/src/process.rs +++ b/src/process.rs @@ -160,8 +160,9 @@ pub fn process_statement( // Connect the generated gate output to the given signal let given_output_id = ctx.get_signal_id(&lh_access)?; let gate_output_id = get_signal_for_access(ac, ctx, &rh_access)?; - - ac.add_connection(gate_output_id, given_output_id)?; + let a_name = lh_access.access_str(); + let b_name = rh_access.access_str(); + ac.add_connection(gate_output_id, given_output_id, a_name, b_name)?; } DataType::Variable => { // Assign the evaluated right-hand side to the left-hand side @@ -179,7 +180,10 @@ pub fn process_statement( let component_signal = ctx.get_component_signal_id(&lh_access)?; let assigned_signal = get_signal_for_access(ac, ctx, &rh_access)?; - ac.add_connection(assigned_signal, component_signal)?; + let a_name = lh_access.access_str(); + let b_name = rh_access.access_str(); + + ac.add_connection(assigned_signal, component_signal, a_name, b_name)?; } _ => return Err(ProgramError::OperationNotSupported), }, @@ -444,7 +448,12 @@ fn handle_infix_op( // Add output signal and gate to the circuit ac.add_signal(output_id)?; - ac.add_gate(gate_type, lhs_id, rhs_id, output_id)?; + + let lh_name = lhe_access.access_str(); + let rh_name = rhe_access.access_str(); + let o_name = output_signal.access_str(); + + ac.add_gate(gate_type, lhs_id, rhs_id, output_id, lh_name, rh_name, o_name)?; Ok(output_signal) } diff --git a/src/runtime.rs b/src/runtime.rs index 43cab8b..0f07acd 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -5,7 +5,7 @@ use crate::program::ProgramError; use circom_program_structure::ast::VariableType; use rand::{thread_rng, Rng}; -use std::collections::{HashMap, HashSet, VecDeque}; +use std::{collections::{HashMap, HashSet, VecDeque}, fmt::format}; use thiserror::Error; pub const RETURN_VAR: &str = "function_return_value"; @@ -487,6 +487,10 @@ impl DataAccess { pub fn get_access(&self) -> &Vec { &self.access } + + pub fn access_str(&self) -> String { + format!("{:?} {:?}", self.get_name(), self.get_access()) + } } /// Processes an access to a component's signal. From 0f2e5cf9f8e26ffdfaf96fa4e55d66a08f843214 Mon Sep 17 00:00:00 2001 From: Nam Ngo Date: Wed, 28 Feb 2024 11:20:12 +0700 Subject: [PATCH 02/10] Update circuit.rs --- src/circuit.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/circuit.rs b/src/circuit.rs index 3a670e9..206b9f8 100644 --- a/src/circuit.rs +++ b/src/circuit.rs @@ -184,7 +184,7 @@ impl ArithmeticCircuit { match gate_type { AGateType::AAdd => { - println!("DEBUG =================================================== {:?} = {:?} + {:?}", o_name, lh_name, rh_name); + println!("{:?} = {:?} + {:?}", o_name, lh_name, rh_name); }, AGateType::ADiv => todo!(), AGateType::AEq => todo!(), @@ -193,12 +193,12 @@ impl ArithmeticCircuit { AGateType::ALEq => todo!(), AGateType::ALt => todo!(), AGateType::AMul => { - println!("DEBUG =================================================== {:?} = {:?} * {:?}", o_name, lh_name, rh_name); + println!("{:?} = {:?} * {:?}", o_name, lh_name, rh_name); }, AGateType::ANeq => todo!(), AGateType::ANone => todo!(), AGateType::ASub => { - println!("DEBUG =================================================== {:?} = {:?} - {:?}", o_name, lh_name, rh_name); + println!("{:?} = {:?} - {:?}", o_name, lh_name, rh_name); }, }; @@ -264,7 +264,7 @@ impl ArithmeticCircuit { .retain(|node| node.id != node_a.id && node.id != node_b.id); self.nodes.push(merged_node); - println!("DEBUG =================================================== {:?} = {:?}", a_name, b_name); + println!("{:?} = {:?}", a_name, b_name); Ok(()) } From 149390055508e5cbb11e1496f9dc4b2407fe374c Mon Sep 17 00:00:00 2001 From: Nam Ngo Date: Wed, 28 Feb 2024 11:24:57 +0700 Subject: [PATCH 03/10] cleanup --- src/circom/execution.rs | 2 +- src/process.rs | 70 ++--------------------------------------- src/runtime.rs | 2 +- 3 files changed, 4 insertions(+), 70 deletions(-) diff --git a/src/circom/execution.rs b/src/circom/execution.rs index 2f9f472..00076b5 100644 --- a/src/circom/execution.rs +++ b/src/circom/execution.rs @@ -28,7 +28,7 @@ pub fn execute_project( let build_config = BuildConfig { no_rounds: config.no_rounds, flag_json_sub: config.json_substitution_flag, - json_substitutions: config.json_substitutions, + // json_substitutions: config.json_substitutions, flag_s: config.flag_s, flag_f: config.flag_f, flag_p: config.flag_p, diff --git a/src/process.rs b/src/process.rs index 4974856..d581570 100644 --- a/src/process.rs +++ b/src/process.rs @@ -204,26 +204,7 @@ pub fn process_statement( Ok(()) } - Statement::MultSubstitution { meta, lhe, op, rhe } => { - println!("Statement not implemented: MultSubstitution"); - Ok(()) - } - Statement::UnderscoreSubstitution { meta, op, rhe } => { - println!("Statement not implemented: UnderscoreSubstitution"); - Ok(()) - } - Statement::ConstraintEquality { meta, lhe, rhe } => { - println!("Statement not implemented: ConstraintEquality"); - Ok(()) - } - Statement::LogCall { meta, args } => { - println!("Statement not implemented: LogCall"); - Ok(()) - } - Statement::Assert { meta, arg } => { - println!("Statement not implemented: Assert"); - Ok(()) - } + _ => todo!() } } @@ -254,54 +235,7 @@ pub fn process_expression( Expression::Variable { name, access, .. } => { build_access(ac, runtime, program_archive, name, access) } - Expression::PrefixOp { - meta, - prefix_op, - rhe, - } => { - println!("Expression not implemented:PrefixOp"); - Ok(DataAccess::new("", vec![])) - } - Expression::InlineSwitchOp { - meta, - cond, - if_true, - if_false, - } => { - println!("Expression not implemented:InlineSwitchOp"); - Ok(DataAccess::new("", vec![])) - } - Expression::ParallelOp { meta, rhe } => { - println!("Expression not implemented:ParallelOp"); - Ok(DataAccess::new("", vec![])) - } - Expression::AnonymousComp { - meta, - id, - is_parallel, - params, - signals, - names, - } => { - println!("Expression not implemented:AnonymousComp"); - Ok(DataAccess::new("", vec![])) - } - Expression::ArrayInLine { meta, values } => { - println!("Expression not implemented:ArrayInLine"); - Ok(DataAccess::new("", vec![])) - } - Expression::Tuple { meta, values } => { - println!("Expression not implemented:Tuple"); - Ok(DataAccess::new("", vec![])) - } - Expression::UniformArray { - meta, - value, - dimension, - } => { - println!("Expression not implemented: UniformArray"); - Ok(DataAccess::new("", vec![])) - } + _ => todo!() } } diff --git a/src/runtime.rs b/src/runtime.rs index 0f07acd..583bf34 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -5,7 +5,7 @@ use crate::program::ProgramError; use circom_program_structure::ast::VariableType; use rand::{thread_rng, Rng}; -use std::{collections::{HashMap, HashSet, VecDeque}, fmt::format}; +use std::collections::{HashMap, HashSet, VecDeque}; use thiserror::Error; pub const RETURN_VAR: &str = "function_return_value"; From 77c7488c3d60b4f14df09b4a2427178f7c33bdcc Mon Sep 17 00:00:00 2001 From: Nam Ngo Date: Wed, 28 Feb 2024 11:40:32 +0700 Subject: [PATCH 04/10] clean printout --- src/circuit.rs | 8 ++++---- src/runtime.rs | 16 ++++++++++++++-- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/src/circuit.rs b/src/circuit.rs index 206b9f8..5e21eb6 100644 --- a/src/circuit.rs +++ b/src/circuit.rs @@ -184,7 +184,7 @@ impl ArithmeticCircuit { match gate_type { AGateType::AAdd => { - println!("{:?} = {:?} + {:?}", o_name, lh_name, rh_name); + println!("{} = {} + {}", o_name, lh_name, rh_name); }, AGateType::ADiv => todo!(), AGateType::AEq => todo!(), @@ -193,12 +193,12 @@ impl ArithmeticCircuit { AGateType::ALEq => todo!(), AGateType::ALt => todo!(), AGateType::AMul => { - println!("{:?} = {:?} * {:?}", o_name, lh_name, rh_name); + println!("{} = {} * {}", o_name, lh_name, rh_name); }, AGateType::ANeq => todo!(), AGateType::ANone => todo!(), AGateType::ASub => { - println!("{:?} = {:?} - {:?}", o_name, lh_name, rh_name); + println!("{} = {} - {}", o_name, lh_name, rh_name); }, }; @@ -264,7 +264,7 @@ impl ArithmeticCircuit { .retain(|node| node.id != node_a.id && node.id != node_b.id); self.nodes.push(merged_node); - println!("{:?} = {:?}", a_name, b_name); + println!("{} = {}", a_name, b_name); Ok(()) } diff --git a/src/runtime.rs b/src/runtime.rs index 583bf34..4e9bbbd 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -5,7 +5,7 @@ use crate::program::ProgramError; use circom_program_structure::ast::VariableType; use rand::{thread_rng, Rng}; -use std::collections::{HashMap, HashSet, VecDeque}; +use std::{collections::{HashMap, HashSet, VecDeque}, fmt::Write}; use thiserror::Error; pub const RETURN_VAR: &str = "function_return_value"; @@ -489,7 +489,19 @@ impl DataAccess { } pub fn access_str(&self) -> String { - format!("{:?} {:?}", self.get_name(), self.get_access()) + let mut ret = String::new(); + ret.write_str(self.get_name().as_str()); + for sub in self.get_access() { + match sub { + SubAccess::Array(index) => { + ret.write_str(format!("[{}]", index).as_str()); + } + SubAccess::Component(name) => { + ret.write_str(format!(".{}]", name).as_str()); + } + } + } + ret } } From 79802c8f395124a6b91a0a0faea928550aabd35d Mon Sep 17 00:00:00 2001 From: Nam Ngo Date: Wed, 28 Feb 2024 11:52:35 +0700 Subject: [PATCH 05/10] trace println --- src/assets/circuit.circom | 92 +++++++++++++++++++-------------------- src/process.rs | 2 + src/program.rs | 3 +- src/runtime.rs | 6 +-- 4 files changed, 51 insertions(+), 52 deletions(-) diff --git a/src/assets/circuit.circom b/src/assets/circuit.circom index 4dfb8a9..ea5113d 100644 --- a/src/assets/circuit.circom +++ b/src/assets/circuit.circom @@ -103,65 +103,61 @@ template div_relu(k) { } template network() { - // var in_len = 2; - // var out_len = 3; - signal input in[1]; - signal output out[2]; - - component l0 = fc(1, 2); - signal input w0[2][1]; - signal input b0[2]; - for (var i = 0; i < 2; i++) { - for (var j = 0; j < 1; j++) { + // var in_len = 3; + // var out_len = 5; + signal input in[3]; + signal output out[5]; + + component l0 = fc(3, 5); + signal input w0[5][3]; + signal input b0[5]; + for (var i = 0; i < 5; i++) { + for (var j = 0; j < 3; j++) { l0.weights[i][j] <== w0[i][j]; } l0.biases[i] <== b0[i]; } // l0.weights <== w0; // l0.biases <== b0; - for (var k = 0; k < 1; k++) { + for (var k = 0; k < 3; k++) { l0.in[k] <== in[k]; } - for (var k = 0; k < 2; k++) { - out[k] <== l0.out[k]; + component l1 = fc(5, 7); + signal input w1[7][5]; + signal input b1[7]; + for (var i = 0; i < 7; i++) { + for (var j = 0; j < 5; j++) { + l1.weights[i][j] <== w1[i][j]; + } + l1.biases[i] <== b1[i]; + } + // l1.weights <== w1; + // l1.biases <== b1; + for (var k = 0; k < 5; k++) { + l1.in[k] <== l0.out[k]; + } + // l1.in <== l0.out; + + component l2 = fc_no_relu(7, 5); + signal input w2[5][7]; + signal input b2[5]; + for (var i = 0; i < 5; i++) { + for (var j = 0; j < 7; j++) { + l2.weights[i][j] <== w2[i][j]; + } + l2.biases[i] <== b2[i]; + } + // l2.weights <== w2; + // l2.biases <== b2; + for (var k = 0; k < 7; k++) { + l2.in[k] <== l1.out[k]; } + // l2.in <== l1.out; - // component l1 = fc(3, 4); - // signal input w1[4][3]; - // signal input b1[4]; - // for (var i = 0; i < 4; i++) { - // for (var j = 0; j < 3; j++) { - // l1.weights[i][j] <== w1[i][j]; - // } - // l1.biases[i] <== b1[i]; - // } - // // l1.weights <== w1; - // // l1.biases <== b1; - // for (var k = 0; k < 3; k++) { - // l1.in[k] <== l0.out[k]; - // } - // // l1.in <== l0.out; - - // component l2 = fc_no_relu(4, 3); - // signal input w2[3][4]; - // signal input b2[3]; - // for (var i = 0; i < 3; i++) { - // for (var j = 0; j < 4; j++) { - // l2.weights[i][j] <== w2[i][j]; - // } - // l2.biases[i] <== b2[i]; - // } - // // l2.weights <== w2; - // // l2.biases <== b2; - // for (var k = 0; k < 4; k++) { - // l2.in[k] <== l1.out[k]; - // } - // // l2.in <== l1.out; - - // for (var k = 0; k < 3; k++) { - // out[k] <== l2.out[k]; - // } + for (var k = 0; k < 5; k++) { + out[k] <== l2.out[k]; + } // out <== l2.out; } diff --git a/src/process.rs b/src/process.rs index d581570..236a973 100644 --- a/src/process.rs +++ b/src/process.rs @@ -291,7 +291,9 @@ fn handle_call( } // Process the function/template body + println!("================================ CALL {}", id); process_statements(ac, runtime, program_archive, &body)?; + println!("============================ END CALL {}", id); // Get return values let mut function_return: Option = None; diff --git a/src/program.rs b/src/program.rs index 9595b0e..8809d06 100644 --- a/src/program.rs +++ b/src/program.rs @@ -22,8 +22,9 @@ pub fn build_circuit(input: &Input) -> Result { if let Expression::Call { id, .. } = program_archive.get_main_expression() { let statements = program_archive.get_template_data(id).get_body_as_vec(); - + println!("================================ MAIN"); process_statements(&mut circuit, &mut runtime, &program_archive, statements)?; + println!("============================ END MAIN"); } Ok(circuit) diff --git a/src/runtime.rs b/src/runtime.rs index 4e9bbbd..1421371 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -490,14 +490,14 @@ impl DataAccess { pub fn access_str(&self) -> String { let mut ret = String::new(); - ret.write_str(self.get_name().as_str()); + ret.write_str(self.get_name().as_str()).ok(); for sub in self.get_access() { match sub { SubAccess::Array(index) => { - ret.write_str(format!("[{}]", index).as_str()); + ret.write_str(format!("[{}]", index).as_str()).ok(); } SubAccess::Component(name) => { - ret.write_str(format!(".{}]", name).as_str()); + ret.write_str(format!(".{}", name).as_str()).ok(); } } } From b66fa92f57addd8538f26f12c68744e6867adda4 Mon Sep 17 00:00:00 2001 From: Nam Ngo Date: Wed, 28 Feb 2024 12:25:13 +0700 Subject: [PATCH 06/10] distinguish call and main vars --- src/circuit.rs | 22 +++++++++++++++------- src/process.rs | 30 +++++++++++++++--------------- src/runtime.rs | 21 ++++++++++++++------- 3 files changed, 44 insertions(+), 29 deletions(-) diff --git a/src/circuit.rs b/src/circuit.rs index 5e21eb6..29018c0 100644 --- a/src/circuit.rs +++ b/src/circuit.rs @@ -50,20 +50,23 @@ impl From<&ExpressionInfixOpcode> for AGateType { pub struct Node { id: u32, signals: Vec, + names: Vec, } impl Node { /// Creates a new node. - pub fn new(signal_id: u32) -> Self { + pub fn new(signal_id: u32, signal_name: String) -> Self { Self { id: generate_u32(), signals: vec![signal_id], + names: vec![signal_name] } } /// Adds a set of signals to the node. - pub fn add_signals(&mut self, signals: Vec) { + pub fn add_signals(&mut self, signals: Vec, names: Vec) { self.signals.extend(signals); + self.names.extend(names); } /// Gets the signals of the node. @@ -71,15 +74,20 @@ impl Node { self.signals.clone() } + pub fn get_signals_names(&self) -> Vec { + self.names.clone() + } + /// Merges the signals of the node with another node, creating a new node. pub fn merge(&self, merge_node: &Node) -> Self { let mut new_node = Node { id: generate_u32(), signals: Vec::new(), + names: Vec::new(), }; - new_node.add_signals(self.get_signals()); - new_node.add_signals(merge_node.get_signals()); + new_node.add_signals(self.get_signals(), self.get_signals_names()); + new_node.add_signals(merge_node.get_signals(), merge_node.get_signals_names()); new_node } @@ -132,7 +140,7 @@ impl ArithmeticCircuit { } /// Adds a new signal variable to the circuit. - pub fn add_signal(&mut self, id: u32) -> Result<(), CircuitError> { + pub fn add_signal(&mut self, id: u32, name: String) -> Result<(), CircuitError> { // Check that the variable isn't already declared if self.contains_var(&id) { return Err(CircuitError::CircuitVariableAlreadyDeclared); @@ -140,7 +148,7 @@ impl ArithmeticCircuit { self.vars.insert(id, None); // Create a new node for the signal - let node = Node::new(id); + let node = Node::new(id, name); debug!("New {:?}", node); self.nodes.push(node); @@ -156,7 +164,7 @@ impl ArithmeticCircuit { self.vars.insert(value, Some(value)); // Create a new node for the constant - let node = Node::new(value); + let node = Node::new(value, format!("{}", value)); debug!("New {:?}", node); self.nodes.push(node); diff --git a/src/process.rs b/src/process.rs index 236a973..7bce702 100644 --- a/src/process.rs +++ b/src/process.rs @@ -75,7 +75,7 @@ pub fn process_statement( if dimensions.is_empty() { let signal_id = ctx.get_signal_id(&signal_access)?; - ac.add_signal(signal_id)?; + ac.add_signal(signal_id, signal_access.access_str(ctx.get_ctx_name()))?; } else { let mut indices: Vec = vec![0; dimensions.len()]; @@ -83,7 +83,7 @@ pub fn process_statement( // Set access and get signal id for the current indices signal_access.set_access(u32_to_access(&indices)); let signal_id = ctx.get_signal_id(&signal_access)?; - ac.add_signal(signal_id)?; + ac.add_signal(signal_id, signal_access.access_str(ctx.get_ctx_name()))?; // Increment indices if !increment_indices(&mut indices, &dimensions)? { @@ -96,7 +96,7 @@ pub fn process_statement( Ok(()) } Statement::While { cond, stmt, .. } => { - runtime.push_context(true)?; + runtime.push_context(true, format!("WHILE_PRE"))?; loop { let access = process_expression(ac, runtime, program_archive, cond)?; let result = runtime @@ -108,7 +108,7 @@ pub fn process_statement( break; } - runtime.push_context(true)?; + runtime.push_context(true, format!("WHILE_EXE"))?; process_statement(ac, runtime, program_archive, stmt)?; runtime.pop_context(true)?; } @@ -130,7 +130,7 @@ pub fn process_statement( if result == 0 { if let Some(else_statement) = else_case { - runtime.push_context(true)?; + runtime.push_context(true, format!("IF_TRUE"))?; process_statement(ac, runtime, program_archive, else_statement)?; runtime.pop_context(true)?; Ok(()) @@ -138,7 +138,7 @@ pub fn process_statement( Ok(()) } } else { - runtime.push_context(true)?; + runtime.push_context(true, format!("IF_FALSE"))?; process_statement(ac, runtime, program_archive, if_case)?; runtime.pop_context(true)?; Ok(()) @@ -160,8 +160,8 @@ pub fn process_statement( // Connect the generated gate output to the given signal let given_output_id = ctx.get_signal_id(&lh_access)?; let gate_output_id = get_signal_for_access(ac, ctx, &rh_access)?; - let a_name = lh_access.access_str(); - let b_name = rh_access.access_str(); + let a_name = lh_access.access_str(ctx.get_ctx_name()); + let b_name = rh_access.access_str(ctx.get_ctx_name()); ac.add_connection(gate_output_id, given_output_id, a_name, b_name)?; } DataType::Variable => { @@ -180,8 +180,8 @@ pub fn process_statement( let component_signal = ctx.get_component_signal_id(&lh_access)?; let assigned_signal = get_signal_for_access(ac, ctx, &rh_access)?; - let a_name = lh_access.access_str(); - let b_name = rh_access.access_str(); + let a_name = lh_access.access_str(ctx.get_ctx_name()); + let b_name = rh_access.access_str(ctx.get_ctx_name()); ac.add_connection(assigned_signal, component_signal, a_name, b_name)?; } @@ -278,7 +278,7 @@ fn handle_call( .collect::, ProgramError>>()?; // Create a new execution context - runtime.push_context(false)?; + runtime.push_context(false, format!("CALL_{}", id))?; // Set arguments in the new context for (arg_name, &arg_value) in arg_names.iter().zip(&arg_values) { @@ -383,11 +383,11 @@ fn handle_infix_op( let output_id = ctx.get_signal_id(&output_signal)?; // Add output signal and gate to the circuit - ac.add_signal(output_id)?; + ac.add_signal(output_id, output_signal.access_str(ctx.get_ctx_name()))?; - let lh_name = lhe_access.access_str(); - let rh_name = rhe_access.access_str(); - let o_name = output_signal.access_str(); + let lh_name = lhe_access.access_str(ctx.get_ctx_name()); + let rh_name = rhe_access.access_str(ctx.get_ctx_name()); + let o_name = output_signal.access_str(ctx.get_ctx_name()); ac.add_gate(gate_type, lhs_id, rhs_id, output_id, lh_name, rh_name, o_name)?; diff --git a/src/runtime.rs b/src/runtime.rs index 1421371..c8952dc 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -61,19 +61,19 @@ impl Runtime { /// Creates an empty runtime with no contexts. pub fn new() -> Self { Self { - contexts: VecDeque::from([Context::new()]), + contexts: VecDeque::from([Context::default()]), } } /// Adds a new context onto the stack, optionally inheriting from the current context. - pub fn push_context(&mut self, inherit: bool) -> Result<(), RuntimeError> { + pub fn push_context(&mut self, inherit: bool, id: String) -> Result<(), RuntimeError> { let new_context = if inherit { match self.contexts.front() { Some(parent_context) => Context::new_with_inheritance(parent_context), None => return Err(RuntimeError::NoContextToInheritFrom), } } else { - Context::new() + Context::new(id) }; self.contexts.push_front(new_context); Ok(()) @@ -114,6 +114,7 @@ impl Runtime { /// Handles a specific scope value tracking. #[derive(Clone)] pub struct Context { + ctx_name: String, names: HashSet, variables: HashMap, signals: HashMap, @@ -122,14 +123,15 @@ pub struct Context { impl Default for Context { fn default() -> Self { - Self::new() + Self::new(format!("0")) } } impl Context { /// Constructs a new Context. - pub fn new() -> Self { + pub fn new(ctx_name: String) -> Self { Self { + ctx_name, names: HashSet::new(), variables: HashMap::new(), signals: HashMap::new(), @@ -140,6 +142,7 @@ impl Context { /// Returns a contexts that inherits from the current context. pub fn new_with_inheritance(&self) -> Self { Self { + ctx_name: self.ctx_name.clone(), names: self.names.clone(), variables: self.variables.clone(), signals: self.signals.clone(), @@ -147,6 +150,10 @@ impl Context { } } + pub fn get_ctx_name(&self) -> String { + self.ctx_name.clone() + } + /// Merges changes from the given context into this context. /// Signals are not merged, as they are read-only. pub fn merge(&mut self, child: &Context) -> Result<(), RuntimeError> { @@ -488,8 +495,8 @@ impl DataAccess { &self.access } - pub fn access_str(&self) -> String { - let mut ret = String::new(); + pub fn access_str(&self, ctx_name: String) -> String { + let mut ret = String::from(format!("{}.", ctx_name)); ret.write_str(self.get_name().as_str()).ok(); for sub in self.get_access() { match sub { From 8679b721d95d928889adbff81a486936d7b21f94 Mon Sep 17 00:00:00 2001 From: Nam Ngo Date: Wed, 28 Feb 2024 13:23:24 +0700 Subject: [PATCH 07/10] cleanup --- src/assets/circuit.circom | 28 +++++++++++++++++++++++++++- src/circuit.rs | 8 +++++--- src/process.rs | 11 ++++++++++- 3 files changed, 42 insertions(+), 5 deletions(-) diff --git a/src/assets/circuit.circom b/src/assets/circuit.circom index ea5113d..47b26af 100644 --- a/src/assets/circuit.circom +++ b/src/assets/circuit.circom @@ -82,6 +82,7 @@ template ShiftRight(k) { template Sign() { signal input in; signal output sign; + sign <== in < 0; } template div_relu(k) { @@ -102,6 +103,30 @@ template div_relu(k) { out <== switcher.outL; } +template test() { + signal input in[2]; + signal output out[3]; + + component l0 = fc(2, 3); + signal input w0[3][2]; + signal input b0[3]; + for (var i = 0; i < 3; i++) { + for (var j = 0; j < 2; j++) { + l0.weights[i][j] <== w0[i][j]; + } + l0.biases[i] <== b0[i]; + } + // l0.weights <== w0; + // l0.biases <== b0; + for (var k = 0; k < 2; k++) { + l0.in[k] <== in[k]; + } + + for (var k = 0; k < 3; k++) { + out[k] <== l0.out[k]; + } +} + template network() { // var in_len = 3; // var out_len = 5; @@ -161,4 +186,5 @@ template network() { // out <== l2.out; } -component main = network(); \ No newline at end of file +component main = test(); +// component main = network(); \ No newline at end of file diff --git a/src/circuit.rs b/src/circuit.rs index 29018c0..9b7bfb1 100644 --- a/src/circuit.rs +++ b/src/circuit.rs @@ -156,7 +156,7 @@ impl ArithmeticCircuit { } /// Adds a new constant variable to the circuit. - pub fn add_const(&mut self, value: u32) -> Result<(), CircuitError> { + pub fn add_const(&mut self, value: u32, name: String) -> Result<(), CircuitError> { // Ignore if the constant is already declared if self.contains_var(&value) { return Ok(()); @@ -164,7 +164,7 @@ impl ArithmeticCircuit { self.vars.insert(value, Some(value)); // Create a new node for the constant - let node = Node::new(value, format!("{}", value)); + let node = Node::new(value, name); debug!("New {:?}", node); self.nodes.push(node); @@ -199,7 +199,9 @@ impl ArithmeticCircuit { AGateType::AGEq => todo!(), AGateType::AGt => todo!(), AGateType::ALEq => todo!(), - AGateType::ALt => todo!(), + AGateType::ALt => { + println!("{} = {} < {}", o_name, lh_name, rh_name); + }, AGateType::AMul => { println!("{} = {} * {}", o_name, lh_name, rh_name); }, diff --git a/src/process.rs b/src/process.rs index 7bce702..c4ee0d0 100644 --- a/src/process.rs +++ b/src/process.rs @@ -373,6 +373,15 @@ fn handle_infix_op( return Ok(item_access); } + match op { + ExpressionInfixOpcode::Lesser => { + println!("DEBUG ALt"); + }, + _ => { + + } + } + // Handle cases where one or both inputs are signals let lhs_id = get_signal_for_access(ac, ctx, &lhe_access)?; let rhs_id = get_signal_for_access(ac, ctx, &rhe_access)?; @@ -408,7 +417,7 @@ fn get_signal_for_access( let value = ctx .get_variable_value(access)? .ok_or(ProgramError::EmptyDataItem)?; - ac.add_const(value)?; + ac.add_const(value, access.access_str(ctx.get_ctx_name()))?; Ok(value) } DataType::Component => Ok(ctx.get_component_signal_id(access)?), From d1b7bd917347ad0464a9f31a235a0c31de0c0359 Mon Sep 17 00:00:00 2001 From: Nam Ngo Date: Wed, 28 Feb 2024 13:28:31 +0700 Subject: [PATCH 08/10] add ip test --- src/assets/circuit.circom | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/src/assets/circuit.circom b/src/assets/circuit.circom index 47b26af..22e40ec 100644 --- a/src/assets/circuit.circom +++ b/src/assets/circuit.circom @@ -186,5 +186,25 @@ template network() { // out <== l2.out; } -component main = test(); -// component main = network(); \ No newline at end of file +// component main = test(); +// component main = network(); + +template InnerProd () { + + // Declaration of signals + signal input input_A[3]; + signal input input_B[3]; + signal output ip; + + signal sum[3]; + + sum[0] <== input_A[0]*input_B[0]; + + for (var i = 1; i < 3; i++) { + sum[i] <== sum[i-1] + input_A[i] * input_B[i]; + } + + ip <== sum[2]; +} + +component main = InnerProd(); \ No newline at end of file From 3e778f250d14f5a29868eea4fb264658a41b1175 Mon Sep 17 00:00:00 2001 From: Nam Ngo Date: Wed, 28 Feb 2024 13:52:00 +0700 Subject: [PATCH 09/10] add is_const and const_value --- src/circuit.rs | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/src/circuit.rs b/src/circuit.rs index 9b7bfb1..c6f2e2f 100644 --- a/src/circuit.rs +++ b/src/circuit.rs @@ -51,18 +51,30 @@ pub struct Node { id: u32, signals: Vec, names: Vec, + is_const: bool, + const_value: u32, } impl Node { /// Creates a new node. - pub fn new(signal_id: u32, signal_name: String) -> Self { + pub fn new(signal_id: u32, signal_name: String, is_const: bool, const_value: u32) -> Self { Self { id: generate_u32(), signals: vec![signal_id], - names: vec![signal_name] + names: vec![signal_name], + is_const, + const_value, } } + pub fn is_const(&self) -> bool { + self.is_const + } + + pub fn const_value(&self) -> u32 { + self.const_value + } + /// Adds a set of signals to the node. pub fn add_signals(&mut self, signals: Vec, names: Vec) { self.signals.extend(signals); @@ -80,10 +92,14 @@ impl Node { /// Merges the signals of the node with another node, creating a new node. pub fn merge(&self, merge_node: &Node) -> Self { + let ic = self.is_const() | merge_node.is_const(); + let cv = self.const_value(); let mut new_node = Node { id: generate_u32(), signals: Vec::new(), names: Vec::new(), + is_const: ic, + const_value: cv, }; new_node.add_signals(self.get_signals(), self.get_signals_names()); @@ -148,7 +164,7 @@ impl ArithmeticCircuit { self.vars.insert(id, None); // Create a new node for the signal - let node = Node::new(id, name); + let node = Node::new(id, name, false, 0); debug!("New {:?}", node); self.nodes.push(node); @@ -164,7 +180,7 @@ impl ArithmeticCircuit { self.vars.insert(value, Some(value)); // Create a new node for the constant - let node = Node::new(value, name); + let node = Node::new(value, name, true, value); debug!("New {:?}", node); self.nodes.push(node); From c9e83e964c6fa4e0b96937ccbbeee16fc60f1587 Mon Sep 17 00:00:00 2001 From: Nam Ngo Date: Fri, 8 Mar 2024 09:59:18 +0700 Subject: [PATCH 10/10] fc 8 layers --- src/assets/circuit.circom | 202 +++++++++++++++++-------- src/assets/circuit_bak.circom | 274 ++++++++++++++++++++++++++++++++++ src/assets/fc.circom | 85 +++++++---- src/assets/fc_lite.circom | 199 ++++++++++++++++++++++++ src/circuit.rs | 10 +- 5 files changed, 674 insertions(+), 96 deletions(-) create mode 100644 src/assets/circuit_bak.circom create mode 100644 src/assets/fc_lite.circom diff --git a/src/assets/circuit.circom b/src/assets/circuit.circom index 22e40ec..87ae4de 100644 --- a/src/assets/circuit.circom +++ b/src/assets/circuit.circom @@ -103,108 +103,178 @@ template div_relu(k) { out <== switcher.outL; } -template test() { - signal input in[2]; - signal output out[3]; - - component l0 = fc(2, 3); - signal input w0[3][2]; - signal input b0[3]; - for (var i = 0; i < 3; i++) { - for (var j = 0; j < 2; j++) { - l0.weights[i][j] <== w0[i][j]; - } - l0.biases[i] <== b0[i]; - } - // l0.weights <== w0; - // l0.biases <== b0; - for (var k = 0; k < 2; k++) { - l0.in[k] <== in[k]; - } +template network() { + // Structure from python example + // self.fc1 = nn.Linear(2, 32) + // self.fc2 = nn.Linear(32, 64) + // self.fc3 = nn.Linear(64, 128) + // self.fc4 = nn.Linear(128, 4) - for (var k = 0; k < 3; k++) { - out[k] <== l0.out[k]; - } -} + var in_len = 2; + var out_len = 4; -template network() { - // var in_len = 3; - // var out_len = 5; - signal input in[3]; - signal output out[5]; - - component l0 = fc(3, 5); - signal input w0[5][3]; - signal input b0[5]; - for (var i = 0; i < 5; i++) { - for (var j = 0; j < 3; j++) { + var l0_w = in_len; + var l0_h = 32; + + var l1_w = l0_h; + var l1_h = 64; + + var l2_w = l1_h; + var l2_h = 128; + + var l3_w = l2_h; + var l3_h = 256; + + var l4_w = l3_h; + var l4_h = 512; + + var l5_w = l4_h; + var l5_h = 1024; + + var l6_w = l5_h; + var l6_h = 2048; + + var l7_w = l6_h; + var l7_h = out_len; + + signal input in[in_len]; + signal output out[out_len]; + + component l0 = fc(l0_w, l0_h); + signal input w0[l0_h][l0_w]; + signal input b0[l0_h]; + for (var i = 0; i < l0_h; i++) { + for (var j = 0; j < l0_w; j++) { l0.weights[i][j] <== w0[i][j]; } l0.biases[i] <== b0[i]; } // l0.weights <== w0; // l0.biases <== b0; - for (var k = 0; k < 3; k++) { + for (var k = 0; k < in_len; k++) { l0.in[k] <== in[k]; } - component l1 = fc(5, 7); - signal input w1[7][5]; - signal input b1[7]; - for (var i = 0; i < 7; i++) { - for (var j = 0; j < 5; j++) { + component l1 = fc(l1_w, l1_h); + signal input w1[l1_h][l1_w]; + signal input b1[l1_h]; + for (var i = 0; i < l1_h; i++) { + for (var j = 0; j < l1_w; j++) { l1.weights[i][j] <== w1[i][j]; } l1.biases[i] <== b1[i]; } // l1.weights <== w1; // l1.biases <== b1; - for (var k = 0; k < 5; k++) { + for (var k = 0; k < l0_h; k++) { l1.in[k] <== l0.out[k]; } // l1.in <== l0.out; - component l2 = fc_no_relu(7, 5); - signal input w2[5][7]; - signal input b2[5]; - for (var i = 0; i < 5; i++) { - for (var j = 0; j < 7; j++) { + component l2 = fc(l2_w, l2_h); + signal input w2[l2_h][l2_w]; + signal input b2[l2_h]; + for (var i = 0; i < l2_h; i++) { + for (var j = 0; j < l2_w; j++) { l2.weights[i][j] <== w2[i][j]; } l2.biases[i] <== b2[i]; } // l2.weights <== w2; // l2.biases <== b2; - for (var k = 0; k < 7; k++) { + for (var k = 0; k < l1_h; k++) { l2.in[k] <== l1.out[k]; } // l2.in <== l1.out; - for (var k = 0; k < 5; k++) { - out[k] <== l2.out[k]; + component l3 = fc(l3_w, l3_h); + signal input w3[l3_h][l3_w]; + signal input b3[l3_h]; + for (var i = 0; i < l3_h; i++) { + for (var j = 0; j < l3_w; j++) { + l3.weights[i][j] <== w3[i][j]; + } + l3.biases[i] <== b3[i]; } - // out <== l2.out; -} - -// component main = test(); -// component main = network(); - -template InnerProd () { + // l2.weights <== w2; + // l2.biases <== b2; + for (var k = 0; k < l2_h; k++) { + l3.in[k] <== l2.out[k]; + } + // l2.in <== l1.out; - // Declaration of signals - signal input input_A[3]; - signal input input_B[3]; - signal output ip; + component l4 = fc(l4_w, l4_h); + signal input w4[l4_h][l4_w]; + signal input b4[l4_h]; + for (var i = 0; i < l4_h; i++) { + for (var j = 0; j < l4_w; j++) { + l4.weights[i][j] <== w4[i][j]; + } + l4.biases[i] <== b4[i]; + } + for (var k = 0; k < l4_h; k++) { + l4.in[k] <== l3.out[k]; + } - signal sum[3]; + component l5 = fc(l5_w, l5_h); + signal input w5[l5_h][l5_w]; + signal input b5[l5_h]; + for (var i = 0; i < l5_h; i++) { + for (var j = 0; j < l5_w; j++) { + l5.weights[i][j] <== w5[i][j]; + } + l5.biases[i] <== b5[i]; + } + for (var k = 0; k < l5_h; k++) { + l5.in[k] <== l4.out[k]; + } - sum[0] <== input_A[0]*input_B[0]; + component l6 = fc(l6_w, l6_h); + signal input w6[l6_h][l6_w]; + signal input b6[l6_h]; + for (var i = 0; i < l6_h; i++) { + for (var j = 0; j < l6_w; j++) { + l6.weights[i][j] <== w6[i][j]; + } + l6.biases[i] <== b6[i]; + } + for (var k = 0; k < l6_h; k++) { + l6.in[k] <== l5.out[k]; + } - for (var i = 1; i < 3; i++) { - sum[i] <== sum[i-1] + input_A[i] * input_B[i]; - } + component l7 = fc(l7_w, l7_h); + signal input w7[l7_h][l7_w]; + signal input b7[l7_h]; + for (var i = 0; i < l7_h; i++) { + for (var j = 0; j < l7_w; j++) { + l7.weights[i][j] <== w7[i][j]; + } + l7.biases[i] <== b7[i]; + } + for (var k = 0; k < l7_h; k++) { + l7.in[k] <== l6.out[k]; + } - ip <== sum[2]; + // component l8 = fc_no_relu(l8_w, l8_h); + // signal input w8[l8_h][l8_w]; + // signal input b8[l8_h]; + // for (var i = 0; i < l8_h; i++) { + // for (var j = 0; j < l8_w; j++) { + // l8.weights[i][j] <== w8[i][j]; + // } + // l8.biases[i] <== b8[i]; + // } + // // l3.weights <== w2; + // // l3.biases <== b2; + // for (var k = 0; k < l8_h; k++) { + // l8.in[k] <== l7.out[k]; + // } + // // l3.in <== l1.out; + + for (var k = 0; k < out_len; k++) { + out[k] <== l7.out[k]; + } + // out <== l2.out; } -component main = InnerProd(); \ No newline at end of file +component main = network(); \ No newline at end of file diff --git a/src/assets/circuit_bak.circom b/src/assets/circuit_bak.circom new file mode 100644 index 0000000..5a36677 --- /dev/null +++ b/src/assets/circuit_bak.circom @@ -0,0 +1,274 @@ +pragma circom 2.0.0; + +template Switcher() { + signal input sel; + signal input L; + signal input R; + signal output outL; + signal output outR; + + signal aux; + + aux <== (R-L)*sel; // We create aux in order to have only one multiplication + outL <== aux + L; + outR <== R - aux; +} + +template auction(n) { + signal input in[n]; + signal signs[n-1]; + signal maxidx[n-1]; + signal maxprice[n-1]; + signal output idx; + signal output price; + component sws[n-1]; + component sws2[n-1]; + + for (var i = 0; i < n-1; i++) { + signs[i] <== in[i] < in[i+1]; + sws[i] <== Switcher(); + sws[i].sel <== signs[i]; + sws[i].L <== in[i]; + sws[i].R <== in[i+1]; + maxprice[i] <== sws[i].outR; + sws2[i] <== Switcher(); + sws2[i].sel <== signs[i]; + sws2[i].L <== i; + sws2[i].R <== i+1; + maxidx[i] <== sws2[i].outR; + } + + idx <== maxidx[n-1]; + price <== maxprice[n-1]; +} + + +template fc (width, height) { + signal input in[width]; + signal input weights[height][width]; + signal input biases[height]; + signal output out[height]; + + component rows[height]; + + component relu[height]; + + for(var index = 0; index < height; index++) { + rows[index] = dot_product(width); + for(var index_input = 0; index_input < width; index_input++) { + rows[index].inputs[index_input] <== in[index_input]; + rows[index].weight_vector[index_input] <== weights[index][index_input]; + } + rows[index].bias <== biases[index]; + relu[index] = div_relu(12); + relu[index].in <== rows[index].out; + out[index] <== relu[index].out; + } +} + +template fc_no_relu (width, height) { + signal input in[width]; + signal input weights[height][width]; + signal input biases[height]; + signal output out[height]; + + component rows[height]; + + for(var index = 0; index < height; index++) { + rows[index] = dot_product(width); + for(var index_input = 0; index_input < width; index_input++) { + rows[index].inputs[index_input] <== in[index_input]; + rows[index].weight_vector[index_input] <== weights[index][index_input]; + } + rows[index].bias <== biases[index]; + out[index] <== rows[index].out; + } +} + +template dot_product (width) { + signal input inputs[width]; + signal input weight_vector[width]; + signal inter_accum[width]; + signal input bias; + signal output out; + + inter_accum[0] <== inputs[0]*weight_vector[0]; + // inter_accum[0]*0 === 0; + + for(var index = 1; index < width; index++) { + inter_accum[index] <== inputs[index]*weight_vector[index] + inter_accum[index-1]; + } + out <== inter_accum[width-1] + bias; +} + +template ShiftRight(k) { + signal input in; + signal output out; + out <== in; +} + +template Sign() { + signal input in; + signal output sign; + sign <== in < 0; +} + +template div_relu(k) { + signal input in; + signal output out; + component shiftRight = ShiftRight(k); + component sign = Sign(); + + shiftRight.in <== in; + sign.in <== shiftRight.out; + + component switcher = Switcher(); + switcher.sel <== sign.sign; + switcher.L <== shiftRight.out; + switcher.R <== 0; + //switcher.outR*0 === 0; + + out <== switcher.outL; +} + +template test() { + signal input in[2]; + signal output out[3]; + + component l0 = fc(2, 3); + signal input w0[3][2]; + signal input b0[3]; + for (var i = 0; i < 3; i++) { + for (var j = 0; j < 2; j++) { + l0.weights[i][j] <== w0[i][j]; + } + l0.biases[i] <== b0[i]; + } + // l0.weights <== w0; + // l0.biases <== b0; + for (var k = 0; k < 2; k++) { + l0.in[k] <== in[k]; + } + + for (var k = 0; k < 3; k++) { + out[k] <== l0.out[k]; + } +} + +template network() { + // Structure from python example + // self.fc1 = nn.Linear(2, 32) + // self.fc2 = nn.Linear(32, 64) + // self.fc3 = nn.Linear(64, 128) + // self.fc4 = nn.Linear(128, 4) + + var in_len = 2; + var out_len = 4; + + var l0_w = in_len; + var l0_h = 32; + + var l1_w = l0_h; + var l1_h = 64; + + var l2_w = l1_h; + var l2_h = 128; + + var l3_w = l2_h; + var l3_h = out_len; + + signal input in[in_len]; + signal output out[out_len]; + + component l0 = fc(l0_w, l0_h); + signal input w0[l0_h][l0_w]; + signal input b0[l0_h]; + for (var i = 0; i < l0_h; i++) { + for (var j = 0; j < l0_w; j++) { + l0.weights[i][j] <== w0[i][j]; + } + l0.biases[i] <== b0[i]; + } + // l0.weights <== w0; + // l0.biases <== b0; + for (var k = 0; k < in_len; k++) { + l0.in[k] <== in[k]; + } + + component l1 = fc(l1_w, l1_h); + signal input w1[l1_h][l1_w]; + signal input b1[l1_h]; + for (var i = 0; i < l1_h; i++) { + for (var j = 0; j < l1_w; j++) { + l1.weights[i][j] <== w1[i][j]; + } + l1.biases[i] <== b1[i]; + } + // l1.weights <== w1; + // l1.biases <== b1; + for (var k = 0; k < l0_h; k++) { + l1.in[k] <== l0.out[k]; + } + // l1.in <== l0.out; + + component l2 = fc(l2_w, l2_h); + signal input w2[l2_h][l2_w]; + signal input b2[l2_h]; + for (var i = 0; i < l2_h; i++) { + for (var j = 0; j < l2_w; j++) { + l2.weights[i][j] <== w2[i][j]; + } + l2.biases[i] <== b2[i]; + } + // l2.weights <== w2; + // l2.biases <== b2; + for (var k = 0; k < l1_h; k++) { + l2.in[k] <== l1.out[k]; + } + // l2.in <== l1.out; + + component l3 = fc_no_relu(l3_w, l3_h); + signal input w3[l3_h][l3_w]; + signal input b3[l3_h]; + for (var i = 0; i < l3_h; i++) { + for (var j = 0; j < l3_w; j++) { + l3.weights[i][j] <== w3[i][j]; + } + l3.biases[i] <== b3[i]; + } + // l2.weights <== w2; + // l2.biases <== b2; + for (var k = 0; k < l2_h; k++) { + l3.in[k] <== l2.out[k]; + } + // l2.in <== l1.out; + + for (var k = 0; k < out_len; k++) { + out[k] <== l3.out[k]; + } + // out <== l2.out; +} + +// template InnerProd () { + +// // Declaration of signals +// signal input input_A[3]; +// signal input input_B[3]; +// signal output ip; + +// signal sum[3]; + +// sum[0] <== input_A[0]*input_B[0]; + +// for (var i = 1; i < 3; i++) { +// sum[i] <== sum[i-1] + input_A[i] * input_B[i]; +// } + +// ip <== sum[2]; +// } + +// component main = InnerProd(); + +// component main = test(); +component main = network(); + diff --git a/src/assets/fc.circom b/src/assets/fc.circom index ea5113d..1824a7b 100644 --- a/src/assets/fc.circom +++ b/src/assets/fc.circom @@ -103,60 +103,95 @@ template div_relu(k) { } template network() { - // var in_len = 3; - // var out_len = 5; - signal input in[3]; - signal output out[5]; - - component l0 = fc(3, 5); - signal input w0[5][3]; - signal input b0[5]; - for (var i = 0; i < 5; i++) { - for (var j = 0; j < 3; j++) { + // Structure from python example + // self.fc1 = nn.Linear(2, 32) + // self.fc2 = nn.Linear(32, 64) + // self.fc3 = nn.Linear(64, 128) + // self.fc4 = nn.Linear(128, 4) + + var in_len = 2; + var out_len = 4; + + var l0_w = in_len; + var l0_h = 32; + + var l1_w = l0_h; + var l1_h = 64; + + var l2_w = l1_h; + var l2_h = 128; + + var l3_w = l2_h; + var l3_h = out_len; + + signal input in[in_len]; + signal output out[out_len]; + + component l0 = fc(l0_w, l0_h); + signal input w0[l0_h][l0_w]; + signal input b0[l0_h]; + for (var i = 0; i < l0_h; i++) { + for (var j = 0; j < l0_w; j++) { l0.weights[i][j] <== w0[i][j]; } l0.biases[i] <== b0[i]; } // l0.weights <== w0; // l0.biases <== b0; - for (var k = 0; k < 3; k++) { + for (var k = 0; k < in_len; k++) { l0.in[k] <== in[k]; } - component l1 = fc(5, 7); - signal input w1[7][5]; - signal input b1[7]; - for (var i = 0; i < 7; i++) { - for (var j = 0; j < 5; j++) { + component l1 = fc(l1_w, l1_h); + signal input w1[l1_h][l1_w]; + signal input b1[l1_h]; + for (var i = 0; i < l1_h; i++) { + for (var j = 0; j < l1_w; j++) { l1.weights[i][j] <== w1[i][j]; } l1.biases[i] <== b1[i]; } // l1.weights <== w1; // l1.biases <== b1; - for (var k = 0; k < 5; k++) { + for (var k = 0; k < l0_h; k++) { l1.in[k] <== l0.out[k]; } // l1.in <== l0.out; - component l2 = fc_no_relu(7, 5); - signal input w2[5][7]; - signal input b2[5]; - for (var i = 0; i < 5; i++) { - for (var j = 0; j < 7; j++) { + component l2 = fc(l2_w, l2_h); + signal input w2[l2_h][l2_w]; + signal input b2[l2_h]; + for (var i = 0; i < l2_h; i++) { + for (var j = 0; j < l2_w; j++) { l2.weights[i][j] <== w2[i][j]; } l2.biases[i] <== b2[i]; } // l2.weights <== w2; // l2.biases <== b2; - for (var k = 0; k < 7; k++) { + for (var k = 0; k < l1_h; k++) { l2.in[k] <== l1.out[k]; } // l2.in <== l1.out; - for (var k = 0; k < 5; k++) { - out[k] <== l2.out[k]; + component l3 = fc_no_relu(l3_w, l3_h); + signal input w3[l3_h][l3_w]; + signal input b3[l3_h]; + for (var i = 0; i < l3_h; i++) { + for (var j = 0; j < l3_w; j++) { + l3.weights[i][j] <== w3[i][j]; + } + l3.biases[i] <== b3[i]; + } + // l2.weights <== w2; + // l2.biases <== b2; + for (var k = 0; k < l2_h; k++) { + l3.in[k] <== l2.out[k]; + } + // l2.in <== l1.out; + + for (var k = 0; k < out_len; k++) { + out[k] <== l3.out[k]; } // out <== l2.out; } diff --git a/src/assets/fc_lite.circom b/src/assets/fc_lite.circom new file mode 100644 index 0000000..2ed26a4 --- /dev/null +++ b/src/assets/fc_lite.circom @@ -0,0 +1,199 @@ +pragma circom 2.0.0; + +template Switcher() { + signal input sel; + signal input L; + signal input R; + signal output outL; + signal output outR; + + signal aux; + + aux <== (R-L)*sel; // We create aux in order to have only one multiplication + outL <== aux + L; + outR <== R - aux; +} + + +template fc (width, height) { + signal input in[width]; + signal input weights[height][width]; + signal input biases[height]; + signal output out[height]; + + component rows[height]; + + component relu[height]; + + for(var index = 0; index < height; index++) { + rows[index] = dot_product(width); + for(var index_input = 0; index_input < width; index_input++) { + rows[index].inputs[index_input] <== in[index_input]; + rows[index].weight_vector[index_input] <== weights[index][index_input]; + } + rows[index].bias <== biases[index]; + relu[index] = div_relu(12); + relu[index].in <== rows[index].out; + out[index] <== relu[index].out; + } +} + +template fc_no_relu (width, height) { + signal input in[width]; + signal input weights[height][width]; + signal input biases[height]; + signal output out[height]; + + component rows[height]; + + for(var index = 0; index < height; index++) { + rows[index] = dot_product(width); + for(var index_input = 0; index_input < width; index_input++) { + rows[index].inputs[index_input] <== in[index_input]; + rows[index].weight_vector[index_input] <== weights[index][index_input]; + } + rows[index].bias <== biases[index]; + out[index] <== rows[index].out; + } +} + +template dot_product (width) { + signal input inputs[width]; + signal input weight_vector[width]; + signal inter_accum[width]; + signal input bias; + signal output out; + + inter_accum[0] <== inputs[0]*weight_vector[0]; + // inter_accum[0]*0 === 0; + + for(var index = 1; index < width; index++) { + inter_accum[index] <== inputs[index]*weight_vector[index] + inter_accum[index-1]; + } + out <== inter_accum[width-1] + bias; +} + +template ShiftRight(k) { + signal input in; + signal output out; + out <== in; +} + +template Sign() { + signal input in; + signal output sign; +} + +template div_relu(k) { + signal input in; + signal output out; + component shiftRight = ShiftRight(k); + component sign = Sign(); + + shiftRight.in <== in; + sign.in <== shiftRight.out; + + component switcher = Switcher(); + switcher.sel <== sign.sign; + switcher.L <== shiftRight.out; + switcher.R <== 0; + //switcher.outR*0 === 0; + + out <== switcher.outL; +} + +template network() { + // Structure from python example + // self.fc1 = nn.Linear(2, 32) + // self.fc2 = nn.Linear(32, 64) + // self.fc3 = nn.Linear(64, 128) + // self.fc4 = nn.Linear(128, 4) + + var in_len = 2; + var out_len = 4; + + var l0_w = in_len; + var l0_h = 5; + + var l1_w = l0_h; + var l1_h = 7; + + var l2_w = l1_h; + var l2_h = 11; + + var l3_w = l2_h; + var l3_h = out_len; + + signal input in[in_len]; + signal output out[out_len]; + + component l0 = fc(l0_w, l0_h); + signal input w0[l0_h][l0_w]; + signal input b0[l0_h]; + for (var i = 0; i < l0_h; i++) { + for (var j = 0; j < l0_w; j++) { + l0.weights[i][j] <== w0[i][j]; + } + l0.biases[i] <== b0[i]; + } + // l0.weights <== w0; + // l0.biases <== b0; + for (var k = 0; k < in_len; k++) { + l0.in[k] <== in[k]; + } + + component l1 = fc(l1_w, l1_h); + signal input w1[l1_h][l1_w]; + signal input b1[l1_h]; + for (var i = 0; i < l1_h; i++) { + for (var j = 0; j < l1_w; j++) { + l1.weights[i][j] <== w1[i][j]; + } + l1.biases[i] <== b1[i]; + } + // l1.weights <== w1; + // l1.biases <== b1; + for (var k = 0; k < l0_h; k++) { + l1.in[k] <== l0.out[k]; + } + // l1.in <== l0.out; + + component l2 = fc(l2_w, l2_h); + signal input w2[l2_h][l2_w]; + signal input b2[l2_h]; + for (var i = 0; i < l2_h; i++) { + for (var j = 0; j < l2_w; j++) { + l2.weights[i][j] <== w2[i][j]; + } + l2.biases[i] <== b2[i]; + } + // l2.weights <== w2; + // l2.biases <== b2; + for (var k = 0; k < l1_h; k++) { + l2.in[k] <== l1.out[k]; + } + // l2.in <== l1.out; + + component l3 = fc_no_relu(l3_w, l3_h); + signal input w3[l3_h][l3_w]; + signal input b3[l3_h]; + for (var i = 0; i < l3_h; i++) { + for (var j = 0; j < l3_w; j++) { + l3.weights[i][j] <== w3[i][j]; + } + l3.biases[i] <== b3[i]; + } + // l2.weights <== w2; + // l2.biases <== b2; + for (var k = 0; k < l2_h; k++) { + l3.in[k] <== l2.out[k]; + } + // l2.in <== l1.out; + + for (var k = 0; k < out_len; k++) { + out[k] <== l3.out[k]; + } + // out <== l2.out; +} + +component main = network(); \ No newline at end of file diff --git a/src/circuit.rs b/src/circuit.rs index c6f2e2f..6b2297f 100644 --- a/src/circuit.rs +++ b/src/circuit.rs @@ -208,7 +208,7 @@ impl ArithmeticCircuit { match gate_type { AGateType::AAdd => { - println!("{} = {} + {}", o_name, lh_name, rh_name); + // println!("{} = {} + {}", o_name, lh_name, rh_name); }, AGateType::ADiv => todo!(), AGateType::AEq => todo!(), @@ -216,15 +216,15 @@ impl ArithmeticCircuit { AGateType::AGt => todo!(), AGateType::ALEq => todo!(), AGateType::ALt => { - println!("{} = {} < {}", o_name, lh_name, rh_name); + // println!("{} = {} < {}", o_name, lh_name, rh_name); }, AGateType::AMul => { - println!("{} = {} * {}", o_name, lh_name, rh_name); + // println!("{} = {} * {}", o_name, lh_name, rh_name); }, AGateType::ANeq => todo!(), AGateType::ANone => todo!(), AGateType::ASub => { - println!("{} = {} - {}", o_name, lh_name, rh_name); + // println!("{} = {} - {}", o_name, lh_name, rh_name); }, }; @@ -290,7 +290,7 @@ impl ArithmeticCircuit { .retain(|node| node.id != node_a.id && node.id != node_b.id); self.nodes.push(merged_node); - println!("{} = {}", a_name, b_name); + // println!("{} = {}", a_name, b_name); Ok(()) }