Skip to content

Commit

Permalink
Review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
sanketkedia committed Nov 13, 2024
1 parent 1e2fe06 commit 12db19e
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 49 deletions.
53 changes: 21 additions & 32 deletions rust/index/src/hnsw_provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,7 @@ const FILES: [&str; 4] = [
"link_lists.bin",
];

pub type HnswIndexParams = (
usize, /* m */
usize, /* ef_construction */
usize, /* ef_search */
);
type CacheKey = CollectionUuid;

// The key of the cache is the collection id and the value is
// the HNSW index for that collection. This restricts the cache to
Expand Down Expand Up @@ -140,12 +136,8 @@ impl HnswIndexProvider {
}
}

pub async fn get(
&self,
index_id: &IndexUuid,
collection_id: &CollectionUuid,
) -> Option<HnswIndexRef> {
match self.cache.get(collection_id).await.ok().flatten() {
pub async fn get(&self, index_id: &IndexUuid, cache_key: &CacheKey) -> Option<HnswIndexRef> {
match self.cache.get(cache_key).await.ok().flatten() {
Some(index) => {
let index_with_lock = index.inner.read();
if index_with_lock.id == *index_id {
Expand All @@ -166,7 +158,7 @@ impl HnswIndexProvider {
pub async fn fork(
&self,
source_id: &IndexUuid,
collection_id: &CollectionUuid,
cache_key: &CacheKey,
dimensionality: i32,
distance_function: DistanceFunction,
) -> Result<HnswIndexRef, Box<HnswIndexProviderForkError>> {
Expand Down Expand Up @@ -205,13 +197,13 @@ impl HnswIndexProvider {
match HnswIndex::load(storage_path_str, &index_config, new_id) {
Ok(index) => {
let _guard = self.write_mutex.lock().await;
match self.get(&new_id, collection_id).await {
match self.get(&new_id, cache_key).await {
Some(index) => Ok(index.clone()),
None => {
let index = HnswIndexRef {
inner: Arc::new(RwLock::new(index)),
};
self.cache.insert(*collection_id, index.clone()).await;
self.cache.insert(*cache_key, index.clone()).await;
Ok(index)
}
}
Expand Down Expand Up @@ -296,7 +288,7 @@ impl HnswIndexProvider {
pub async fn open(
&self,
id: &IndexUuid,
collection_id: &CollectionUuid,
cache_key: &CacheKey,
dimensionality: i32,
distance_function: DistanceFunction,
) -> Result<HnswIndexRef, Box<HnswIndexProviderOpenError>> {
Expand Down Expand Up @@ -335,13 +327,13 @@ impl HnswIndexProvider {
match HnswIndex::load(index_storage_path_str, &index_config, *id) {
Ok(index) => {
let _guard = self.write_mutex.lock().await;
match self.get(id, collection_id).await {
match self.get(id, cache_key).await {
Some(index) => Ok(index.clone()),
None => {
let index = HnswIndexRef {
inner: Arc::new(RwLock::new(index)),
};
self.cache.insert(*collection_id, index.clone()).await;
self.cache.insert(*cache_key, index.clone()).await;
Ok(index)
}
}
Expand All @@ -362,9 +354,10 @@ impl HnswIndexProvider {
// A query comes in and the index is not in the cache -> we need to load the index from s3 based on the segment files id
pub async fn create(
&self,
collection_id: &CollectionUuid,
hnsw_params: HnswIndexParams,
persist_path: &std::path::Path,
cache_key: &CacheKey,
m: usize,
ef_construction: usize,
ef_search: usize,
dimensionality: i32,
distance_function: DistanceFunction,
) -> Result<HnswIndexRef, Box<HnswIndexProviderCreateError>> {
Expand All @@ -381,7 +374,7 @@ impl HnswIndexProvider {
let index_config = IndexConfig::new(dimensionality, distance_function);

let hnsw_config =
match HnswIndexConfig::new(hnsw_params.0, hnsw_params.1, hnsw_params.2, persist_path) {
match HnswIndexConfig::new(m, ef_construction, ef_search, &index_storage_path) {
Ok(hnsw_config) => hnsw_config,
Err(e) => {
return Err(Box::new(HnswIndexProviderCreateError::HnswConfigError(*e)));
Expand All @@ -393,13 +386,13 @@ impl HnswIndexProvider {
.map_err(|e| Box::new(HnswIndexProviderCreateError::IndexInitError(e)))?;

let _guard = self.write_mutex.lock().await;
match self.get(&id, collection_id).await {
match self.get(&id, cache_key).await {
Some(index) => Ok(index.clone()),
None => {
let index = HnswIndexRef {
inner: Arc::new(RwLock::new(index)),
};
self.cache.insert(*collection_id, index.clone()).await;
self.cache.insert(*cache_key, index.clone()).await;
Ok(index)
}
}
Expand Down Expand Up @@ -438,8 +431,8 @@ impl HnswIndexProvider {
}

/// Purge entries from the cache by index ID and remove temporary files from disk.
pub async fn purge_by_id(&mut self, collection_uuids: &[CollectionUuid]) {
for collection_uuid in collection_uuids {
pub async fn purge_by_id(&mut self, cache_keys: &[CacheKey]) {
for collection_uuid in cache_keys {
let Some(index_id) = self
.cache
.get(collection_uuid)
Expand Down Expand Up @@ -623,17 +616,13 @@ mod tests {
let collection_id = CollectionUuid(Uuid::new_v4());

let dimensionality = 128;
let hnsw_params = (
DEFAULT_HNSW_M,
DEFAULT_HNSW_EF_CONSTRUCTION,
DEFAULT_HNSW_EF_SEARCH,
);
let distance_function = DistanceFunction::Euclidean;
let created_index = provider
.create(
&collection_id,
hnsw_params,
&provider.temporary_storage_path,
DEFAULT_HNSW_M,
DEFAULT_HNSW_EF_CONSTRUCTION,
DEFAULT_HNSW_EF_SEARCH,
dimensionality,
distance_function.clone(),
)
Expand Down
52 changes: 35 additions & 17 deletions rust/worker/src/segment/distributed_hnsw_segment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use async_trait::async_trait;
use chroma_distance::{DistanceFunction, DistanceFunctionError};
use chroma_error::{ChromaError, ErrorCodes};
use chroma_index::hnsw_provider::{
HnswIndexParams, HnswIndexProvider, HnswIndexProviderCreateError, HnswIndexProviderForkError,
HnswIndexProvider, HnswIndexProviderCreateError, HnswIndexProviderForkError,
HnswIndexProviderOpenError, HnswIndexRef,
};
use chroma_index::{Index, IndexUuid};
Expand All @@ -18,6 +18,12 @@ use uuid::Uuid;

const HNSW_INDEX: &str = "hnsw_index";

pub struct HnswIndexParamsFromSegment {
pub m: usize,
pub ef_construction: usize,
pub ef_search: usize,
}

#[derive(Clone)]
pub(crate) struct DistributedHNSWSegmentWriter {
index: HnswIndexRef,
Expand Down Expand Up @@ -65,15 +71,15 @@ impl ChromaError for DistributedHNSWSegmentFromSegmentError {
}
}

fn hnsw_params_from_segment(segment: &Segment) -> HnswIndexParams {
fn hnsw_params_from_segment(segment: &Segment) -> HnswIndexParamsFromSegment {
let metadata = match &segment.metadata {
Some(metadata) => metadata,
None => {
return (
DEFAULT_HNSW_M,
DEFAULT_HNSW_EF_CONSTRUCTION,
DEFAULT_HNSW_EF_SEARCH,
);
return HnswIndexParamsFromSegment {
m: DEFAULT_HNSW_M,
ef_construction: DEFAULT_HNSW_EF_CONSTRUCTION,
ef_search: DEFAULT_HNSW_EF_SEARCH,
};
}
};

Expand All @@ -90,7 +96,11 @@ fn hnsw_params_from_segment(segment: &Segment) -> HnswIndexParams {
Err(_) => DEFAULT_HNSW_EF_SEARCH,
};

(m, ef_construction, ef_search)
HnswIndexParamsFromSegment {
m,
ef_construction,
ef_search,
}
}

pub fn distance_function_from_segment(
Expand Down Expand Up @@ -130,7 +140,6 @@ impl DistributedHNSWSegmentWriter {
hnsw_index_provider: HnswIndexProvider,
) -> Result<Box<DistributedHNSWSegmentWriter>, Box<DistributedHNSWSegmentFromSegmentError>>
{
let persist_path = &hnsw_index_provider.temporary_storage_path;
// TODO: this is hacky, we use the presence of files to determine if we need to load or create the index
// ideally, an explicit state would be better. When we implement distributed HNSW segments,
// we can introduce a state in the segment metadata for this
Expand Down Expand Up @@ -205,8 +214,9 @@ impl DistributedHNSWSegmentWriter {
let index = match hnsw_index_provider
.create(
&segment.collection,
hnsw_params,
persist_path,
hnsw_params.m,
hnsw_params.ef_construction,
hnsw_params.ef_search,
dimensionality as i32,
distance_function,
)
Expand Down Expand Up @@ -445,9 +455,13 @@ pub mod test {
};

let hnsw_params = hnsw_params_from_segment(&segment);
let config =
HnswIndexConfig::new(hnsw_params.0, hnsw_params.1, hnsw_params.2, &persist_path)
.expect("Error creating hnsw index config");
let config = HnswIndexConfig::new(
hnsw_params.m,
hnsw_params.ef_construction,
hnsw_params.ef_search,
&persist_path,
)
.expect("Error creating hnsw index config");

assert_eq!(config.max_elements, DEFAULT_MAX_ELEMENTS);
assert_eq!(config.m, DEFAULT_HNSW_M);
Expand All @@ -470,9 +484,13 @@ pub mod test {
};

let hnsw_params = hnsw_params_from_segment(&segment);
let config =
HnswIndexConfig::new(hnsw_params.0, hnsw_params.1, hnsw_params.2, &persist_path)
.expect("Error creating hnsw index config");
let config = HnswIndexConfig::new(
hnsw_params.m,
hnsw_params.ef_construction,
hnsw_params.ef_search,
&persist_path,
)
.expect("Error creating hnsw index config");

assert_eq!(config.max_elements, DEFAULT_MAX_ELEMENTS);
assert_eq!(config.m, 10);
Expand Down

0 comments on commit 12db19e

Please sign in to comment.