[ENH] Add NAC to the write path #3341

Merged 4 commits on Dec 20, 2024
Changes from all commits
17 changes: 9 additions & 8 deletions rust/blockstore/src/arrow/flusher.rs
@@ -44,19 +44,20 @@ impl ArrowBlockfileFlusher {
         // Flush all blocks in parallel using futures unordered
         // NOTE(hammadb) we do not use try_join_all here because we want to flush all blocks
         // in parallel and try_join_all / join_all switches to using futures_ordered if the
-        // number of futures is high. However, our NAC controls the number of futures that can be
-        // created at once, so that behavior is redundant and suboptimal for us.
-        // As of 10/28 the NAC does not impact the write path, only the read path.
-        // As a workaround we used buffered futures to reduce concurrency
-        // once the NAC supports write path admission control we can switch back
-        // to unbuffered futures.
+        // number of futures is high.
         let mut futures = Vec::new();
         for block in &self.blocks {
             futures.push(self.block_manager.flush(block));
         }
+        let num_futures = futures.len();
+        // buffer_unordered hangs with 0 futures.
+        if num_futures == 0 {
+            self.root_manager.flush::<K>(&self.root).await?;
+            return Ok(());
+        }
+        tracing::debug!("Flushing {} blocks", num_futures);
         futures::stream::iter(futures)
-            .buffer_unordered(30)
+            .buffer_unordered(num_futures)
             .try_collect::<Vec<_>>()
             .await?;

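The hunk above caps concurrency with buffer_unordered sized to the batch and short-circuits empty batches. A minimal standalone sketch of the same pattern, assuming only the futures crate (flush_all and the string keys are illustrative stand-ins for the flusher and block_manager.flush):

use futures::{stream, StreamExt, TryStreamExt};

async fn flush_all(keys: Vec<String>) -> Result<(), std::io::Error> {
    // One future per item, like the Vec built in the hunk above.
    let futures: Vec<_> = keys
        .into_iter()
        .map(|k| async move { Ok::<_, std::io::Error>(k) })
        .collect();
    let num_futures = futures.len();
    // Mirrors the guard above: per the PR comment, buffer_unordered with
    // zero futures hangs, so an empty batch must return early.
    if num_futures == 0 {
        return Ok(());
    }
    stream::iter(futures)
        .buffer_unordered(num_futures) // run all futures concurrently; results arrive unordered
        .try_collect::<Vec<_>>()
        .await?;
    Ok(())
}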
122 changes: 113 additions & 9 deletions rust/storage/src/admissioncontrolleds3.rs
@@ -4,10 +4,14 @@ use crate::{
     s3::{S3GetError, S3PutError, S3Storage},
 };
 use async_trait::async_trait;
+use aws_sdk_s3::primitives::{ByteStream, Length};
+use bytes::Bytes;
 use chroma_config::Configurable;
 use chroma_error::{ChromaError, ErrorCodes};
+use futures::future::BoxFuture;
 use futures::{future::Shared, stream, FutureExt, StreamExt};
 use parking_lot::Mutex;
+use std::ops::Range;
 use std::{collections::HashMap, future::Future, pin::Pin, sync::Arc};
 use thiserror::Error;
 use tokio::{
@@ -18,11 +22,14 @@ use tracing::{Instrument, Span};

 /// Wrapper over s3 storage that provides proxy features such as
 /// request coalescing, rate limiting, etc.
+/// For reads, it will coalesce requests for the same key and rate limit
+/// the number of concurrent requests.
+/// For writes, it will rate limit the number of concurrent requests.
 #[derive(Clone)]
 pub struct AdmissionControlledS3Storage {
     storage: S3Storage,
     #[allow(clippy::type_complexity)]
-    outstanding_requests: Arc<
+    outstanding_read_requests: Arc<
         Mutex<
             HashMap<
                 String,
@@ -63,15 +70,15 @@ impl AdmissionControlledS3Storage {
     pub fn new_with_default_policy(storage: S3Storage) -> Self {
         Self {
             storage,
-            outstanding_requests: Arc::new(Mutex::new(HashMap::new())),
+            outstanding_read_requests: Arc::new(Mutex::new(HashMap::new())),
             rate_limiter: Arc::new(RateLimitPolicy::CountBasedPolicy(CountBasedPolicy::new(2))),
         }
     }

     pub fn new(storage: S3Storage, policy: RateLimitPolicy) -> Self {
         Self {
             storage,
-            outstanding_requests: Arc::new(Mutex::new(HashMap::new())),
+            outstanding_read_requests: Arc::new(Mutex::new(HashMap::new())),
             rate_limiter: Arc::new(policy),
         }
     }
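The outstanding_read_requests map above backs the read-side coalescing described in the doc comment: every concurrent reader of a key awaits a clone of one Shared future. A hedged sketch of the idea, independent of this PR's exact types (Coalescer, coalesce_get, and fetch are illustrative names):

use futures::{future::Shared, FutureExt};
use parking_lot::Mutex;
use std::{collections::HashMap, future::Future, pin::Pin, sync::Arc};

// The output must be Clone so every waiter can receive a copy.
type SharedGet = Shared<Pin<Box<dyn Future<Output = Result<Arc<Vec<u8>>, String>> + Send>>>;

#[derive(Clone, Default)]
struct Coalescer {
    outstanding: Arc<Mutex<HashMap<String, SharedGet>>>,
}

impl Coalescer {
    async fn coalesce_get(&self, key: String) -> Result<Arc<Vec<u8>>, String> {
        let fut = {
            let mut map = self.outstanding.lock();
            match map.get(&key) {
                // A request for this key is already in flight: share it.
                Some(f) => f.clone(),
                None => {
                    let k = key.clone();
                    let f: SharedGet = async move { fetch(&k).await }.boxed().shared();
                    map.insert(key.clone(), f.clone());
                    f
                }
            }
        };
        let res = fut.await;
        // Remove the entry so later reads issue a fresh request.
        self.outstanding.lock().remove(&key);
        res
    }
}

async fn fetch(_key: &str) -> Result<Arc<Vec<u8>>, String> {
    Ok(Arc::new(Vec::new())) // stand-in for the real S3 get
}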
@@ -180,7 +187,7 @@ impl AdmissionControlledS3Storage {
         // request to S3.
         let future_to_await;
         {
-            let mut requests = self.outstanding_requests.lock();
+            let mut requests = self.outstanding_read_requests.lock();
             let maybe_inflight = requests.get(&key).cloned();
             future_to_await = match maybe_inflight {
                 Some(fut) => {
@@ -203,7 +210,7 @@

         let res = future_to_await.await;
         {
-            let mut requests = self.outstanding_requests.lock();
+            let mut requests = self.outstanding_read_requests.lock();
             requests.remove(&key);
         }
         res
@@ -218,7 +225,7 @@
         // request to S3.
         let future_to_await;
         {
-            let mut requests = self.outstanding_requests.lock();
+            let mut requests = self.outstanding_read_requests.lock();
             let maybe_inflight = requests.get(&key).cloned();
             future_to_await = match maybe_inflight {
                 Some(fut) => fut,
@@ -238,18 +245,115 @@

         let res = future_to_await.await;
         {
-            let mut requests = self.outstanding_requests.lock();
+            let mut requests = self.outstanding_read_requests.lock();
             requests.remove(&key);
         }
         res
     }

+    async fn oneshot_upload(
+        &self,
+        key: &str,
+        total_size_bytes: usize,
+        create_bytestream_fn: impl Fn(
+            Range<usize>,
+        ) -> BoxFuture<'static, Result<ByteStream, S3PutError>>,
+    ) -> Result<(), S3PutError> {
+        // Acquire permit.
+        let _permit = self.rate_limiter.enter().await;
+        self.storage
+            .oneshot_upload(key, total_size_bytes, create_bytestream_fn)
+            .await
+        // Permit gets dropped due to RAII.
+    }
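The `_permit` binding above is the admission control itself: holding the permit occupies a slot, and dropping it at scope exit frees the slot. A sketch of a count-based policy with that RAII behavior, assuming tokio's Semaphore (the crate's actual CountBasedPolicy and enter() may differ in detail):

use std::sync::Arc;
use tokio::sync::{OwnedSemaphorePermit, Semaphore};

struct CountBasedPolicy {
    semaphore: Arc<Semaphore>,
}

impl CountBasedPolicy {
    fn new(max_in_flight: usize) -> Self {
        Self {
            semaphore: Arc::new(Semaphore::new(max_in_flight)),
        }
    }

    // Waits for a free slot; the permit releases it when dropped (RAII),
    // which is what the `_permit` bindings in this diff rely on.
    async fn enter(&self) -> OwnedSemaphorePermit {
        self.semaphore
            .clone()
            .acquire_owned()
            .await
            .expect("semaphore is never closed")
    }
}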

+    async fn multipart_upload(
+        &self,
+        key: &str,
+        total_size_bytes: usize,
+        create_bytestream_fn: impl Fn(
+            Range<usize>,
+        ) -> BoxFuture<'static, Result<ByteStream, S3PutError>>,
+    ) -> Result<(), S3PutError> {
+        let (part_count, size_of_last_part, upload_id) = self
+            .storage
+            .prepare_multipart_upload(key, total_size_bytes)
+            .await?;
+        let mut upload_parts = Vec::new();
+        for part_index in 0..part_count {
+            // Acquire token.
+            let _permit = self.rate_limiter.enter().await;
+            let completed_part = self
+                .storage
+                .upload_part(
+                    key,
+                    &upload_id,
+                    part_count,
+                    part_index,
+                    size_of_last_part,
+                    &create_bytestream_fn,
+                )
+                .await?;
+            upload_parts.push(completed_part);
+            // Permit gets dropped due to RAII.
+        }
+
+        self.storage
+            .finish_multipart_upload(key, &upload_id, upload_parts)
+            .await
+    }
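prepare_multipart_upload is expected to return a part count plus the size of the trailing part; the arithmetic it implies looks roughly like this sketch (PART_SIZE_BYTES and plan_parts are assumptions for illustration, not the storage layer's real names):

use std::ops::Range;

const PART_SIZE_BYTES: usize = 8 * 1024 * 1024; // assumed fixed part size

fn plan_parts(total_size_bytes: usize) -> (usize, usize, Vec<Range<usize>>) {
    let part_count = total_size_bytes.div_ceil(PART_SIZE_BYTES).max(1);
    let size_of_last_part = total_size_bytes - PART_SIZE_BYTES * (part_count - 1);
    let ranges = (0..part_count)
        .map(|i| {
            let start = i * PART_SIZE_BYTES;
            let len = if i + 1 == part_count {
                size_of_last_part
            } else {
                PART_SIZE_BYTES
            };
            start..start + len
        })
        .collect();
    (part_count, size_of_last_part, ranges)
}

Each resulting range then drives one create_bytestream_fn call, one rate-limiter permit, and one upload_part request.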

+    async fn put_object(
+        &self,
+        key: &str,
+        total_size_bytes: usize,
+        create_bytestream_fn: impl Fn(
+            Range<usize>,
+        ) -> BoxFuture<'static, Result<ByteStream, S3PutError>>,
+    ) -> Result<(), S3PutError> {
+        if self.storage.is_oneshot_upload(total_size_bytes) {
+            return self
+                .oneshot_upload(key, total_size_bytes, create_bytestream_fn)
+                .await;
+        }
+
+        self.multipart_upload(key, total_size_bytes, create_bytestream_fn)
+            .await
+    }
+
     pub async fn put_file(&self, key: &str, path: &str) -> Result<(), S3PutError> {
-        self.storage.put_file(key, path).await
+        let file_size = tokio::fs::metadata(path)
+            .await
+            .map_err(|err| S3PutError::S3PutError(err.to_string()))?
+            .len();
+
+        let path = path.to_string();
+
+        self.put_object(key, file_size as usize, move |range| {
+            let path = path.clone();
+
+            async move {
+                ByteStream::read_from()
+                    .path(path)
+                    .offset(range.start as u64)
+                    .length(Length::Exact(range.len() as u64))
+                    .build()
+                    .await
+                    .map_err(|err| S3PutError::S3PutError(err.to_string()))
+            }
+            .boxed()
+        })
+        .await
+    }

     pub async fn put_bytes(&self, key: &str, bytes: Vec<u8>) -> Result<(), S3PutError> {
-        self.storage.put_bytes(key, bytes).await
+        let bytes = Arc::new(Bytes::from(bytes));
+
+        self.put_object(key, bytes.len(), move |range| {
+            let bytes = bytes.clone();
+            async move { Ok(ByteStream::from(bytes.slice(range))) }.boxed()
+        })
+        .await
+    }
 }

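Taken together, writes now pass through the same admission control as reads. A hypothetical call site (key names and file path are made up):

async fn demo(storage: &AdmissionControlledS3Storage) -> Result<(), S3PutError> {
    // Small payloads take the oneshot path; large ones are split into parts,
    // each gated by a rate-limiter permit.
    storage.put_bytes("prefix/small-key", b"hello".to_vec()).await?;
    storage.put_file("prefix/large-key", "/tmp/snapshot.bin").await?;
    Ok(())
}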