From e565cea69608190f7ea4e0770db76b75caae3404 Mon Sep 17 00:00:00 2001
From: Joey Kraut
Date: Sat, 4 Nov 2023 14:44:39 -0700
Subject: [PATCH] fabric: executor: Remove recursion to avoid stack overflows

Previously we recursively called `handle_new_result` from the executor
method to avoid brokering recursive evaluation through the slower job
queue. For very deep circuits, this can lead to excessive recursion and
a stack overflow. Here we refactor the recursive approach into an
iterative one driven by a dispatch method.
---
 .github/workflows/test.yml |  2 +-
 src/fabric.rs              | 24 ++++++++++++++
 src/fabric/executor.rs     | 68 +++++++++++++++++++++++++++-----------
 3 files changed, 73 insertions(+), 21 deletions(-)

diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 60dde6e..e6e5ccc 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -11,6 +11,6 @@ jobs:
     steps:
       - uses: actions/checkout@v3
       - name: Build
-        run: cargo build --workspace --verbose
+        run: cargo build --workspace --all-features --verbose
       - name: Run tests
         run: cargo test --lib --all-features --verbose
diff --git a/src/fabric.rs b/src/fabric.rs
index aedcd7d..2555333 100644
--- a/src/fabric.rs
+++ b/src/fabric.rs
@@ -1121,3 +1121,27 @@ impl MpcFabric {
         AuthenticatedScalarResult::new_shared_batch(&bits)
     }
 }
+
+#[cfg(test)]
+mod test {
+    use crate::{algebra::Scalar, test_helpers::execute_mock_mpc, PARTY0};
+
+    /// Tests a linear circuit of very large depth
+    #[tokio::test]
+    async fn test_deep_circuit() {
+        const DEPTH: usize = 1_000_000;
+        let (res, _) = execute_mock_mpc(|fabric| async move {
+            // Perform an operation that takes time, so that further operations will enqueue
+            // behind it
+            let mut res = fabric.share_plaintext(Scalar::from(1u8), PARTY0);
+            for _ in 0..DEPTH {
+                res = res + fabric.one();
+            }
+
+            res.await
+        })
+        .await;
+
+        assert_eq!(res, Scalar::from(DEPTH + 1));
+    }
+}
diff --git a/src/fabric/executor.rs b/src/fabric/executor.rs
index cadb035..147421b 100644
--- a/src/fabric/executor.rs
+++ b/src/fabric/executor.rs
@@ -20,7 +20,7 @@ use crate::network::NetworkOutbound;
 
 use super::result::ResultWaiter;
 use super::{result::OpResult, FabricInner};
-use super::{Operation, OperationType, ResultId};
+use super::{Operation, OperationId, OperationType, ResultId};
 // ---------
 // | Stats |
 // ---------
@@ -214,12 +214,27 @@ impl Executor {
     /// Handle a new result
     fn handle_new_result(&mut self, result: OpResult) {
         let id = result.id;
+        self.insert_result(result);
+
+        // Execute all operations that are ready after committing this result
+        let mut ops_queue = Vec::new();
+        self.append_ready_ops(id, &mut ops_queue);
+        self.execute_operations(ops_queue);
+    }
+
+    /// Insert a result into the buffer
+    fn insert_result(&mut self, result: OpResult) {
         let prev = self.results.insert(result.id, result);
-        assert!(prev.is_none(), "duplicate result id: {id:?}");
+        assert!(
+            prev.is_none(),
+            "duplicate result id: {:?}",
+            prev.unwrap().id
+        );
+    }
 
-        // Execute any ready dependencies
+    /// Get the operations that are ready for execution after a result comes in
+    fn append_ready_ops(&mut self, id: OperationId, ready_ops: &mut Vec<OperationId>) {
         if let Some(deps) = self.dependencies.get(id) {
-            let mut ready_ops = Vec::new();
             for op_id in deps.iter() {
                 let operation = self.operations.get_mut(*op_id).unwrap();
 
@@ -231,14 +246,7 @@
                 // Mark the operation as ready for execution
                 ready_ops.push(*op_id);
             }
-
-            for op in ready_ops.into_iter() {
-                let op = self.operations.take(op).unwrap();
-                self.execute_operation(op);
-            }
         }
-
-        self.wake_waiters_on_result(id);
     }
 
     /// Handle a new operation
@@ -264,7 +272,10 @@
 
         // If the operation is ready for execution, do so
        if inflight_args == 0 {
-            self.execute_operation(op);
+            let id = op.id;
+            self.operations.insert(id, op);
+
+            self.execute_operations(vec![id]);
             return;
         }
 
@@ -289,8 +300,25 @@
         self.stats.new_operation(op, is_network_op);
     }
 
-    /// Executes an operation whose arguments are ready
-    fn execute_operation(&mut self, op: Operation) {
+    /// Executes the operations in the buffer, recursively executing any
+    /// dependencies that become ready
+    fn execute_operations(&mut self, mut ops: Vec<OperationId>) {
+        while let Some(op_id) = ops.pop() {
+            let op = self.operations.take(op_id).unwrap();
+            let res = self.compute_result(op);
+
+            for result in res.into_iter() {
+                let id = result.id;
+
+                self.append_ready_ops(result.id, &mut ops);
+                self.insert_result(result);
+                self.wake_waiters_on_result(id);
+            }
+        }
+    }
+
+    /// Compute the result of an operation
+    fn compute_result(&mut self, op: Operation) -> Vec<OpResult> {
         let result_ids = op.result_ids();
 
         // Collect the inputs to the operation
@@ -303,10 +331,10 @@
         match op.op_type {
             OperationType::Gate { function } => {
                 let value = (function)(inputs);
-                self.handle_new_result(OpResult {
+                vec![OpResult {
                     id: op.result_id,
                     value,
-                });
+                }]
             },
 
             OperationType::GateBatch { function } => {
@@ -315,7 +343,7 @@
                     .into_iter()
                     .zip(output)
                     .map(|(id, value)| OpResult { id, value })
-                    .for_each(|res| self.handle_new_result(res));
+                    .collect()
             },
 
             OperationType::Network { function } => {
@@ -336,12 +364,12 @@
                 // On a `send`, the local party receives a copy of the value placed as the
                 // result of the network operation, so we must re-enqueue the
                 // result
-                self.handle_new_result(OpResult {
+                vec![OpResult {
                     id: result_id,
                     value: payload.into(),
-                });
+                }]
             },
-        };
+        }
     }
 
     /// Handle a new waiter for a result
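
Aside: below is a minimal, self-contained sketch (not taken from this patch or repository; the names `execute_iteratively`, `deps`, and `done` are hypothetical stand-ins) of the same transformation the commit describes: replacing a recursive "handle each new result immediately" call chain with an explicit worklist drained in a loop, so call-stack depth stays constant no matter how deep the dependency chain is.

// Toy dependency-chain executor. A recursive version (calling handle(next)
// from inside handle(id)) would overflow the stack at this depth; the
// explicit worklist below does not.
fn execute_iteratively(mut ready: Vec<usize>, deps: &[Vec<usize>], done: &mut [bool]) {
    while let Some(id) = ready.pop() {
        done[id] = true;

        // Every dependent of the finished node joins the queue; in this toy
        // graph each node has a single input, so it is ready immediately.
        for &dep in &deps[id] {
            if !done[dep] {
                ready.push(dep);
            }
        }
    }
}

fn main() {
    // Linear chain 0 -> 1 -> ... -> DEPTH-1, analogous to the deep circuit
    // exercised by the new test_deep_circuit test in src/fabric.rs.
    const DEPTH: usize = 1_000_000;
    let deps: Vec<Vec<usize>> = (0..DEPTH)
        .map(|i| if i + 1 < DEPTH { vec![i + 1] } else { vec![] })
        .collect();
    let mut done = vec![false; DEPTH];

    execute_iteratively(vec![0], &deps, &mut done);
    assert!(done.iter().all(|&d| d));
    println!("executed {} nodes without recursion", DEPTH);
}

Using a Vec as a last-in, first-out worklist mirrors the patch's execute_operations, which pushes newly ready operation ids onto `ops` and pops them in the same loop instead of recursing.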