Skip to content

Commit

Permalink
Merge pull request #59 from namnc/issue-54-all-infix-ops
Browse files Browse the repository at this point in the history
Support ALL circom InfixOp
  • Loading branch information
voltrevo authored Jun 4, 2024
2 parents f77b082 + 48df5a2 commit 5acb088
Show file tree
Hide file tree
Showing 14 changed files with 308 additions and 91 deletions.
170 changes: 93 additions & 77 deletions src/circom/input.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use input_processing::SimplificationStyle;
use std::path::{Path, PathBuf};

pub struct Input {
Expand Down Expand Up @@ -36,84 +35,51 @@ pub struct Input {
pub link_libraries: Vec<PathBuf>,
}

#[allow(dead_code)]
const R1CS: &str = "r1cs";
const WAT: &str = "wat";
const WASM: &str = "wasm";
const CPP: &str = "cpp";
const JS: &str = "js";
const DAT: &str = "dat";
const SYM: &str = "sym";
const JSON: &str = "json";

impl Input {
pub fn new(input: PathBuf, output_path: PathBuf) -> Result<Input, ()> {
let matches = input_processing::view();
let mut file_name = input.file_stem().unwrap().to_str().unwrap().to_string();

let c_flag = input_processing::get_c(&matches);

if c_flag && (file_name == "main" || file_name == "fr" || file_name == "calcwit") {
println!("The name {} is reserved in Circom when using de --c flag. The files generated for your circuit will use the name {}_c instead of {}.", file_name, file_name, file_name);
file_name = format!("{}_c", file_name)
pub fn new(input_file: &str, output_dir: &str, override_file_name: Option<&str>) -> Input {
let file_name = match override_file_name {
Some(f) => f.to_string(),
None => Path::new(input_file)
.file_stem()
.unwrap()
.to_string_lossy()
.to_string(),
};
let output_c_path = Input::build_folder(&output_path, &file_name, CPP);
let output_js_path = Input::build_folder(&output_path, &file_name, JS);
let o_style = input_processing::get_simplification_style(&matches)?;
let link_libraries = input_processing::get_link_libraries(&matches);
Result::Ok(Input {
//field: P_BN128,
input_program: input,
out_r1cs: Input::build_output(&output_path, &file_name, R1CS),
out_wat_code: Input::build_output(&output_js_path, &file_name, WAT),
out_wasm_code: Input::build_output(&output_js_path, &file_name, WASM),
out_js_folder: output_js_path.clone(),

Input {
input_program: input_file.into(),
out_r1cs: format!("{}/{}.r1cs", output_dir, file_name).into(),
out_json_constraints: format!("{}/{}_constraints.json", output_dir, file_name).into(),
out_json_substitutions: format!("{}/{}_substitutions.json", output_dir, file_name)
.into(),
out_wat_code: format!("{}/{}_js/{}.wat", output_dir, file_name, file_name).into(),
out_wasm_code: format!("{}/{}_js/{}.wasm", output_dir, file_name, file_name).into(),
out_wasm_name: file_name.clone(),
out_c_folder: output_c_path.clone(),
out_js_folder: format!("{}/{}_js", output_dir, file_name).into(),
out_c_run_name: file_name.clone(),
out_c_code: Input::build_output(&output_c_path, &file_name, CPP),
out_c_dat: Input::build_output(&output_c_path, &file_name, DAT),
out_sym: Input::build_output(&output_path, &file_name, SYM),
out_json_constraints: Input::build_output(
&output_path,
&format!("{}_constraints", file_name),
JSON,
),
out_json_substitutions: Input::build_output(
&output_path,
&format!("{}_substitutions", file_name),
JSON,
),
wat_flag: input_processing::get_wat(&matches),
wasm_flag: input_processing::get_wasm(&matches),
c_flag,
r1cs_flag: input_processing::get_r1cs(&matches),
sym_flag: input_processing::get_sym(&matches),
main_inputs_flag: input_processing::get_main_inputs_log(&matches),
json_constraint_flag: input_processing::get_json_constraints(&matches),
json_substitution_flag: input_processing::get_json_substitutions(&matches),
print_ir_flag: input_processing::get_ir(&matches),
no_rounds: if let SimplificationStyle::O2(r) = o_style {
r
} else {
0
},
fast_flag: o_style == SimplificationStyle::O0,
reduced_simplification_flag: o_style == SimplificationStyle::O1,
parallel_simplification_flag: input_processing::get_parallel_simplification(&matches),
inspect_constraints_flag: input_processing::get_inspect_constraints(&matches),
flag_old_heuristics: input_processing::get_flag_old_heuristics(&matches),
flag_verbose: input_processing::get_flag_verbose(&matches),
prime: input_processing::get_prime(&matches)?,
link_libraries,
})
}

fn build_folder(output_path: &Path, filename: &str, ext: &str) -> PathBuf {
let mut file = output_path.to_path_buf();
let folder_name = format!("{}_{}", filename, ext);
file.push(folder_name);
file
out_c_folder: format!("{}/{}_cpp", output_dir, file_name).into(),
out_c_code: format!("{}/{}_cpp/{}.cpp", output_dir, file_name, file_name).into(),
out_c_dat: format!("{}/{}_cpp/{}.dat", output_dir, file_name, file_name).into(),
out_sym: format!("{}/{}.sym", output_dir, file_name).into(),
c_flag: false,
wasm_flag: false,
wat_flag: false,
r1cs_flag: false,
sym_flag: false,
json_constraint_flag: false,
json_substitution_flag: false,
main_inputs_flag: false,
print_ir_flag: false,
fast_flag: false,
reduced_simplification_flag: false,
parallel_simplification_flag: false,
flag_old_heuristics: false,
inspect_constraints_flag: false,
no_rounds: 18446744073709551615,
flag_verbose: false,
prime: "bn128".into(),
link_libraries: vec![],
}
}

pub fn build_output(output_path: &Path, filename: &str, ext: &str) -> PathBuf {
Expand Down Expand Up @@ -219,12 +185,62 @@ impl Input {
self.prime.clone()
}
}

pub mod input_processing {
use clap::{App, Arg, ArgMatches};
use std::path::{Path, PathBuf};

use crate::circom::compilation::VERSION;
use clap::{App, Arg, ArgMatches};

use super::Input;

pub fn generate_input(input_file: PathBuf, output_dir: PathBuf) -> Result<Input, ()> {
let matches = view();
let mut file_name = input_file
.file_stem()
.unwrap()
.to_string_lossy()
.to_string();

let c_flag = get_c(&matches);

if c_flag && (file_name == "main" || file_name == "fr" || file_name == "calcwit") {
println!("The name {} is reserved in Circom when using de --c flag. The files generated for your circuit will use the name {}_c instead of {}.", file_name, file_name, file_name);
file_name = format!("{}_c", file_name)
};

let mut input = Input::new(
input_file.to_str().unwrap(),
output_dir.to_str().unwrap(),
Some(&file_name),
);

let o_style = get_simplification_style(&matches)?;

input.wat_flag = get_wat(&matches);
input.wasm_flag = get_wasm(&matches);
input.c_flag = c_flag;
input.r1cs_flag = get_r1cs(&matches);
input.sym_flag = get_sym(&matches);
input.main_inputs_flag = get_main_inputs_log(&matches);
input.json_constraint_flag = get_json_constraints(&matches);
input.json_substitution_flag = get_json_substitutions(&matches);
input.print_ir_flag = get_ir(&matches);
input.no_rounds = if let SimplificationStyle::O2(r) = o_style {
r
} else {
0
};
input.fast_flag = o_style == SimplificationStyle::O0;
input.reduced_simplification_flag = o_style == SimplificationStyle::O1;
input.parallel_simplification_flag = get_parallel_simplification(&matches);
input.inspect_constraints_flag = get_inspect_constraints(&matches);
input.flag_old_heuristics = get_flag_old_heuristics(&matches);
input.flag_verbose = get_flag_verbose(&matches);
input.prime = get_prime(&matches)?;
input.link_libraries = get_link_libraries(&matches);

Ok(input)
}

pub fn get_input(matches: &ArgMatches) -> Result<PathBuf, ()> {
let route = Path::new(matches.value_of("input").unwrap()).to_path_buf();
Expand Down
40 changes: 33 additions & 7 deletions src/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,23 +34,40 @@ pub enum AGateType {
ANeq,
ASub,
AXor,
APow,
AIntDiv,
AMod,
AShiftL,
AShiftR,
ABoolOr,
ABoolAnd,
ABitOr,
ABitAnd,
}

impl From<&ExpressionInfixOpcode> for AGateType {
fn from(opcode: &ExpressionInfixOpcode) -> Self {
match opcode {
ExpressionInfixOpcode::Add => AGateType::AAdd,
ExpressionInfixOpcode::Mul => AGateType::AMul,
ExpressionInfixOpcode::Div => AGateType::ADiv,
ExpressionInfixOpcode::Eq => AGateType::AEq,
ExpressionInfixOpcode::Greater => AGateType::AGt,
ExpressionInfixOpcode::Add => AGateType::AAdd,
ExpressionInfixOpcode::Sub => AGateType::ASub,
ExpressionInfixOpcode::Pow => AGateType::APow,
ExpressionInfixOpcode::IntDiv => AGateType::AIntDiv,
ExpressionInfixOpcode::Mod => AGateType::AMod,
ExpressionInfixOpcode::ShiftL => AGateType::AShiftL,
ExpressionInfixOpcode::ShiftR => AGateType::AShiftR,
ExpressionInfixOpcode::LesserEq => AGateType::ALEq,
ExpressionInfixOpcode::GreaterEq => AGateType::AGEq,
ExpressionInfixOpcode::Lesser => AGateType::ALt,
ExpressionInfixOpcode::LesserEq => AGateType::ALEq,
ExpressionInfixOpcode::Mul => AGateType::AMul,
ExpressionInfixOpcode::Greater => AGateType::AGt,
ExpressionInfixOpcode::Eq => AGateType::AEq,
ExpressionInfixOpcode::NotEq => AGateType::ANeq,
ExpressionInfixOpcode::Sub => AGateType::ASub,
ExpressionInfixOpcode::BoolOr => AGateType::ABoolOr,
ExpressionInfixOpcode::BoolAnd => AGateType::ABoolAnd,
ExpressionInfixOpcode::BitOr => AGateType::ABitOr,
ExpressionInfixOpcode::BitAnd => AGateType::ABitAnd,
ExpressionInfixOpcode::BitXor => AGateType::AXor,
_ => unimplemented!("Unsupported opcode"),
}
}
}
Expand All @@ -69,6 +86,15 @@ impl From<&AGateType> for Operation {
AGateType::AGt => Operation::GreaterThan,
AGateType::AGEq => Operation::GreaterOrEqual,
AGateType::AXor => Operation::XorBitwise,
AGateType::APow => Operation::Exponentiate,
AGateType::AIntDiv => Operation::IntegerDivide,
AGateType::AMod => Operation::Modulus,
AGateType::AShiftL => Operation::ShiftLeft,
AGateType::AShiftR => Operation::ShiftRight,
AGateType::ABoolOr => Operation::Or,
AGateType::ABoolAnd => Operation::And,
AGateType::ABitOr => Operation::OrBitwise,
AGateType::ABitAnd => Operation::AndBitwise,
}
}
}
Expand Down
7 changes: 5 additions & 2 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
use circom_2_arithc::{
circom::input::{input_processing::view, Input},
circom::input::{
input_processing::{generate_input, view},
Input,
},
program::{build_circuit, ProgramError},
};
use dotenv::dotenv;
Expand All @@ -20,7 +23,7 @@ fn main() -> Result<(), ProgramError> {
fs::create_dir_all(output_path.clone())
.map_err(|_| ProgramError::OutputDirectoryCreationError)?;

let input = Input::new(PathBuf::from("./src/assets/circuit.circom"), output_path)
let input = generate_input(PathBuf::from("./src/assets/circuit.circom"), output_path)
.map_err(|_| ProgramError::InputInitializationError)?;
let output_dir = input
.out_r1cs
Expand Down
2 changes: 1 addition & 1 deletion tests/add_zero.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ const TEST_FILE_PATH: &str = "./tests/circuits/addZero.circom";

#[test]
fn test_add_zero() {
let input = Input::new(TEST_FILE_PATH.into(), "./".into()).unwrap();
let input = Input::new(TEST_FILE_PATH, "./", None);
let circuit = build_circuit(&input).unwrap();
let sim_circuit = circuit.build_sim_circuit().unwrap();

Expand Down
74 changes: 74 additions & 0 deletions tests/circuits/infixOps.circom
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
pragma circom 2.0.0;

template infixOps() {
signal input x0;
signal input x1;
signal input x2;
signal input x3;
signal input x4;
signal input x5;

signal output mul_2_3;
signal output div_4_3;
signal output idiv_4_3;
signal output add_3_4;
signal output sub_4_1;
signal output pow_2_4;
signal output mod_5_3;
signal output shl_5_1;
signal output shr_5_1;
signal output leq_2_3;
signal output leq_3_3;
signal output leq_4_3;
signal output geq_2_3;
signal output geq_3_3;
signal output geq_4_3;
signal output lt_2_3;
signal output lt_3_3;
signal output lt_4_3;
signal output gt_2_3;
signal output gt_3_3;
signal output gt_4_3;
signal output eq_2_3;
signal output eq_3_3;
signal output neq_2_3;
signal output neq_3_3;
signal output or_0_1;
signal output and_0_1;
signal output bit_or_1_3;
signal output bit_and_1_3;
signal output bit_xor_1_3;

mul_2_3 <== x2 * x3;
div_4_3 <== x4 / x3;
idiv_4_3 <== x4 \ x3;
add_3_4 <== x3 + x4;
sub_4_1 <== x4 - x1;
pow_2_4 <== x2 ** x4;
mod_5_3 <== x5 % x3;
shl_5_1 <== x5 << x1;
shr_5_1 <== x5 >> x1;
leq_2_3 <== x2 <= x3;
leq_3_3 <== x3 <= x3;
leq_4_3 <== x4 <= x3;
geq_2_3 <== x2 >= x3;
geq_3_3 <== x3 >= x3;
geq_4_3 <== x4 >= x3;
lt_2_3 <== x2 < x3;
lt_3_3 <== x3 < x3;
lt_4_3 <== x4 < x3;
gt_2_3 <== x2 > x3;
gt_3_3 <== x3 > x3;
gt_4_3 <== x4 > x3;
eq_2_3 <== x2 == x3;
eq_3_3 <== x3 == x3;
neq_2_3 <== x2 != x3;
neq_3_3 <== x3 != x3;
or_0_1 <== x0 || x1;
and_0_1 <== x0 && x1;
bit_or_1_3 <== x1 | x3;
bit_and_1_3 <== x1 & x3;
bit_xor_1_3 <== x1 ^ x3;
}

component main = infixOps();
7 changes: 7 additions & 0 deletions tests/circuits/underConstrained.circom
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
pragma circom 2.0.0;

template underConstrained() {
signal output x;
}

component main = underConstrained();
10 changes: 10 additions & 0 deletions tests/circuits/xEqX.circom
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
pragma circom 2.0.0;

template xEqX() {
signal input x;
signal output out;

out <== x == x;
}

component main = xEqX();
2 changes: 1 addition & 1 deletion tests/constant_sum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ const TEST_FILE_PATH: &str = "./tests/circuits/constantSum.circom";

#[test]
fn test_constant_sum() {
let input = Input::new(TEST_FILE_PATH.into(), "./".into()).unwrap();
let input = Input::new(TEST_FILE_PATH, "./", None);
let circuit = build_circuit(&input).unwrap();
let sim_circuit = circuit.build_sim_circuit().unwrap();

Expand Down
Loading

0 comments on commit 5acb088

Please sign in to comment.