From 655394038e97732f5796f22116e92a265f3ba42d Mon Sep 17 00:00:00 2001 From: Brechy <113475764+brech1@users.noreply.github.com> Date: Wed, 27 Mar 2024 21:50:59 -0300 Subject: [PATCH] fix: signals and nodes id (#40) --- src/circuit.rs | 26 +++++++++++++++++--------- src/process.rs | 32 ++++++++++++++++++++++++-------- src/runtime.rs | 49 ++++++++++++++++++++++++++++++++++++++++--------- 3 files changed, 81 insertions(+), 26 deletions(-) diff --git a/src/circuit.rs b/src/circuit.rs index 768b179..c05626c 100644 --- a/src/circuit.rs +++ b/src/circuit.rs @@ -2,7 +2,7 @@ //! //! This module defines the data structures used to represent the arithmetic circuit. -use crate::{program::ProgramError, runtime::generate_u32}; +use crate::program::ProgramError; use circom_program_structure::ast::ExpressionInfixOpcode; use log::debug; use mpz_circuits::GateType; @@ -54,9 +54,9 @@ pub struct Node { impl Node { /// Creates a new node. - pub fn new(signal_id: u32) -> Self { + pub fn new(id: u32, signal_id: u32) -> Self { Self { - id: generate_u32(), + id, signals: vec![signal_id], } } @@ -72,9 +72,9 @@ impl Node { } /// Merges the signals of the node with another node, creating a new node. - pub fn merge(&self, merge_node: &Node) -> Self { - let mut new_node = Node { - id: generate_u32(), + pub fn merge(&self, merge_node: &Node, id: u32) -> Self { + let mut new_node = Self { + id, signals: Vec::new(), }; @@ -119,6 +119,7 @@ pub struct ArithmeticCircuit { vars: HashMap>, nodes: Vec, gates: Vec, + node_count: u32, } impl ArithmeticCircuit { @@ -128,6 +129,7 @@ impl ArithmeticCircuit { vars: HashMap::new(), nodes: Vec::new(), gates: Vec::new(), + node_count: 0, } } @@ -140,7 +142,7 @@ impl ArithmeticCircuit { self.vars.insert(id, None); // Create a new node for the signal - let node = Node::new(id); + let node = Node::new(self.get_node_id(), id); debug!("New {:?}", node); self.nodes.push(node); @@ -156,7 +158,7 @@ impl ArithmeticCircuit { self.vars.insert(value, Some(value)); // Create a new node for the constant - let node = Node::new(value); + let node = Node::new(self.get_node_id(), value); debug!("New {:?}", node); self.nodes.push(node); @@ -221,7 +223,7 @@ impl ArithmeticCircuit { } // Merge the nodes - let merged_node = node_a.merge(&node_b); + let merged_node = node_a.merge(&node_b, self.get_node_id()); // Update connections in gates to point to the new merged node self.gates.iter_mut().for_each(|gate| { @@ -264,6 +266,12 @@ impl ArithmeticCircuit { pub fn gate_count(&self) -> u32 { self.gates.len() as u32 } + + /// Generates a new node id + fn get_node_id(&mut self) -> u32 { + self.node_count += 1; + self.node_count + } } #[allow(dead_code)] diff --git a/src/process.rs b/src/process.rs index f07e770..01c1474 100644 --- a/src/process.rs +++ b/src/process.rs @@ -60,6 +60,7 @@ pub fn process_statement( .map(|expression| process_expression(ac, runtime, program_archive, expression)) .collect::, ProgramError>>()?; + let signal_gen = runtime.get_signal_gen(); let ctx = runtime.current_context()?; let dimensions: Vec = dim_access .iter() @@ -68,7 +69,7 @@ pub fn process_statement( .ok_or(ProgramError::EmptyDataItem) }) .collect::, ProgramError>>()?; - ctx.declare_item(data_type.clone(), name, &dimensions)?; + ctx.declare_item(data_type.clone(), name, &dimensions, signal_gen)?; // If the declared item is a signal we should add it to the arithmetic circuit if data_type == DataType::Signal { @@ -148,12 +149,13 @@ pub fn process_statement( Statement::Return { value, .. } => { let return_access = process_expression(ac, runtime, program_archive, value)?; + let signal_gen = runtime.get_signal_gen(); let ctx = runtime.current_context()?; let return_value = ctx .get_variable_value(&return_access)? .ok_or(ProgramError::EmptyDataItem)?; - ctx.declare_item(DataType::Variable, RETURN_VAR, &[])?; + ctx.declare_item(DataType::Variable, RETURN_VAR, &[], signal_gen)?; ctx.set_variable(&DataAccess::new(RETURN_VAR, vec![]), Some(return_value))?; Ok(()) @@ -258,9 +260,10 @@ pub fn process_expression( lhe, infix_op, rhe, .. } => handle_infix_op(ac, runtime, program_archive, infix_op, lhe, rhe), Expression::Number(_, value) => { + let signal_gen = runtime.get_signal_gen(); let access = runtime .current_context()? - .declare_random_item(DataType::Variable)?; + .declare_random_item(signal_gen, DataType::Variable)?; runtime.current_context()?.set_variable( &access, @@ -319,9 +322,10 @@ fn handle_call( // Set arguments in the new context for (arg_name, &arg_value) in arg_names.iter().zip(&arg_values) { + let signal_gen = runtime.get_signal_gen(); runtime .current_context()? - .declare_item(DataType::Variable, arg_name, &[])?; + .declare_item(DataType::Variable, arg_name, &[], signal_gen)?; runtime .current_context()? .set_variable(&DataAccess::new(arg_name, vec![]), Some(arg_value))?; @@ -356,15 +360,26 @@ fn handle_call( // Return to parent context runtime.pop_context(false)?; + let signal_gen = runtime.get_signal_gen(); let ctx = runtime.current_context()?; let return_access = DataAccess::new(&format!("{}_{}_{}", id, RETURN_VAR, generate_u32()), vec![]); if is_function { - ctx.declare_item(DataType::Variable, &return_access.get_name(), &[])?; + ctx.declare_item( + DataType::Variable, + &return_access.get_name(), + &[], + signal_gen, + )?; ctx.set_variable(&return_access, function_return)?; } else { - ctx.declare_item(DataType::Component, &return_access.get_name(), &[])?; + ctx.declare_item( + DataType::Component, + &return_access.get_name(), + &[], + signal_gen, + )?; ctx.set_component(&return_access, component_return)?; } @@ -386,6 +401,7 @@ fn handle_infix_op( let lhe_access = process_expression(ac, runtime, program_archive, lhe)?; let rhe_access = process_expression(ac, runtime, program_archive, rhe)?; + let signal_gen = runtime.get_signal_gen(); let ctx = runtime.current_context()?; // Determine the data types of the left and right operands @@ -402,7 +418,7 @@ fn handle_infix_op( .ok_or(ProgramError::EmptyDataItem)?; let op_res = execute_op(lhs_value, rhs_value, op)?; - let item_access = ctx.declare_random_item(DataType::Variable)?; + let item_access = ctx.declare_random_item(signal_gen, DataType::Variable)?; ctx.set_variable(&item_access, Some(op_res))?; return Ok(item_access); @@ -414,7 +430,7 @@ fn handle_infix_op( // Construct the corresponding circuit gate let gate_type = AGateType::from(op); - let output_signal = ctx.declare_random_item(DataType::Signal)?; + let output_signal = ctx.declare_random_item(signal_gen, DataType::Signal)?; let output_id = ctx.get_signal_id(&output_signal)?; // Add output signal and gate to the circuit diff --git a/src/runtime.rs b/src/runtime.rs index 67b2cb7..29de747 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -5,7 +5,11 @@ use crate::program::ProgramError; use circom_program_structure::ast::VariableType; use rand::{thread_rng, Rng}; -use std::collections::{HashMap, HashSet, VecDeque}; +use std::{ + cell::RefCell, + collections::{HashMap, HashSet, VecDeque}, + rc::Rc, +}; use thiserror::Error; pub const RETURN_VAR: &str = "function_return_value"; @@ -49,6 +53,7 @@ pub enum SubAccess { /// Manages a stack of execution contexts for a runtime environment. pub struct Runtime { contexts: VecDeque, + next_signal_id: Rc>, } impl Default for Runtime { @@ -62,6 +67,7 @@ impl Runtime { pub fn new() -> Self { Self { contexts: VecDeque::from([Context::new()]), + next_signal_id: Rc::new(RefCell::new(0)), } } @@ -108,6 +114,19 @@ impl Runtime { .front_mut() .ok_or(RuntimeError::EmptyContextStack) } + + /// Returns a clone of the Rc> for next_signal_id. + pub fn get_signal_gen(&self) -> Rc> { + Rc::clone(&self.next_signal_id) + } + + /// Generates a new unique signal ID. + fn gen_signal(next_signal_id: Rc>) -> u32 { + let mut id_ref = next_signal_id.borrow_mut(); + let id = *id_ref; + *id_ref += 1; + id + } } /// Context @@ -177,6 +196,7 @@ impl Context { data_type: DataType, name: &str, dimensions: &[u32], + next_signal_id: Rc>, ) -> Result<(), RuntimeError> { // Parse name let name = name.to_string(); @@ -188,7 +208,7 @@ impl Context { match data_type { DataType::Signal => { - let signal = Signal::new(dimensions); + let signal = Signal::new(dimensions, next_signal_id); self.signals.insert(name, signal); } DataType::Variable => { @@ -205,9 +225,13 @@ impl Context { } /// Declares a new item with a random name. - pub fn declare_random_item(&mut self, data_type: DataType) -> Result { + pub fn declare_random_item( + &mut self, + next_signal_id: Rc>, + data_type: DataType, + ) -> Result { let name = format!("random_{}", generate_u32()); - self.declare_item(data_type, &name, &[])?; + self.declare_item(data_type, &name, &[], next_signal_id)?; Ok(DataAccess::new(&name, vec![])) } @@ -407,18 +431,25 @@ pub struct Signal { impl Signal { /// Constructs a new Signal as a nested structure based on provided dimensions. - fn new(dimensions: &[u32]) -> Self { - fn create_nested_signal(dimensions: &[u32]) -> NestedValue { + fn new(dimensions: &[u32], next_signal_id: Rc>) -> Self { + fn create_nested_signal( + dimensions: &[u32], + next_signal_id: Rc>, + ) -> NestedValue { if let Some((&first, rest)) = dimensions.split_first() { - let array = (0..first).map(|_| create_nested_signal(rest)).collect(); + let array = (0..first) + .map(|_| create_nested_signal(rest, next_signal_id.clone())) + .collect(); NestedValue::Array(array) } else { - NestedValue::Value(generate_u32()) + // Generate a new signal ID + let id = Runtime::gen_signal(next_signal_id); + NestedValue::Value(id) } } Self { - value: create_nested_signal(dimensions), + value: create_nested_signal(dimensions, next_signal_id), } }