diff --git a/Cargo.toml b/Cargo.toml index 415b695..d3f1f5a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,4 +34,3 @@ mpz-circuits = { git = "https://github.com/privacy-scaling-explorations/mpz", p [[bin]] name="circom" path="./src/main.rs" - diff --git a/src/circuit.rs b/src/circuit.rs index c6f2e2f..66c6619 100644 --- a/src/circuit.rs +++ b/src/circuit.rs @@ -45,65 +45,58 @@ impl From<&ExpressionInfixOpcode> for AGateType { } } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Signal { + id: u32, + name: String, + value: Option, +} + +impl Signal { + pub fn new(id: u32, name: String, value: Option) -> Self { + Self { id, name, value } + } + + pub fn is_const(&self) -> bool { + self.value.is_some() + } +} + /// Represents a node in the circuit, with an identifier and a set of signals that it is connected to. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Node { id: u32, - signals: Vec, - names: Vec, - is_const: bool, - const_value: u32, + signals: Vec, } impl Node { /// Creates a new node. - pub fn new(signal_id: u32, signal_name: String, is_const: bool, const_value: u32) -> Self { + pub fn new(signal: Signal) -> Self { Self { id: generate_u32(), - signals: vec![signal_id], - names: vec![signal_name], - is_const, - const_value, + signals: vec![signal], } } - pub fn is_const(&self) -> bool { - self.is_const - } - - pub fn const_value(&self) -> u32 { - self.const_value - } - /// Adds a set of signals to the node. - pub fn add_signals(&mut self, signals: Vec, names: Vec) { + pub fn add_signals(&mut self, signals: Vec) { self.signals.extend(signals); - self.names.extend(names); } /// Gets the signals of the node. - pub fn get_signals(&self) -> Vec { + pub fn get_signals(&self) -> Vec { self.signals.clone() } - pub fn get_signals_names(&self) -> Vec { - self.names.clone() - } - /// Merges the signals of the node with another node, creating a new node. pub fn merge(&self, merge_node: &Node) -> Self { - let ic = self.is_const() | merge_node.is_const(); - let cv = self.const_value(); let mut new_node = Node { id: generate_u32(), signals: Vec::new(), - names: Vec::new(), - is_const: ic, - const_value: cv, }; - new_node.add_signals(self.get_signals(), self.get_signals_names()); - new_node.add_signals(merge_node.get_signals(), merge_node.get_signals_names()); + new_node.add_signals(self.get_signals()); + new_node.add_signals(merge_node.get_signals()); new_node } @@ -140,7 +133,7 @@ impl ArithmeticGate { /// Represents an arithmetic circuit, with a set of variables and gates. #[derive(Debug, Default, Serialize, Deserialize)] pub struct ArithmeticCircuit { - vars: HashMap>, + signals: HashMap, nodes: Vec, gates: Vec, } @@ -149,7 +142,7 @@ impl ArithmeticCircuit { /// Creates a new arithmetic circuit. pub fn new() -> ArithmeticCircuit { ArithmeticCircuit { - vars: HashMap::new(), + signals: HashMap::new(), nodes: Vec::new(), gates: Vec::new(), } @@ -158,13 +151,18 @@ impl ArithmeticCircuit { /// Adds a new signal variable to the circuit. pub fn add_signal(&mut self, id: u32, name: String) -> Result<(), CircuitError> { // Check that the variable isn't already declared - if self.contains_var(&id) { + if self.is_signal_declared(&id) { return Err(CircuitError::CircuitVariableAlreadyDeclared); } - self.vars.insert(id, None); + + // Create a new Signal + let signal = Signal::new(id, name, None); + + // Store the signal data + self.signals.insert(signal.id, signal.clone()); // Create a new node for the signal - let node = Node::new(id, name, false, 0); + let node = Node::new(signal); debug!("New {:?}", node); self.nodes.push(node); @@ -174,13 +172,18 @@ impl ArithmeticCircuit { /// Adds a new constant variable to the circuit. pub fn add_const(&mut self, value: u32, name: String) -> Result<(), CircuitError> { // Ignore if the constant is already declared - if self.contains_var(&value) { + if self.is_signal_declared(&value) { return Ok(()); } - self.vars.insert(value, Some(value)); + + // Create a new constant Signal + let signal = Signal::new(value, name, Some(value)); + + // Store the signal data + self.signals.insert(signal.id, signal.clone()); // Create a new node for the constant - let node = Node::new(value, name, true, value); + let node = Node::new(signal); debug!("New {:?}", node); self.nodes.push(node); @@ -194,40 +197,15 @@ impl ArithmeticCircuit { lhs_id: u32, rhs_id: u32, output_id: u32, - lh_name: String, - rh_name: String, - o_name: String ) -> Result<(), CircuitError> { // Check that the inputs are declared - if !self.contains_var(&lhs_id) - || !self.contains_var(&rhs_id) - || !self.contains_var(&output_id) + if !self.is_signal_declared(&lhs_id) + || !self.is_signal_declared(&rhs_id) + || !self.is_signal_declared(&output_id) { return Err(CircuitError::VariableNotDeclared); } - match gate_type { - AGateType::AAdd => { - println!("{} = {} + {}", o_name, lh_name, rh_name); - }, - AGateType::ADiv => todo!(), - AGateType::AEq => todo!(), - AGateType::AGEq => todo!(), - AGateType::AGt => todo!(), - AGateType::ALEq => todo!(), - AGateType::ALt => { - println!("{} = {} < {}", o_name, lh_name, rh_name); - }, - AGateType::AMul => { - println!("{} = {} * {}", o_name, lh_name, rh_name); - }, - AGateType::ANeq => todo!(), - AGateType::ANone => todo!(), - AGateType::ASub => { - println!("{} = {} - {}", o_name, lh_name, rh_name); - }, - }; - // Get the signal nodes let lhs_node = self.get_signal_node(lhs_id)?; let rhs_node = self.get_signal_node(rhs_id)?; @@ -249,9 +227,9 @@ impl ArithmeticCircuit { /// Creates a connection between two signals in the circuit. /// This is done by finding the nodes that contain the signals and merging them. - pub fn add_connection(&mut self, a: u32, b: u32, a_name: String, b_name: String) -> Result<(), CircuitError> { + pub fn add_connection(&mut self, a: u32, b: u32) -> Result<(), CircuitError> { // Check that the endpoints are declared - if !self.contains_var(&a) || !self.contains_var(&b) { + if !self.is_signal_declared(&a) || !self.is_signal_declared(&b) { return Err(CircuitError::VariableNotDeclared); } @@ -290,25 +268,25 @@ impl ArithmeticCircuit { .retain(|node| node.id != node_a.id && node.id != node_b.id); self.nodes.push(merged_node); - println!("{} = {}", a_name, b_name); - Ok(()) } /// Returns the node containing the given signal. fn get_signal_node(&self, signal_id: u32) -> Result { for node in &self.nodes { - if node.signals.contains(&signal_id) { - return Ok(node.clone()); + for signal in &node.signals { + if signal.id == signal_id { + return Ok(node.clone()); + } } } Err(CircuitError::NodeNotFound) } - /// Checks if the variable exists - pub fn contains_var(&self, var: &u32) -> bool { - self.vars.contains_key(var) + /// Checks if the signal exists + pub fn is_signal_declared(&self, id: &u32) -> bool { + self.signals.contains_key(id) } /// Returns the number of gates in the circuit. diff --git a/src/process.rs b/src/process.rs index c4ee0d0..8294601 100644 --- a/src/process.rs +++ b/src/process.rs @@ -160,9 +160,7 @@ pub fn process_statement( // Connect the generated gate output to the given signal let given_output_id = ctx.get_signal_id(&lh_access)?; let gate_output_id = get_signal_for_access(ac, ctx, &rh_access)?; - let a_name = lh_access.access_str(ctx.get_ctx_name()); - let b_name = rh_access.access_str(ctx.get_ctx_name()); - ac.add_connection(gate_output_id, given_output_id, a_name, b_name)?; + ac.add_connection(gate_output_id, given_output_id)?; } DataType::Variable => { // Assign the evaluated right-hand side to the left-hand side @@ -180,10 +178,7 @@ pub fn process_statement( let component_signal = ctx.get_component_signal_id(&lh_access)?; let assigned_signal = get_signal_for_access(ac, ctx, &rh_access)?; - let a_name = lh_access.access_str(ctx.get_ctx_name()); - let b_name = rh_access.access_str(ctx.get_ctx_name()); - - ac.add_connection(assigned_signal, component_signal, a_name, b_name)?; + ac.add_connection(assigned_signal, component_signal)?; } _ => return Err(ProgramError::OperationNotSupported), }, @@ -204,7 +199,7 @@ pub fn process_statement( Ok(()) } - _ => todo!() + _ => todo!(), } } @@ -235,7 +230,7 @@ pub fn process_expression( Expression::Variable { name, access, .. } => { build_access(ac, runtime, program_archive, name, access) } - _ => todo!() + _ => todo!(), } } @@ -376,10 +371,8 @@ fn handle_infix_op( match op { ExpressionInfixOpcode::Lesser => { println!("DEBUG ALt"); - }, - _ => { - } + _ => {} } // Handle cases where one or both inputs are signals @@ -394,11 +387,7 @@ fn handle_infix_op( // Add output signal and gate to the circuit ac.add_signal(output_id, output_signal.access_str(ctx.get_ctx_name()))?; - let lh_name = lhe_access.access_str(ctx.get_ctx_name()); - let rh_name = rhe_access.access_str(ctx.get_ctx_name()); - let o_name = output_signal.access_str(ctx.get_ctx_name()); - - ac.add_gate(gate_type, lhs_id, rhs_id, output_id, lh_name, rh_name, o_name)?; + ac.add_gate(gate_type, lhs_id, rhs_id, output_id)?; Ok(output_signal) }