Skip to content

Commit

Permalink
update circuit
Browse files Browse the repository at this point in the history
  • Loading branch information
brech1 committed Mar 1, 2024
1 parent 3e778f2 commit 09423a3
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 94 deletions.
1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,3 @@ mpz-circuits = { git = "https://github.com/privacy-scaling-explorations/mpz", p
[[bin]]
name="circom"
path="./src/main.rs"

130 changes: 54 additions & 76 deletions src/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,65 +45,58 @@ impl From<&ExpressionInfixOpcode> for AGateType {
}
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Signal {
id: u32,
name: String,
value: Option<u32>,
}

impl Signal {
pub fn new(id: u32, name: String, value: Option<u32>) -> Self {
Self { id, name, value }
}

pub fn is_const(&self) -> bool {
self.value.is_some()
}
}

/// Represents a node in the circuit, with an identifier and a set of signals that it is connected to.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Node {
id: u32,
signals: Vec<u32>,
names: Vec<String>,
is_const: bool,
const_value: u32,
signals: Vec<Signal>,
}

impl Node {
/// Creates a new node.
pub fn new(signal_id: u32, signal_name: String, is_const: bool, const_value: u32) -> Self {
pub fn new(signal: Signal) -> Self {
Self {
id: generate_u32(),
signals: vec![signal_id],
names: vec![signal_name],
is_const,
const_value,
signals: vec![signal],
}
}

pub fn is_const(&self) -> bool {
self.is_const
}

pub fn const_value(&self) -> u32 {
self.const_value
}

/// Adds a set of signals to the node.
pub fn add_signals(&mut self, signals: Vec<u32>, names: Vec<String>) {
pub fn add_signals(&mut self, signals: Vec<Signal>) {
self.signals.extend(signals);
self.names.extend(names);
}

/// Gets the signals of the node.
pub fn get_signals(&self) -> Vec<u32> {
pub fn get_signals(&self) -> Vec<Signal> {
self.signals.clone()
}

pub fn get_signals_names(&self) -> Vec<String> {
self.names.clone()
}

/// Merges the signals of the node with another node, creating a new node.
pub fn merge(&self, merge_node: &Node) -> Self {
let ic = self.is_const() | merge_node.is_const();
let cv = self.const_value();
let mut new_node = Node {
id: generate_u32(),
signals: Vec::new(),
names: Vec::new(),
is_const: ic,
const_value: cv,
};

new_node.add_signals(self.get_signals(), self.get_signals_names());
new_node.add_signals(merge_node.get_signals(), merge_node.get_signals_names());
new_node.add_signals(self.get_signals());
new_node.add_signals(merge_node.get_signals());

new_node
}
Expand Down Expand Up @@ -140,7 +133,7 @@ impl ArithmeticGate {
/// Represents an arithmetic circuit, with a set of variables and gates.
#[derive(Debug, Default, Serialize, Deserialize)]
pub struct ArithmeticCircuit {
vars: HashMap<u32, Option<u32>>,
signals: HashMap<u32, Signal>,
nodes: Vec<Node>,
gates: Vec<ArithmeticGate>,
}
Expand All @@ -149,7 +142,7 @@ impl ArithmeticCircuit {
/// Creates a new arithmetic circuit.
pub fn new() -> ArithmeticCircuit {
ArithmeticCircuit {
vars: HashMap::new(),
signals: HashMap::new(),
nodes: Vec::new(),
gates: Vec::new(),
}
Expand All @@ -158,13 +151,18 @@ impl ArithmeticCircuit {
/// Adds a new signal variable to the circuit.
pub fn add_signal(&mut self, id: u32, name: String) -> Result<(), CircuitError> {
// Check that the variable isn't already declared
if self.contains_var(&id) {
if self.is_signal_declared(&id) {
return Err(CircuitError::CircuitVariableAlreadyDeclared);
}
self.vars.insert(id, None);

// Create a new Signal
let signal = Signal::new(id, name, None);

// Store the signal data
self.signals.insert(signal.id, signal.clone());

// Create a new node for the signal
let node = Node::new(id, name, false, 0);
let node = Node::new(signal);
debug!("New {:?}", node);

self.nodes.push(node);
Expand All @@ -174,13 +172,18 @@ impl ArithmeticCircuit {
/// Adds a new constant variable to the circuit.
pub fn add_const(&mut self, value: u32, name: String) -> Result<(), CircuitError> {
// Ignore if the constant is already declared
if self.contains_var(&value) {
if self.is_signal_declared(&value) {
return Ok(());
}
self.vars.insert(value, Some(value));

// Create a new constant Signal
let signal = Signal::new(value, name, Some(value));

// Store the signal data
self.signals.insert(signal.id, signal.clone());

// Create a new node for the constant
let node = Node::new(value, name, true, value);
let node = Node::new(signal);
debug!("New {:?}", node);

self.nodes.push(node);
Expand All @@ -194,40 +197,15 @@ impl ArithmeticCircuit {
lhs_id: u32,
rhs_id: u32,
output_id: u32,
lh_name: String,
rh_name: String,
o_name: String
) -> Result<(), CircuitError> {
// Check that the inputs are declared
if !self.contains_var(&lhs_id)
|| !self.contains_var(&rhs_id)
|| !self.contains_var(&output_id)
if !self.is_signal_declared(&lhs_id)
|| !self.is_signal_declared(&rhs_id)
|| !self.is_signal_declared(&output_id)
{
return Err(CircuitError::VariableNotDeclared);
}

match gate_type {
AGateType::AAdd => {
println!("{} = {} + {}", o_name, lh_name, rh_name);
},
AGateType::ADiv => todo!(),
AGateType::AEq => todo!(),
AGateType::AGEq => todo!(),
AGateType::AGt => todo!(),
AGateType::ALEq => todo!(),
AGateType::ALt => {
println!("{} = {} < {}", o_name, lh_name, rh_name);
},
AGateType::AMul => {
println!("{} = {} * {}", o_name, lh_name, rh_name);
},
AGateType::ANeq => todo!(),
AGateType::ANone => todo!(),
AGateType::ASub => {
println!("{} = {} - {}", o_name, lh_name, rh_name);
},
};

// Get the signal nodes
let lhs_node = self.get_signal_node(lhs_id)?;
let rhs_node = self.get_signal_node(rhs_id)?;
Expand All @@ -249,9 +227,9 @@ impl ArithmeticCircuit {

/// Creates a connection between two signals in the circuit.
/// This is done by finding the nodes that contain the signals and merging them.
pub fn add_connection(&mut self, a: u32, b: u32, a_name: String, b_name: String) -> Result<(), CircuitError> {
pub fn add_connection(&mut self, a: u32, b: u32) -> Result<(), CircuitError> {
// Check that the endpoints are declared
if !self.contains_var(&a) || !self.contains_var(&b) {
if !self.is_signal_declared(&a) || !self.is_signal_declared(&b) {
return Err(CircuitError::VariableNotDeclared);
}

Expand Down Expand Up @@ -290,25 +268,25 @@ impl ArithmeticCircuit {
.retain(|node| node.id != node_a.id && node.id != node_b.id);
self.nodes.push(merged_node);

println!("{} = {}", a_name, b_name);

Ok(())
}

/// Returns the node containing the given signal.
fn get_signal_node(&self, signal_id: u32) -> Result<Node, CircuitError> {
for node in &self.nodes {
if node.signals.contains(&signal_id) {
return Ok(node.clone());
for signal in &node.signals {
if signal.id == signal_id {
return Ok(node.clone());
}
}
}

Err(CircuitError::NodeNotFound)
}

/// Checks if the variable exists
pub fn contains_var(&self, var: &u32) -> bool {
self.vars.contains_key(var)
/// Checks if the signal exists
pub fn is_signal_declared(&self, id: &u32) -> bool {
self.signals.contains_key(id)
}

/// Returns the number of gates in the circuit.
Expand Down
23 changes: 6 additions & 17 deletions src/process.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,7 @@ pub fn process_statement(
// Connect the generated gate output to the given signal
let given_output_id = ctx.get_signal_id(&lh_access)?;
let gate_output_id = get_signal_for_access(ac, ctx, &rh_access)?;
let a_name = lh_access.access_str(ctx.get_ctx_name());
let b_name = rh_access.access_str(ctx.get_ctx_name());
ac.add_connection(gate_output_id, given_output_id, a_name, b_name)?;
ac.add_connection(gate_output_id, given_output_id)?;
}
DataType::Variable => {
// Assign the evaluated right-hand side to the left-hand side
Expand All @@ -180,10 +178,7 @@ pub fn process_statement(
let component_signal = ctx.get_component_signal_id(&lh_access)?;
let assigned_signal = get_signal_for_access(ac, ctx, &rh_access)?;

let a_name = lh_access.access_str(ctx.get_ctx_name());
let b_name = rh_access.access_str(ctx.get_ctx_name());

ac.add_connection(assigned_signal, component_signal, a_name, b_name)?;
ac.add_connection(assigned_signal, component_signal)?;
}
_ => return Err(ProgramError::OperationNotSupported),
},
Expand All @@ -204,7 +199,7 @@ pub fn process_statement(

Ok(())
}
_ => todo!()
_ => todo!(),
}
}

Expand Down Expand Up @@ -235,7 +230,7 @@ pub fn process_expression(
Expression::Variable { name, access, .. } => {
build_access(ac, runtime, program_archive, name, access)
}
_ => todo!()
_ => todo!(),
}
}

Expand Down Expand Up @@ -376,10 +371,8 @@ fn handle_infix_op(
match op {
ExpressionInfixOpcode::Lesser => {
println!("DEBUG ALt");
},
_ => {

}
_ => {}
}

// Handle cases where one or both inputs are signals
Expand All @@ -394,11 +387,7 @@ fn handle_infix_op(
// Add output signal and gate to the circuit
ac.add_signal(output_id, output_signal.access_str(ctx.get_ctx_name()))?;

let lh_name = lhe_access.access_str(ctx.get_ctx_name());
let rh_name = rhe_access.access_str(ctx.get_ctx_name());
let o_name = output_signal.access_str(ctx.get_ctx_name());

ac.add_gate(gate_type, lhs_id, rhs_id, output_id, lh_name, rh_name, o_name)?;
ac.add_gate(gate_type, lhs_id, rhs_id, output_id)?;

Ok(output_signal)
}
Expand Down

0 comments on commit 09423a3

Please sign in to comment.