Add segment manager, hnsw segment
HammadB committed Jan 15, 2024
1 parent d7e7293 commit ea38dca
Showing 7 changed files with 253 additions and 9 deletions.
1 change: 1 addition & 0 deletions rust/worker/src/index/mod.rs
@@ -3,4 +3,5 @@ mod types;
 mod utils;
 
 // Re-export types
+pub(crate) use hnsw::*;
 pub(crate) use types::*;
6 changes: 5 additions & 1 deletion rust/worker/src/ingest/scheduler.rs
@@ -81,7 +81,11 @@ impl Component for RoundRobinScheduler {
         // Randomly pick a subscriber to send the message to
         // This serves as a crude load balancing between available threads
         // Future improvements here could be
-        // -
+        // - Use a work stealing scheduler
+        // - Use rayon
+        // - We need to enforce partial order over writes to a given key
+        //   so we need a mechanism to ensure that all writes to a given key
+        //   occur in order
         let mut subscriber = None;
         {
             let mut rng = rand::thread_rng();
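Editor's note: the ordering constraint in the new comment (all writes to a given key must stay in order) is commonly met by routing on a hash of the key instead of picking a subscriber at random, so every write for a key lands on the same worker. A minimal sketch of that idea, using plain std mpsc channels rather than this codebase's component system:

```rust
// Sketch (assumption, not from this commit): hash-based routing so that all
// writes for one key land on the same worker, preserving per-key order.
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::sync::mpsc;
use std::thread;

fn route<K: Hash>(key: &K, num_workers: usize) -> usize {
    let mut hasher = DefaultHasher::new();
    key.hash(&mut hasher);
    (hasher.finish() as usize) % num_workers
}

fn main() {
    let num_workers = 4;
    let mut senders = Vec::new();
    let mut handles = Vec::new();
    for i in 0..num_workers {
        let (tx, rx) = mpsc::channel::<(String, u64)>();
        senders.push(tx);
        handles.push(thread::spawn(move || {
            // Within one worker, messages for any given key arrive in the
            // order they were sent, giving the partial order the comment asks for.
            for (key, seq) in rx {
                println!("worker {} applying write {} for key {}", i, seq, key);
            }
        }));
    }
    for seq in 0..8u64 {
        let key = format!("key-{}", seq % 3);
        senders[route(&key, num_workers)].send((key, seq)).unwrap();
    }
    drop(senders); // close channels so workers exit
    for h in handles {
        h.join().unwrap();
    }
}
```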
10 changes: 9 additions & 1 deletion rust/worker/src/lib.rs
@@ -44,10 +44,18 @@ pub async fn worker_entrypoint() {
 
     let mut scheduler = ingest::RoundRobinScheduler::new();
 
+    let segment_manager = match segment::SegmentManager::try_from_config(&config.worker).await {
+        Ok(segment_manager) => segment_manager,
+        Err(err) => {
+            println!("Failed to create segment manager component: {:?}", err);
+            return;
+        }
+    };
+
     let mut segment_ingestor_receivers =
         Vec::with_capacity(config.worker.num_indexing_threads as usize);
     for _ in 0..config.worker.num_indexing_threads {
-        let segment_ingestor = segment::SegmentIngestor::new();
+        let segment_ingestor = segment::SegmentIngestor::new(segment_manager.clone());
         let segment_ingestor_handle = system.start_component(segment_ingestor);
         let recv = segment_ingestor_handle.receiver();
         segment_ingestor_receivers.push(recv);
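Editor's note: one SegmentManager is handed by value to each of the num_indexing_threads ingestors via clone(). As the new segment_manager.rs below shows, it derives Clone around an Arc<Inner>, so every clone is a cheap handle onto the same shared caches. A reduced sketch of the pattern (field names simplified; the real struct also carries a sysdb handle):

```rust
// Sketch (simplified, not from this commit): a cheaply clonable manager.
// Cloning only bumps an Arc refcount; all clones share the state behind Arc<Inner>.
use std::sync::Arc;

#[derive(Clone)]
struct SegmentManager {
    inner: Arc<Inner>, // shared state; clone is a pointer copy
}

struct Inner {
    // caches, locks, ...
}

fn main() {
    let manager = SegmentManager { inner: Arc::new(Inner {}) };
    let for_thread_a = manager.clone(); // cheap: refcount increment
    let for_thread_b = manager.clone();
    // Both clones point at the same shared Inner.
    assert!(Arc::ptr_eq(&for_thread_a.inner, &for_thread_b.inner));
}
```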
65 changes: 65 additions & 0 deletions rust/worker/src/segment/distributed_hnsw_segment.rs
@@ -0,0 +1,65 @@
use parking_lot::{Mutex, RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
use std::sync::atomic::AtomicUsize;
use std::sync::Arc;

use crate::errors::ChromaError;
use crate::index::{HnswIndex, HnswIndexConfig, Index, IndexConfig};
use crate::types::{EmbeddingRecord, Operation};

pub(crate) struct DistributedHNSWSegment {
index: Arc<RwLock<HnswIndex>>,
id: AtomicUsize,
index_config: IndexConfig,
hnsw_config: HnswIndexConfig,
}

impl DistributedHNSWSegment {
pub(crate) fn new(
index_config: IndexConfig,
hnsw_config: HnswIndexConfig,
) -> Result<Self, Box<dyn ChromaError>> {
let hnsw_index = HnswIndex::init(&index_config, Some(&hnsw_config));
let hnsw_index = match hnsw_index {
Ok(index) => index,
Err(e) => {
// TODO: log + handle an error that we failed to init the index
return Err(e);
}
};
let index = Arc::new(RwLock::new(hnsw_index));
return Ok(DistributedHNSWSegment {
index: index,
id: AtomicUsize::new(0),
index_config: index_config,
hnsw_config,
});
}

pub(crate) fn write_records(&mut self, records: Vec<Box<EmbeddingRecord>>) {
for record in records {
let op = Operation::try_from(record.operation);
match op {
Ok(Operation::Add) => {
// TODO: make lock xor lock
match record.embedding {
Some(vector) => {
let next_id = self.id.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
println!("Adding item: {}", next_id);
self.index.write().add(next_id, &vector);
}
None => {
// TODO: log an error
println!("No vector found in record");
}
}
}
Ok(Operation::Upsert) => {}
Ok(Operation::Update) => {}
Ok(Operation::Delete) => {}
Err(_) => {
println!("Error parsing operation");
}
}
}
}
}
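Editor's note: write_records above takes the index write lock once per record, and the "lock xor lock" TODO suggests the locking strategy is still unsettled. One alternative (an assumption about the intent, not this commit's code) is to hoist the guard out of the loop and lock once per batch:

```rust
// Sketch (assumption, not from this commit): amortizing the RwLock by taking
// the write guard once per batch rather than once per record.
use parking_lot::RwLock;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

struct Index; // stand-in for HnswIndex
impl Index {
    fn add(&self, id: usize, vector: &[f32]) {
        println!("added {} ({} dims)", id, vector.len());
    }
}

struct Segment {
    index: Arc<RwLock<Index>>,
    id: AtomicUsize,
}

impl Segment {
    fn write_batch(&self, vectors: Vec<Vec<f32>>) {
        let index = self.index.write(); // one guard for the whole batch
        for vector in vectors {
            let next_id = self.id.fetch_add(1, Ordering::SeqCst);
            index.add(next_id, &vector);
        }
    }
}

fn main() {
    let segment = Segment {
        index: Arc::new(RwLock::new(Index)),
        id: AtomicUsize::new(0),
    };
    segment.write_batch(vec![vec![0.1, 0.2], vec![0.3, 0.4]]);
}
```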
5 changes: 4 additions & 1 deletion rust/worker/src/segment/mod.rs
@@ -1,3 +1,6 @@
+mod distributed_hnsw_segment;
 mod segment_ingestor;
+mod segment_manager;
 
-pub use segment_ingestor::*;
+pub(crate) use segment_ingestor::*;
+pub(crate) use segment_manager::*;
16 changes: 10 additions & 6 deletions rust/worker/src/segment/segment_ingestor.rs
@@ -10,7 +10,11 @@ use crate::{
     types::EmbeddingRecord,
 };
 
-pub(crate) struct SegmentIngestor {}
+use super::segment_manager::{self, SegmentManager};
+
+pub(crate) struct SegmentIngestor {
+    segment_manager: SegmentManager,
+}
 
 impl Component for SegmentIngestor {
     fn queue_size(&self) -> usize {
@@ -28,17 +32,17 @@ impl Debug for SegmentIngestor {
 }
 
 impl SegmentIngestor {
-    pub(crate) fn new() -> Self {
-        SegmentIngestor {}
+    pub(crate) fn new(segment_manager: SegmentManager) -> Self {
+        SegmentIngestor {
+            segment_manager: segment_manager,
+        }
     }
 }
 
 #[async_trait]
 impl Handler<Arc<EmbeddingRecord>> for SegmentIngestor {
     async fn handle(&mut self, message: Arc<EmbeddingRecord>, ctx: &ComponentContext<Self>) {
         println!("INGEST: ID of embedding is {}", message.id);
-        // let segment_manager = ctx.system.get_segment_manager();
-        // let segment = segment_manager.get_segment(&tenant);
-        // segment.ingest(embedding);
+        self.segment_manager.write_record(message).await;
     }
 }
159 changes: 159 additions & 0 deletions rust/worker/src/segment/segment_manager.rs
@@ -0,0 +1,159 @@
use crate::{
config::{Configurable, WorkerConfig},
errors::ChromaError,
sysdb::sysdb::{GrpcSysDb, SysDb},
};
use async_trait::async_trait;
use parking_lot::{
MappedRwLockReadGuard, RwLock, RwLockReadGuard, RwLockUpgradableReadGuard, RwLockWriteGuard,
};
use std::collections::HashMap;
use std::sync::Arc;
use uuid::Uuid;

use super::distributed_hnsw_segment::DistributedHNSWSegment;
use crate::types::{EmbeddingRecord, MetadataValue, Segment, SegmentScope};

#[derive(Clone)]
pub(crate) struct SegmentManager {
inner: Arc<Inner>,
sysdb: Box<dyn SysDb>,
}

struct Inner {
vector_segments: RwLock<HashMap<Uuid, Box<DistributedHNSWSegment>>>,
collection_to_segment_cache: RwLock<HashMap<Uuid, Vec<Segment>>>,
}

impl SegmentManager {
pub(crate) fn new(sysdb: Box<dyn SysDb>) -> Self {
SegmentManager {
inner: Arc::new(Inner {
vector_segments: RwLock::new(HashMap::new()),
collection_to_segment_cache: RwLock::new(HashMap::new()),
}),
sysdb: sysdb,
}
}

pub(crate) async fn write_record(&mut self, record: Arc<EmbeddingRecord>) {
println!(
"Manager is writing record for collection: {}",
record.collection_id
);
let collection_id = record.collection_id;
let mut target_segment_id = None;

// TODO: don't assume 1:1 mapping between collection and segment
{
let segments = self.get_segments(&collection_id).await;
// For now we assume segment is 1:1 with collection
target_segment_id = match segments {
Ok(segments) => {
if segments.len() == 0 {
return; // TODO: handle no segment found
}
Some(segments[0].id)
}
Err(_) => None,
};
}

if target_segment_id.is_none() {
return; // TODO: handle no segment found
}
// let target_segment_id = target_segment_id.unwrap();
println!("Writing record to segment: {}", target_segment_id.unwrap());

// let segment_cache = self.inner.vector_segments.upgradable_read();
// match segment_cache.get(&target_segment_id) {
// Some(segment) => {
// segment.write_records(vec![record]);
// }
// None => {
// let mut segment_cache = RwLockUpgradableReadGuard::upgrade(segment_cache);
// // Parse metadata from the segment and hydrate the params for the segment
// let new_segment = Box::new(DistributedHNSWSegment::new(
// "ip".to_string(),
// 100000,
// "./test/".to_string(),
// 100,
// 10000,
// ));
// segment_cache.insert(target_segment_id.clone(), new_segment);
// let segment_cache = RwLockWriteGuard::downgrade(segment_cache);
// let segment = RwLockReadGuard::map(segment_cache, |cache| {
// return cache.get(&target_segment_id).unwrap();
// });
// segment.write_records(vec![record]);
// }
// }
}

async fn get_segments(
&mut self,
collection_uuid: &Uuid,
) -> Result<MappedRwLockReadGuard<Vec<Segment>>, &'static str> {
let cache_guard = self.inner.collection_to_segment_cache.read();
        // This lets us return a reference to the segments with the lock held. The caller is
        // responsible for dropping the lock.
let segments =
RwLockReadGuard::try_map(cache_guard, |cache: &HashMap<Uuid, Vec<Segment>>| {
return cache.get(&collection_uuid);
});
match segments {
Ok(segments) => {
return Ok(segments);
}
Err(_) => {
// Data was not in the cache, so we need to get it from the database
// Drop the lock since we need to upgrade it
// Mappable locks cannot be upgraded, so we need to drop the lock and re-acquire it
// https://github.com/Amanieu/parking_lot/issues/83
drop(segments);

let segments = self
.sysdb
.get_segments(
None,
None,
Some(SegmentScope::VECTOR),
None,
Some(collection_uuid.clone()),
)
.await;
match segments {
Ok(segments) => {
let mut cache_guard = self.inner.collection_to_segment_cache.write();
cache_guard.insert(collection_uuid.clone(), segments);
let cache_guard = RwLockWriteGuard::downgrade(cache_guard);
let segments = RwLockReadGuard::map(cache_guard, |cache| {
// This unwrap is safe because we just inserted the segments into the cache and currently,
// there is no way to remove segments from the cache.
return cache.get(&collection_uuid).unwrap();
});
return Ok(segments);
}
Err(e) => {
return Err("Failed to get segments for collection from SysDB");
}
}
}
}
}
}

#[async_trait]
impl Configurable for SegmentManager {
async fn try_from_config(worker_config: &WorkerConfig) -> Result<Self, Box<dyn ChromaError>> {
        // TODO: the SysDb implementation should be resolved dynamically from config
let sysdb = GrpcSysDb::try_from_config(worker_config).await;
let sysdb = match sysdb {
Ok(sysdb) => sysdb,
Err(err) => {
return Err(err);
}
};
Ok(SegmentManager::new(Box::new(sysdb)))
}
}
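Editor's note: the get_segments read-through cache above leans on parking_lot's guard mapping: try_map on a read guard for the hit path, and a write-then-downgrade for the fill path, since mapped guards cannot be upgraded (the linked parking_lot issue #83). A self-contained sketch of the same pattern against a toy cache, with the async SysDb fetch replaced by a synchronous stand-in:

```rust
// Sketch (simplified, not from this commit): read-through cache with
// parking_lot guard mapping and write-guard downgrade.
use parking_lot::{MappedRwLockReadGuard, RwLock, RwLockReadGuard, RwLockWriteGuard};
use std::collections::HashMap;

struct Cache {
    entries: RwLock<HashMap<u32, Vec<String>>>,
}

impl Cache {
    fn get_or_fetch(&self, key: u32) -> MappedRwLockReadGuard<'_, Vec<String>> {
        let guard = self.entries.read();
        // Fast path: map the read guard down to the cached value on a hit.
        match RwLockReadGuard::try_map(guard, |m| m.get(&key)) {
            Ok(value) => return value,
            // Mapped guards cannot be upgraded, so on a miss we drop the
            // read guard and re-acquire as a writer.
            Err(guard) => drop(guard),
        }
        // Slow path: fetch (here, fabricate) the value, insert under the
        // write lock, then downgrade so the caller still gets a read guard.
        let fetched = vec![format!("segment-for-{}", key)];
        let mut write_guard = self.entries.write();
        write_guard.insert(key, fetched);
        let read_guard = RwLockWriteGuard::downgrade(write_guard);
        // Safe to unwrap: we just inserted the entry and nothing removes it.
        RwLockReadGuard::map(read_guard, |m| m.get(&key).unwrap())
    }
}

fn main() {
    let cache = Cache { entries: RwLock::new(HashMap::new()) };
    println!("{:?}", *cache.get_or_fetch(7)); // miss, then fill
    println!("{:?}", *cache.get_or_fetch(7)); // hit
}
```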
