Skip to content

Commit

Permalink
Rewrite partitions and run_by_batch
Browse files Browse the repository at this point in the history
  • Loading branch information
j178 committed Dec 10, 2024
1 parent 00406c0 commit 5cc66b3
Showing 1 changed file with 88 additions and 65 deletions.
153 changes: 88 additions & 65 deletions src/run.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,65 +2,102 @@ use std::cmp::max;
use std::future::Future;
use std::sync::{Arc, LazyLock};

use tokio::task::JoinSet;
use futures::StreamExt;
use tracing::trace;

use crate::hook::Hook;

pub static CONCURRENCY: LazyLock<usize> = LazyLock::new(|| target_concurrency(false));

fn target_concurrency(serial: bool) -> usize {
if serial || std::env::var_os("PRE_COMMIT_NO_CONCURRENCY").is_some() {
pub static CONCURRENCY: LazyLock<usize> = LazyLock::new(|| {
if std::env::var_os("PRE_COMMIT_NO_CONCURRENCY").is_some() {
1
} else {
std::thread::available_parallelism()
.map(std::num::NonZero::get)
.unwrap_or(1)
}
});

fn target_concurrency(serial: bool) -> usize {
if serial {
1
} else {
*CONCURRENCY
}
}

// TODO: do a more accurate calculation
fn partitions<'a>(
/// Iterator that yields partitions of filenames that fit within the maximum command line length.
struct Partitions<'a> {
hook: &'a Hook,
filenames: &'a [&String],
filenames: &'a [&'a String],
concurrency: usize,
) -> Vec<Vec<&'a String>> {
// If there are no filenames, we still want to run the hook once.
if filenames.is_empty() {
return vec![vec![]];
}
current_index: usize,
command_length: usize,
max_per_batch: usize,
max_cli_length: usize,
}

let max_per_batch = max(4, filenames.len().div_ceil(concurrency));
// TODO: subtract the env size
let max_cli_length = if cfg!(unix) {
1 << 12
} else {
(1 << 15) - 2048 // UNICODE_STRING max - headroom
};

let command_length =
hook.entry.len() + hook.args.iter().map(String::len).sum::<usize>() + hook.args.len();

let mut partitions = Vec::new();
let mut current = Vec::new();
let mut current_length = command_length + 1;

for &filename in filenames {
let length = filename.len() + 1;
if current_length + length > max_cli_length || current.len() >= max_per_batch {
partitions.push(current);
current = Vec::new();
current_length = command_length + 1;
// TODO: do a more accurate calculation
impl<'a> Partitions<'a> {
fn new(hook: &'a Hook, filenames: &'a [&'a String], concurrency: usize) -> Self {
let max_per_batch = max(4, filenames.len().div_ceil(concurrency));
// TODO: subtract the env size
let max_cli_length = if cfg!(unix) {
1 << 12
} else {
(1 << 15) - 2048 // UNICODE_STRING max - headroom
};
let command_length =
hook.entry.len() + hook.args.iter().map(String::len).sum::<usize>() + hook.args.len();

Self {
hook,
filenames,
concurrency,
current_index: 0,
command_length,
max_per_batch,
max_cli_length,
}
current.push(filename);
current_length += length;
}
}

if !current.is_empty() {
partitions.push(current);
}
impl<'a> Iterator for Partitions<'a> {
type Item = Vec<&'a String>;

fn next(&mut self) -> Option<Self::Item> {
// Handle empty filenames case
if self.filenames.is_empty() && self.current_index == 0 {
self.current_index = 1;
return Some(vec![]);
}

if self.current_index >= self.filenames.len() {
return None;
}

let mut current = Vec::new();
let mut current_length = self.command_length + 1;

while self.current_index < self.filenames.len() {
let filename = self.filenames[self.current_index];
let length = filename.len() + 1;

if current_length + length > self.max_cli_length || current.len() >= self.max_per_batch
{
break;
}

current.push(filename);
current_length += length;
self.current_index += 1;
}

partitions
if current.is_empty() {
None
} else {
Some(current)
}
}
}

pub async fn run_by_batch<T, F, Fut>(
Expand All @@ -74,44 +111,30 @@ where
Fut: Future<Output = anyhow::Result<T>> + Send + 'static,
T: Send + 'static,
{
let mut concurrency = target_concurrency(hook.require_serial);
let concurrency = target_concurrency(hook.require_serial);

// Split files into batches
let partitions = partitions(hook, filenames, concurrency);
concurrency = concurrency.min(partitions.len());
let semaphore = Arc::new(tokio::sync::Semaphore::new(concurrency));
let partitions = Partitions::new(hook, filenames, concurrency);
trace!(
total_files = filenames.len(),
partitions = partitions.len(),
concurrency = concurrency,
"Running {}",
hook.id,
);

let run = Arc::new(run);

// Spawn tasks for each batch
let mut tasks = JoinSet::new();

for batch in partitions {
let semaphore = semaphore.clone();
let run = run.clone();

let batch: Vec<_> = batch.into_iter().map(ToString::to_string).collect();

tasks.spawn(async move {
let _permit = semaphore
.acquire()
.await
.map_err(|_| anyhow::anyhow!("Failed to acquire semaphore"))?;

run(batch).await
});
}
let mut tasks = futures::stream::iter(partitions)
.map(|batch| {
let run = run.clone();
let batch: Vec<_> = batch.into_iter().map(ToString::to_string).collect();
run(batch)
})
.buffer_unordered(concurrency);

let mut results = Vec::new();
while let Some(result) = tasks.join_next().await {
results.push(result??);
while let Some(result) = tasks.next().await {
results.push(result?);
}

Ok(results)
Expand Down

0 comments on commit 5cc66b3

Please sign in to comment.