Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
HammadB committed Jan 15, 2024
1 parent ea38dca commit 1ac8043
Show file tree
Hide file tree
Showing 14 changed files with 294 additions and 76 deletions.
1 change: 1 addition & 0 deletions rust/worker/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ schemars = "0.8.16"
kube = { version = "0.87.1", features = ["runtime", "derive"] }
k8s-openapi = { version = "0.20.0", features = ["latest"] }
bytes = "1.5.0"
parking_lot = "0.12.1"

[build-dependencies]
tonic-build = "0.10"
Expand Down
2 changes: 2 additions & 0 deletions rust/worker/chroma_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,5 @@ worker:
Grpc:
host: "localhost"
port: 50051
segment_manager:
storage_path: "./tmp/segment_manager/"
9 changes: 9 additions & 0 deletions rust/worker/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ pub(crate) struct WorkerConfig {
pub(crate) memberlist_provider: crate::memberlist::config::MemberlistProviderConfig,
pub(crate) ingest: crate::ingest::config::IngestConfig,
pub(crate) sysdb: crate::sysdb::config::SysDbConfig,
pub(crate) segment_manager: crate::segment::config::SegmentManagerConfig,
}

/// # Description
Expand Down Expand Up @@ -151,6 +152,8 @@ mod tests {
Grpc:
host: "localhost"
port: 50051
segment_manager:
storage_path: "/tmp"
"#,
);
Expand Down Expand Up @@ -190,6 +193,8 @@ mod tests {
Grpc:
host: "localhost"
port: 50051
segment_manager:
storage_path: "/tmp"
"#,
);
Expand Down Expand Up @@ -244,6 +249,8 @@ mod tests {
Grpc:
host: "localhost"
port: 50051
segment_manager:
storage_path: "/tmp"
"#,
);
Expand Down Expand Up @@ -279,6 +286,8 @@ mod tests {
Grpc:
host: "localhost"
port: 50051
segment_manager:
storage_path: "/tmp"
"#,
);
let config = RootConfig::load();
Expand Down
81 changes: 81 additions & 0 deletions rust/worker/src/index/hnsw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use std::ffi::{c_char, c_int};
use crate::errors::{ChromaError, ErrorCodes};

use super::{Index, IndexConfig, PersistentIndex};
use crate::types::{Metadata, MetadataValue, MetadataValueConversionError, Segment};
use thiserror::Error;

// https://doc.rust-lang.org/nomicon/ffi.html#representing-opaque-structs
Expand All @@ -29,6 +30,86 @@ pub(crate) struct HnswIndexConfig {
pub(crate) persist_path: String,
}

/// Errors produced while building an `HnswIndexConfig` from a segment.
#[derive(Error, Debug)]
pub(crate) enum HnswIndexFromSegmentError {
/// A required configuration value was unavailable; the payload is the name of that value.
#[error("Missing config `{0}`")]
MissingConfig(String),
}

impl ChromaError for HnswIndexFromSegmentError {
fn code(&self) -> ErrorCodes {
crate::errors::ErrorCodes::InvalidArgument
}
}

impl HnswIndexConfig {
    /// Builds an `HnswIndexConfig` from a segment's metadata.
    ///
    /// # Arguments
    /// - segment: The segment whose metadata carries the `hnsw:*` settings.
    /// - persist_path: Stored (as a string) in the resulting config's `persist_path`;
    ///   it must be representable as UTF-8.
    ///
    /// # Returns
    /// The config read from the metadata keys `hnsw:max_elements`, `hnsw:m`,
    /// `hnsw:ef_construction` and `hnsw:ef_search`, with `random_seed` fixed to 0.
    /// When the segment has no metadata at all, a hard-coded default config is
    /// returned instead (see the TODO below).
    ///
    /// # Errors
    /// - `HnswIndexFromSegmentError::MissingConfig` if `persist_path` is not valid
    ///   UTF-8, or if metadata is present but lacks one of the keys above.
    /// - The underlying `MetadataValueConversionError` if a stored value cannot be
    ///   converted to `i32`.
    pub(crate) fn from_segment(
        segment: &Segment,
        persist_path: &std::path::Path,
    ) -> Result<HnswIndexConfig, Box<dyn ChromaError>> {
        /// Builds the boxed error reported when the setting `name` is unavailable.
        fn missing(name: &str) -> Box<dyn ChromaError> {
            Box::new(HnswIndexFromSegmentError::MissingConfig(name.to_string()))
        }

        /// Looks up `key` in `metadata` and converts the stored value to `T`.
        fn get_metadata_value_as<'a, T>(
            metadata: &'a Metadata,
            key: &str,
        ) -> Result<T, Box<dyn ChromaError>>
        where
            T: TryFrom<&'a MetadataValue, Error = MetadataValueConversionError>,
        {
            let value = metadata.get(key).ok_or_else(|| missing(key))?;
            T::try_from(value).map_err(|e| Box::new(e) as Box<dyn ChromaError>)
        }

        let persist_path = persist_path
            .to_str()
            .ok_or_else(|| missing("persist_path"))?
            .to_string();

        let metadata = match &segment.metadata {
            Some(metadata) => metadata,
            None => {
                // TODO: This should fail with MissingConfig("metadata"), but the
                // configuration is not stored correctly yet. Once the configuration is
                // refactored to always be stored (instead of relying on defaults),
                // replace this fallback with that error.
                return Ok(HnswIndexConfig {
                    max_elements: 1000,
                    m: 16,
                    ef_construction: 100,
                    ef_search: 10,
                    random_seed: 0,
                    persist_path,
                });
            }
        };

        // The key was previously misspelled "hsnw:max_elements", so the lookup failed
        // for metadata that uses the same "hnsw:" prefix as every other setting.
        let max_elements = get_metadata_value_as::<i32>(metadata, "hnsw:max_elements")?;
        let m = get_metadata_value_as::<i32>(metadata, "hnsw:m")?;
        let ef_construction = get_metadata_value_as::<i32>(metadata, "hnsw:ef_construction")?;
        let ef_search = get_metadata_value_as::<i32>(metadata, "hnsw:ef_search")?;

        // NOTE(review): `as usize` turns a negative i32 into a very large value;
        // rejecting negatives needs a dedicated error variant on the error enum.
        Ok(HnswIndexConfig {
            max_elements: max_elements as usize,
            m: m as usize,
            ef_construction: ef_construction as usize,
            ef_search: ef_search as usize,
            random_seed: 0,
            persist_path,
        })
    }
}

#[repr(C)]
/// The HnswIndex struct.
/// # Description
Expand Down
37 changes: 37 additions & 0 deletions rust/worker/src/index/types.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::errors::{ChromaError, ErrorCodes};
use crate::types::{MetadataValue, Segment};
use thiserror::Error;

#[derive(Clone, Debug)]
Expand All @@ -7,6 +8,42 @@ pub(crate) struct IndexConfig {
pub(crate) distance_function: DistanceFunction,
}

/// Errors associated with building an `IndexConfig` from a segment.
#[derive(Error, Debug)]
pub(crate) enum IndexConfigFromSegmentError {
/// No distance space was defined.
#[error("No space defined")]
NoSpaceDefined,
}

impl ChromaError for IndexConfigFromSegmentError {
    /// Reports the error code associated with each variant.
    fn code(&self) -> ErrorCodes {
        // Exhaustive on purpose: a new variant must pick its code explicitly.
        match *self {
            Self::NoSpaceDefined => ErrorCodes::InvalidArgument,
        }
    }
}

impl IndexConfig {
    /// Builds an `IndexConfig` for `segment` with the given dimensionality.
    ///
    /// The distance function is parsed from the string stored under the
    /// `hnsw:space` metadata key. When the segment has no metadata, the key is
    /// absent, or the value is not a string, `"l2"` is used.
    ///
    /// # Errors
    /// Returns the boxed conversion error when the space name is not accepted by
    /// `DistanceFunction::try_from`.
    pub(crate) fn from_segment(
        segment: &Segment,
        dimensionality: i32,
    ) -> Result<Self, Box<dyn ChromaError>> {
        let configured = segment
            .metadata
            .as_ref()
            .and_then(|metadata| metadata.get("hnsw:space"));
        let space = match configured {
            Some(MetadataValue::Str(name)) => name.as_str(),
            _ => "l2",
        };
        let distance_function =
            DistanceFunction::try_from(space).map_err(|e| Box::new(e) as Box<dyn ChromaError>)?;
        Ok(IndexConfig {
            dimensionality,
            distance_function,
        })
    }
}

/// The index trait.
/// # Description
/// This trait defines the interface for a KNN index.
Expand Down
16 changes: 8 additions & 8 deletions rust/worker/src/ingest/ingest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ pub(crate) struct Ingest {
pulsar_namespace: String,
pulsar: Pulsar<TokioExecutor>,
sysdb: Box<dyn SysDb>,
scheduler: Option<Box<dyn Receiver<(String, Arc<EmbeddingRecord>)>>>,
scheduler: Option<Box<dyn Receiver<(String, Box<EmbeddingRecord>)>>>,
}

impl Component for Ingest {
Expand Down Expand Up @@ -133,7 +133,7 @@ impl Ingest {

pub(crate) fn subscribe(
&mut self,
scheduler: Box<dyn Receiver<(String, Arc<EmbeddingRecord>)>>,
scheduler: Box<dyn Receiver<(String, Box<EmbeddingRecord>)>>,
) {
self.scheduler = Some(scheduler);
}
Expand Down Expand Up @@ -277,7 +277,7 @@ impl DeserializeMessage for chroma_proto::SubmitEmbeddingRecord {
struct PulsarIngestTopic {
consumer: RwLock<Option<Consumer<chroma_proto::SubmitEmbeddingRecord, TokioExecutor>>>,
sysdb: Box<dyn SysDb>,
scheduler: Box<dyn Receiver<(String, Arc<EmbeddingRecord>)>>,
scheduler: Box<dyn Receiver<(String, Box<EmbeddingRecord>)>>,
}

impl Debug for PulsarIngestTopic {
Expand All @@ -290,7 +290,7 @@ impl PulsarIngestTopic {
fn new(
consumer: Consumer<chroma_proto::SubmitEmbeddingRecord, TokioExecutor>,
sysdb: Box<dyn SysDb>,
scheduler: Box<dyn Receiver<(String, Arc<EmbeddingRecord>)>>,
scheduler: Box<dyn Receiver<(String, Box<EmbeddingRecord>)>>,
) -> Self {
PulsarIngestTopic {
consumer: RwLock::new(Some(consumer)),
Expand Down Expand Up @@ -327,7 +327,7 @@ impl Component for PulsarIngestTopic {
(proto_embedding_record, seq_id).try_into();
match embedding_record {
Ok(embedding_record) => {
return Some(Arc::new(embedding_record));
return Some(Box::new(embedding_record));
}
Err(err) => {
// TODO: Handle and log
Expand All @@ -348,10 +348,10 @@ impl Component for PulsarIngestTopic {
}

#[async_trait]
impl Handler<Option<Arc<EmbeddingRecord>>> for PulsarIngestTopic {
impl Handler<Option<Box<EmbeddingRecord>>> for PulsarIngestTopic {
async fn handle(
&mut self,
message: Option<Arc<EmbeddingRecord>>,
message: Option<Box<EmbeddingRecord>>,
_ctx: &ComponentContext<PulsarIngestTopic>,
) -> () {
// Use the sysdb to tenant id for the embedding record
Expand Down Expand Up @@ -396,4 +396,4 @@ impl Handler<Option<Arc<EmbeddingRecord>>> for PulsarIngestTopic {
}

#[async_trait]
impl StreamHandler<Option<Arc<EmbeddingRecord>>> for PulsarIngestTopic {}
impl StreamHandler<Option<Box<EmbeddingRecord>>> for PulsarIngestTopic {}
14 changes: 7 additions & 7 deletions rust/worker/src/ingest/scheduler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ pub(crate) struct RoundRobinScheduler {
// The segment manager to schedule to, a segment manager is a component
// segment_manager: SegmentManager
curr_wake_up: Option<tokio::sync::oneshot::Sender<WakeMessage>>,
tenant_to_queue: HashMap<String, tokio::sync::mpsc::Sender<Arc<EmbeddingRecord>>>,
tenant_to_queue: HashMap<String, tokio::sync::mpsc::Sender<Box<EmbeddingRecord>>>,
new_tenant_channel: Option<tokio::sync::mpsc::Sender<NewTenantMessage>>,
subscribers: Option<Vec<Box<dyn Receiver<Arc<EmbeddingRecord>>>>>,
subscribers: Option<Vec<Box<dyn Receiver<Box<EmbeddingRecord>>>>>,
}

impl Debug for RoundRobinScheduler {
Expand All @@ -39,7 +39,7 @@ impl RoundRobinScheduler {
}
}

pub(crate) fn subscribe(&mut self, subscriber: Box<dyn Receiver<Arc<EmbeddingRecord>>>) {
pub(crate) fn subscribe(&mut self, subscriber: Box<dyn Receiver<Box<EmbeddingRecord>>>) {
match self.subscribers {
Some(ref mut subscribers) => {
subscribers.push(subscriber);
Expand Down Expand Up @@ -70,7 +70,7 @@ impl Component for RoundRobinScheduler {
tokio::spawn(async move {
let mut tenant_queues: HashMap<
String,
tokio::sync::mpsc::Receiver<Arc<EmbeddingRecord>>,
tokio::sync::mpsc::Receiver<Box<EmbeddingRecord>>,
> = HashMap::new();
loop {
// TODO: handle cancellation
Expand Down Expand Up @@ -136,10 +136,10 @@ impl Component for RoundRobinScheduler {
}

#[async_trait]
impl Handler<(String, Arc<EmbeddingRecord>)> for RoundRobinScheduler {
impl Handler<(String, Box<EmbeddingRecord>)> for RoundRobinScheduler {
async fn handle(
&mut self,
message: (String, Arc<EmbeddingRecord>),
message: (String, Box<EmbeddingRecord>),
_ctx: &ComponentContext<Self>,
) {
let (tenant, embedding_record) = message;
Expand Down Expand Up @@ -208,5 +208,5 @@ struct SleepMessage {

struct NewTenantMessage {
tenant: String,
channel: tokio::sync::mpsc::Receiver<Arc<EmbeddingRecord>>,
channel: tokio::sync::mpsc::Receiver<Box<EmbeddingRecord>>,
}
9 changes: 9 additions & 0 deletions rust/worker/src/segment/config.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
use serde::Deserialize;

/// The configuration for the segment manager.
/// # Fields
/// - storage_path: The path to use for temporary storage in the segment manager, if needed.
#[derive(Deserialize)]
pub(crate) struct SegmentManagerConfig {
pub(crate) storage_path: String,
}
29 changes: 24 additions & 5 deletions rust/worker/src/segment/distributed_hnsw_segment.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
use parking_lot::{Mutex, RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
use std::collections::HashMap;
use std::sync::atomic::AtomicUsize;
use std::sync::Arc;

use crate::errors::ChromaError;
use crate::index::{HnswIndex, HnswIndexConfig, Index, IndexConfig};
use crate::types::{EmbeddingRecord, Operation};
use crate::types::{EmbeddingRecord, Operation, Segment};

/// A segment that holds an HNSW index together with its index and HNSW configuration.
pub(crate) struct DistributedHNSWSegment {
// The HNSW index, shared behind a read-write lock.
index: Arc<RwLock<HnswIndex>>,
// Counter from which `write_records` takes the next internal id for an added record.
id: AtomicUsize,
// Maps a record's id to the internal id under which it was added to the index.
user_id_to_id: Arc<RwLock<HashMap<String, usize>>>,
index_config: IndexConfig,
hnsw_config: HnswIndexConfig,
}
Expand All @@ -30,22 +32,39 @@ impl DistributedHNSWSegment {
return Ok(DistributedHNSWSegment {
index: index,
id: AtomicUsize::new(0),
user_id_to_id: Arc::new(RwLock::new(HashMap::new())),
index_config: index_config,
hnsw_config,
});
}

pub(crate) fn write_records(&mut self, records: Vec<Box<EmbeddingRecord>>) {
/// Creates a boxed `DistributedHNSWSegment` configured from `segment`.
///
/// The index configuration is derived first (from the segment and
/// `dimensionality`), then the HNSW configuration (from the segment and
/// `persist_path`); the first failure is returned to the caller.
pub(crate) fn from_segment(
    segment: &Segment,
    persist_path: &std::path::Path,
    dimensionality: usize,
) -> Result<Box<DistributedHNSWSegment>, Box<dyn ChromaError>> {
    let index_config = IndexConfig::from_segment(segment, dimensionality as i32)?;
    let hnsw_config = HnswIndexConfig::from_segment(segment, persist_path)?;
    let built = DistributedHNSWSegment::new(index_config, hnsw_config)?;
    Ok(Box::new(built))
}

pub(crate) fn write_records(&self, records: Vec<Box<EmbeddingRecord>>) {
for record in records {
let op = Operation::try_from(record.operation);
match op {
Ok(Operation::Add) => {
// TODO: make lock xor lock
match record.embedding {
match &record.embedding {
Some(vector) => {
let next_id = self.id.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
println!("Adding item: {}", next_id);
self.index.write().add(next_id, &vector);
self.user_id_to_id
.write()
.insert(record.id.clone(), next_id);
println!("DIS SEGMENT Adding item: {}", next_id);
self.index.read().add(next_id, &vector);
}
None => {
// TODO: log an error
Expand Down
1 change: 1 addition & 0 deletions rust/worker/src/segment/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pub(crate) mod config;
mod distributed_hnsw_segment;
mod segment_ingestor;
mod segment_manager;
Expand Down
Loading

0 comments on commit 1ac8043

Please sign in to comment.