Skip to content

Commit

Permalink
Fix dataset decoding
Browse files Browse the repository at this point in the history
  • Loading branch information
Sicheng Pan committed Dec 17, 2024
1 parent c024cb1 commit 9df14a7
Show file tree
Hide file tree
Showing 4 changed files with 183 additions and 78 deletions.
83 changes: 52 additions & 31 deletions rust/benchmark/src/datasets/sift.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ pub struct Sift1MData {
impl Sift1MData {
pub async fn init() -> Result<Self> {
let base = get_or_populate_cached_dataset_file(
"gist",
"sift1m",
"base.fvecs",
None,
|mut writer| async move {
Expand All @@ -45,7 +45,7 @@ impl Sift1MData {
},
).await?;
let query = get_or_populate_cached_dataset_file(
"gist",
"sift1m",
"query.fvecs",
None,
|mut writer| async move {
Expand All @@ -70,8 +70,8 @@ impl Sift1MData {
},
).await?;
let ground = get_or_populate_cached_dataset_file(
"gist",
"groundtruth.fvecs",
"sift1m",
"groundtruth.ivecs",
None,
|mut writer| async move {
let client = reqwest::Client::new();
Expand Down Expand Up @@ -134,43 +134,64 @@ impl Sift1MData {
return Ok(Vec::new());
}

let start = SeekFrom::Start(
(size_of::<u32>() + lower_bound * Self::dimension() * size_of::<f32>()) as u64,
);
let vector_size = size_of::<u32>() + Self::dimension() * size_of::<f32>();

let start = SeekFrom::Start((lower_bound * vector_size) as u64);
self.base.seek(start).await?;
let batch_size = upper_bound - lower_bound;
let mut base_bytes = vec![0; batch_size * Self::dimension() * size_of::<f32>()];
let mut base_bytes = vec![0; batch_size * vector_size];
self.base.read_exact(&mut base_bytes).await?;
let embedding_f32s: Vec<_> = base_bytes
.chunks(size_of::<f32>())
.map(|c| f32::from_le_bytes(c.try_into().unwrap()))
.collect();
Ok(embedding_f32s
.chunks(Self::dimension())
.map(|embedding| embedding.to_vec())
.collect())
read_raw_vec(&base_bytes, |bytes| {
Ok(f32::from_le_bytes(bytes.try_into()?))
})
}

pub async fn query(&mut self) -> Result<Vec<(Vec<f32>, Vec<u32>)>> {
let mut query_bytes = Vec::new();
self.query.read_to_end(&mut query_bytes).await?;
let (_, embeddings_bytes) = query_bytes.split_at(size_of::<u32>());
let embedding_f32s: Vec<_> = embeddings_bytes
.chunks(size_of::<f32>())
.map(|c| f32::from_le_bytes(c.try_into().unwrap()))
.collect();
let queries = read_raw_vec(&query_bytes, |bytes| {
Ok(f32::from_le_bytes(bytes.try_into()?))
})?;

let mut ground_bytes = Vec::new();
self.ground.read_to_end(&mut ground_bytes).await?;
let (_, embeddings_bytes) = query_bytes.split_at(size_of::<u32>());
let ground_u32s: Vec<_> = embeddings_bytes
.chunks(size_of::<u32>())
.map(|c| u32::from_le_bytes(c.try_into().unwrap()))
.collect();
Ok(embedding_f32s
.chunks(Self::dimension())
.zip(ground_u32s.chunks(Self::k()))
.map(|(embedding, ground)| (embedding.to_vec(), ground.to_vec()))
.collect())
let grounds = read_raw_vec(&ground_bytes, |bytes| {
Ok(u32::from_le_bytes(bytes.try_into()?))
})?;
if queries.len() != grounds.len() {
return Err(anyhow!(
"Queries and grounds count mismatch: {} != {}",
queries.len(),
grounds.len()
));
}
Ok(queries.into_iter().zip(grounds).collect())
}
}

fn read_raw_vec<T>(
raw_bytes: &[u8],
convert_from_bytes: impl Fn(&[u8]) -> Result<T>,
) -> Result<Vec<Vec<T>>> {
let mut result = Vec::new();
let mut bytes = raw_bytes;
while !bytes.is_empty() {
let (dimension_bytes, rem_bytes) = bytes.split_at(size_of::<u32>());
let dimension = u32::from_le_bytes(dimension_bytes.try_into()?);
let (embedding_bytes, rem_bytes) = rem_bytes.split_at(dimension as usize * size_of::<T>());
let embedding = embedding_bytes
.chunks(size_of::<T>())
.map(&convert_from_bytes)
.collect::<Result<Vec<T>>>()?;
if embedding.len() != dimension as usize {
return Err(anyhow!(
"Embedding dimension mismatch: {} != {}",
embedding.len(),
dimension
));
}
result.push(embedding);
bytes = rem_bytes;
}
Ok(result)
}
67 changes: 50 additions & 17 deletions rust/worker/benches/get.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@ use chroma_benchmark::benchmark::{bench_run, tokio_multi_thread};
use chroma_config::Configurable;
use criterion::{criterion_group, criterion_main, Criterion};
use load::{
all_projection, always_true_filter_for_modulo_metadata, empty_fetch_log, offset_limit,
sift1m_segments, trivial_filter, trivial_limit, trivial_projection,
all_projection, always_false_filter_for_modulo_metadata,
always_true_filter_for_modulo_metadata, empty_fetch_log, offset_limit, sift1m_segments,
trivial_filter, trivial_limit, trivial_projection,
};
use worker::{
config::RootConfig,
Expand All @@ -33,7 +34,25 @@ fn trivial_get(
)
}

fn get_filter(
fn get_false_filter(
test_segments: TestSegment,
dispatcher_handle: ComponentHandle<Dispatcher>,
) -> GetOrchestrator {
let blockfile_provider = test_segments.blockfile_provider.clone();
let collection_uuid = test_segments.collection.collection_id;
GetOrchestrator::new(
blockfile_provider,
dispatcher_handle,
1000,
test_segments.into(),
empty_fetch_log(collection_uuid),
always_false_filter_for_modulo_metadata(),
trivial_limit(),
trivial_projection(),
)
}

fn get_true_filter(
test_segments: TestSegment,
dispatcher_handle: ComponentHandle<Dispatcher>,
) -> GetOrchestrator {
Expand All @@ -51,7 +70,7 @@ fn get_filter(
)
}

fn get_filter_limit(
fn get_true_filter_limit(
test_segments: TestSegment,
dispatcher_handle: ComponentHandle<Dispatcher>,
) -> GetOrchestrator {
Expand All @@ -69,7 +88,7 @@ fn get_filter_limit(
)
}

fn get_filter_limit_projection(
fn get_true_filter_limit_projection(
test_segments: TestSegment,
dispatcher_handle: ComponentHandle<Dispatcher>,
) -> GetOrchestrator {
Expand Down Expand Up @@ -123,24 +142,31 @@ fn bench_get(criterion: &mut Criterion) {
(0..100).map(|id| id.to_string()).collect(),
)
};
let get_filter_setup = || {
let get_false_filter_setup = || {
(
system.clone(),
get_false_filter(test_segments.clone(), dispatcher_handle.clone()),
Vec::new(),
)
};
let get_true_filter_setup = || {
(
system.clone(),
get_filter(test_segments.clone(), dispatcher_handle.clone()),
get_true_filter(test_segments.clone(), dispatcher_handle.clone()),
(0..100).map(|id| id.to_string()).collect(),
)
};
let get_filter_limit_setup = || {
let get_true_filter_limit_setup = || {
(
system.clone(),
get_filter_limit(test_segments.clone(), dispatcher_handle.clone()),
get_true_filter_limit(test_segments.clone(), dispatcher_handle.clone()),
(100..200).map(|id| id.to_string()).collect(),
)
};
let get_filter_limit_projection_setup = || {
let get_true_filter_limit_projection_setup = || {
(
system.clone(),
get_filter_limit_projection(test_segments.clone(), dispatcher_handle.clone()),
get_true_filter_limit_projection(test_segments.clone(), dispatcher_handle.clone()),
(100..200).map(|id| id.to_string()).collect(),
)
};
Expand All @@ -153,24 +179,31 @@ fn bench_get(criterion: &mut Criterion) {
bench_routine,
);
bench_run(
"test-get-filter",
"test-get-false-filter",
criterion,
&runtime,
get_false_filter_setup,
bench_routine,
);
bench_run(
"test-get-true-filter",
criterion,
&runtime,
get_filter_setup,
get_true_filter_setup,
bench_routine,
);
bench_run(
"test-get-filter-limit",
"test-get-true-filter-limit",
criterion,
&runtime,
get_filter_limit_setup,
get_true_filter_limit_setup,
bench_routine,
);
bench_run(
"test-get-filter-limit-projection",
"test-get-true-filter-limit-projection",
criterion,
&runtime,
get_filter_limit_projection_setup,
get_true_filter_limit_projection_setup,
bench_routine,
);
}
Expand Down
22 changes: 22 additions & 0 deletions rust/worker/benches/load.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,28 @@ pub fn trivial_filter() -> FilterOperator {
}
}

pub fn always_false_filter_for_modulo_metadata() -> FilterOperator {
FilterOperator {
query_ids: None,
where_clause: Some(Where::disjunction(vec![
Where::DirectWhereComparison(DirectWhereComparison {
key: "is_even".to_string(),
comparison: WhereComparison::Set(
SetOperator::NotIn,
MetadataSetValue::Bool(vec![false, true]),
),
}),
Where::DirectWhereComparison(DirectWhereComparison {
key: "modulo_3".to_string(),
comparison: WhereComparison::Set(
SetOperator::NotIn,
MetadataSetValue::Int(vec![0, 1, 2]),
),
}),
])),
}
}

pub fn always_true_filter_for_modulo_metadata() -> FilterOperator {
FilterOperator {
query_ids: None,
Expand Down
Loading

0 comments on commit 9df14a7

Please sign in to comment.