From b5ed84af50a68414a3e2cb2620513b91e13d8a3c Mon Sep 17 00:00:00 2001 From: hammadb Date: Thu, 14 Dec 2023 16:49:28 -0800 Subject: [PATCH] [ENH] Rush SysDB --- rust/worker/src/ingest/ingest.rs | 2 +- rust/worker/src/lib.rs | 1 + rust/worker/src/sysdb/mod.rs | 1 + rust/worker/src/sysdb/sysdb.rs | 237 +++++++++++++++++++++++++++++++ 4 files changed, 240 insertions(+), 1 deletion(-) create mode 100644 rust/worker/src/sysdb/mod.rs create mode 100644 rust/worker/src/sysdb/sysdb.rs diff --git a/rust/worker/src/ingest/ingest.rs b/rust/worker/src/ingest/ingest.rs index c22e6cc27fe..f1c53de47b0 100644 --- a/rust/worker/src/ingest/ingest.rs +++ b/rust/worker/src/ingest/ingest.rs @@ -207,7 +207,7 @@ impl Handler for Ingest { // Bookkeep the handle so we can shut the stream down later match self.topic_to_handle.write() { Ok(mut topic_to_handle) => { - topic_to_handle.insert("test".to_string(), handle); + topic_to_handle.insert(topic.to_string(), handle); } Err(err) => { // TODO: log error and handle lock poisoning diff --git a/rust/worker/src/lib.rs b/rust/worker/src/lib.rs index 936ebec50bb..d8cda20d06c 100644 --- a/rust/worker/src/lib.rs +++ b/rust/worker/src/lib.rs @@ -4,6 +4,7 @@ mod errors; mod index; mod ingest; mod memberlist; +mod sysdb; mod system; mod types; diff --git a/rust/worker/src/sysdb/mod.rs b/rust/worker/src/sysdb/mod.rs new file mode 100644 index 00000000000..fe0e0423079 --- /dev/null +++ b/rust/worker/src/sysdb/mod.rs @@ -0,0 +1 @@ +mod sysdb; diff --git a/rust/worker/src/sysdb/sysdb.rs b/rust/worker/src/sysdb/sysdb.rs new file mode 100644 index 00000000000..546dd8e231a --- /dev/null +++ b/rust/worker/src/sysdb/sysdb.rs @@ -0,0 +1,237 @@ +use async_trait::async_trait; +use uuid::Uuid; + +use crate::chroma_proto; +use crate::types::{CollectionConversionError, SegmentConversionError}; +use crate::{ + chroma_proto::sys_db_client, + errors::{ChromaError, ErrorCodes}, + types::{Collection, Segment, SegmentScope}, +}; +use thiserror::Error; + +const DEFAULT_DATBASE: &str = "default_database"; +const DEFAULT_TENANT: &str = "default_tenant"; + +#[async_trait] +pub(crate) trait SysDb: Send + Sync + SysDbClone { + async fn get_collections( + &mut self, + collection_id: Option, + topic: Option, + name: Option, + tenant: Option, + database: Option, + ) -> Result, GetCollectionsError>; + + async fn get_segments( + &mut self, + id: Option, + r#type: Option, + scope: Option, + topic: Option, + collection: Option, + ) -> Result, GetSegmentsError>; +} + +// We'd like to be able to clone the trait object, so we need to use the +// "clone box" pattern. See https://stackoverflow.com/questions/30353462/how-to-clone-a-struct-storing-a-boxed-trait-object#comment48814207_30353928 +// https://chat.openai.com/share/b3eae92f-0b80-446f-b79d-6287762a2420 +pub(crate) trait SysDbClone { + fn clone_box(&self) -> Box; +} + +impl SysDbClone for T +where + T: 'static + SysDb + Clone, +{ + fn clone_box(&self) -> Box { + Box::new(self.clone()) + } +} + +impl Clone for Box { + fn clone(&self) -> Box { + self.clone_box() + } +} + +#[derive(Clone)] +// Since this uses tonic transport channel, cloning is cheap. Each client only supports +// one inflight request at a time, so we need to clone the client for each requester. +pub(crate) struct GrpcSysDb { + client: sys_db_client::SysDbClient, +} + +impl GrpcSysDb { + pub(crate) async fn new() -> Self { + let client = sys_db_client::SysDbClient::connect("http://[::1]:50051").await; + match client { + Ok(client) => { + return GrpcSysDb { client: client }; + } + Err(e) => { + // TODO: config error + panic!("Failed to connect to sysdb: {}", e); + } + } + } +} + +#[async_trait] +impl SysDb for GrpcSysDb { + async fn get_collections( + &mut self, + collection_id: Option, + topic: Option, + name: Option, + tenant: Option, + database: Option, + ) -> Result, GetCollectionsError> { + // TODO: move off of status into our own error type + let collection_id_str; + match collection_id { + Some(id) => { + collection_id_str = Some(id.to_string()); + } + None => { + collection_id_str = None; + } + } + + let res = self + .client + .get_collections(chroma_proto::GetCollectionsRequest { + id: collection_id_str, + topic: topic, + name: name, + tenant: if tenant.is_some() { + tenant.unwrap() + } else { + DEFAULT_TENANT.to_string() + }, + database: if database.is_some() { + database.unwrap() + } else { + DEFAULT_DATBASE.to_string() + }, + }) + .await; + + match res { + Ok(res) => { + let collections = res.into_inner().collections; + + let collections = collections + .into_iter() + .map(|proto_collection| proto_collection.try_into()) + .collect::, CollectionConversionError>>(); + + match collections { + Ok(collections) => { + return Ok(collections); + } + Err(e) => { + return Err(GetCollectionsError::ConversionError(e)); + } + } + } + Err(e) => { + return Err(GetCollectionsError::FailedToGetCollections(e)); + } + } + } + + async fn get_segments( + &mut self, + id: Option, + r#type: Option, + scope: Option, + topic: Option, + collection: Option, + ) -> Result, GetSegmentsError> { + let res = self + .client + .get_segments(chroma_proto::GetSegmentsRequest { + // TODO: modularize + id: if id.is_some() { + Some(id.unwrap().to_string()) + } else { + None + }, + r#type: r#type, + scope: if scope.is_some() { + Some(scope.unwrap() as i32) + } else { + None + }, + topic: topic, + collection: if collection.is_some() { + Some(collection.unwrap().to_string()) + } else { + None + }, + }) + .await; + println!("get_segments: {:?}", res); + match res { + Ok(res) => { + let segments = res.into_inner().segments; + let converted_segments = segments + .into_iter() + .map(|proto_segment| proto_segment.try_into()) + .collect::, SegmentConversionError>>(); + + match converted_segments { + Ok(segments) => { + println!("returning segments"); + return Ok(segments); + } + Err(e) => { + println!("failed to convert segments: {}", e); + return Err(GetSegmentsError::ConversionError(e)); + } + } + } + Err(e) => { + return Err(GetSegmentsError::FailedToGetSegments(e)); + } + } + } +} + +#[derive(Error, Debug)] +// TODO: This should use our sysdb errors from the proto definition +// We will have to do an error uniformization pass at some point +pub(crate) enum GetCollectionsError { + #[error("Failed to fetch")] + FailedToGetCollections(#[from] tonic::Status), + #[error("Failed to convert proto collection")] + ConversionError(#[from] CollectionConversionError), +} + +impl ChromaError for GetCollectionsError { + fn code(&self) -> ErrorCodes { + match self { + GetCollectionsError::FailedToGetCollections(_) => ErrorCodes::Internal, + GetCollectionsError::ConversionError(_) => ErrorCodes::Internal, + } + } +} + +#[derive(Error, Debug)] +pub(crate) enum GetSegmentsError { + #[error("Failed to fetch")] + FailedToGetSegments(#[from] tonic::Status), + #[error("Failed to convert proto segment")] + ConversionError(#[from] SegmentConversionError), +} + +impl ChromaError for GetSegmentsError { + fn code(&self) -> ErrorCodes { + match self { + GetSegmentsError::FailedToGetSegments(_) => ErrorCodes::Internal, + GetSegmentsError::ConversionError(_) => ErrorCodes::Internal, + } + } +}