diff --git a/src/run.rs b/src/run.rs
index d68e62d..ff4008d 100644
--- a/src/run.rs
+++ b/src/run.rs
@@ -2,64 +2,110 @@ use std::cmp::max;
 use std::future::Future;
 use std::sync::{Arc, LazyLock};
 
-use tokio::task::JoinSet;
+use futures::StreamExt;
 use tracing::trace;
 
 use crate::hook::Hook;
 
-pub static CONCURRENCY: LazyLock<usize> = LazyLock::new(|| target_concurrency(false));
-
-fn target_concurrency(serial: bool) -> usize {
-    if serial || std::env::var_os("PRE_COMMIT_NO_CONCURRENCY").is_some() {
+/// Default number of batches to run in parallel: 1 when `PRE_COMMIT_NO_CONCURRENCY` is set,
+/// otherwise the available parallelism (1 if it cannot be determined).
+pub static CONCURRENCY: LazyLock<usize> = LazyLock::new(|| {
+    if std::env::var_os("PRE_COMMIT_NO_CONCURRENCY").is_some() {
         1
     } else {
         std::thread::available_parallelism()
             .map(std::num::NonZero::get)
             .unwrap_or(1)
     }
+});
+
+/// Concurrency for one hook: 1 when it must run serially, otherwise [`CONCURRENCY`].
+fn target_concurrency(serial: bool) -> usize {
+    if serial {
+        1
+    } else {
+        *CONCURRENCY
+    }
 }
 
-// TODO: do a more accurate calculation
-fn partitions<'a>(
+/// Iterator that yields partitions of filenames that fit within the maximum command line length.
+struct Partitions<'a> {
     hook: &'a Hook,
-    filenames: &'a [&String],
+    filenames: &'a [&'a String],
     concurrency: usize,
-) -> Vec<Vec<&'a String>> {
-    // If there are no filenames, we still want to run the hook once.
-    if filenames.is_empty() {
-        return vec![vec![]];
-    }
+    current_index: usize,
+    command_length: usize,
+    max_per_batch: usize,
+    max_cli_length: usize,
+}
 
-    let max_per_batch = max(4, filenames.len().div_ceil(concurrency));
-    // TODO: subtract the env size
-    let max_cli_length = if cfg!(unix) {
-        1 << 12
-    } else {
-        (1 << 15) - 2048 // UNICODE_STRING max - headroom
-    };
-
-    let command_length =
-        hook.entry.len() + hook.args.iter().map(String::len).sum::<usize>() + hook.args.len();
-
-    let mut partitions = Vec::new();
-    let mut current = Vec::new();
-    let mut current_length = command_length + 1;
-
-    for &filename in filenames {
-        let length = filename.len() + 1;
-        if current_length + length > max_cli_length || current.len() >= max_per_batch {
-            partitions.push(current);
-            current = Vec::new();
-            current_length = command_length + 1;
+// TODO: do a more accurate calculation
+impl<'a> Partitions<'a> {
+    /// Precomputes the per-batch limits: at most `max(4, ceil(len / concurrency))` filenames and
+    /// a fixed command-line length budget, part of which is taken by the hook entry and args.
+    fn new(hook: &'a Hook, filenames: &'a [&'a String], concurrency: usize) -> Self {
+        let max_per_batch = max(4, filenames.len().div_ceil(concurrency));
+        // TODO: subtract the env size
+        let max_cli_length = if cfg!(unix) {
+            1 << 12
+        } else {
+            (1 << 15) - 2048 // UNICODE_STRING max - headroom
+        };
+        let command_length =
+            hook.entry.len() + hook.args.iter().map(String::len).sum::<usize>() + hook.args.len();
+
+        Self {
+            hook,
+            filenames,
+            concurrency,
+            current_index: 0,
+            command_length,
+            max_per_batch,
+            max_cli_length,
         }
-        current.push(filename);
-        current_length += length;
     }
+}
 
-    if !current.is_empty() {
-        partitions.push(current);
-    }
+impl<'a> Iterator for Partitions<'a> {
+    type Item = Vec<&'a String>;
+
+    /// Returns the next batch of filenames, or `None` once every filename has been yielded.
+    /// An empty input yields a single empty batch so the hook still runs once; otherwise every
+    /// batch holds at least one filename.
+    fn next(&mut self) -> Option<Self::Item> {
+        // Handle empty filenames case
+        if self.filenames.is_empty() && self.current_index == 0 {
+            self.current_index = 1;
+            return Some(vec![]);
+        }
+
+        if self.current_index >= self.filenames.len() {
+            return None;
+        }
+
+        let mut current = Vec::new();
+        let mut current_length = self.command_length + 1;
+
+        while self.current_index < self.filenames.len() {
+            let filename = self.filenames[self.current_index];
+            let length = filename.len() + 1;
+
+            // Close the batch only if it already holds a file. Breaking on an empty batch would
+            // end the iteration and silently skip every remaining file whenever a single
+            // filename (plus the hook's entry and args) exceeds the length budget.
+            let over_limit = current_length + length > self.max_cli_length
+                || current.len() >= self.max_per_batch;
+            if over_limit && !current.is_empty() {
+                break;
+            }
+
+            current.push(filename);
+            current_length += length;
+            self.current_index += 1;
+        }
 
-    partitions
+        // The loop above always pushes at least one filename, so `current` is never empty.
+        Some(current)
+    }
 }
 
@@ -75,14 +121,11 @@ where
     T: Send + 'static,
 {
-    let mut concurrency = target_concurrency(hook.require_serial);
+    let concurrency = target_concurrency(hook.require_serial);
 
     // Split files into batches
-    let partitions = partitions(hook, filenames, concurrency);
-    concurrency = concurrency.min(partitions.len());
-    let semaphore = Arc::new(tokio::sync::Semaphore::new(concurrency));
+    let partitions = Partitions::new(hook, filenames, concurrency);
     trace!(
         total_files = filenames.len(),
-        partitions = partitions.len(),
         concurrency = concurrency,
         "Running {}",
         hook.id,
@@ -90,28 +133,17 @@ where
 
     let run = Arc::new(run);
 
-    // Spawn tasks for each batch
-    let mut tasks = JoinSet::new();
-
-    for batch in partitions {
-        let semaphore = semaphore.clone();
-        let run = run.clone();
-
-        let batch: Vec<_> = batch.into_iter().map(ToString::to_string).collect();
-
-        tasks.spawn(async move {
-            let _permit = semaphore
-                .acquire()
-                .await
-                .map_err(|_| anyhow::anyhow!("Failed to acquire semaphore"))?;
-
-            run(batch).await
-        });
-    }
+    let mut tasks = futures::stream::iter(partitions)
+        .map(|batch| {
+            let run = run.clone();
+            let batch: Vec<_> = batch.into_iter().map(ToString::to_string).collect();
+            run(batch)
+        })
+        .buffer_unordered(concurrency);
 
     let mut results = Vec::new();
-    while let Some(result) = tasks.join_next().await {
-        results.push(result??);
+    while let Some(result) = tasks.next().await {
+        results.push(result?);
     }
 
     Ok(results)