Skip to content

Commit

Permalink
[ENH] Rust SysDB (#1529)
Browse files Browse the repository at this point in the history
## Description of changes

*Summarize the changes made by this PR.*
 - Improvements & Bug fixes
	 - Fix a typo in ingest
 - New functionality
	 - Adds the rust sysdb

Some cleanup is TODO - will address in this PR 

## Test plan
*How are these changes tested?*
No tests were added. I will add an integration test for this.
- [x] Tests pass locally with `cargo test`

## Documentation Changes
None
  • Loading branch information
HammadB authored Jan 16, 2024
1 parent 2063588 commit fc7b3ea
Show file tree
Hide file tree
Showing 5 changed files with 241 additions and 2 deletions.
2 changes: 1 addition & 1 deletion rust/worker/chroma_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# for now we nest it in the worker directory

worker:
my_ip: "10.244.0.85"
my_ip: "10.244.0.90"
num_indexing_threads: 4
pulsar_url: "pulsar://127.0.0.1:6650"
pulsar_tenant: "public"
Expand Down
2 changes: 1 addition & 1 deletion rust/worker/src/ingest/ingest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ impl Handler<Memberlist> for Ingest {
// Bookkeep the handle so we can shut the stream down later
match self.topic_to_handle.write() {
Ok(mut topic_to_handle) => {
topic_to_handle.insert("test".to_string(), handle);
topic_to_handle.insert(topic.to_string(), handle);
}
Err(err) => {
// TODO: log error and handle lock poisoning
Expand Down
1 change: 1 addition & 0 deletions rust/worker/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ mod errors;
mod index;
mod ingest;
mod memberlist;
mod sysdb;
mod system;
mod types;

Expand Down
1 change: 1 addition & 0 deletions rust/worker/src/sysdb/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
mod sysdb;
237 changes: 237 additions & 0 deletions rust/worker/src/sysdb/sysdb.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,237 @@
use async_trait::async_trait;
use uuid::Uuid;

use crate::chroma_proto;
use crate::types::{CollectionConversionError, SegmentConversionError};
use crate::{
chroma_proto::sys_db_client,
errors::{ChromaError, ErrorCodes},
types::{Collection, Segment, SegmentScope},
};
use thiserror::Error;

const DEFAULT_DATBASE: &str = "default_database";
const DEFAULT_TENANT: &str = "default_tenant";

#[async_trait]
pub(crate) trait SysDb: Send + Sync + SysDbClone {
async fn get_collections(
&mut self,
collection_id: Option<Uuid>,
topic: Option<String>,
name: Option<String>,
tenant: Option<String>,
database: Option<String>,
) -> Result<Vec<Collection>, GetCollectionsError>;

async fn get_segments(
&mut self,
id: Option<Uuid>,
r#type: Option<String>,
scope: Option<SegmentScope>,
topic: Option<String>,
collection: Option<Uuid>,
) -> Result<Vec<Segment>, GetSegmentsError>;
}

// We'd like to be able to clone the trait object, so we need to use the
// "clone box" pattern. See https://stackoverflow.com/questions/30353462/how-to-clone-a-struct-storing-a-boxed-trait-object#comment48814207_30353928
// https://chat.openai.com/share/b3eae92f-0b80-446f-b79d-6287762a2420
pub(crate) trait SysDbClone {
fn clone_box(&self) -> Box<dyn SysDb>;
}

impl<T> SysDbClone for T
where
T: 'static + SysDb + Clone,
{
fn clone_box(&self) -> Box<dyn SysDb> {
Box::new(self.clone())
}
}

impl Clone for Box<dyn SysDb> {
fn clone(&self) -> Box<dyn SysDb> {
self.clone_box()
}
}

#[derive(Clone)]
// Since this uses tonic transport channel, cloning is cheap. Each client only supports
// one inflight request at a time, so we need to clone the client for each requester.
pub(crate) struct GrpcSysDb {
client: sys_db_client::SysDbClient<tonic::transport::Channel>,
}

impl GrpcSysDb {
pub(crate) async fn new() -> Self {
let client = sys_db_client::SysDbClient::connect("http://[::1]:50051").await;
match client {
Ok(client) => {
return GrpcSysDb { client: client };
}
Err(e) => {
// TODO: config error
panic!("Failed to connect to sysdb: {}", e);
}
}
}
}

#[async_trait]
impl SysDb for GrpcSysDb {
async fn get_collections(
&mut self,
collection_id: Option<Uuid>,
topic: Option<String>,
name: Option<String>,
tenant: Option<String>,
database: Option<String>,
) -> Result<Vec<Collection>, GetCollectionsError> {
// TODO: move off of status into our own error type
let collection_id_str;
match collection_id {
Some(id) => {
collection_id_str = Some(id.to_string());
}
None => {
collection_id_str = None;
}
}

let res = self
.client
.get_collections(chroma_proto::GetCollectionsRequest {
id: collection_id_str,
topic: topic,
name: name,
tenant: if tenant.is_some() {
tenant.unwrap()
} else {
DEFAULT_TENANT.to_string()
},
database: if database.is_some() {
database.unwrap()
} else {
DEFAULT_DATBASE.to_string()
},
})
.await;

match res {
Ok(res) => {
let collections = res.into_inner().collections;

let collections = collections
.into_iter()
.map(|proto_collection| proto_collection.try_into())
.collect::<Result<Vec<Collection>, CollectionConversionError>>();

match collections {
Ok(collections) => {
return Ok(collections);
}
Err(e) => {
return Err(GetCollectionsError::ConversionError(e));
}
}
}
Err(e) => {
return Err(GetCollectionsError::FailedToGetCollections(e));
}
}
}

async fn get_segments(
&mut self,
id: Option<Uuid>,
r#type: Option<String>,
scope: Option<SegmentScope>,
topic: Option<String>,
collection: Option<Uuid>,
) -> Result<Vec<Segment>, GetSegmentsError> {
let res = self
.client
.get_segments(chroma_proto::GetSegmentsRequest {
// TODO: modularize
id: if id.is_some() {
Some(id.unwrap().to_string())
} else {
None
},
r#type: r#type,
scope: if scope.is_some() {
Some(scope.unwrap() as i32)
} else {
None
},
topic: topic,
collection: if collection.is_some() {
Some(collection.unwrap().to_string())
} else {
None
},
})
.await;
println!("get_segments: {:?}", res);
match res {
Ok(res) => {
let segments = res.into_inner().segments;
let converted_segments = segments
.into_iter()
.map(|proto_segment| proto_segment.try_into())
.collect::<Result<Vec<Segment>, SegmentConversionError>>();

match converted_segments {
Ok(segments) => {
println!("returning segments");
return Ok(segments);
}
Err(e) => {
println!("failed to convert segments: {}", e);
return Err(GetSegmentsError::ConversionError(e));
}
}
}
Err(e) => {
return Err(GetSegmentsError::FailedToGetSegments(e));
}
}
}
}

#[derive(Error, Debug)]
// TODO: This should use our sysdb errors from the proto definition
// We will have to do an error uniformization pass at some point
pub(crate) enum GetCollectionsError {
#[error("Failed to fetch")]
FailedToGetCollections(#[from] tonic::Status),
#[error("Failed to convert proto collection")]
ConversionError(#[from] CollectionConversionError),
}

impl ChromaError for GetCollectionsError {
fn code(&self) -> ErrorCodes {
match self {
GetCollectionsError::FailedToGetCollections(_) => ErrorCodes::Internal,
GetCollectionsError::ConversionError(_) => ErrorCodes::Internal,
}
}
}

#[derive(Error, Debug)]
pub(crate) enum GetSegmentsError {
#[error("Failed to fetch")]
FailedToGetSegments(#[from] tonic::Status),
#[error("Failed to convert proto segment")]
ConversionError(#[from] SegmentConversionError),
}

impl ChromaError for GetSegmentsError {
fn code(&self) -> ErrorCodes {
match self {
GetSegmentsError::FailedToGetSegments(_) => ErrorCodes::Internal,
GetSegmentsError::ConversionError(_) => ErrorCodes::Internal,
}
}
}

0 comments on commit fc7b3ea

Please sign in to comment.