From 70754e902459a306edec7459d9e65be750608094 Mon Sep 17 00:00:00 2001 From: Brechy Date: Mon, 29 Apr 2024 22:21:56 -0300 Subject: [PATCH] feat: sim circuit --- Cargo.toml | 7 ++++--- src/circuit.rs | 49 ++++++++++++++++++++++++++++++++++++++++--- tests/mat_elem_mul.rs | 5 ++--- tests/sum.rs | 5 ++--- 4 files changed, 54 insertions(+), 12 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 5533001..395f144 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,7 +13,7 @@ rand = "0.8.5" regex = "1.10.3" serde_json = "1.0" serde = { version = "1.0.196", features = ["derive"] } -thiserror = "1.0.56" +thiserror = "1.0.59" # DSL circom-circom_algebra = { git = "https://github.com/iden3/circom", package = "circom_algebra"} @@ -29,8 +29,9 @@ circom-program_structure = { git = "https://github.com/iden3/circom", package = circom-type_analysis = { git = "https://github.com/iden3/circom", package = "type_analysis"} # MPZ -mpz-circuits = { git = "https://github.com/privacy-scaling-explorations/mpz", package = "mpz-circuits"} -bmr16-mpz = { git = "https://github.com/tkmct/mpz", package = "mpz-circuits"} +mpz-circuits = { git = "https://github.com/privacy-scaling-explorations/mpz", package = "mpz-circuits" } +bmr16-mpz = { git = "https://github.com/tkmct/mpz", package = "mpz-circuits" } +sim-circuit = { git = "https://github.com/brech1/sim-circuit" } [[bin]] name="circom" diff --git a/src/circuit.rs b/src/circuit.rs index 9fd0ada..9896101 100644 --- a/src/circuit.rs +++ b/src/circuit.rs @@ -15,6 +15,7 @@ use bmr16_mpz::{ use circom_program_structure::ast::ExpressionInfixOpcode; use log::debug; use serde::{Deserialize, Serialize}; +use sim_circuit::circuit::{Circuit as SimCircuit, Gate as SimGate, Node as SimNode, Operation}; use std::collections::HashMap; use thiserror::Error; @@ -30,7 +31,6 @@ pub enum AGateType { ALt, AMul, ANeq, - ANone, ASub, } @@ -47,7 +47,24 @@ impl From<&ExpressionInfixOpcode> for AGateType { ExpressionInfixOpcode::Mul => AGateType::AMul, ExpressionInfixOpcode::NotEq => AGateType::ANeq, ExpressionInfixOpcode::Sub => AGateType::ASub, - _ => AGateType::ANone, + _ => unimplemented!("Unsupported opcode"), + } + } +} + +impl From<&AGateType> for Operation { + fn from(gate: &AGateType) -> Self { + match gate { + AGateType::AAdd => Operation::Add, + AGateType::ASub => Operation::Subtract, + AGateType::AMul => Operation::Multiply, + AGateType::ADiv => Operation::Divide, + AGateType::AEq => Operation::Equals, + AGateType::ANeq => Operation::NotEquals, + AGateType::ALt => Operation::LessThan, + AGateType::ALEq => Operation::LessOrEqual, + AGateType::AGt => Operation::GreaterThan, + AGateType::AGEq => Operation::GreaterOrEqual, } } } @@ -251,7 +268,6 @@ impl ArithmeticCircuit { if node_a_id == node_b_id { return Ok(()); } - // Check for output and constant nodes if node_a.is_out && node_b.is_out { return Err(CircuitError::CannotMergeOutputNodes); @@ -413,6 +429,33 @@ impl ArithmeticCircuit { .map_err(|_| CircuitError::MPZCircuitBuilderError) } + /// Builds a sim circuit instance. + pub fn build_sim_circuit(&self) -> Result { + let mut sim_circuit = SimCircuit::new(); + + // Add nodes + for (&id, node) in &self.nodes { + let mut new_node = SimNode::new(); + if let Some(value) = node + .signals + .first() + .and_then(|&sig_id| self.signals.get(&sig_id).and_then(|sig| sig.value)) + { + new_node.set_value(value); + } + sim_circuit.add_node(id, new_node); + } + + // Add gates + for gate in &self.gates { + let operation = Operation::from(&gate.op); + let sim_gate = SimGate::new(operation, gate.lh_in, gate.rh_in, gate.out); + sim_circuit.add_gate(sim_gate); + } + + Ok(sim_circuit) + } + /// Returns a node id and increments the count. fn get_node_id(&mut self) -> u32 { self.node_count += 1; diff --git a/tests/mat_elem_mul.rs b/tests/mat_elem_mul.rs index ba597cb..a11b3db 100644 --- a/tests/mat_elem_mul.rs +++ b/tests/mat_elem_mul.rs @@ -6,10 +6,9 @@ const TEST_FILE_PATH: &str = "./tests/circuits/matElemMul.circom"; fn test_matrix_element_multiplication() { let input = Input::new(TEST_FILE_PATH.into(), "./".into()).unwrap(); let circuit = build_circuit(&input).unwrap(); - let report = circuit.generate_circuit_report().unwrap(); - let mpz_circuit = circuit.build_mpz_circuit(&report).unwrap(); + let mut sim_circuit = circuit.build_sim_circuit().unwrap(); let circuit_input = vec![2, 2, 2, 2, 2, 2, 2, 2]; - let res = mpz_circuit.evaluate(&circuit_input).unwrap(); + let res = sim_circuit.execute(&circuit_input).unwrap(); assert_eq!(res, vec![4, 4, 4, 4]); } diff --git a/tests/sum.rs b/tests/sum.rs index aed59d3..fea1c40 100644 --- a/tests/sum.rs +++ b/tests/sum.rs @@ -6,10 +6,9 @@ const TEST_FILE_PATH: &str = "./tests/circuits/sum.circom"; fn test_sum() { let input = Input::new(TEST_FILE_PATH.into(), "./".into()).unwrap(); let circuit = build_circuit(&input).unwrap(); - let report = circuit.generate_circuit_report().unwrap(); - let mpz_circuit = circuit.build_mpz_circuit(&report).unwrap(); + let mut sim_circuit = circuit.build_sim_circuit().unwrap(); let circuit_input = vec![1, 2]; - let res = mpz_circuit.evaluate(&circuit_input).unwrap(); + let res = sim_circuit.execute(&circuit_input).unwrap(); assert_eq!(res, vec![3]); }