Skip to content

Commit

Permalink
[TST] Benchmark query node (#3320)
Browse files Browse the repository at this point in the history
## Description of changes

*Summarize the changes made by this PR.*
 - Improvements & Bug fixes
   - N/A
 - New functionality
   - Implement benchmark that runs the entire get/knn orchestrator

## Test plan
*How are these changes tested?*

- [ ] Tests pass locally with `pytest` for python, `yarn test` for js, `cargo test` for rust

## Documentation Changes
*Are all docstrings for user-facing APIs updated if required? Do we need to make documentation changes in the [docs repository](https://github.com/chroma-core/docs)?*
  • Loading branch information
Sicheng-Pan authored Dec 17, 2024
1 parent da0858d commit d833e53
Show file tree
Hide file tree
Showing 33 changed files with 995 additions and 133 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions rust/benchmark/src/datasets/gist.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::path::PathBuf;

use anyhow::Ok;
use anyhow::Result;
use tokio::io::{AsyncReadExt, BufReader};

use super::{
Expand All @@ -20,7 +20,7 @@ impl RecordDataset for GistDataset {
const DISPLAY_NAME: &'static str = "Gist";
const NAME: &'static str = "gist";

async fn init() -> anyhow::Result<Self> {
async fn init() -> Result<Self> {
// TODO(Sanket): Download file if it doesn't exist.
// move file from downloads to cached path.
let current_path = "/Users/sanketkedia/Downloads/siftsmall/siftsmall_base.fvecs";
Expand Down
1 change: 1 addition & 0 deletions rust/benchmark/src/datasets/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ pub mod util;
pub mod gist;
pub mod ms_marco_queries;
pub mod scidocs;
pub mod sift;
pub mod wikipedia;
197 changes: 197 additions & 0 deletions rust/benchmark/src/datasets/sift.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
use std::{
io::SeekFrom,
ops::{Bound, RangeBounds},
};

use anyhow::{anyhow, Ok, Result};
use tokio::{
fs::File,
io::{AsyncReadExt, AsyncSeekExt, AsyncWriteExt, BufReader},
};

use super::util::get_or_populate_cached_dataset_file;

/// Buffered readers over the three cached SIFT1M dataset files.
pub struct Sift1MData {
    /// Base vectors (`base.fvecs`) — the indexed collection.
    pub base: BufReader<File>,
    /// Query vectors (`query.fvecs`).
    pub query: BufReader<File>,
    /// Ground-truth nearest-neighbor ids per query (`groundtruth.ivecs`).
    pub ground: BufReader<File>,
}

impl Sift1MData {
    /// Fetches the SIFT1M dataset into the local dataset cache (downloading
    /// the base, query, and ground-truth files from Hugging Face on first
    /// use) and opens buffered readers over the three files.
    pub async fn init() -> Result<Self> {
        // The three downloads only differ by cached file name, source URL,
        // and the label used in the error message, so share one fetcher
        // instead of repeating the closure three times.
        let fetch = |file_name: &'static str, url: &'static str, label: &'static str| {
            get_or_populate_cached_dataset_file("sift1m", file_name, None, move |mut writer| {
                async move {
                    let response = reqwest::Client::new().get(url).send().await?;

                    if !response.status().is_success() {
                        return Err(anyhow!(
                            "Failed to download Sift1M {}, got status code {}",
                            label,
                            response.status()
                        ));
                    }

                    writer.write_all(&response.bytes().await?).await?;

                    Ok(())
                }
            })
        };

        let base = fetch(
            "base.fvecs",
            "https://huggingface.co/datasets/qbo-odp/sift1m/resolve/main/sift_base.fvecs",
            "base data",
        )
        .await?;
        let query = fetch(
            "query.fvecs",
            "https://huggingface.co/datasets/qbo-odp/sift1m/resolve/main/sift_query.fvecs",
            "query data",
        )
        .await?;
        let ground = fetch(
            "groundtruth.ivecs",
            "https://huggingface.co/datasets/qbo-odp/sift1m/resolve/main/sift_groundtruth.ivecs",
            "ground data",
        )
        .await?;

        Ok(Self {
            base: BufReader::new(File::open(base).await?),
            query: BufReader::new(File::open(query).await?),
            ground: BufReader::new(File::open(ground).await?),
        })
    }

    /// Number of vectors in the base collection.
    pub fn collection_size() -> usize {
        1000000
    }

    /// Number of query vectors in the dataset.
    pub fn query_size() -> usize {
        10000
    }

    /// Dimensionality of every base and query vector.
    pub fn dimension() -> usize {
        128
    }

    /// Number of ground-truth neighbors recorded per query.
    pub fn k() -> usize {
        100
    }

    /// Reads the base vectors whose indices fall in `range`, with the upper
    /// bound clamped to [`Self::collection_size`]. Returns an empty vec when
    /// the (clamped) range is empty.
    pub async fn data_range(&mut self, range: impl RangeBounds<usize>) -> Result<Vec<Vec<f32>>> {
        let lower_bound = match range.start_bound() {
            Bound::Included(include) => *include,
            Bound::Excluded(exclude) => exclude + 1,
            Bound::Unbounded => 0,
        };
        let upper_bound = match range.end_bound() {
            Bound::Included(include) => include + 1,
            Bound::Excluded(exclude) => *exclude,
            Bound::Unbounded => usize::MAX,
        }
        .min(Self::collection_size());

        if lower_bound >= upper_bound {
            return Ok(Vec::new());
        }

        // Each .fvecs record is a u32 dimension header followed by
        // `dimension` little-endian f32 values, so records have fixed size
        // and can be addressed by a direct seek.
        let vector_size = size_of::<u32>() + Self::dimension() * size_of::<f32>();

        let start = SeekFrom::Start((lower_bound * vector_size) as u64);
        self.base.seek(start).await?;
        let batch_size = upper_bound - lower_bound;
        let mut base_bytes = vec![0; batch_size * vector_size];
        self.base.read_exact(&mut base_bytes).await?;
        read_raw_vec(&base_bytes, |bytes| {
            Ok(f32::from_le_bytes(bytes.try_into()?))
        })
    }

    /// Reads all query vectors together with their ground-truth neighbor
    /// ids, returning `(query, neighbor_ids)` pairs. Errors if the query and
    /// ground-truth files disagree on the number of records.
    pub async fn query(&mut self) -> Result<Vec<(Vec<f32>, Vec<u32>)>> {
        let mut query_bytes = Vec::new();
        self.query.read_to_end(&mut query_bytes).await?;
        let queries = read_raw_vec(&query_bytes, |bytes| {
            Ok(f32::from_le_bytes(bytes.try_into()?))
        })?;

        let mut ground_bytes = Vec::new();
        self.ground.read_to_end(&mut ground_bytes).await?;
        let grounds = read_raw_vec(&ground_bytes, |bytes| {
            Ok(u32::from_le_bytes(bytes.try_into()?))
        })?;
        if queries.len() != grounds.len() {
            return Err(anyhow!(
                "Queries and grounds count mismatch: {} != {}",
                queries.len(),
                grounds.len()
            ));
        }
        Ok(queries.into_iter().zip(grounds).collect())
    }
}

/// Decodes a contiguous `.fvecs`/`.ivecs`-style buffer into a list of
/// vectors.
///
/// Each record is a little-endian `u32` dimension header followed by
/// `dimension` elements of `size_of::<T>()` bytes each; `convert_from_bytes`
/// decodes a single element from its raw bytes.
///
/// Returns an error (instead of panicking, as a bare `split_at` would) when
/// the buffer ends mid-record, or when a record's element count does not
/// match its header.
fn read_raw_vec<T>(
    raw_bytes: &[u8],
    convert_from_bytes: impl Fn(&[u8]) -> Result<T>,
) -> Result<Vec<Vec<T>>> {
    let mut result = Vec::new();
    let mut bytes = raw_bytes;
    while !bytes.is_empty() {
        // Guard the header read: `split_at` panics on short input.
        if bytes.len() < size_of::<u32>() {
            return Err(anyhow!(
                "Truncated vector header: {} byte(s) remaining",
                bytes.len()
            ));
        }
        let (dimension_bytes, rem_bytes) = bytes.split_at(size_of::<u32>());
        let dimension = u32::from_le_bytes(dimension_bytes.try_into()?) as usize;
        let embedding_size = dimension * size_of::<T>();
        // Guard the body read for the same reason.
        if rem_bytes.len() < embedding_size {
            return Err(anyhow!(
                "Truncated vector body: expected {} bytes, found {}",
                embedding_size,
                rem_bytes.len()
            ));
        }
        let (embedding_bytes, rem_bytes) = rem_bytes.split_at(embedding_size);
        let embedding = embedding_bytes
            .chunks(size_of::<T>())
            .map(&convert_from_bytes)
            .collect::<Result<Vec<T>>>()?;
        if embedding.len() != dimension {
            return Err(anyhow!(
                "Embedding dimension mismatch: {} != {}",
                embedding.len(),
                dimension
            ));
        }
        result.push(embedding);
        bytes = rem_bytes;
    }
    Ok(result)
}
17 changes: 13 additions & 4 deletions rust/worker/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,24 +61,33 @@ chroma-distance = { workspace = true }
random-port = "0.1.1"
serial_test = "3.2.0"

criterion = { workspace = true }
indicatif = { workspace = true }
proptest = { workspace = true }
proptest-state-machine = { workspace = true }
shuttle = { workspace = true }
rand = { workspace = true }
rand_xorshift = { workspace = true }
tempfile = { workspace = true }
shuttle = { workspace = true }
proptest = { workspace = true }
proptest-state-machine = { workspace = true }
criterion = { workspace = true }

chroma-benchmark = { workspace = true }

[[bench]]
name = "filter"
harness = false

[[bench]]
name = "get"
harness = false

[[bench]]
name = "limit"
harness = false

[[bench]]
name = "query"
harness = false

[[bench]]
name = "spann"
harness = false
9 changes: 4 additions & 5 deletions rust/worker/benches/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use criterion::Criterion;
use criterion::{criterion_group, criterion_main};
use worker::execution::operator::Operator;
use worker::execution::operators::filter::{FilterInput, FilterOperator};
use worker::log::test::{upsert_generator, LogGenerator};
use worker::log::test::upsert_generator;
use worker::segment::test::TestSegment;

fn baseline_where_clauses() -> Vec<(&'static str, Option<Where>)> {
Expand Down Expand Up @@ -71,14 +71,13 @@ fn baseline_where_clauses() -> Vec<(&'static str, Option<Where>)> {

fn bench_filter(criterion: &mut Criterion) {
let runtime = tokio_multi_thread();
let logen = LogGenerator {
generator: upsert_generator,
};

for record_count in [1000, 10000, 100000] {
let test_segment = runtime.block_on(async {
let mut segment = TestSegment::default();
segment.populate_with_generator(record_count, &logen).await;
segment
.populate_with_generator(record_count, upsert_generator)
.await;
segment
});

Expand Down
Loading

0 comments on commit d833e53

Please sign in to comment.