Skip to content

Commit

Permalink
Cleanup hnsw provider to not know about segments
Browse files Browse the repository at this point in the history
  • Loading branch information
sanketkedia committed Nov 11, 2024
1 parent bc48745 commit 1223b7f
Show file tree
Hide file tree
Showing 6 changed files with 280 additions and 229 deletions.
124 changes: 25 additions & 99 deletions rust/index/src/hnsw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@ use chroma_error::{ChromaError, ErrorCodes};
use chroma_types::{Metadata, MetadataValue, MetadataValueConversionError, Segment};
use std::ffi::CString;
use std::ffi::{c_char, c_int};
use std::path::Path;
use std::str::Utf8Error;
use thiserror::Error;
use tracing::instrument;

const DEFAULT_MAX_ELEMENTS: usize = 10000;
const DEFAULT_HNSW_M: usize = 16;
const DEFAULT_HNSW_EF_CONSTRUCTION: usize = 100;
const DEFAULT_HNSW_EF_SEARCH: usize = 10;
pub const DEFAULT_MAX_ELEMENTS: usize = 10000;
pub const DEFAULT_HNSW_M: usize = 16;
pub const DEFAULT_HNSW_EF_CONSTRUCTION: usize = 100;
pub const DEFAULT_HNSW_EF_SEARCH: usize = 10;

// https://doc.rust-lang.org/nomicon/ffi.html#representing-opaque-structs
#[repr(C)]
Expand Down Expand Up @@ -50,10 +51,23 @@ impl ChromaError for HnswIndexFromSegmentError {
}

impl HnswIndexConfig {
pub fn from_segment(
segment: &Segment,
persist_path: &std::path::Path,
) -> Result<HnswIndexConfig, Box<HnswIndexFromSegmentError>> {
pub fn new_default() -> Self {
HnswIndexConfig {
max_elements: DEFAULT_MAX_ELEMENTS,
m: DEFAULT_HNSW_M,
ef_construction: DEFAULT_HNSW_EF_CONSTRUCTION,
ef_search: DEFAULT_HNSW_EF_SEARCH,
random_seed: 0,
persist_path: "".to_string(),
}
}

pub fn new(
m: usize,
ef_construction: usize,
ef_search: usize,
persist_path: &Path,
) -> Result<Self, Box<HnswIndexFromSegmentError>> {
let persist_path = match persist_path.to_str() {
Some(persist_path) => persist_path,
None => {
Expand All @@ -62,53 +76,11 @@ impl HnswIndexConfig {
)))
}
};
let metadata = match &segment.metadata {
Some(metadata) => metadata,
None => {
// TODO: This should error, but the configuration is not stored correctly
// after the configuration is refactored to be always stored and doesn't rely on defaults we can fix this
return Ok(HnswIndexConfig {
max_elements: DEFAULT_MAX_ELEMENTS,
m: DEFAULT_HNSW_M,
ef_construction: DEFAULT_HNSW_EF_CONSTRUCTION,
ef_search: DEFAULT_HNSW_EF_SEARCH,
random_seed: 0,
persist_path: persist_path.to_string(),
});
}
};

fn get_metadata_value_as<'a, T>(
metadata: &'a Metadata,
key: &str,
) -> Result<T, Box<HnswIndexFromSegmentError>>
where
T: TryFrom<&'a MetadataValue, Error = MetadataValueConversionError>,
{
let res = match metadata.get(key) {
Some(value) => T::try_from(value),
None => {
return Err(Box::new(HnswIndexFromSegmentError::MissingConfig(
key.to_string(),
)))
}
};
match res {
Ok(value) => Ok(value),
Err(e) => Err(Box::new(HnswIndexFromSegmentError::MetadataValueError(e))),
}
}

let m = get_metadata_value_as::<i64>(metadata, "hnsw:M").unwrap_or(DEFAULT_HNSW_M as i64);
let ef_construction = get_metadata_value_as::<i64>(metadata, "hnsw:construction_ef")
.unwrap_or(DEFAULT_HNSW_EF_CONSTRUCTION as i64);
let ef_search = get_metadata_value_as::<i64>(metadata, "hnsw:search_ef")
.unwrap_or(DEFAULT_HNSW_EF_SEARCH as i64);
Ok(HnswIndexConfig {
max_elements: DEFAULT_MAX_ELEMENTS,
m: m as usize,
ef_construction: ef_construction as usize,
ef_search: ef_search as usize,
m,
ef_construction,
ef_search,
random_seed: 0,
persist_path: persist_path.to_string(),
})
Expand Down Expand Up @@ -826,52 +798,6 @@ pub mod test {
});
}

#[test]
fn parameter_defaults() {
    // Builds a distributed-HNSW vector segment carrying the given metadata map.
    let build_segment = |metadata| Segment {
        id: SegmentUuid::new(),
        r#type: chroma_types::SegmentType::HnswDistributed,
        scope: chroma_types::SegmentScope::VECTOR,
        metadata: Some(metadata),
        collection: CollectionUuid(Uuid::new_v4()),
        file_path: HashMap::new(),
    };

    let persist_path = tempdir().unwrap().path().to_owned();

    // Empty metadata: every tunable must fall back to its crate default.
    let config = HnswIndexConfig::from_segment(&build_segment(HashMap::new()), &persist_path)
        .expect("Failed to create config from segment");
    assert_eq!(config.max_elements, DEFAULT_MAX_ELEMENTS);
    assert_eq!(config.m, DEFAULT_HNSW_M);
    assert_eq!(config.ef_construction, DEFAULT_HNSW_EF_CONSTRUCTION);
    assert_eq!(config.ef_search, DEFAULT_HNSW_EF_SEARCH);
    assert_eq!(config.random_seed, 0);
    assert_eq!(config.persist_path, persist_path.to_str().unwrap());

    // Partial metadata: an explicit "hnsw:M" wins; everything else stays default.
    let mut metadata = HashMap::new();
    metadata.insert("hnsw:M".to_string(), MetadataValue::Int(10_i64));
    let config = HnswIndexConfig::from_segment(&build_segment(metadata), &persist_path)
        .expect("Failed to create config from segment");
    assert_eq!(config.max_elements, DEFAULT_MAX_ELEMENTS);
    assert_eq!(config.m, 10);
    assert_eq!(config.ef_construction, DEFAULT_HNSW_EF_CONSTRUCTION);
    assert_eq!(config.ef_search, DEFAULT_HNSW_EF_SEARCH);
    assert_eq!(config.random_seed, 0);
    assert_eq!(config.persist_path, persist_path.to_str().unwrap());
}

#[test]
fn it_can_catch_error() {
let n = 10;
Expand Down
120 changes: 61 additions & 59 deletions rust/index/src/hnsw_provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use super::{
use async_trait::async_trait;
use chroma_cache::Cache;
use chroma_config::Configurable;
use chroma_distance::DistanceFunction;
use chroma_error::ChromaError;
use chroma_error::ErrorCodes;
use chroma_storage::Storage;
Expand All @@ -32,6 +33,12 @@ const FILES: [&str; 4] = [
"link_lists.bin",
];

/// Positional bundle of HNSW build/search tuning knobs:
/// `(m, ef_construction, ef_search)`.
// NOTE(review): a named struct (or reusing `HnswIndexConfig` fields) would be
// less error-prone than three positional `usize`s — TODO consider.
pub type HnswIndexParams = (
    usize, /* m */
    usize, /* ef_construction */
    usize, /* ef_search */
);

// The key of the cache is the collection id and the value is
// the HNSW index for that collection. This restricts the cache to
contain at most one index per collection. Ideally, we would like
Expand Down Expand Up @@ -151,8 +158,9 @@ impl HnswIndexProvider {
pub async fn fork(
&self,
source_id: &IndexUuid,
segment: &Segment,
collection_id: &CollectionUuid,
dimensionality: i32,
distance_function: DistanceFunction,
) -> Result<HnswIndexRef, Box<HnswIndexProviderForkError>> {
let new_id = IndexUuid(Uuid::new_v4());
let new_storage_path = self.temporary_storage_path.join(new_id.to_string());
Expand All @@ -175,22 +183,7 @@ impl HnswIndexProvider {
}
}

let index_config = IndexConfig::from_segment(segment, dimensionality);

let index_config = match index_config {
Ok(index_config) => index_config,
Err(e) => {
return Err(Box::new(HnswIndexProviderForkError::IndexConfigError(*e)));
}
};

let hnsw_config = HnswIndexConfig::from_segment(segment, &new_storage_path);
match hnsw_config {
Ok(hnsw_config) => hnsw_config,
Err(e) => {
return Err(Box::new(HnswIndexProviderForkError::HnswConfigError(*e)));
}
};
let index_config = IndexConfig::new(dimensionality, distance_function);

let storage_path_str = match new_storage_path.to_str() {
Some(storage_path_str) => storage_path_str,
Expand All @@ -204,13 +197,15 @@ impl HnswIndexProvider {
match HnswIndex::load(storage_path_str, &index_config, new_id) {
Ok(index) => {
let _guard = self.write_mutex.lock().await;
match self.get(&new_id, &segment.collection).await {
match self.get(&new_id, collection_id).await {
Some(index) => Ok(index.clone()),
None => {
let index = HnswIndexRef {
inner: Arc::new(RwLock::new(index)),
};
self.cache.insert(segment.collection, index.clone()).await;
self.cache
.insert(collection_id.clone(), index.clone())
.await;
Ok(index)
}
}
Expand Down Expand Up @@ -295,8 +290,9 @@ impl HnswIndexProvider {
pub async fn open(
&self,
id: &IndexUuid,
segment: &Segment,
collection_id: &CollectionUuid,
dimensionality: i32,
distance_function: DistanceFunction,
) -> Result<HnswIndexRef, Box<HnswIndexProviderOpenError>> {
let index_storage_path = self.temporary_storage_path.join(id.to_string());

Expand All @@ -319,17 +315,7 @@ impl HnswIndexProvider {
}

// Thread safe.
let index_config = IndexConfig::from_segment(segment, dimensionality);
let index_config = match index_config {
Ok(index_config) => index_config,
Err(e) => {
return Err(Box::new(HnswIndexProviderOpenError::IndexConfigError(*e)));
}
};

// Thread safe.
let _hnsw_config = HnswIndexConfig::from_segment(segment, &index_storage_path)
.map_err(|e| Box::new(HnswIndexProviderOpenError::HnswConfigError(*e)))?;
let index_config = IndexConfig::new(dimensionality, distance_function);

let index_storage_path_str = match index_storage_path.to_str() {
Some(index_storage_path_str) => index_storage_path_str,
Expand All @@ -343,13 +329,13 @@ impl HnswIndexProvider {
match HnswIndex::load(index_storage_path_str, &index_config, *id) {
Ok(index) => {
let _guard = self.write_mutex.lock().await;
match self.get(id, &segment.collection).await {
match self.get(id, collection_id).await {
Some(index) => Ok(index.clone()),
None => {
let index = HnswIndexRef {
inner: Arc::new(RwLock::new(index)),
};
self.cache.insert(segment.collection, index.clone()).await;
self.cache.insert(*collection_id, index.clone()).await;
Ok(index)
}
}
Expand All @@ -370,9 +356,11 @@ impl HnswIndexProvider {
// A query comes in and the index is not in the cache -> we need to load the index from s3 based on the segment files id
pub async fn create(
&self,
// TODO: This should not take Segment. The index layer should not know about the segment concept
segment: &Segment,
collection_id: &Uuid,
hnsw_params: HnswIndexParams,
persist_path: &std::path::Path,
dimensionality: i32,
distance_function: DistanceFunction,
) -> Result<HnswIndexRef, Box<HnswIndexProviderCreateError>> {
let id = IndexUuid(Uuid::new_v4());
let index_storage_path = self.temporary_storage_path.join(id.to_string());
Expand All @@ -384,31 +372,30 @@ impl HnswIndexProvider {
}
}

let index_config = match IndexConfig::from_segment(segment, dimensionality) {
Ok(index_config) => index_config,
Err(e) => {
return Err(Box::new(HnswIndexProviderCreateError::IndexConfigError(*e)));
}
};
let index_config = IndexConfig::new(dimensionality, distance_function);

let hnsw_config =
match HnswIndexConfig::new(hnsw_params.0, hnsw_params.1, hnsw_params.2, persist_path) {
Ok(hnsw_config) => hnsw_config,
Err(e) => {
return Err(Box::new(HnswIndexProviderCreateError::HnswConfigError(*e)));
}
};

let hnsw_config = match HnswIndexConfig::from_segment(segment, &index_storage_path) {
Ok(hnsw_config) => hnsw_config,
Err(e) => {
return Err(Box::new(HnswIndexProviderCreateError::HnswConfigError(*e)));
}
};
// HnswIndex init is not thread safe. We should not call it from multiple threads
let index = HnswIndex::init(&index_config, Some(&hnsw_config), id)
.map_err(|e| Box::new(HnswIndexProviderCreateError::IndexInitError(e)))?;

let _guard = self.write_mutex.lock().await;
match self.get(&id, &segment.collection).await {
match self.get(&id, &collection_id).await {
Some(index) => Ok(index.clone()),
None => {
let index = HnswIndexRef {
inner: Arc::new(RwLock::new(index)),
};
self.cache.insert(segment.collection, index.clone()).await;
self.cache
.insert(collection_id.clone(), index.clone())
.await;
Ok(index)
}
}
Expand Down Expand Up @@ -611,6 +598,8 @@ pub enum HnswIndexProviderFileError {

#[cfg(test)]
mod tests {
use crate::{DEFAULT_HNSW_EF_CONSTRUCTION, DEFAULT_HNSW_EF_SEARCH, DEFAULT_HNSW_M};

use super::*;
use chroma_cache::new_non_persistent_cache_for_test;
use chroma_storage::local::LocalStorage;
Expand All @@ -629,21 +618,34 @@ mod tests {
let cache = new_non_persistent_cache_for_test();
let (_tx, rx) = tokio::sync::mpsc::unbounded_channel();
let provider = HnswIndexProvider::new(storage, hnsw_tmp_path, cache, rx);
let segment = Segment {
id: SegmentUuid::new(),
r#type: SegmentType::HnswDistributed,
scope: chroma_types::SegmentScope::VECTOR,
collection: CollectionUuid(Uuid::new_v4()),
metadata: None,
file_path: HashMap::new(),
};
let collection_id = CollectionUuid(Uuid::new_v4());

let dimensionality = 128;
let created_index = provider.create(&segment, dimensionality).await.unwrap();
let hnsw_params = (
DEFAULT_HNSW_M,
DEFAULT_HNSW_EF_CONSTRUCTION,
DEFAULT_HNSW_EF_SEARCH,
);
let distance_function = DistanceFunction::Euclidean;
let created_index = provider
.create(
&collection_id,
hnsw_params,
&provider.temporary_storage_path,
dimensionality,
distance_function.clone(),
)
.await
.unwrap();
let created_index_id = created_index.inner.read().id;

let forked_index = provider
.fork(&created_index_id, &segment, dimensionality)
.fork(
&created_index_id,
&collection_id,
dimensionality,
distance_function,
)
.await
.unwrap();
let forked_index_id = forked_index.inner.read().id;
Expand Down
Loading

0 comments on commit 1223b7f

Please sign in to comment.