Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Add NAC to the write path #3341

Merged
merged 4 commits into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 3 additions & 7 deletions rust/blockstore/src/arrow/flusher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,19 +44,15 @@ impl ArrowBlockfileFlusher {
// Flush all blocks in parallel using futures unordered
// NOTE(hammadb) we do not use try_join_all here because we want to flush all blocks
// in parallel and try_join_all / join_all switches to using futures_ordered if the
// number of futures is high. However, our NAC controls the number of futures that can be
// created at once, so that behavior is redundant and suboptimal for us.
// As of 10/28 the NAC does not impact the write path, only the read path.
// As a workaround we used buffered futures to reduce concurrency
// once the NAC supports write path admission control we can switch back
// to unbuffered futures.
// number of futures is high.

let mut futures = Vec::new();
for block in &self.blocks {
futures.push(self.block_manager.flush(block));
}
let num_futures = futures.len();
futures::stream::iter(futures)
.buffer_unordered(30)
.buffer_unordered(num_futures)
.try_collect::<Vec<_>>()
.await?;

Expand Down
122 changes: 113 additions & 9 deletions rust/storage/src/admissioncontrolleds3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,14 @@ use crate::{
s3::{S3GetError, S3PutError, S3Storage},
};
use async_trait::async_trait;
use aws_sdk_s3::primitives::{ByteStream, Length};
use bytes::Bytes;
use chroma_config::Configurable;
use chroma_error::{ChromaError, ErrorCodes};
use futures::future::BoxFuture;
use futures::{future::Shared, stream, FutureExt, StreamExt};
use parking_lot::Mutex;
use std::ops::Range;
use std::{collections::HashMap, future::Future, pin::Pin, sync::Arc};
use thiserror::Error;
use tokio::{
Expand All @@ -18,11 +22,14 @@ use tracing::{Instrument, Span};

/// Wrapper over s3 storage that provides proxy features such as
/// request coalescing, rate limiting, etc.
/// For reads, it will coalesce requests for the same key and rate limit
/// the number of concurrent requests.
/// For writes, it will rate limit the number of concurrent requests.
#[derive(Clone)]
pub struct AdmissionControlledS3Storage {
storage: S3Storage,
#[allow(clippy::type_complexity)]
outstanding_requests: Arc<
outstanding_read_requests: Arc<
Mutex<
HashMap<
String,
Expand Down Expand Up @@ -63,15 +70,15 @@ impl AdmissionControlledS3Storage {
/// Construct an [`AdmissionControlledS3Storage`] with the default rate-limit
/// policy: a count-based policy permitting 2 concurrent S3 requests.
pub fn new_with_default_policy(storage: S3Storage) -> Self {
    Self {
        storage,
        // No reads are in flight yet, so the request-coalescing map starts empty.
        outstanding_read_requests: Arc::new(Mutex::new(HashMap::new())),
        // NOTE(review): the default limit of 2 concurrent requests is
        // conservative — confirm against deployment configuration.
        rate_limiter: Arc::new(RateLimitPolicy::CountBasedPolicy(CountBasedPolicy::new(2))),
    }
}

/// Construct an [`AdmissionControlledS3Storage`] with a caller-supplied
/// rate-limit `policy`.
pub fn new(storage: S3Storage, policy: RateLimitPolicy) -> Self {
    Self {
        storage,
        // No reads are in flight yet, so the request-coalescing map starts empty.
        outstanding_read_requests: Arc::new(Mutex::new(HashMap::new())),
        rate_limiter: Arc::new(policy),
    }
}
Expand Down Expand Up @@ -180,7 +187,7 @@ impl AdmissionControlledS3Storage {
// request to S3.
let future_to_await;
{
let mut requests = self.outstanding_requests.lock();
let mut requests = self.outstanding_read_requests.lock();
let maybe_inflight = requests.get(&key).cloned();
future_to_await = match maybe_inflight {
Some(fut) => {
Expand All @@ -203,7 +210,7 @@ impl AdmissionControlledS3Storage {

let res = future_to_await.await;
{
let mut requests = self.outstanding_requests.lock();
let mut requests = self.outstanding_read_requests.lock();
requests.remove(&key);
}
res
Expand All @@ -218,7 +225,7 @@ impl AdmissionControlledS3Storage {
// request to S3.
let future_to_await;
{
let mut requests = self.outstanding_requests.lock();
let mut requests = self.outstanding_read_requests.lock();
let maybe_inflight = requests.get(&key).cloned();
future_to_await = match maybe_inflight {
Some(fut) => fut,
Expand All @@ -238,18 +245,115 @@ impl AdmissionControlledS3Storage {

let res = future_to_await.await;
{
let mut requests = self.outstanding_requests.lock();
let mut requests = self.outstanding_read_requests.lock();
requests.remove(&key);
}
res
}

/// Upload `key` in a single request, gated by the admission-control rate limiter.
///
/// The permit is held for the full duration of the underlying upload and is
/// returned when the guard is dropped (RAII).
async fn oneshot_upload(
    &self,
    key: &str,
    total_size_bytes: usize,
    create_bytestream_fn: impl Fn(
        Range<usize>,
    ) -> BoxFuture<'static, Result<ByteStream, S3PutError>>,
) -> Result<(), S3PutError> {
    // Block until the rate limiter admits this request.
    let _rate_limit_guard = self.rate_limiter.enter().await;
    let result = self
        .storage
        .oneshot_upload(key, total_size_bytes, create_bytestream_fn)
        .await;
    // `_rate_limit_guard` drops here, releasing the permit.
    result
}

/// Upload `key` as a multipart upload, acquiring a rate-limiter permit for
/// each part so that writes participate in admission control.
///
/// Parts are uploaded sequentially; each part's permit is released (RAII)
/// before the next permit is requested.
async fn multipart_upload(
    &self,
    key: &str,
    total_size_bytes: usize,
    create_bytestream_fn: impl Fn(
        Range<usize>,
    ) -> BoxFuture<'static, Result<ByteStream, S3PutError>>,
) -> Result<(), S3PutError> {
    let (part_count, size_of_last_part, upload_id) = self
        .storage
        .prepare_multipart_upload(key, total_size_bytes)
        .await?;

    let mut completed_parts = Vec::with_capacity(part_count);
    for part_index in 0..part_count {
        // Wait for admission before touching the network for this part.
        let _rate_limit_guard = self.rate_limiter.enter().await;
        completed_parts.push(
            self.storage
                .upload_part(
                    key,
                    &upload_id,
                    part_count,
                    part_index,
                    size_of_last_part,
                    &create_bytestream_fn,
                )
                .await?,
        );
        // Guard drops here; the permit is returned before the next iteration.
    }

    self.storage
        .finish_multipart_upload(key, &upload_id, completed_parts)
        .await
}

/// Route an upload to the one-shot or multipart path based on its total size.
async fn put_object(
    &self,
    key: &str,
    total_size_bytes: usize,
    create_bytestream_fn: impl Fn(
        Range<usize>,
    ) -> BoxFuture<'static, Result<ByteStream, S3PutError>>,
) -> Result<(), S3PutError> {
    // Small payloads go out as a single PUT; anything larger is chunked.
    if self.storage.is_oneshot_upload(total_size_bytes) {
        self.oneshot_upload(key, total_size_bytes, create_bytestream_fn)
            .await
    } else {
        self.multipart_upload(key, total_size_bytes, create_bytestream_fn)
            .await
    }
}

/// Upload the file at `path` to `key` through the rate-limited write path
/// (`put_object`).
///
/// # Errors
/// Returns `S3PutError::S3PutError` if the file metadata cannot be read,
/// a byte range of the file cannot be streamed, or the upload fails.
pub async fn put_file(&self, key: &str, path: &str) -> Result<(), S3PutError> {
    // The total size decides between one-shot and multipart upload.
    let file_size = tokio::fs::metadata(path)
        .await
        .map_err(|err| S3PutError::S3PutError(err.to_string()))?
        .len();

    let path = path.to_string();

    self.put_object(key, file_size as usize, move |range| {
        let path = path.clone();

        async move {
            // Stream only the requested byte range so each multipart part
            // reads its own slice of the file independently.
            ByteStream::read_from()
                .path(path)
                .offset(range.start as u64)
                .length(Length::Exact(range.len() as u64))
                .build()
                .await
                .map_err(|err| S3PutError::S3PutError(err.to_string()))
        }
        .boxed()
    })
    .await
}

/// Upload an in-memory buffer to `key` through the rate-limited write path
/// (`put_object`).
pub async fn put_bytes(&self, key: &str, bytes: Vec<u8>) -> Result<(), S3PutError> {
    // Arc lets every part's bytestream closure share the buffer without copying.
    let bytes = Arc::new(Bytes::from(bytes));

    self.put_object(key, bytes.len(), move |range| {
        let bytes = bytes.clone();
        // `Bytes::slice` is a cheap refcounted view of the requested range.
        async move { Ok(ByteStream::from(bytes.slice(range))) }.boxed()
    })
    .await
}
}

Expand Down
128 changes: 88 additions & 40 deletions rust/storage/src/s3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,10 @@ impl S3Storage {
}
}

/// Whether `total_size_bytes` is small enough to be sent as a single PUT
/// rather than as a multipart upload.
pub(super) fn is_oneshot_upload(&self, total_size_bytes: usize) -> bool {
    let multipart_threshold = self.upload_part_size_bytes;
    total_size_bytes < multipart_threshold
}

pub async fn put_bytes(&self, key: &str, bytes: Vec<u8>) -> Result<(), S3PutError> {
let bytes = Arc::new(Bytes::from(bytes));

Expand Down Expand Up @@ -382,7 +386,7 @@ impl S3Storage {
Range<usize>,
) -> BoxFuture<'static, Result<ByteStream, S3PutError>>,
) -> Result<(), S3PutError> {
if total_size_bytes < self.upload_part_size_bytes {
if self.is_oneshot_upload(total_size_bytes) {
return self
.oneshot_upload(key, total_size_bytes, create_bytestream_fn)
.await;
Expand All @@ -392,7 +396,7 @@ impl S3Storage {
.await
}

async fn oneshot_upload(
pub(super) async fn oneshot_upload(
&self,
key: &str,
total_size_bytes: usize,
Expand All @@ -412,14 +416,11 @@ impl S3Storage {
Ok(())
}

async fn multipart_upload(
pub(super) async fn prepare_multipart_upload(
&self,
key: &str,
total_size_bytes: usize,
create_bytestream_fn: impl Fn(
Range<usize>,
) -> BoxFuture<'static, Result<ByteStream, S3PutError>>,
) -> Result<(), S3PutError> {
) -> Result<(usize, usize, String), S3PutError> {
let mut part_count = (total_size_bytes / self.upload_part_size_bytes) + 1;
let mut size_of_last_part = total_size_bytes % self.upload_part_size_bytes;
if size_of_last_part == 0 {
Expand All @@ -445,39 +446,55 @@ impl S3Storage {
}
};

let mut upload_parts = Vec::new();
for part_index in 0..part_count {
let this_part = if part_count - 1 == part_index {
size_of_last_part
} else {
self.upload_part_size_bytes
};
let part_number = part_index as i32 + 1; // Part numbers start at 1
let offset = part_index * self.upload_part_size_bytes;
let length = this_part;

let stream = create_bytestream_fn(offset..(offset + length)).await?;

let upload_part_res = self
.client
.upload_part()
.key(key)
.bucket(&self.bucket)
.upload_id(&upload_id)
.body(stream)
.part_number(part_number)
.send()
.await
.map_err(|err| S3PutError::S3PutError(err.to_string()))?;

upload_parts.push(
CompletedPart::builder()
.e_tag(upload_part_res.e_tag.unwrap_or_default())
.part_number(part_number)
.build(),
);
}
Ok((part_count, size_of_last_part, upload_id))
}

/// Upload a single part of an in-progress multipart upload and return the
/// `CompletedPart` needed to finalize it later.
///
/// `size_of_last_part` is the byte length of the final part; every other
/// part is exactly `self.upload_part_size_bytes` long.
pub(super) async fn upload_part(
    &self,
    key: &str,
    upload_id: &str,
    part_count: usize,
    part_index: usize,
    size_of_last_part: usize,
    create_bytestream_fn: &impl Fn(
        Range<usize>,
    ) -> BoxFuture<'static, Result<ByteStream, S3PutError>>,
) -> Result<CompletedPart, S3PutError> {
    let is_last_part = part_index == part_count - 1;
    let length = if is_last_part {
        size_of_last_part
    } else {
        self.upload_part_size_bytes
    };
    let offset = part_index * self.upload_part_size_bytes;
    // S3 part numbers are 1-based.
    let part_number = part_index as i32 + 1;

    let stream = create_bytestream_fn(offset..(offset + length)).await?;

    let response = self
        .client
        .upload_part()
        .key(key)
        .bucket(&self.bucket)
        .upload_id(upload_id)
        .body(stream)
        .part_number(part_number)
        .send()
        .await
        .map_err(|err| S3PutError::S3PutError(err.to_string()))?;

    Ok(CompletedPart::builder()
        .e_tag(response.e_tag.unwrap_or_default())
        .part_number(part_number)
        .build())
}

pub(super) async fn finish_multipart_upload(
&self,
key: &str,
upload_id: &str,
upload_parts: Vec<CompletedPart>,
) -> Result<(), S3PutError> {
self.client
.complete_multipart_upload()
.bucket(&self.bucket)
Expand All @@ -487,13 +504,44 @@ impl S3Storage {
.set_parts(Some(upload_parts))
.build(),
)
.upload_id(&upload_id)
.upload_id(upload_id)
.send()
.await
.map_err(|err| S3PutError::S3PutError(err.to_string()))?;

Ok(())
}

/// Drive a complete multipart upload: prepare, upload every part
/// sequentially, then finalize.
async fn multipart_upload(
    &self,
    key: &str,
    total_size_bytes: usize,
    create_bytestream_fn: impl Fn(
        Range<usize>,
    ) -> BoxFuture<'static, Result<ByteStream, S3PutError>>,
) -> Result<(), S3PutError> {
    let (part_count, size_of_last_part, upload_id) =
        self.prepare_multipart_upload(key, total_size_bytes).await?;

    let mut completed_parts = Vec::with_capacity(part_count);
    for part_index in 0..part_count {
        let part = self
            .upload_part(
                key,
                &upload_id,
                part_count,
                part_index,
                size_of_last_part,
                &create_bytestream_fn,
            )
            .await?;
        completed_parts.push(part);
    }

    self.finish_multipart_upload(key, &upload_id, completed_parts)
        .await
}
}

#[async_trait]
Expand Down
Loading
Loading