Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Introduce spann segment and index #3080

Merged
merged 10 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion rust/blockstore/src/memory/storage.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::key::{CompositeKey, KeyWrapper};
use chroma_error::ChromaError;
use chroma_types::DataRecord;
use chroma_types::{DataRecord, SpannPostingList};
use parking_lot::RwLock;
use roaring::RoaringBitmap;
use std::{
Expand Down Expand Up @@ -585,6 +585,16 @@ impl Writeable for &DataRecord<'_> {
}
}

impl Writeable for &SpannPostingList<'_> {
fn write_to_storage(_: &str, _: KeyWrapper, _: Self, _: &StorageBuilder) {
todo!()
}

fn remove_from_storage(_: &str, _: KeyWrapper, _: &StorageBuilder) {
todo!()
}
}

impl<'referred_data> Readable<'referred_data> for DataRecord<'referred_data> {
fn read_from_storage(
prefix: &str,
Expand Down
11 changes: 4 additions & 7 deletions rust/index/src/hnsw.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use super::{Index, IndexConfig, IndexUuid, PersistentIndex};
use chroma_error::{ChromaError, ErrorCodes};
use chroma_types::MetadataValueConversionError;
use std::ffi::CString;
use std::ffi::{c_char, c_int};
use std::path::Path;
Expand Down Expand Up @@ -37,14 +36,12 @@ pub struct HnswIndexConfig {
}

#[derive(Error, Debug)]
pub enum HnswIndexFromSegmentError {
pub enum HnswIndexConfigError {
#[error("Missing config `{0}`")]
MissingConfig(String),
#[error("Invalid metadata value")]
MetadataValueError(#[from] MetadataValueConversionError),
}

impl ChromaError for HnswIndexFromSegmentError {
impl ChromaError for HnswIndexConfigError {
fn code(&self) -> ErrorCodes {
ErrorCodes::InvalidArgument
}
Expand All @@ -56,11 +53,11 @@ impl HnswIndexConfig {
ef_construction: usize,
ef_search: usize,
persist_path: &Path,
) -> Result<Self, Box<HnswIndexFromSegmentError>> {
) -> Result<Self, Box<HnswIndexConfigError>> {
let persist_path = match persist_path.to_str() {
Some(persist_path) => persist_path,
None => {
return Err(Box::new(HnswIndexFromSegmentError::MissingConfig(
return Err(Box::new(HnswIndexConfigError::MissingConfig(
"persist_path".to_string(),
)))
}
Expand Down
26 changes: 5 additions & 21 deletions rust/index/src/hnsw_provider.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
use crate::{HnswIndexConfigError, PersistentIndex};

use super::config::HnswProviderConfig;
use super::{
HnswIndex, HnswIndexConfig, HnswIndexFromSegmentError, Index, IndexConfig,
IndexConfigFromSegmentError, IndexUuid,
};
use crate::PersistentIndex;
use super::{HnswIndex, HnswIndexConfig, Index, IndexConfig, IndexUuid};

use async_trait::async_trait;
use chroma_cache::{Cache, Weighted};
use chroma_config::Configurable;
Expand Down Expand Up @@ -547,12 +546,8 @@ impl HnswIndexProvider {

#[derive(Error, Debug)]
pub enum HnswIndexProviderOpenError {
#[error("Index configuration error")]
IndexConfigError(#[from] IndexConfigFromSegmentError),
#[error("Hnsw index file error")]
FileError(#[from] HnswIndexProviderFileError),
#[error("Hnsw config error")]
HnswConfigError(#[from] HnswIndexFromSegmentError),
#[error("Index load error")]
IndexLoadError(#[from] Box<dyn ChromaError>),
#[error("Path: {0} could not be converted to string")]
Expand All @@ -562,9 +557,7 @@ pub enum HnswIndexProviderOpenError {
impl ChromaError for HnswIndexProviderOpenError {
fn code(&self) -> ErrorCodes {
match self {
HnswIndexProviderOpenError::IndexConfigError(e) => e.code(),
HnswIndexProviderOpenError::FileError(_) => ErrorCodes::Internal,
HnswIndexProviderOpenError::HnswConfigError(e) => e.code(),
HnswIndexProviderOpenError::IndexLoadError(e) => e.code(),
HnswIndexProviderOpenError::PathToStringError(_) => ErrorCodes::InvalidArgument,
}
Expand All @@ -573,12 +566,8 @@ impl ChromaError for HnswIndexProviderOpenError {

#[derive(Error, Debug)]
pub enum HnswIndexProviderForkError {
#[error("Index configuration error")]
IndexConfigError(#[from] IndexConfigFromSegmentError),
#[error("Hnsw index file error")]
FileError(#[from] HnswIndexProviderFileError),
#[error("Hnsw config error")]
HnswConfigError(#[from] HnswIndexFromSegmentError),
#[error("Index load error")]
IndexLoadError(#[from] Box<dyn ChromaError>),
#[error("Path: {0} could not be converted to string")]
Expand All @@ -588,9 +577,7 @@ pub enum HnswIndexProviderForkError {
impl ChromaError for HnswIndexProviderForkError {
fn code(&self) -> ErrorCodes {
match self {
HnswIndexProviderForkError::IndexConfigError(e) => e.code(),
HnswIndexProviderForkError::FileError(_) => ErrorCodes::Internal,
HnswIndexProviderForkError::HnswConfigError(e) => e.code(),
HnswIndexProviderForkError::IndexLoadError(e) => e.code(),
HnswIndexProviderForkError::PathToStringError(_) => ErrorCodes::InvalidArgument,
}
Expand All @@ -599,20 +586,17 @@ impl ChromaError for HnswIndexProviderForkError {

#[derive(Error, Debug)]
pub enum HnswIndexProviderCreateError {
#[error("Index configuration error")]
IndexConfigError(#[from] IndexConfigFromSegmentError),
#[error("Hnsw index file error")]
FileError(#[from] HnswIndexProviderFileError),
#[error("Hnsw config error")]
HnswConfigError(#[from] HnswIndexFromSegmentError),
HnswConfigError(#[from] HnswIndexConfigError),
#[error("Index init error")]
IndexInitError(#[from] Box<dyn ChromaError>),
}

impl ChromaError for HnswIndexProviderCreateError {
fn code(&self) -> ErrorCodes {
match self {
HnswIndexProviderCreateError::IndexConfigError(e) => e.code(),
HnswIndexProviderCreateError::FileError(_) => ErrorCodes::Internal,
HnswIndexProviderCreateError::HnswConfigError(e) => e.code(),
HnswIndexProviderCreateError::IndexInitError(e) => e.code(),
Expand Down
3 changes: 3 additions & 0 deletions rust/index/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ pub mod fulltext;
mod hnsw;
pub mod hnsw_provider;
pub mod metadata;
pub mod spann;
mod types;
pub mod utils;

Expand All @@ -12,6 +13,8 @@ use chroma_cache::new_non_persistent_cache_for_test;
use chroma_storage::test_storage;
pub use hnsw::*;
use hnsw_provider::HnswIndexProvider;
#[allow(unused_imports)]
pub use spann::*;
use tempfile::tempdir;
pub use types::*;

Expand Down
1 change: 1 addition & 0 deletions rust/index/src/spann.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub mod types;
220 changes: 220 additions & 0 deletions rust/index/src/spann/types.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
use std::collections::HashMap;

use chroma_blockstore::{provider::BlockfileProvider, BlockfileWriter, BlockfileWriterOptions};
use chroma_distance::DistanceFunction;
use chroma_error::{ChromaError, ErrorCodes};
use chroma_types::{CollectionUuid, SpannPostingList};
use thiserror::Error;
use uuid::Uuid;

use crate::{
hnsw_provider::{HnswIndexProvider, HnswIndexRef},
IndexUuid,
};

// TODO(Sanket): Add locking structures as necessary.
#[allow(dead_code)]
pub struct SpannIndexWriter {
// HNSW index and its provider for centroid search.
hnsw_index: HnswIndexRef,
hnsw_provider: HnswIndexProvider,
// Posting list of the centroids.
// The blockfile also contains next id for the head.
sanketkedia marked this conversation as resolved.
Show resolved Hide resolved
posting_list_writer: BlockfileWriter,
// Version number of each point.
versions_map: HashMap<u32, u32>,
HammadB marked this conversation as resolved.
Show resolved Hide resolved
}

#[derive(Error, Debug)]
pub enum SpannIndexWriterConstructionError {
#[error("HNSW index construction error")]
HnswIndexConstructionError,
#[error("Blockfile reader construction error")]
BlockfileReaderConstructionError,
#[error("Blockfile writer construction error")]
BlockfileWriterConstructionError,
#[error("Error loading version data from blockfile")]
BlockfileVersionDataLoadError,
}

impl ChromaError for SpannIndexWriterConstructionError {
fn code(&self) -> ErrorCodes {
match self {
Self::HnswIndexConstructionError => ErrorCodes::Internal,
Self::BlockfileReaderConstructionError => ErrorCodes::Internal,
Self::BlockfileWriterConstructionError => ErrorCodes::Internal,
Self::BlockfileVersionDataLoadError => ErrorCodes::Internal,
}
}
}

impl SpannIndexWriter {
pub fn new(
hnsw_index: HnswIndexRef,
hnsw_provider: HnswIndexProvider,
posting_list_writer: BlockfileWriter,
versions_map: HashMap<u32, u32>,
) -> Self {
SpannIndexWriter {
hnsw_index,
hnsw_provider,
posting_list_writer,
versions_map,
}
}

async fn hnsw_index_from_id(
hnsw_provider: &HnswIndexProvider,
id: &IndexUuid,
collection_id: &CollectionUuid,
distance_function: DistanceFunction,
dimensionality: usize,
) -> Result<HnswIndexRef, SpannIndexWriterConstructionError> {
match hnsw_provider
.fork(id, collection_id, dimensionality as i32, distance_function)
.await
{
Ok(index) => Ok(index),
Err(_) => Err(SpannIndexWriterConstructionError::HnswIndexConstructionError),
}
}

async fn create_hnsw_index(
hnsw_provider: &HnswIndexProvider,
collection_id: &CollectionUuid,
distance_function: DistanceFunction,
dimensionality: usize,
m: usize,
ef_construction: usize,
ef_search: usize,
) -> Result<HnswIndexRef, SpannIndexWriterConstructionError> {
match hnsw_provider
.create(
collection_id,
m,
ef_construction,
ef_search,
dimensionality as i32,
distance_function,
)
.await
{
Ok(index) => Ok(index),
Err(_) => Err(SpannIndexWriterConstructionError::HnswIndexConstructionError),
}
}

async fn load_versions_map(
blockfile_id: &Uuid,
blockfile_provider: &BlockfileProvider,
) -> Result<HashMap<u32, u32>, SpannIndexWriterConstructionError> {
// Create a reader for the blockfile. Load all the data into the versions map.
let mut versions_map = HashMap::new();
let reader = match blockfile_provider.read::<u32, u32>(blockfile_id).await {
Ok(reader) => reader,
Err(_) => {
return Err(SpannIndexWriterConstructionError::BlockfileReaderConstructionError)
}
};
// Load data using the reader.
let versions_data = reader
.get_range(.., ..)
.await
.map_err(|_| SpannIndexWriterConstructionError::BlockfileVersionDataLoadError)?;
versions_data.iter().for_each(|(key, value)| {
versions_map.insert(*key, *value);
});
Ok(versions_map)
}

async fn fork_postings_list(
blockfile_id: &Uuid,
blockfile_provider: &BlockfileProvider,
) -> Result<BlockfileWriter, SpannIndexWriterConstructionError> {
let mut bf_options = BlockfileWriterOptions::new();
bf_options = bf_options.unordered_mutations();
bf_options = bf_options.fork(*blockfile_id);
match blockfile_provider
.write::<u32, &SpannPostingList<'_>>(bf_options)
.await
{
Ok(writer) => Ok(writer),
Err(_) => Err(SpannIndexWriterConstructionError::BlockfileWriterConstructionError),
}
}

async fn create_posting_list(
blockfile_provider: &BlockfileProvider,
) -> Result<BlockfileWriter, SpannIndexWriterConstructionError> {
let mut bf_options = BlockfileWriterOptions::new();
bf_options = bf_options.unordered_mutations();
match blockfile_provider
.write::<u32, &SpannPostingList<'_>>(bf_options)
.await
{
Ok(writer) => Ok(writer),
Err(_) => Err(SpannIndexWriterConstructionError::BlockfileWriterConstructionError),
}
}

#[allow(clippy::too_many_arguments)]
pub async fn from_id(
hnsw_provider: &HnswIndexProvider,
hnsw_id: Option<&IndexUuid>,
versions_map_id: Option<&Uuid>,
posting_list_id: Option<&Uuid>,
m: Option<usize>,
ef_construction: Option<usize>,
ef_search: Option<usize>,
collection_id: &CollectionUuid,
distance_function: DistanceFunction,
dimensionality: usize,
blockfile_provider: &BlockfileProvider,
) -> Result<Self, SpannIndexWriterConstructionError> {
// Create the HNSW index.
let hnsw_index = match hnsw_id {
Some(hnsw_id) => {
Self::hnsw_index_from_id(
hnsw_provider,
hnsw_id,
collection_id,
distance_function,
dimensionality,
)
.await?
}
None => {
Self::create_hnsw_index(
hnsw_provider,
collection_id,
distance_function,
dimensionality,
m.unwrap(), // Safe since caller should always provide this.
Copy link
Collaborator

@HammadB HammadB Nov 14, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is it an option if caller always provides it? Is the assumption that If you don't give an hnsw_id you will give these values?

Can we force input to be correct by construction?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

discussed offline - will switch to two methods

ef_construction.unwrap(), // Safe since caller should always provide this.
ef_search.unwrap(), // Safe since caller should always provide this.
)
.await?
}
};
// Load the versions map.
let versions_map = match versions_map_id {
Some(versions_map_id) => {
Self::load_versions_map(versions_map_id, blockfile_provider).await?
}
None => HashMap::new(),
};
// Fork the posting list writer.
let posting_list_writer = match posting_list_id {
Some(posting_list_id) => {
Self::fork_postings_list(posting_list_id, blockfile_provider).await?
}
None => Self::create_posting_list(blockfile_provider).await?,
};
Ok(Self::new(
hnsw_index,
hnsw_provider.clone(),
posting_list_writer,
versions_map,
))
}
}
Loading
Loading