diff --git a/src/process.rs b/src/process.rs index 41e6884..20fba40 100644 --- a/src/process.rs +++ b/src/process.rs @@ -6,6 +6,7 @@ use crate::circuit::{AGateType, ArithmeticCircuit}; use crate::program::ProgramError; use crate::runtime::{ increment_indices, u32_to_access, Context, DataAccess, DataType, Runtime, Signal, SubAccess, + RETURN_VAR, }; use circom_circom_algebra::num_traits::ToPrimitive; use circom_program_structure::ast::{ @@ -14,8 +15,6 @@ use circom_program_structure::ast::{ use circom_program_structure::program_archive::ProgramArchive; use std::collections::HashMap; -pub const RETURN_VAR: &str = "function_return"; - /// Processes a sequence of statements. pub fn process_statements( ac: &mut ArithmeticCircuit, @@ -189,18 +188,15 @@ pub fn process_statement( Ok(()) } Statement::Return { value, .. } => { - let return_var_access = DataAccess::new(RETURN_VAR, vec![]); let return_access = process_expression(ac, runtime, program_archive, value)?; - let return_value = runtime - .current_context()? + + let ctx = runtime.current_context()?; + let return_value = ctx .get_variable_value(&return_access)? .ok_or(ProgramError::EmptyDataItem)?; - let ctx = runtime.current_context()?; - if ctx.get_variable_value(&return_var_access).is_err() { - ctx.declare_item(DataType::Variable, RETURN_VAR, &[])?; - } - ctx.set_variable(&return_var_access, Some(return_value))?; + ctx.declare_item(DataType::Variable, RETURN_VAR, &[])?; + ctx.set_variable(&DataAccess::new(RETURN_VAR, vec![]), Some(return_value))?; Ok(()) } @@ -363,7 +359,7 @@ fn handle_infix_op( .get_variable_value(&rhe_access)? .ok_or(ProgramError::EmptyDataItem)?; - let op_res = execute_op(&lhs_value, &rhs_value, op); + let op_res = execute_op(lhs_value, rhs_value, op)?; let item_access = ctx.declare_random_item(DataType::Variable)?; ctx.set_variable(&item_access, Some(op_res))?; @@ -436,55 +432,98 @@ pub fn build_access( Ok(DataAccess::new(name, access_vec)) } -/// Executes an operation, performing the specified arithmetic or logical computation. -pub fn execute_op(lhs: &u32, rhs: &u32, op: &ExpressionInfixOpcode) -> u32 { - match AGateType::from(op) { - AGateType::AAdd => lhs + rhs, - AGateType::ADiv => lhs / rhs, - AGateType::AEq => { - if lhs == rhs { +/// Executes an operation on two u32 values, performing the specified arithmetic or logical computation. +pub fn execute_op(lhs: u32, rhs: u32, op: &ExpressionInfixOpcode) -> Result { + let res = match op { + ExpressionInfixOpcode::Mul => lhs * rhs, + ExpressionInfixOpcode::Div => { + if rhs == 0 { + return Err(ProgramError::OperationError("Division by zero".to_string())); + } + + lhs / rhs + } + ExpressionInfixOpcode::Add => lhs + rhs, + ExpressionInfixOpcode::Sub => lhs - rhs, + ExpressionInfixOpcode::Pow => lhs.pow(rhs), + ExpressionInfixOpcode::IntDiv => { + if rhs == 0 { + return Err(ProgramError::OperationError( + "Integer division by zero".to_string(), + )); + } + + lhs / rhs + } + ExpressionInfixOpcode::Mod => { + if rhs == 0 { + return Err(ProgramError::OperationError("Modulo by zero".to_string())); + } + + lhs % rhs + } + ExpressionInfixOpcode::ShiftL => lhs << rhs, + ExpressionInfixOpcode::ShiftR => lhs >> rhs, + ExpressionInfixOpcode::LesserEq => { + if lhs <= rhs { 1 } else { 0 } } - AGateType::AGEq => { + ExpressionInfixOpcode::GreaterEq => { if lhs >= rhs { 1 } else { 0 } } - AGateType::AGt => { - if lhs > rhs { + ExpressionInfixOpcode::Lesser => { + if lhs < rhs { 1 } else { 0 } } - AGateType::ALEq => { - if lhs <= rhs { + ExpressionInfixOpcode::Greater => { + if lhs > rhs { 1 } else { 0 } } - AGateType::ALt => { - if lhs < rhs { + ExpressionInfixOpcode::Eq => { + if lhs == rhs { 1 } else { 0 } } - AGateType::AMul => lhs * rhs, - AGateType::ANeq => { + ExpressionInfixOpcode::NotEq => { if lhs != rhs { 1 } else { 0 } } - AGateType::ANone => unimplemented!(), - AGateType::ASub => lhs - rhs, - } + ExpressionInfixOpcode::BoolOr => { + if lhs != 0 || rhs != 0 { + 1 + } else { + 0 + } + } + ExpressionInfixOpcode::BoolAnd => { + if lhs != 0 && rhs != 0 { + 1 + } else { + 0 + } + } + ExpressionInfixOpcode::BitOr => lhs | rhs, + ExpressionInfixOpcode::BitAnd => lhs & rhs, + ExpressionInfixOpcode::BitXor => lhs ^ rhs, + }; + + Ok(res) } diff --git a/src/program.rs b/src/program.rs index a1733ce..9595b0e 100644 --- a/src/program.rs +++ b/src/program.rs @@ -48,10 +48,12 @@ pub enum ProgramError { IOError(#[from] io::Error), #[error("JSON serialization error: {0}")] JsonSerializationError(#[from] serde_json::Error), - #[error("Output directory creation error")] - OutputDirectoryCreationError, + #[error("Operation error: {0}")] + OperationError(String), #[error("Operation not supported")] OperationNotSupported, + #[error("Output directory creation error")] + OutputDirectoryCreationError, #[error("Parsing error")] ParsingError, #[error("Runtime error: {0}")] @@ -59,4 +61,3 @@ pub enum ProgramError { #[error("Undefined function or template")] UndefinedFunctionOrTemplate, } - diff --git a/src/runtime.rs b/src/runtime.rs index a2ea4fa..7f1bb25 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -8,6 +8,8 @@ use rand::{thread_rng, Rng}; use std::collections::{HashMap, HashSet, VecDeque}; use thiserror::Error; +pub const RETURN_VAR: &str = "function_return_value"; + /// Data type #[derive(Clone, Debug, PartialEq, Eq)] pub enum DataType { @@ -154,6 +156,12 @@ impl Context { } } + // Force the merge of the return variable. + if child.variables.contains_key(RETURN_VAR) { + self.variables + .insert(RETURN_VAR.to_string(), child.variables[RETURN_VAR].clone()); + } + for (name, component) in &child.components { if self.components.contains_key(name) { self.components.insert(name.clone(), component.clone());