[ENH] Add NAC to the write path (#3341)
## Description of changes

*Summarize the changes made by this PR.*
 - Improvements & Bug fixes
   - Adds NAC to the write path. The NAC rate-limits the number of outstanding writes (see the sketch below).

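To make the mechanism concrete, here is a minimal sketch of a count-based admission policy backed by a `tokio` semaphore. Only `CountBasedPolicy::new` and the `enter()` calls are visible in this diff, so the internals below are illustrative assumptions rather than the actual implementation:

```rust
use std::sync::Arc;
use tokio::sync::{Semaphore, SemaphorePermit};

/// Sketch of a count-based admission policy: `enter` suspends callers once
/// `max_outstanding` permits are held; each permit frees its slot on drop (RAII).
struct CountBasedPolicy {
    remaining: Arc<Semaphore>,
}

impl CountBasedPolicy {
    fn new(max_outstanding: usize) -> Self {
        Self {
            remaining: Arc::new(Semaphore::new(max_outstanding)),
        }
    }

    async fn enter(&self) -> SemaphorePermit<'_> {
        self.remaining
            .acquire()
            .await
            .expect("semaphore is never closed")
    }
}

#[tokio::main]
async fn main() {
    let policy = CountBasedPolicy::new(2);
    let first = policy.enter().await;
    let _second = policy.enter().await;
    // A third `enter().await` would now suspend until a permit drops.
    drop(first);
    let _third = policy.enter().await; // proceeds immediately
}
```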
## Test plan
*How are these changes tested?*
- [x] Tests pass locally with `pytest` for python, `yarn test` for js, `cargo test` for rust

## Documentation Changes
None
sanketkedia authored Dec 20, 2024
1 parent e279b76 commit 4488279
Showing 4 changed files with 211 additions and 58 deletions.
rust/blockstore/src/arrow/flusher.rs (17 changes: 9 additions & 8 deletions)
@@ -44,19 +44,20 @@ impl ArrowBlockfileFlusher {
         // Flush all blocks in parallel using futures unordered
         // NOTE(hammadb) we do not use try_join_all here because we want to flush all blocks
         // in parallel and try_join_all / join_all switches to using futures_ordered if the
-        // number of futures is high.
-        // As of 10/28 the NAC does not impact the write path, only the read path.
-        // As a workaround we used buffered futures to reduce concurrency
-        // once the NAC supports write path admission control we can switch back
-        // to unbuffered futures.
+        // number of futures is high. However, our NAC controls the number of futures that can be
+        // created at once, so that behavior is redundant and suboptimal for us.
         let mut futures = Vec::new();
         for block in &self.blocks {
             futures.push(self.block_manager.flush(block));
         }
         let num_futures = futures.len();
+        // buffer_unordered hangs with 0 futures.
+        if num_futures == 0 {
+            self.root_manager.flush::<K>(&self.root).await?;
+            return Ok(());
+        }
         tracing::debug!("Flushing {} blocks", num_futures);
         futures::stream::iter(futures)
-            .buffer_unordered(30)
+            .buffer_unordered(num_futures)
             .try_collect::<Vec<_>>()
             .await?;

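For context, the resulting flush pattern looks roughly like the self-contained sketch below, with `flush_one` standing in for `block_manager.flush`. The empty-input guard mirrors the `buffer_unordered` hang noted in the diff comment; actual concurrency to S3 is bounded by the NAC inside the storage layer, not by the buffer size here:

```rust
use futures::{stream, StreamExt, TryStreamExt};

// Stand-in for `block_manager.flush(block)`.
async fn flush_one(i: usize) -> Result<usize, std::io::Error> {
    Ok(i)
}

#[tokio::main]
async fn main() -> Result<(), std::io::Error> {
    let futures: Vec<_> = (0..8usize).map(flush_one).collect();
    let num_futures = futures.len();
    // Mirrors the diff: skip buffer_unordered entirely when there is nothing to flush.
    if num_futures == 0 {
        return Ok(());
    }
    // Poll every flush concurrently; downstream admission control does the limiting.
    let results: Vec<usize> = stream::iter(futures)
        .buffer_unordered(num_futures)
        .try_collect()
        .await?;
    assert_eq!(results.len(), 8);
    Ok(())
}
```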
rust/storage/src/admissioncontrolleds3.rs (122 changes: 113 additions & 9 deletions)
@@ -4,10 +4,14 @@ use crate::{
     s3::{S3GetError, S3PutError, S3Storage},
 };
 use async_trait::async_trait;
+use aws_sdk_s3::primitives::{ByteStream, Length};
+use bytes::Bytes;
 use chroma_config::Configurable;
 use chroma_error::{ChromaError, ErrorCodes};
+use futures::future::BoxFuture;
 use futures::{future::Shared, stream, FutureExt, StreamExt};
 use parking_lot::Mutex;
+use std::ops::Range;
 use std::{collections::HashMap, future::Future, pin::Pin, sync::Arc};
 use thiserror::Error;
 use tokio::{
@@ -18,11 +22,14 @@ use tracing::{Instrument, Span};
 
 /// Wrapper over s3 storage that provides proxy features such as
 /// request coalescing, rate limiting, etc.
+/// For reads, it will coalesce requests for the same key and rate limit
+/// the number of concurrent requests.
+/// For writes, it will rate limit the number of concurrent requests.
 #[derive(Clone)]
 pub struct AdmissionControlledS3Storage {
     storage: S3Storage,
     #[allow(clippy::type_complexity)]
-    outstanding_requests: Arc<
+    outstanding_read_requests: Arc<
         Mutex<
             HashMap<
                 String,
@@ -63,15 +70,15 @@ impl AdmissionControlledS3Storage {
     pub fn new_with_default_policy(storage: S3Storage) -> Self {
         Self {
             storage,
-            outstanding_requests: Arc::new(Mutex::new(HashMap::new())),
+            outstanding_read_requests: Arc::new(Mutex::new(HashMap::new())),
             rate_limiter: Arc::new(RateLimitPolicy::CountBasedPolicy(CountBasedPolicy::new(2))),
         }
     }
 
     pub fn new(storage: S3Storage, policy: RateLimitPolicy) -> Self {
         Self {
             storage,
-            outstanding_requests: Arc::new(Mutex::new(HashMap::new())),
+            outstanding_read_requests: Arc::new(Mutex::new(HashMap::new())),
             rate_limiter: Arc::new(policy),
         }
     }
@@ -180,7 +187,7 @@ impl AdmissionControlledS3Storage {
         // request to S3.
         let future_to_await;
         {
-            let mut requests = self.outstanding_requests.lock();
+            let mut requests = self.outstanding_read_requests.lock();
             let maybe_inflight = requests.get(&key).cloned();
             future_to_await = match maybe_inflight {
                 Some(fut) => {
@@ -203,7 +210,7 @@
 
         let res = future_to_await.await;
         {
-            let mut requests = self.outstanding_requests.lock();
+            let mut requests = self.outstanding_read_requests.lock();
             requests.remove(&key);
         }
         res
@@ -218,7 +225,7 @@
         // request to S3.
         let future_to_await;
         {
-            let mut requests = self.outstanding_requests.lock();
+            let mut requests = self.outstanding_read_requests.lock();
             let maybe_inflight = requests.get(&key).cloned();
             future_to_await = match maybe_inflight {
                 Some(fut) => fut,
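The renames in the hunks above (`outstanding_requests` to `outstanding_read_requests`) make room for the write path while leaving read coalescing unchanged. Roughly, that coalescing works like the sketch below, where concurrent readers of the same key await one `Shared` future instead of each issuing an S3 request; the `Coalescer` type and the sleep-based fetch are stand-ins for illustration:

```rust
use futures::future::{BoxFuture, FutureExt, Shared};
use parking_lot::Mutex;
use std::{collections::HashMap, sync::Arc, time::Duration};

type SharedRead = Shared<BoxFuture<'static, Arc<Vec<u8>>>>;

#[derive(Clone, Default)]
struct Coalescer {
    outstanding_read_requests: Arc<Mutex<HashMap<String, SharedRead>>>,
}

impl Coalescer {
    async fn get(&self, key: &str) -> Arc<Vec<u8>> {
        // Either join the in-flight read for this key or start a new one.
        let fut = {
            let mut requests = self.outstanding_read_requests.lock();
            requests
                .entry(key.to_string())
                .or_insert_with(|| {
                    let key = key.to_string();
                    async move {
                        // Stand-in for the rate-limited S3 fetch.
                        tokio::time::sleep(Duration::from_millis(50)).await;
                        Arc::new(key.into_bytes())
                    }
                    .boxed()
                    .shared()
                })
                .clone()
        };
        let value = fut.await;
        // Clear the entry so later reads fetch fresh data.
        self.outstanding_read_requests.lock().remove(key);
        value
    }
}

#[tokio::main]
async fn main() {
    let coalescer = Coalescer::default();
    // Both calls await the same underlying future.
    let (a, b) = tokio::join!(coalescer.get("block-1"), coalescer.get("block-1"));
    assert_eq!(a, b);
}
```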
@@ -238,18 +245,115 @@
 
         let res = future_to_await.await;
         {
-            let mut requests = self.outstanding_requests.lock();
+            let mut requests = self.outstanding_read_requests.lock();
             requests.remove(&key);
         }
         res
     }
 
+    async fn oneshot_upload(
+        &self,
+        key: &str,
+        total_size_bytes: usize,
+        create_bytestream_fn: impl Fn(
+            Range<usize>,
+        ) -> BoxFuture<'static, Result<ByteStream, S3PutError>>,
+    ) -> Result<(), S3PutError> {
+        // Acquire permit.
+        let _permit = self.rate_limiter.enter().await;
+        self.storage
+            .oneshot_upload(key, total_size_bytes, create_bytestream_fn)
+            .await
+        // Permit gets dropped due to RAII.
+    }
+
+    async fn multipart_upload(
+        &self,
+        key: &str,
+        total_size_bytes: usize,
+        create_bytestream_fn: impl Fn(
+            Range<usize>,
+        ) -> BoxFuture<'static, Result<ByteStream, S3PutError>>,
+    ) -> Result<(), S3PutError> {
+        let (part_count, size_of_last_part, upload_id) = self
+            .storage
+            .prepare_multipart_upload(key, total_size_bytes)
+            .await?;
+        let mut upload_parts = Vec::new();
+        for part_index in 0..part_count {
+            // Acquire token.
+            let _permit = self.rate_limiter.enter().await;
+            let completed_part = self
+                .storage
+                .upload_part(
+                    key,
+                    &upload_id,
+                    part_count,
+                    part_index,
+                    size_of_last_part,
+                    &create_bytestream_fn,
+                )
+                .await?;
+            upload_parts.push(completed_part);
+            // Permit gets dropped due to RAII.
+        }
+
+        self.storage
+            .finish_multipart_upload(key, &upload_id, upload_parts)
+            .await
+    }
+
+    async fn put_object(
+        &self,
+        key: &str,
+        total_size_bytes: usize,
+        create_bytestream_fn: impl Fn(
+            Range<usize>,
+        ) -> BoxFuture<'static, Result<ByteStream, S3PutError>>,
+    ) -> Result<(), S3PutError> {
+        if self.storage.is_oneshot_upload(total_size_bytes) {
+            return self
+                .oneshot_upload(key, total_size_bytes, create_bytestream_fn)
+                .await;
+        }
+
+        self.multipart_upload(key, total_size_bytes, create_bytestream_fn)
+            .await
+    }
+
     pub async fn put_file(&self, key: &str, path: &str) -> Result<(), S3PutError> {
-        self.storage.put_file(key, path).await
+        let file_size = tokio::fs::metadata(path)
+            .await
+            .map_err(|err| S3PutError::S3PutError(err.to_string()))?
+            .len();
+
+        let path = path.to_string();
+
+        self.put_object(key, file_size as usize, move |range| {
+            let path = path.clone();
+
+            async move {
+                ByteStream::read_from()
+                    .path(path)
+                    .offset(range.start as u64)
+                    .length(Length::Exact(range.len() as u64))
+                    .build()
+                    .await
+                    .map_err(|err| S3PutError::S3PutError(err.to_string()))
+            }
+            .boxed()
+        })
+        .await
+    }
 
     pub async fn put_bytes(&self, key: &str, bytes: Vec<u8>) -> Result<(), S3PutError> {
-        self.storage.put_bytes(key, bytes).await
+        let bytes = Arc::new(Bytes::from(bytes));
+
+        self.put_object(key, bytes.len(), move |range| {
+            let bytes = bytes.clone();
+            async move { Ok(ByteStream::from(bytes.slice(range))) }.boxed()
+        })
+        .await
+    }
 }

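One design point worth noting in the new write path: `multipart_upload` acquires a fresh permit for every part and uploads parts sequentially, so one large write shares the admission budget with other writers instead of pinning a slot for its entire duration. Below is a minimal sketch of that permit-per-part loop, using a plain `tokio` semaphore and a stand-in `upload_part`:

```rust
use std::sync::Arc;
use tokio::sync::Semaphore;

// Stand-in for `storage.upload_part(...)`; returns the completed part index.
async fn upload_part(part_index: usize) -> Result<usize, String> {
    Ok(part_index)
}

async fn multipart_upload(
    limiter: Arc<Semaphore>,
    part_count: usize,
) -> Result<Vec<usize>, String> {
    let mut upload_parts = Vec::with_capacity(part_count);
    for part_index in 0..part_count {
        // One permit per part; released (RAII) at the end of each iteration,
        // letting other writers interleave between parts.
        let _permit = limiter.acquire().await.map_err(|e| e.to_string())?;
        upload_parts.push(upload_part(part_index).await?);
    }
    Ok(upload_parts)
}

#[tokio::main]
async fn main() -> Result<(), String> {
    let limiter = Arc::new(Semaphore::new(2));
    let parts = multipart_upload(limiter, 4).await?;
    assert_eq!(parts, vec![0, 1, 2, 3]);
    Ok(())
}
```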
Diffs for the remaining two changed files are not shown.
