Skip to content

Commit

Permalink
Cache extracted loop body functions for reuse (#131)
Browse files Browse the repository at this point in the history
Also:
* Refactor storing of global data for extracted loop functions
* Refactor context kind into the EnvRecorder
* Remove unnecessary derives
* Rename some things for clarity
  • Loading branch information
tim-hoffman authored Jul 10, 2024
1 parent 7279b84 commit bb24637
Show file tree
Hide file tree
Showing 8 changed files with 491 additions and 165 deletions.
247 changes: 247 additions & 0 deletions circom/tests/controlflow/multiuse_func_with_loop_same.circom

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,7 @@ impl Display for UnrolledBlockEnvData<'_> {
impl LibraryAccess for UnrolledBlockEnvData<'_> {
fn get_function(&self, name: &String) -> Ref<FunctionCode> {
if name.starts_with(LOOP_BODY_FN_PREFIX) {
Ref::map(self.extractor.get_new_functions(), |f| {
f.iter()
.find(|f| f.header.eq(name))
.expect("Cannot find extracted function definition!")
})
self.extractor.search_new_functions(name)
} else {
self.base.get_function(name)
}
Expand Down
1 change: 1 addition & 0 deletions circuit_passes/src/passes/checks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ macro_rules! checked_insert {
($map: expr, $key: expr, $val: expr) => {{
let key = $key;
let val = $val;
#[allow(unused_mut)] // some callers may already pass a &mut and that causes warning here
let mut map = $map;
assert!(
!map.contains_key(&key) || map[&key] == val,
Expand Down
342 changes: 220 additions & 122 deletions circuit_passes/src/passes/loop_unroll/body_extractor.rs

Large diffs are not rendered by default.

32 changes: 10 additions & 22 deletions circuit_passes/src/passes/loop_unroll/loop_env_recorder.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
use std::cell::{RefCell, Ref};
use std::collections::{BTreeMap, HashSet, HashMap};
use std::collections::{BTreeMap, HashMap};
use std::fmt::{Debug, Formatter};
use indexmap::IndexMap;
use compiler::intermediate_representation::BucketId;
use compiler::intermediate_representation::ir_interface::*;
use crate::bucket_interpreter::env::Env;
use crate::bucket_interpreter::env::{Env, EnvContextKind};
use crate::bucket_interpreter::error::BadInterp;
use crate::bucket_interpreter::memory::PassMemory;
use crate::bucket_interpreter::observer::Observer;
use crate::bucket_interpreter::value::Value;
use crate::passes::GlobalPassData;
use super::DEBUG_LOOP_UNROLL;
use super::body_extractor::{UnrolledIterLvars, ToOriginalLocation, FuncArgIdx};

pub struct EnvRecorder<'a, 'd> {
global_data: &'d RefCell<GlobalPassData>,
pub global_data: &'d RefCell<GlobalPassData>,
mem: &'a PassMemory,
pub(crate) ctx_kind: EnvContextKind,
// NOTE: RefCell is needed here because the instance of this struct is borrowed by
// the main interpreter while we also need to mutate these internal structures.
current_iter_num: RefCell<usize>,
Expand Down Expand Up @@ -55,10 +55,15 @@ impl Debug for EnvRecorder<'_, '_> {
}

impl<'a, 'd> EnvRecorder<'a, 'd> {
pub fn new(global_data: &'d RefCell<GlobalPassData>, mem: &'a PassMemory) -> Self {
pub fn new(
global_data: &'d RefCell<GlobalPassData>,
mem: &'a PassMemory,
ctx_kind: EnvContextKind,
) -> Self {
EnvRecorder {
global_data,
mem,
ctx_kind,
current_iter_num: RefCell::new(0),
safe_to_move: RefCell::new(true),
loadstore_to_index_per_iter: Default::default(),
Expand Down Expand Up @@ -110,23 +115,6 @@ impl<'a, 'd> EnvRecorder<'a, 'd> {
self.env_at_header.replace(None);
}

pub fn record_reverse_arg_mapping(
&self,
extract_func: String,
iter_env: UnrolledIterLvars,
value: (ToOriginalLocation, HashSet<FuncArgIdx>),
) {
if DEBUG_LOOP_UNROLL {
println!("[EnvRecorder] stored data {:?} -> {:?}", iter_env, value);
}
self.global_data
.borrow_mut()
.extract_func_orig_loc
.entry(extract_func)
.or_default()
.insert(iter_env, value);
}

#[inline]
fn default_return(&self) -> Result<bool, BadInterp> {
Ok(self.is_safe_to_move()) //continue observing unless something unsafe has been found
Expand Down
15 changes: 5 additions & 10 deletions circuit_passes/src/passes/loop_unroll/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ impl<'d> LoopUnrollPass<'d> {
println!("[UNROLL] LOOP ENTRY env {}", env);
}
// Compute loop iteration count. If unknown, return immediately.
let recorder = EnvRecorder::new(self.global_data, &self.memory);
let recorder = EnvRecorder::new(self.global_data, &self.memory, env.get_context_kind());
{
let interpreter = self.memory.build_interpreter(self.global_data, &recorder);
let mut inner_env = env.clone();
Expand Down Expand Up @@ -151,12 +151,7 @@ impl<'d> LoopUnrollPass<'d> {
if DEBUG_LOOP_UNROLL {
println!("[UNROLL][try_unroll_loop] OUTCOME: safe to move, extracting");
}
self.extractor.extract(
bucket,
recorder,
env.get_context_kind(),
&mut block_body,
)?;
self.extractor.extract(bucket, recorder, &mut block_body)?;
}
}
} else {
Expand Down Expand Up @@ -193,7 +188,7 @@ impl Observer<Env<'_>> for LoopUnrollPass<'_> {
if DEBUG_LOOP_UNROLL {
println!("[UNROLL][try_unroll_loop] result = {:?}", result);
}
// Add the loop bucket to the ordering for the before visiting within via continue_inside()
// Add the loop bucket to the ordering before visiting within via continue_inside()
// so that outer loop iteration counts appear first in the new function name
self.loop_bucket_order.borrow_mut().insert(bucket.id);
//
Expand Down Expand Up @@ -251,9 +246,9 @@ impl CircuitTransformationPass for LoopUnrollPass<'_> {

fn post_hook_circuit(&self, cir: &mut Circuit) -> Result<(), BadInterp> {
// Transform and add the new body functions from the extractor
let new_funcs = self.extractor.get_new_functions();
let new_funcs = self.extractor.take_new_functions();
cir.functions.reserve_exact(new_funcs.len());
for f in new_funcs.iter() {
for f in new_funcs {
cir.functions.insert(0, self.transform_function(&f)?);
}
// Add the duplicated versions of functions created by transform_call_bucket()
Expand Down
13 changes: 7 additions & 6 deletions circuit_passes/src/passes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -766,15 +766,16 @@ pub enum PassKind {
MappedToIndexed,
UnknownIndexSanitization,
}
/// Maps UnrolledIterLvars (from Env::get_vars_sort) to a pair containing:
/// (1) location references from the original function, used by ExtractedFuncEnvData to
/// access the original function's Env via the extracted function's parameter references
/// (2) the set of parameters that contain subcomponent arenas
pub type ExtractedFuncData = BTreeMap<UnrolledIterLvars, (ToOriginalLocation, HashSet<FuncArgIdx>)>;

#[derive(Debug)]
pub struct GlobalPassData {
/// Created during loop unrolling, maps generated function name + UnrolledIterLvars
/// (from Env::get_vars_sort) to location reference in the original function. Used
/// by ExtractedFuncEnvData to access the original function's Env via the extracted
/// function's parameter references.
extract_func_orig_loc:
HashMap<String, BTreeMap<UnrolledIterLvars, (ToOriginalLocation, HashSet<FuncArgIdx>)>>,
/// Created during loop unrolling, maps generated function name to ExtractedFuncData for it.
extract_func_orig_loc: HashMap<String, ExtractedFuncData>,
}

impl GlobalPassData {
Expand Down

0 comments on commit bb24637

Please sign in to comment.