From ea38dca521ff6d391485324a62915cd4e6aaec22 Mon Sep 17 00:00:00 2001
From: hammadb
Date: Sun, 17 Dec 2023 16:58:47 -0800
Subject: [PATCH] Add segment manager, hnsw segment

---
 rust/worker/src/index/mod.rs                  |   1 +
 rust/worker/src/ingest/scheduler.rs           |   6 +-
 rust/worker/src/lib.rs                        |  10 +-
 .../src/segment/distributed_hnsw_segment.rs   |  65 +++++++
 rust/worker/src/segment/mod.rs                |   5 +-
 rust/worker/src/segment/segment_ingestor.rs   |  16 +-
 rust/worker/src/segment/segment_manager.rs    | 159 ++++++++++++++++++
 7 files changed, 253 insertions(+), 9 deletions(-)
 create mode 100644 rust/worker/src/segment/distributed_hnsw_segment.rs
 create mode 100644 rust/worker/src/segment/segment_manager.rs

diff --git a/rust/worker/src/index/mod.rs b/rust/worker/src/index/mod.rs
index 007387584076..ddaf8d737a46 100644
--- a/rust/worker/src/index/mod.rs
+++ b/rust/worker/src/index/mod.rs
@@ -3,4 +3,5 @@ mod types;
 mod utils;
 
 // Re-export types
+pub(crate) use hnsw::*;
 pub(crate) use types::*;
diff --git a/rust/worker/src/ingest/scheduler.rs b/rust/worker/src/ingest/scheduler.rs
index 8f82e143f84a..4589c84b8009 100644
--- a/rust/worker/src/ingest/scheduler.rs
+++ b/rust/worker/src/ingest/scheduler.rs
@@ -81,7 +81,11 @@ impl Component for RoundRobinScheduler {
                 // Randomly pick a subscriber to send the message to
                 // This serves as a crude load balancing between available threads
                 // Future improvements here could be
-                // -
+                // - Use a work stealing scheduler
+                // - Use rayon
+                // - We need to enforce partial order over writes to a given key
+                //   so we need a mechanism to ensure that all writes to a given key
+                //   occur in order
                 let mut subscriber = None;
                 {
                     let mut rng = rand::thread_rng();
diff --git a/rust/worker/src/lib.rs b/rust/worker/src/lib.rs
index f5aa6978714b..e24bf64c416c 100644
--- a/rust/worker/src/lib.rs
+++ b/rust/worker/src/lib.rs
@@ -44,10 +44,18 @@ pub async fn worker_entrypoint() {
 
     let mut scheduler = ingest::RoundRobinScheduler::new();
 
+    let segment_manager = match segment::SegmentManager::try_from_config(&config.worker).await {
+        Ok(segment_manager) => segment_manager,
+        Err(err) => {
+            println!("Failed to create segment manager component: {:?}", err);
+            return;
+        }
+    };
+
     let mut segment_ingestor_receivers =
         Vec::with_capacity(config.worker.num_indexing_threads as usize);
     for _ in 0..config.worker.num_indexing_threads {
-        let segment_ingestor = segment::SegmentIngestor::new();
+        let segment_ingestor = segment::SegmentIngestor::new(segment_manager.clone());
         let segment_ingestor_handle = system.start_component(segment_ingestor);
         let recv = segment_ingestor_handle.receiver();
         segment_ingestor_receivers.push(recv);
diff --git a/rust/worker/src/segment/distributed_hnsw_segment.rs b/rust/worker/src/segment/distributed_hnsw_segment.rs
new file mode 100644
index 000000000000..75f9e1e90f7f
--- /dev/null
+++ b/rust/worker/src/segment/distributed_hnsw_segment.rs
@@ -0,0 +1,65 @@
+use parking_lot::{Mutex, RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
+use std::sync::atomic::AtomicUsize;
+use std::sync::Arc;
+
+use crate::errors::ChromaError;
+use crate::index::{HnswIndex, HnswIndexConfig, Index, IndexConfig};
+use crate::types::{EmbeddingRecord, Operation};
+
+pub(crate) struct DistributedHNSWSegment {
+    index: Arc<RwLock<HnswIndex>>,
+    id: AtomicUsize,
+    index_config: IndexConfig,
+    hnsw_config: HnswIndexConfig,
+}
+
+impl DistributedHNSWSegment {
+    pub(crate) fn new(
+        index_config: IndexConfig,
+        hnsw_config: HnswIndexConfig,
+    ) -> Result<Self, Box<dyn ChromaError>> {
+        let hnsw_index = HnswIndex::init(&index_config, Some(&hnsw_config));
+        let hnsw_index = match hnsw_index {
+            Ok(index) => index,
+            Err(e) => {
+                // TODO: log and handle the error if we fail to init the index
+                return Err(e);
+            }
+        };
+        let index = Arc::new(RwLock::new(hnsw_index));
+        return Ok(DistributedHNSWSegment {
+            index: index,
+            id: AtomicUsize::new(0),
+            index_config: index_config,
+            hnsw_config,
+        });
+    }
+
+    pub(crate) fn write_records(&mut self, records: Vec<Box<EmbeddingRecord>>) {
+        for record in records {
+            let op = Operation::try_from(record.operation);
+            match op {
+                Ok(Operation::Add) => {
+                    // TODO: make lock xor lock
+                    match record.embedding {
+                        Some(vector) => {
+                            let next_id = self.id.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
+                            println!("Adding item: {}", next_id);
+                            self.index.write().add(next_id, &vector);
+                        }
+                        None => {
+                            // TODO: log an error
+                            println!("No vector found in record");
+                        }
+                    }
+                }
+                Ok(Operation::Upsert) => {}
+                Ok(Operation::Update) => {}
+                Ok(Operation::Delete) => {}
+                Err(_) => {
+                    println!("Error parsing operation");
+                }
+            }
+        }
+    }
+}
diff --git a/rust/worker/src/segment/mod.rs b/rust/worker/src/segment/mod.rs
index 170811b22ce4..25c6d0a17151 100644
--- a/rust/worker/src/segment/mod.rs
+++ b/rust/worker/src/segment/mod.rs
@@ -1,3 +1,6 @@
+mod distributed_hnsw_segment;
 mod segment_ingestor;
+mod segment_manager;
 
-pub use segment_ingestor::*;
+pub(crate) use segment_ingestor::*;
+pub(crate) use segment_manager::*;
diff --git a/rust/worker/src/segment/segment_ingestor.rs b/rust/worker/src/segment/segment_ingestor.rs
index 10d0dde2f01e..06a5f54ce260 100644
--- a/rust/worker/src/segment/segment_ingestor.rs
+++ b/rust/worker/src/segment/segment_ingestor.rs
@@ -10,7 +10,11 @@ use crate::{
     types::EmbeddingRecord,
 };
 
-pub(crate) struct SegmentIngestor {}
+use super::segment_manager::{self, SegmentManager};
+
+pub(crate) struct SegmentIngestor {
+    segment_manager: SegmentManager,
+}
 
 impl Component for SegmentIngestor {
     fn queue_size(&self) -> usize {
@@ -28,8 +32,10 @@ impl Debug for SegmentIngestor {
 }
 
 impl SegmentIngestor {
-    pub(crate) fn new() -> Self {
-        SegmentIngestor {}
+    pub(crate) fn new(segment_manager: SegmentManager) -> Self {
+        SegmentIngestor {
+            segment_manager: segment_manager,
+        }
     }
 }
 
@@ -37,8 +43,6 @@ impl SegmentIngestor {
 impl Handler<Arc<EmbeddingRecord>> for SegmentIngestor {
     async fn handle(&mut self, message: Arc<EmbeddingRecord>, ctx: &ComponentContext<SegmentIngestor>) {
         println!("INGEST: ID of embedding is {}", message.id);
-        // let segment_manager = ctx.system.get_segment_manager();
-        // let segment = segment_manager.get_segment(&tenant);
-        // segment.ingest(embedding);
+        self.segment_manager.write_record(message).await;
     }
 }
diff --git a/rust/worker/src/segment/segment_manager.rs b/rust/worker/src/segment/segment_manager.rs
new file mode 100644
index 000000000000..a435ed4e2d6b
--- /dev/null
+++ b/rust/worker/src/segment/segment_manager.rs
@@ -0,0 +1,159 @@
+use crate::{
+    config::{Configurable, WorkerConfig},
+    errors::ChromaError,
+    sysdb::sysdb::{GrpcSysDb, SysDb},
+};
+use async_trait::async_trait;
+use parking_lot::{
+    MappedRwLockReadGuard, RwLock, RwLockReadGuard, RwLockUpgradableReadGuard, RwLockWriteGuard,
+};
+use std::collections::HashMap;
+use std::sync::Arc;
+use uuid::Uuid;
+
+use super::distributed_hnsw_segment::DistributedHNSWSegment;
+use crate::types::{EmbeddingRecord, MetadataValue, Segment, SegmentScope};
+
+#[derive(Clone)]
+pub(crate) struct SegmentManager {
+    inner: Arc<Inner>,
+    sysdb: Box<dyn SysDb>,
+}
+
+struct Inner {
+    vector_segments: RwLock<HashMap<Uuid, Box<DistributedHNSWSegment>>>,
+    collection_to_segment_cache: RwLock<HashMap<Uuid, Vec<Segment>>>,
+}
+
+impl SegmentManager {
+    pub(crate) fn new(sysdb: Box<dyn SysDb>) -> Self {
+        SegmentManager {
+            inner: Arc::new(Inner {
+                vector_segments: RwLock::new(HashMap::new()),
+                collection_to_segment_cache: RwLock::new(HashMap::new()),
+            }),
+            sysdb: sysdb,
+        }
+    }
+
+    pub(crate) async fn write_record(&mut self, record: Arc<EmbeddingRecord>) {
+        println!(
+            "Manager is writing record for collection: {}",
+            record.collection_id
+        );
+        let collection_id = record.collection_id;
+        let mut target_segment_id = None;
+
+        // TODO: don't assume 1:1 mapping between collection and segment
+        {
+            let segments = self.get_segments(&collection_id).await;
+            // For now we assume segment is 1:1 with collection
+            target_segment_id = match segments {
+                Ok(segments) => {
+                    if segments.len() == 0 {
+                        return; // TODO: handle no segment found
+                    }
+                    Some(segments[0].id)
+                }
+                Err(_) => None,
+            };
+        }
+
+        if target_segment_id.is_none() {
+            return; // TODO: handle no segment found
+        }
+        // let target_segment_id = target_segment_id.unwrap();
+        println!("Writing record to segment: {}", target_segment_id.unwrap());
+
+        // let segment_cache = self.inner.vector_segments.upgradable_read();
+        // match segment_cache.get(&target_segment_id) {
+        //     Some(segment) => {
+        //         segment.write_records(vec![record]);
+        //     }
+        //     None => {
+        //         let mut segment_cache = RwLockUpgradableReadGuard::upgrade(segment_cache);
+        //         // Parse metadata from the segment and hydrate the params for the segment
+        //         let new_segment = Box::new(DistributedHNSWSegment::new(
+        //             "ip".to_string(),
+        //             100000,
+        //             "./test/".to_string(),
+        //             100,
+        //             10000,
+        //         ));
+        //         segment_cache.insert(target_segment_id.clone(), new_segment);
+        //         let segment_cache = RwLockWriteGuard::downgrade(segment_cache);
+        //         let segment = RwLockReadGuard::map(segment_cache, |cache| {
+        //             return cache.get(&target_segment_id).unwrap();
+        //         });
+        //         segment.write_records(vec![record]);
+        //     }
+        // }
+    }
+
+    async fn get_segments(
+        &mut self,
+        collection_uuid: &Uuid,
+    ) -> Result<MappedRwLockReadGuard<Vec<Segment>>, &'static str> {
+        let cache_guard = self.inner.collection_to_segment_cache.read();
+        // This lets us return a reference to the segments with the lock. The caller is responsible
+        // for dropping the lock.
+        let segments =
+            RwLockReadGuard::try_map(cache_guard, |cache: &HashMap<Uuid, Vec<Segment>>| {
+                return cache.get(&collection_uuid);
+            });
+        match segments {
+            Ok(segments) => {
+                return Ok(segments);
+            }
+            Err(_) => {
+                // Data was not in the cache, so we need to get it from the database
+                // Drop the lock since we need to upgrade it
+                // Mapped guards cannot be upgraded, so we need to drop the lock and re-acquire it
+                // https://github.com/Amanieu/parking_lot/issues/83
+                drop(segments);
+
+                let segments = self
+                    .sysdb
+                    .get_segments(
+                        None,
+                        None,
+                        Some(SegmentScope::VECTOR),
+                        None,
+                        Some(collection_uuid.clone()),
+                    )
+                    .await;
+                match segments {
+                    Ok(segments) => {
+                        let mut cache_guard = self.inner.collection_to_segment_cache.write();
+                        cache_guard.insert(collection_uuid.clone(), segments);
+                        let cache_guard = RwLockWriteGuard::downgrade(cache_guard);
+                        let segments = RwLockReadGuard::map(cache_guard, |cache| {
+                            // This unwrap is safe because we just inserted the segments into the cache and currently,
+                            // there is no way to remove segments from the cache.
+                            return cache.get(&collection_uuid).unwrap();
+                        });
+                        return Ok(segments);
+                    }
+                    Err(e) => {
+                        return Err("Failed to get segments for collection from SysDB");
+                    }
+                }
+            }
+        }
+    }
+}
+
+#[async_trait]
+impl Configurable for SegmentManager {
+    async fn try_from_config(worker_config: &WorkerConfig) -> Result<Self, Box<dyn ChromaError>> {
+        // TODO: the SysDb implementation should be resolved dynamically from the config
+        let sysdb = GrpcSysDb::try_from_config(worker_config).await;
+        let sysdb = match sysdb {
+            Ok(sysdb) => sysdb,
+            Err(err) => {
+                return Err(err);
+            }
+        };
+        Ok(SegmentManager::new(Box::new(sysdb)))
+    }
+}
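
Reviewer note (not part of the patch): SegmentManager::get_segments and the commented-out block in write_record both rely on parking_lot's read / write / downgrade pattern: try to serve the request from the cache under a read lock, and on a miss take a write lock, insert, downgrade back to a read lock, and hand out a mapped guard. The sketch below is a minimal, self-contained illustration of that pattern only; the Cache type, get_or_fill, and the String value standing in for Vec<Segment> are illustrative assumptions, not code from this change, and it elides the SysDB fetch and error handling that the real get_segments performs.

use parking_lot::{MappedRwLockReadGuard, RwLock, RwLockReadGuard, RwLockWriteGuard};
use std::collections::HashMap;

// Illustrative stand-in for the collection_to_segment_cache in this patch.
struct Cache {
    inner: RwLock<HashMap<u32, String>>,
}

impl Cache {
    // On a hit, return a mapped read guard pointing at the cached value.
    // On a miss, fill the cache under a write lock, then downgrade and map.
    fn get_or_fill(&self, key: u32) -> MappedRwLockReadGuard<'_, String> {
        let guard = self.inner.read();
        match RwLockReadGuard::try_map(guard, |cache| cache.get(&key)) {
            Ok(cached) => cached,
            Err(guard) => {
                // Neither a plain read guard nor a mapped guard can be upgraded in
                // place (see parking_lot#83), so drop the read lock and take the
                // write lock instead.
                drop(guard);
                let mut write_guard = self.inner.write();
                write_guard.insert(key, format!("value for {}", key));
                let read_guard = RwLockWriteGuard::downgrade(write_guard);
                // Safe to unwrap: the entry was inserted above and nothing removes it
                // while we still hold the downgraded read lock.
                RwLockReadGuard::map(read_guard, |cache| cache.get(&key).unwrap())
            }
        }
    }
}

fn main() {
    let cache = Cache { inner: RwLock::new(HashMap::new()) };
    println!("{}", *cache.get_or_fill(42)); // miss: fills the cache, then reads
    println!("{}", *cache.get_or_fill(42)); // hit: served straight from the cache
}

As the issue linked in the patch (parking_lot#83) notes, mapped guards cannot participate in upgrades, which is why the miss path drops the guard and re-acquires the lock for writing rather than upgrading in place.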