fabric: executor: Remove recursion to avoid stack overflows
Previously we recursively called `handle_new_result` from the executor
method to avoid brokering recursive evaluation through the slower
job queue. For very deep circuits, this can lead to excessive recursion
and a stack overflow.

Here we refactor the recursive approach into an iterative one driven by a
dispatch method.
joeykraut committed Nov 4, 2023
1 parent: 0c22cb0 · commit: e565cea
Showing 3 changed files with 73 additions and 21 deletions.
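
For context on the change, a minimal standalone sketch of the pattern follows (this is not code from the commit; `ToyExecutor`, `OpId`, and `dependents` are illustrative names only): operations unblocked by a new result are pushed onto an explicit, heap-allocated worklist that a loop drains, rather than being handled through a recursive call, so circuit depth costs heap memory instead of stack frames.

// Standalone sketch (not from this commit): a linear chain of "add one" ops
// dispatched from an explicit worklist instead of by recursion.
use std::collections::HashMap;

type OpId = usize;

struct ToyExecutor {
    /// Maps an op to the single op that consumes its result (a linear circuit)
    dependents: HashMap<OpId, OpId>,
    /// Maps an op to its computed value
    results: HashMap<OpId, u64>,
}

impl ToyExecutor {
    /// Iterative dispatch: ops unblocked by a new result are pushed onto the
    /// same heap-allocated worklist rather than handled by a recursive call,
    /// so a depth-N chain uses O(1) stack instead of N frames
    fn execute_operations(&mut self, mut ops: Vec<OpId>) {
        while let Some(op) = ops.pop() {
            // "Compute" the op: add one to its predecessor's value (or to zero)
            let input = self.results.get(&op.wrapping_sub(1)).copied().unwrap_or(0);
            self.results.insert(op, input + 1);

            // Enqueue the op this result unblocks instead of recursing into it
            if let Some(&next) = self.dependents.get(&op) {
                ops.push(next);
            }
        }
    }
}

fn main() {
    const DEPTH: usize = 1_000_000;
    let dependents: HashMap<OpId, OpId> = (0..DEPTH - 1).map(|i| (i, i + 1)).collect();
    let mut exec = ToyExecutor { dependents, results: HashMap::new() };

    // The same depth that overflows the stack under recursive dispatch completes here
    exec.execute_operations(vec![0]);
    assert_eq!(exec.results[&(DEPTH - 1)], DEPTH as u64);
}

The committed change below applies the same idea: `execute_operations` drains a `Vec<OperationId>` worklist, `append_ready_ops` feeds newly ready operations back into it, and `handle_new_result` no longer re-enters itself.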
.github/workflows/test.yml (2 changes: 1 addition & 1 deletion)

@@ -11,6 +11,6 @@ jobs:
     steps:
     - uses: actions/checkout@v3
     - name: Build
-      run: cargo build --workspace --verbose
+      run: cargo build --workspace --all-features --verbose
     - name: Run tests
       run: cargo test --lib --all-features --verbose
src/fabric.rs (24 changes: 24 additions & 0 deletions)

@@ -1121,3 +1121,27 @@ impl<C: CurveGroup> MpcFabric<C> {
         AuthenticatedScalarResult::new_shared_batch(&bits)
     }
 }
+
+#[cfg(test)]
+mod test {
+    use crate::{algebra::Scalar, test_helpers::execute_mock_mpc, PARTY0};
+
+    /// Tests a linear circuit of very large depth
+    #[tokio::test]
+    async fn test_deep_circuit() {
+        const DEPTH: usize = 1_000_000;
+        let (res, _) = execute_mock_mpc(|fabric| async move {
+            // Perform an operation that takes time, so that further operations will enqueue
+            // behind it
+            let mut res = fabric.share_plaintext(Scalar::from(1u8), PARTY0);
+            for _ in 0..DEPTH {
+                res = res + fabric.one();
+            }
+
+            res.await
+        })
+        .await;
+
+        assert_eq!(res, Scalar::from(DEPTH + 1));
+    }
+}
src/fabric/executor.rs (68 changes: 48 additions & 20 deletions)

@@ -20,7 +20,7 @@ use crate::network::NetworkOutbound;
 
 use super::result::ResultWaiter;
 use super::{result::OpResult, FabricInner};
-use super::{Operation, OperationType, ResultId};
+use super::{Operation, OperationId, OperationType, ResultId};
 
 // ---------
 // | Stats |
@@ -214,12 +214,27 @@ impl<C: CurveGroup> Executor<C> {
     /// Handle a new result
     fn handle_new_result(&mut self, result: OpResult<C>) {
         let id = result.id;
+        self.insert_result(result);
+
+        // Execute all operations that are ready after committing this result
+        let mut ops_queue = Vec::new();
+        self.append_ready_ops(id, &mut ops_queue);
+        self.execute_operations(ops_queue);
+    }
+
+    /// Insert a result into the buffer
+    fn insert_result(&mut self, result: OpResult<C>) {
         let prev = self.results.insert(result.id, result);
-        assert!(prev.is_none(), "duplicate result id: {id:?}");
+        assert!(
+            prev.is_none(),
+            "duplicate result id: {:?}",
+            prev.unwrap().id
+        );
+    }
 
-        // Execute any ready dependencies
+    /// Get the operations that are ready for execution after a result comes in
+    fn append_ready_ops(&mut self, id: OperationId, ready_ops: &mut Vec<OperationId>) {
         if let Some(deps) = self.dependencies.get(id) {
-            let mut ready_ops = Vec::new();
             for op_id in deps.iter() {
                 let operation = self.operations.get_mut(*op_id).unwrap();
 
@@ -231,14 +246,7 @@ impl<C: CurveGroup> Executor<C> {
                 // Mark the operation as ready for execution
                 ready_ops.push(*op_id);
             }
-
-            for op in ready_ops.into_iter() {
-                let op = self.operations.take(op).unwrap();
-                self.execute_operation(op);
-            }
         }
-
-        self.wake_waiters_on_result(id);
     }
 
     /// Handle a new operation
@@ -264,7 +272,10 @@ impl<C: CurveGroup> Executor<C> {
 
         // If the operation is ready for execution, do so
         if inflight_args == 0 {
-            self.execute_operation(op);
+            let id = op.id;
+            self.operations.insert(id, op);
+
+            self.execute_operations(vec![id]);
             return;
         }
 
@@ -289,8 +300,25 @@ impl<C: CurveGroup> Executor<C> {
         self.stats.new_operation(op, is_network_op);
     }
 
-    /// Executes an operation whose arguments are ready
-    fn execute_operation(&mut self, op: Operation<C>) {
+    /// Executes the operations in the buffer, recursively executing any
+    /// dependencies that become ready
+    fn execute_operations(&mut self, mut ops: Vec<OperationId>) {
+        while let Some(op_id) = ops.pop() {
+            let op = self.operations.take(op_id).unwrap();
+            let res = self.compute_result(op);
+
+            for result in res.into_iter() {
+                let id = result.id;
+
+                self.append_ready_ops(result.id, &mut ops);
+                self.insert_result(result);
+                self.wake_waiters_on_result(id);
+            }
+        }
+    }
+
+    /// Compute the result of an operation
+    fn compute_result(&mut self, op: Operation<C>) -> Vec<OpResult<C>> {
         let result_ids = op.result_ids();
 
         // Collect the inputs to the operation
@@ -303,10 +331,10 @@ impl<C: CurveGroup> Executor<C> {
         match op.op_type {
             OperationType::Gate { function } => {
                 let value = (function)(inputs);
-                self.handle_new_result(OpResult {
+                vec![OpResult {
                     id: op.result_id,
                     value,
-                });
+                }]
             },
 
             OperationType::GateBatch { function } => {
@@ -315,7 +343,7 @@ impl<C: CurveGroup> Executor<C> {
                     .into_iter()
                     .zip(output)
                     .map(|(id, value)| OpResult { id, value })
-                    .for_each(|res| self.handle_new_result(res));
+                    .collect()
             },
 
             OperationType::Network { function } => {
@@ -336,12 +364,12 @@ impl<C: CurveGroup> Executor<C> {
                 // On a `send`, the local party receives a copy of the value placed as the
                 // result of the network operation, so we must re-enqueue the
                 // result
-                self.handle_new_result(OpResult {
+                vec![OpResult {
                     id: result_id,
                     value: payload.into(),
-                });
+                }]
             },
-        };
+        }
     }
 
     /// Handle a new waiter for a result
