Skip to content

Commit

Permalink
[ENH] Introduce spann segment reader (#3212)
Browse files Browse the repository at this point in the history
## Description of changes

*Summarize the changes made by this PR.*
 - Improvements & Bug fixes
   - Spann segment reader with from_segment impl
 - New functionality
   - ...

## Test plan
*How are these changes tested?*
- [x] Tests pass locally with `pytest` for python, `yarn test` for js, `cargo test` for rust

## Documentation Changes
None
  • Loading branch information
sanketkedia authored Dec 5, 2024
1 parent f824329 commit 65b5e01
Show file tree
Hide file tree
Showing 5 changed files with 477 additions and 12 deletions.
33 changes: 33 additions & 0 deletions rust/blockstore/src/memory/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -680,6 +680,39 @@ impl<'referred_data> Readable<'referred_data> for DataRecord<'referred_data> {
}
}

/// `Readable` implementation for `SpannPostingList` against the memory blockstore `Storage`.
///
/// Placeholder only: every method body below is `todo!()`, so calling any of
/// them panics. All arguments are intentionally ignored (bound to `_`).
impl<'referred_data> Readable<'referred_data> for SpannPostingList<'referred_data> {
// Stub: not implemented yet, panics via `todo!()`.
fn read_from_storage(_: &str, _: KeyWrapper, _: &'referred_data Storage) -> Option<Self> {
todo!()
}

// Stub: not implemented yet, panics via `todo!()`.
fn read_range_from_storage<'prefix, PrefixRange, KeyRange>(
_: PrefixRange,
_: KeyRange,
_: &'referred_data Storage,
) -> Vec<(&'referred_data CompositeKey, Self)>
where
PrefixRange: std::ops::RangeBounds<&'prefix str>,
KeyRange: std::ops::RangeBounds<KeyWrapper>,
{
todo!()
}

// Stub: not implemented yet, panics via `todo!()`.
fn get_at_index(
_: &'referred_data Storage,
_: usize,
) -> Option<(&'referred_data CompositeKey, Self)> {
todo!()
}

// Stub: not implemented yet, panics via `todo!()` instead of returning an `Err`.
fn count(_: &Storage) -> Result<usize, Box<dyn ChromaError>> {
todo!()
}

// Stub: not implemented yet, panics via `todo!()`.
fn contains(_: &str, _: KeyWrapper, _: &'referred_data Storage) -> bool {
todo!()
}
}

#[derive(Clone)]
pub struct StorageBuilder {
bool_storage: Arc<RwLock<Option<BTreeMap<CompositeKey, bool>>>>,
Expand Down
6 changes: 6 additions & 0 deletions rust/blockstore/src/types/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,12 @@ impl Value for &DataRecord<'_> {
}
}

/// `Value` implementation for an owned `SpannPostingList`.
impl Value for SpannPostingList<'_> {
/// Returns whatever `compute_size()` reports for this posting list.
fn get_size(&self) -> usize {
self.compute_size()
}
}

impl Value for &SpannPostingList<'_> {
fn get_size(&self) -> usize {
self.compute_size()
Expand Down
148 changes: 138 additions & 10 deletions rust/index/src/spann/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::{

use chroma_blockstore::{
provider::{BlockfileProvider, CreateError, OpenError},
BlockfileFlusher, BlockfileWriter, BlockfileWriterOptions,
BlockfileFlusher, BlockfileReader, BlockfileWriter, BlockfileWriterOptions,
};
use chroma_distance::{normalize, DistanceFunction};
use chroma_error::{ChromaError, ErrorCodes};
Expand Down Expand Up @@ -1427,6 +1427,124 @@ impl SpannIndexFlusher {
}
}

/// Errors produced while constructing a `SpannIndexReader`.
#[derive(Error, Debug)]
pub enum SpannIndexReaderError {
/// Opening the HNSW index through the HNSW provider failed.
#[error("Error creating/opening hnsw index")]
HnswIndexConstructionError,
/// The blockfile provider could not produce a reader for a blockfile id.
#[error("Error creating/opening blockfile reader")]
BlockfileReaderConstructionError,
/// One of the ids needed to build the reader was `None`.
#[error("Spann index uninitialized")]
UninitializedIndex,
}

impl ChromaError for SpannIndexReaderError {
    /// Reports the error code for this error.
    ///
    /// Every variant currently maps to `ErrorCodes::Internal`. The variants are
    /// still spelled out (no `_` catch-all) so that adding a new variant forces
    /// this mapping to be revisited at compile time.
    fn code(&self) -> ErrorCodes {
        match self {
            Self::HnswIndexConstructionError
            | Self::BlockfileReaderConstructionError
            | Self::UninitializedIndex => ErrorCodes::Internal,
        }
    }
}

/// Read-side view of a SPANN index, bundling two blockfile readers and an
/// HNSW index reference.
#[derive(Clone)]
pub struct SpannIndexReader<'me> {
/// Blockfile reader with `u32` keys and `SpannPostingList` values.
pub posting_lists: BlockfileReader<'me, u32, SpannPostingList<'me>>,
/// Reference to the HNSW index obtained from the HNSW provider.
pub hnsw_index: HnswIndexRef,
/// Blockfile reader with `u32` keys and `u32` values.
// NOTE(review): presumably maps a doc offset id to its current version — confirm against the writer.
pub versions_map: BlockfileReader<'me, u32, u32>,
}

impl<'me> SpannIndexReader<'me> {
async fn hnsw_index_from_id(
hnsw_provider: &HnswIndexProvider,
id: &IndexUuid,
cache_key: &CollectionUuid,
distance_function: DistanceFunction,
dimensionality: usize,
) -> Result<HnswIndexRef, SpannIndexReaderError> {
match hnsw_provider.get(id, cache_key).await {
Some(index) => Ok(index),
None => {
match hnsw_provider
.open(id, cache_key, dimensionality as i32, distance_function)
.await
{
Ok(index) => Ok(index),
Err(_) => Err(SpannIndexReaderError::HnswIndexConstructionError),
}
}
}
}

async fn posting_list_reader_from_id(
blockfile_id: &Uuid,
blockfile_provider: &BlockfileProvider,
) -> Result<BlockfileReader<'me, u32, SpannPostingList<'me>>, SpannIndexReaderError> {
match blockfile_provider
.read::<u32, SpannPostingList<'me>>(blockfile_id)
.await
{
Ok(reader) => Ok(reader),
Err(_) => Err(SpannIndexReaderError::BlockfileReaderConstructionError),
}
}

async fn versions_map_reader_from_id(
blockfile_id: &Uuid,
blockfile_provider: &BlockfileProvider,
) -> Result<BlockfileReader<'me, u32, u32>, SpannIndexReaderError> {
match blockfile_provider.read::<u32, u32>(blockfile_id).await {
Ok(reader) => Ok(reader),
Err(_) => Err(SpannIndexReaderError::BlockfileReaderConstructionError),
}
}

#[allow(clippy::too_many_arguments)]
pub async fn from_id(
hnsw_id: Option<&IndexUuid>,
hnsw_provider: &HnswIndexProvider,
hnsw_cache_key: &CollectionUuid,
distance_function: DistanceFunction,
dimensionality: usize,
pl_blockfile_id: Option<&Uuid>,
versions_map_blockfile_id: Option<&Uuid>,
blockfile_provider: &BlockfileProvider,
) -> Result<SpannIndexReader<'me>, SpannIndexReaderError> {
let hnsw_reader = match hnsw_id {
Some(hnsw_id) => {
Self::hnsw_index_from_id(
hnsw_provider,
hnsw_id,
hnsw_cache_key,
distance_function,
dimensionality,
)
.await?
}
None => {
return Err(SpannIndexReaderError::UninitializedIndex);
}
};
let postings_list_reader = match pl_blockfile_id {
Some(pl_id) => Self::posting_list_reader_from_id(pl_id, blockfile_provider).await?,
None => return Err(SpannIndexReaderError::UninitializedIndex),
};

let versions_map_reader = match versions_map_blockfile_id {
Some(versions_id) => {
Self::versions_map_reader_from_id(versions_id, blockfile_provider).await?
}
None => return Err(SpannIndexReaderError::UninitializedIndex),
};

Ok(Self {
posting_lists: postings_list_reader,
hnsw_index: hnsw_reader,
versions_map: versions_map_reader,
})
}
}

#[cfg(test)]
mod tests {
use std::{f32::consts::PI, path::PathBuf};
Expand Down Expand Up @@ -1556,22 +1674,32 @@ mod tests {
{
// Posting list should have 100 points.
let pl_read_guard = writer.posting_list_writer.lock().await;
let pl = pl_read_guard
let pl1 = pl_read_guard
.get_owned::<u32, &SpannPostingList<'_>>("", emb_1_id)
.await
.expect("Error getting posting list")
.unwrap();
assert_eq!(pl.0.len(), 100);
assert_eq!(pl.1.len(), 100);
assert_eq!(pl.2.len(), 200);
let pl = pl_read_guard
let pl2 = pl_read_guard
.get_owned::<u32, &SpannPostingList<'_>>("", emb_2_id)
.await
.expect("Error getting posting list")
.unwrap();
assert_eq!(pl.0.len(), 1);
assert_eq!(pl.1.len(), 1);
assert_eq!(pl.2.len(), 2);
// Only two combinations possible.
if pl1.0.len() == 100 {
assert_eq!(pl1.1.len(), 100);
assert_eq!(pl1.2.len(), 200);
assert_eq!(pl2.0.len(), 1);
assert_eq!(pl2.1.len(), 1);
assert_eq!(pl2.2.len(), 2);
} else if pl2.0.len() == 100 {
assert_eq!(pl2.1.len(), 100);
assert_eq!(pl2.2.len(), 200);
assert_eq!(pl1.0.len(), 1);
assert_eq!(pl1.1.len(), 1);
assert_eq!(pl1.2.len(), 2);
} else {
panic!("Invalid posting list lengths");
}
}
// Next insert 99 points in the region of (1000.0, 1000.0)
for i in 102..=200 {
Expand Down Expand Up @@ -1911,7 +2039,7 @@ mod tests {
version_map_guard.versions_map.insert(100 + point as u32, 1);
}
}
// Delete 60 points each from the centers. Since merge_threshold is 40, this should
// Delete 60 points each from the centers. Since merge_threshold is 50, this should
// trigger a merge between the two centers.
for point in 1..=60 {
writer
Expand Down
1 change: 1 addition & 0 deletions rust/types/src/spann_posting_list.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#[derive(Clone, Debug)]
pub struct SpannPostingList<'referred_data> {
pub doc_offset_ids: &'referred_data [u32],
pub doc_versions: &'referred_data [u32],
Expand Down
Loading

0 comments on commit 65b5e01

Please sign in to comment.