Skip to content

Commit

Permalink
fix: number operations (#29)
Browse files Browse the repository at this point in the history
  • Loading branch information
brech1 authored Feb 24, 2024
1 parent e9b327c commit 65c01e3
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 33 deletions.
99 changes: 69 additions & 30 deletions src/process.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use crate::circuit::{AGateType, ArithmeticCircuit};
use crate::program::ProgramError;
use crate::runtime::{
increment_indices, u32_to_access, Context, DataAccess, DataType, Runtime, Signal, SubAccess,
RETURN_VAR,
};
use circom_circom_algebra::num_traits::ToPrimitive;
use circom_program_structure::ast::{
Expand All @@ -14,8 +15,6 @@ use circom_program_structure::ast::{
use circom_program_structure::program_archive::ProgramArchive;
use std::collections::HashMap;

pub const RETURN_VAR: &str = "function_return";

/// Processes a sequence of statements.
pub fn process_statements(
ac: &mut ArithmeticCircuit,
Expand Down Expand Up @@ -189,18 +188,15 @@ pub fn process_statement(
Ok(())
}
Statement::Return { value, .. } => {
let return_var_access = DataAccess::new(RETURN_VAR, vec![]);
let return_access = process_expression(ac, runtime, program_archive, value)?;
let return_value = runtime
.current_context()?

let ctx = runtime.current_context()?;
let return_value = ctx
.get_variable_value(&return_access)?
.ok_or(ProgramError::EmptyDataItem)?;

let ctx = runtime.current_context()?;
if ctx.get_variable_value(&return_var_access).is_err() {
ctx.declare_item(DataType::Variable, RETURN_VAR, &[])?;
}
ctx.set_variable(&return_var_access, Some(return_value))?;
ctx.declare_item(DataType::Variable, RETURN_VAR, &[])?;
ctx.set_variable(&DataAccess::new(RETURN_VAR, vec![]), Some(return_value))?;

Ok(())
}
Expand Down Expand Up @@ -363,7 +359,7 @@ fn handle_infix_op(
.get_variable_value(&rhe_access)?
.ok_or(ProgramError::EmptyDataItem)?;

let op_res = execute_op(&lhs_value, &rhs_value, op);
let op_res = execute_op(lhs_value, rhs_value, op)?;
let item_access = ctx.declare_random_item(DataType::Variable)?;
ctx.set_variable(&item_access, Some(op_res))?;

Expand Down Expand Up @@ -436,55 +432,98 @@ pub fn build_access(
Ok(DataAccess::new(name, access_vec))
}

/// Executes an operation, performing the specified arithmetic or logical computation.
pub fn execute_op(lhs: &u32, rhs: &u32, op: &ExpressionInfixOpcode) -> u32 {
match AGateType::from(op) {
AGateType::AAdd => lhs + rhs,
AGateType::ADiv => lhs / rhs,
AGateType::AEq => {
if lhs == rhs {
/// Executes an operation on two u32 values, performing the specified arithmetic or logical computation.
pub fn execute_op(lhs: u32, rhs: u32, op: &ExpressionInfixOpcode) -> Result<u32, ProgramError> {
let res = match op {
ExpressionInfixOpcode::Mul => lhs * rhs,
ExpressionInfixOpcode::Div => {
if rhs == 0 {
return Err(ProgramError::OperationError("Division by zero".to_string()));
}

lhs / rhs
}
ExpressionInfixOpcode::Add => lhs + rhs,
ExpressionInfixOpcode::Sub => lhs - rhs,
ExpressionInfixOpcode::Pow => lhs.pow(rhs),
ExpressionInfixOpcode::IntDiv => {
if rhs == 0 {
return Err(ProgramError::OperationError(
"Integer division by zero".to_string(),
));
}

lhs / rhs
}
ExpressionInfixOpcode::Mod => {
if rhs == 0 {
return Err(ProgramError::OperationError("Modulo by zero".to_string()));
}

lhs % rhs
}
ExpressionInfixOpcode::ShiftL => lhs << rhs,
ExpressionInfixOpcode::ShiftR => lhs >> rhs,
ExpressionInfixOpcode::LesserEq => {
if lhs <= rhs {
1
} else {
0
}
}
AGateType::AGEq => {
ExpressionInfixOpcode::GreaterEq => {
if lhs >= rhs {
1
} else {
0
}
}
AGateType::AGt => {
if lhs > rhs {
ExpressionInfixOpcode::Lesser => {
if lhs < rhs {
1
} else {
0
}
}
AGateType::ALEq => {
if lhs <= rhs {
ExpressionInfixOpcode::Greater => {
if lhs > rhs {
1
} else {
0
}
}
AGateType::ALt => {
if lhs < rhs {
ExpressionInfixOpcode::Eq => {
if lhs == rhs {
1
} else {
0
}
}
AGateType::AMul => lhs * rhs,
AGateType::ANeq => {
ExpressionInfixOpcode::NotEq => {
if lhs != rhs {
1
} else {
0
}
}
AGateType::ANone => unimplemented!(),
AGateType::ASub => lhs - rhs,
}
ExpressionInfixOpcode::BoolOr => {
if lhs != 0 || rhs != 0 {
1
} else {
0
}
}
ExpressionInfixOpcode::BoolAnd => {
if lhs != 0 && rhs != 0 {
1
} else {
0
}
}
ExpressionInfixOpcode::BitOr => lhs | rhs,
ExpressionInfixOpcode::BitAnd => lhs & rhs,
ExpressionInfixOpcode::BitXor => lhs ^ rhs,
};

Ok(res)
}
7 changes: 4 additions & 3 deletions src/program.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,16 @@ pub enum ProgramError {
IOError(#[from] io::Error),
#[error("JSON serialization error: {0}")]
JsonSerializationError(#[from] serde_json::Error),
#[error("Output directory creation error")]
OutputDirectoryCreationError,
#[error("Operation error: {0}")]
OperationError(String),
#[error("Operation not supported")]
OperationNotSupported,
#[error("Output directory creation error")]
OutputDirectoryCreationError,
#[error("Parsing error")]
ParsingError,
#[error("Runtime error: {0}")]
RuntimeError(RuntimeError),
#[error("Undefined function or template")]
UndefinedFunctionOrTemplate,
}

8 changes: 8 additions & 0 deletions src/runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ use rand::{thread_rng, Rng};
use std::collections::{HashMap, HashSet, VecDeque};
use thiserror::Error;

pub const RETURN_VAR: &str = "function_return_value";

/// Data type
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum DataType {
Expand Down Expand Up @@ -154,6 +156,12 @@ impl Context {
}
}

// Force the merge of the return variable.
if child.variables.contains_key(RETURN_VAR) {
self.variables
.insert(RETURN_VAR.to_string(), child.variables[RETURN_VAR].clone());
}

for (name, component) in &child.components {
if self.components.contains_key(name) {
self.components.insert(name.clone(), component.clone());
Expand Down

0 comments on commit 65c01e3

Please sign in to comment.