Skip to content

Commit

Permalink
fix: signals and nodes id (#40)
Browse files Browse the repository at this point in the history
  • Loading branch information
brech1 authored Mar 28, 2024
1 parent 530613e commit 6553940
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 26 deletions.
26 changes: 17 additions & 9 deletions src/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
//!
//! This module defines the data structures used to represent the arithmetic circuit.

use crate::{program::ProgramError, runtime::generate_u32};
use crate::program::ProgramError;
use circom_program_structure::ast::ExpressionInfixOpcode;
use log::debug;
use mpz_circuits::GateType;
Expand Down Expand Up @@ -54,9 +54,9 @@ pub struct Node {

impl Node {
/// Creates a new node.
pub fn new(signal_id: u32) -> Self {
pub fn new(id: u32, signal_id: u32) -> Self {
Self {
id: generate_u32(),
id,
signals: vec![signal_id],
}
}
Expand All @@ -72,9 +72,9 @@ impl Node {
}

/// Merges the signals of the node with another node, creating a new node.
pub fn merge(&self, merge_node: &Node) -> Self {
let mut new_node = Node {
id: generate_u32(),
pub fn merge(&self, merge_node: &Node, id: u32) -> Self {
let mut new_node = Self {
id,
signals: Vec::new(),
};

Expand Down Expand Up @@ -119,6 +119,7 @@ pub struct ArithmeticCircuit {
vars: HashMap<u32, Option<u32>>,
nodes: Vec<Node>,
gates: Vec<ArithmeticGate>,
node_count: u32,
}

impl ArithmeticCircuit {
Expand All @@ -128,6 +129,7 @@ impl ArithmeticCircuit {
vars: HashMap::new(),
nodes: Vec::new(),
gates: Vec::new(),
node_count: 0,
}
}

Expand All @@ -140,7 +142,7 @@ impl ArithmeticCircuit {
self.vars.insert(id, None);

// Create a new node for the signal
let node = Node::new(id);
let node = Node::new(self.get_node_id(), id);
debug!("New {:?}", node);

self.nodes.push(node);
Expand All @@ -156,7 +158,7 @@ impl ArithmeticCircuit {
self.vars.insert(value, Some(value));

// Create a new node for the constant
let node = Node::new(value);
let node = Node::new(self.get_node_id(), value);
debug!("New {:?}", node);

self.nodes.push(node);
Expand Down Expand Up @@ -221,7 +223,7 @@ impl ArithmeticCircuit {
}

// Merge the nodes
let merged_node = node_a.merge(&node_b);
let merged_node = node_a.merge(&node_b, self.get_node_id());

// Update connections in gates to point to the new merged node
self.gates.iter_mut().for_each(|gate| {
Expand Down Expand Up @@ -264,6 +266,12 @@ impl ArithmeticCircuit {
pub fn gate_count(&self) -> u32 {
self.gates.len() as u32
}

/// Generates a new node id
fn get_node_id(&mut self) -> u32 {
self.node_count += 1;
self.node_count
}
}

#[allow(dead_code)]
Expand Down
32 changes: 24 additions & 8 deletions src/process.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ pub fn process_statement(
.map(|expression| process_expression(ac, runtime, program_archive, expression))
.collect::<Result<Vec<DataAccess>, ProgramError>>()?;

let signal_gen = runtime.get_signal_gen();
let ctx = runtime.current_context()?;
let dimensions: Vec<u32> = dim_access
.iter()
Expand All @@ -68,7 +69,7 @@ pub fn process_statement(
.ok_or(ProgramError::EmptyDataItem)
})
.collect::<Result<Vec<u32>, ProgramError>>()?;
ctx.declare_item(data_type.clone(), name, &dimensions)?;
ctx.declare_item(data_type.clone(), name, &dimensions, signal_gen)?;

// If the declared item is a signal we should add it to the arithmetic circuit
if data_type == DataType::Signal {
Expand Down Expand Up @@ -148,12 +149,13 @@ pub fn process_statement(
Statement::Return { value, .. } => {
let return_access = process_expression(ac, runtime, program_archive, value)?;

let signal_gen = runtime.get_signal_gen();
let ctx = runtime.current_context()?;
let return_value = ctx
.get_variable_value(&return_access)?
.ok_or(ProgramError::EmptyDataItem)?;

ctx.declare_item(DataType::Variable, RETURN_VAR, &[])?;
ctx.declare_item(DataType::Variable, RETURN_VAR, &[], signal_gen)?;
ctx.set_variable(&DataAccess::new(RETURN_VAR, vec![]), Some(return_value))?;

Ok(())
Expand Down Expand Up @@ -258,9 +260,10 @@ pub fn process_expression(
lhe, infix_op, rhe, ..
} => handle_infix_op(ac, runtime, program_archive, infix_op, lhe, rhe),
Expression::Number(_, value) => {
let signal_gen = runtime.get_signal_gen();
let access = runtime
.current_context()?
.declare_random_item(DataType::Variable)?;
.declare_random_item(signal_gen, DataType::Variable)?;

runtime.current_context()?.set_variable(
&access,
Expand Down Expand Up @@ -319,9 +322,10 @@ fn handle_call(

// Set arguments in the new context
for (arg_name, &arg_value) in arg_names.iter().zip(&arg_values) {
let signal_gen = runtime.get_signal_gen();
runtime
.current_context()?
.declare_item(DataType::Variable, arg_name, &[])?;
.declare_item(DataType::Variable, arg_name, &[], signal_gen)?;
runtime
.current_context()?
.set_variable(&DataAccess::new(arg_name, vec![]), Some(arg_value))?;
Expand Down Expand Up @@ -356,15 +360,26 @@ fn handle_call(

// Return to parent context
runtime.pop_context(false)?;
let signal_gen = runtime.get_signal_gen();
let ctx = runtime.current_context()?;
let return_access =
DataAccess::new(&format!("{}_{}_{}", id, RETURN_VAR, generate_u32()), vec![]);

if is_function {
ctx.declare_item(DataType::Variable, &return_access.get_name(), &[])?;
ctx.declare_item(
DataType::Variable,
&return_access.get_name(),
&[],
signal_gen,
)?;
ctx.set_variable(&return_access, function_return)?;
} else {
ctx.declare_item(DataType::Component, &return_access.get_name(), &[])?;
ctx.declare_item(
DataType::Component,
&return_access.get_name(),
&[],
signal_gen,
)?;
ctx.set_component(&return_access, component_return)?;
}

Expand All @@ -386,6 +401,7 @@ fn handle_infix_op(
let lhe_access = process_expression(ac, runtime, program_archive, lhe)?;
let rhe_access = process_expression(ac, runtime, program_archive, rhe)?;

let signal_gen = runtime.get_signal_gen();
let ctx = runtime.current_context()?;

// Determine the data types of the left and right operands
Expand All @@ -402,7 +418,7 @@ fn handle_infix_op(
.ok_or(ProgramError::EmptyDataItem)?;

let op_res = execute_op(lhs_value, rhs_value, op)?;
let item_access = ctx.declare_random_item(DataType::Variable)?;
let item_access = ctx.declare_random_item(signal_gen, DataType::Variable)?;
ctx.set_variable(&item_access, Some(op_res))?;

return Ok(item_access);
Expand All @@ -414,7 +430,7 @@ fn handle_infix_op(

// Construct the corresponding circuit gate
let gate_type = AGateType::from(op);
let output_signal = ctx.declare_random_item(DataType::Signal)?;
let output_signal = ctx.declare_random_item(signal_gen, DataType::Signal)?;
let output_id = ctx.get_signal_id(&output_signal)?;

// Add output signal and gate to the circuit
Expand Down
49 changes: 40 additions & 9 deletions src/runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@
use crate::program::ProgramError;
use circom_program_structure::ast::VariableType;
use rand::{thread_rng, Rng};
use std::collections::{HashMap, HashSet, VecDeque};
use std::{
cell::RefCell,
collections::{HashMap, HashSet, VecDeque},
rc::Rc,
};
use thiserror::Error;

pub const RETURN_VAR: &str = "function_return_value";
Expand Down Expand Up @@ -49,6 +53,7 @@ pub enum SubAccess {
/// Manages a stack of execution contexts for a runtime environment.
pub struct Runtime {
contexts: VecDeque<Context>,
next_signal_id: Rc<RefCell<u32>>,
}

impl Default for Runtime {
Expand All @@ -62,6 +67,7 @@ impl Runtime {
pub fn new() -> Self {
Self {
contexts: VecDeque::from([Context::new()]),
next_signal_id: Rc::new(RefCell::new(0)),
}
}

Expand Down Expand Up @@ -108,6 +114,19 @@ impl Runtime {
.front_mut()
.ok_or(RuntimeError::EmptyContextStack)
}

/// Returns a clone of the Rc<RefCell<u32>> for next_signal_id.
pub fn get_signal_gen(&self) -> Rc<RefCell<u32>> {
Rc::clone(&self.next_signal_id)
}

/// Generates a new unique signal ID.
fn gen_signal(next_signal_id: Rc<RefCell<u32>>) -> u32 {
let mut id_ref = next_signal_id.borrow_mut();
let id = *id_ref;
*id_ref += 1;
id
}
}

/// Context
Expand Down Expand Up @@ -177,6 +196,7 @@ impl Context {
data_type: DataType,
name: &str,
dimensions: &[u32],
next_signal_id: Rc<RefCell<u32>>,
) -> Result<(), RuntimeError> {
// Parse name
let name = name.to_string();
Expand All @@ -188,7 +208,7 @@ impl Context {

match data_type {
DataType::Signal => {
let signal = Signal::new(dimensions);
let signal = Signal::new(dimensions, next_signal_id);
self.signals.insert(name, signal);
}
DataType::Variable => {
Expand All @@ -205,9 +225,13 @@ impl Context {
}

/// Declares a new item with a random name.
pub fn declare_random_item(&mut self, data_type: DataType) -> Result<DataAccess, RuntimeError> {
pub fn declare_random_item(
&mut self,
next_signal_id: Rc<RefCell<u32>>,
data_type: DataType,
) -> Result<DataAccess, RuntimeError> {
let name = format!("random_{}", generate_u32());
self.declare_item(data_type, &name, &[])?;
self.declare_item(data_type, &name, &[], next_signal_id)?;
Ok(DataAccess::new(&name, vec![]))
}

Expand Down Expand Up @@ -407,18 +431,25 @@ pub struct Signal {

impl Signal {
/// Constructs a new Signal as a nested structure based on provided dimensions.
fn new(dimensions: &[u32]) -> Self {
fn create_nested_signal(dimensions: &[u32]) -> NestedValue<u32> {
fn new(dimensions: &[u32], next_signal_id: Rc<RefCell<u32>>) -> Self {
fn create_nested_signal(
dimensions: &[u32],
next_signal_id: Rc<RefCell<u32>>,
) -> NestedValue<u32> {
if let Some((&first, rest)) = dimensions.split_first() {
let array = (0..first).map(|_| create_nested_signal(rest)).collect();
let array = (0..first)
.map(|_| create_nested_signal(rest, next_signal_id.clone()))
.collect();
NestedValue::Array(array)
} else {
NestedValue::Value(generate_u32())
// Generate a new signal ID
let id = Runtime::gen_signal(next_signal_id);
NestedValue::Value(id)
}
}

Self {
value: create_nested_signal(dimensions),
value: create_nested_signal(dimensions, next_signal_id),
}
}

Expand Down

0 comments on commit 6553940

Please sign in to comment.