From af37c9a1313e7a0d2423e75d6556138525087dcb Mon Sep 17 00:00:00 2001 From: Hammad Bashir Date: Mon, 15 Jan 2024 17:09:31 -0800 Subject: [PATCH] [ENH] Add rust protobufs and conversion. Add build.rs, protobufs, and conversions (#1513) ## Description of changes *Summarize the changes made by this PR.* - Improvements & Bug fixes - Update dockerfile to use a fetched protoc since that is needed for protoc to include/ WKT - New functionality - Adds a build.rs so that we can build the protos into rust bindings - Adds a types/ folder with types and rust-y/idiomatic TryFrom conversions so that callers can just do .try_from() and get type inference. - Adds a macro for error type wrapping and impl'ing ChromaError on the macro. - All types are rexported from types/ so that the rest of the code can easily use it. ## Test plan *How are these changes tested?* - Add _very_ rudimentary tests for conversion. We should do a pass where we add some more rigorous conversion testing. - [x] Tests pass locally with `cargo test` ## Documentation Changes None required. --- .github/workflows/chroma-worker-test.yml | 2 + Cargo.lock | 33 ++++ rust/worker/Cargo.toml | 3 + rust/worker/Dockerfile | 8 +- rust/worker/build.rs | 10 + rust/worker/src/config.rs | 3 + rust/worker/src/lib.rs | 5 + rust/worker/src/types/collection.rs | 88 +++++++++ rust/worker/src/types/embedding_record.rs | 229 ++++++++++++++++++++++ rust/worker/src/types/metadata.rs | 229 ++++++++++++++++++++++ rust/worker/src/types/mod.rs | 19 ++ rust/worker/src/types/operation.rs | 73 +++++++ rust/worker/src/types/scalar_encoding.rs | 66 +++++++ rust/worker/src/types/segment.rs | 114 +++++++++++ rust/worker/src/types/segment_scope.rs | 70 +++++++ rust/worker/src/types/types.rs | 36 ++++ 16 files changed, 987 insertions(+), 1 deletion(-) create mode 100644 rust/worker/build.rs create mode 100644 rust/worker/src/types/collection.rs create mode 100644 rust/worker/src/types/embedding_record.rs create mode 100644 rust/worker/src/types/metadata.rs create mode 100644 rust/worker/src/types/mod.rs create mode 100644 rust/worker/src/types/operation.rs create mode 100644 rust/worker/src/types/scalar_encoding.rs create mode 100644 rust/worker/src/types/segment.rs create mode 100644 rust/worker/src/types/segment_scope.rs create mode 100644 rust/worker/src/types/types.rs diff --git a/.github/workflows/chroma-worker-test.yml b/.github/workflows/chroma-worker-test.yml index 5325f52fda4..2cfce1b6d4a 100644 --- a/.github/workflows/chroma-worker-test.yml +++ b/.github/workflows/chroma-worker-test.yml @@ -19,6 +19,8 @@ jobs: steps: - name: Checkout uses: actions/checkout@v3 + - name: Install Protoc + uses: arduino/setup-protoc@v2 - name: Build run: cargo build --verbose - name: Test diff --git a/Cargo.lock b/Cargo.lock index 8a2f24ca53b..8077c626d8d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -580,6 +580,36 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9252111cf132ba0929b6f8e030cac2a24b507f3a4d6db6fb2896f27b354c714b" +[[package]] +name = "num-bigint" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "608e7659b5c3d7cba262d894801b9ec9d00de989e8a82bd4bef91d08da45cdc0" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + +[[package]] +name = "num-integer" +version = "0.1.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "225d3389fb3509a24c93f5c29eb6bde2586b98d9f016636dff58d7c6f7569cd9" +dependencies = [ + "autocfg", + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39e3200413f237f41ab11ad6d161bc7239c84dcb631773ccd7de3dfe4b5c267c" +dependencies = [ + "autocfg", +] + [[package]] name = "num_cpus" version = "1.16.0" @@ -1453,7 +1483,10 @@ dependencies = [ "cc", "figment", "murmur3", + "num-bigint", "num_cpus", + "prost", + "prost-types", "rand", "rayon", "serde", diff --git a/rust/worker/Cargo.toml b/rust/worker/Cargo.toml index d0a0bff6ded..c1c2776078a 100644 --- a/rust/worker/Cargo.toml +++ b/rust/worker/Cargo.toml @@ -5,6 +5,8 @@ edition = "2021" [dependencies] tonic = "0.10" +prost = "0.12" +prost-types = "0.12" tokio = { version = "1.0", features = ["macros", "rt-multi-thread"] } tokio-util = "0.7.10" rand = "0.8.5" @@ -16,6 +18,7 @@ serde = { version = "1.0.193", features = ["derive"] } num_cpus = "1.16.0" murmur3 = "0.5.2" thiserror = "1.0.50" +num-bigint = "0.4.4" [build-dependencies] tonic-build = "0.10" diff --git a/rust/worker/Dockerfile b/rust/worker/Dockerfile index 7beb21d2b28..9fec202fda1 100644 --- a/rust/worker/Dockerfile +++ b/rust/worker/Dockerfile @@ -1,8 +1,14 @@ FROM rust:1.74.1 as builder -WORKDIR /chroma +WORKDIR /chroma/ COPY . . +ENV PROTOC_ZIP=protoc-25.1-linux-x86_64.zip +RUN curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v25.1/$PROTOC_ZIP \ + && unzip -o $PROTOC_ZIP -d /usr/local bin/protoc \ + && unzip -o $PROTOC_ZIP -d /usr/local 'include/*' \ + && rm -f $PROTOC_ZIP + RUN cargo build # For now this runs cargo test since we have no main binary diff --git a/rust/worker/build.rs b/rust/worker/build.rs new file mode 100644 index 00000000000..78b226a0e0c --- /dev/null +++ b/rust/worker/build.rs @@ -0,0 +1,10 @@ +fn main() -> Result<(), Box> { + tonic_build::configure().compile( + &[ + "../../idl/chromadb/proto/chroma.proto", + "../../idl/chromadb/proto/coordinator.proto", + ], + &["../../idl/"], + )?; + Ok(()) +} diff --git a/rust/worker/src/config.rs b/rust/worker/src/config.rs index f2efa97df00..44ba38ab7b9 100644 --- a/rust/worker/src/config.rs +++ b/rust/worker/src/config.rs @@ -84,6 +84,9 @@ impl RootConfig { /// ## Description of parameters /// - my_ip: The IP address of the worker service. Used for memberlist assignment. Must be provided /// - num_indexing_threads: The number of indexing threads to use. If not provided, defaults to the number of cores on the machine. +/// - pulsar_tenant: The pulsar tenant to use. Must be provided. +/// - pulsar_namespace: The pulsar namespace to use. Must be provided. +/// - assignment_policy: The assignment policy to use. Must be provided. /// # Notes /// In order to set the enviroment variables, you must prefix them with CHROMA_WORKER__. /// For example, to set my_ip, you would set CHROMA_WORKER__MY_IP. diff --git a/rust/worker/src/lib.rs b/rust/worker/src/lib.rs index a9d10c436e2..d48649febd1 100644 --- a/rust/worker/src/lib.rs +++ b/rust/worker/src/lib.rs @@ -1,3 +1,8 @@ mod assignment; mod config; mod errors; +mod types; + +mod chroma_proto { + tonic::include_proto!("chroma"); +} diff --git a/rust/worker/src/types/collection.rs b/rust/worker/src/types/collection.rs new file mode 100644 index 00000000000..2dd495a5afc --- /dev/null +++ b/rust/worker/src/types/collection.rs @@ -0,0 +1,88 @@ +use super::{Metadata, MetadataValueConversionError}; +use crate::{ + chroma_proto, + errors::{ChromaError, ErrorCodes}, +}; +use thiserror::Error; +use uuid::Uuid; + +#[derive(Debug, PartialEq)] +pub(crate) struct Collection { + pub(crate) id: Uuid, + pub(crate) name: String, + pub(crate) topic: String, + pub(crate) metadata: Option, + pub(crate) dimension: Option, + pub(crate) tenant: String, + pub(crate) database: String, +} + +#[derive(Error, Debug)] +pub(crate) enum CollectionConversionError { + #[error("Invalid UUID")] + InvalidUuid, + #[error(transparent)] + MetadataValueConversionError(#[from] MetadataValueConversionError), +} + +impl ChromaError for CollectionConversionError { + fn code(&self) -> crate::errors::ErrorCodes { + match self { + CollectionConversionError::InvalidUuid => ErrorCodes::InvalidArgument, + CollectionConversionError::MetadataValueConversionError(e) => e.code(), + } + } +} + +impl TryFrom for Collection { + type Error = CollectionConversionError; + + fn try_from(proto_collection: chroma_proto::Collection) -> Result { + let collection_uuid = match Uuid::try_parse(&proto_collection.id) { + Ok(uuid) => uuid, + Err(_) => return Err(CollectionConversionError::InvalidUuid), + }; + let collection_metadata: Option = match proto_collection.metadata { + Some(proto_metadata) => match proto_metadata.try_into() { + Ok(metadata) => Some(metadata), + Err(e) => return Err(CollectionConversionError::MetadataValueConversionError(e)), + }, + None => None, + }; + Ok(Collection { + id: collection_uuid, + name: proto_collection.name, + topic: proto_collection.topic, + metadata: collection_metadata, + dimension: proto_collection.dimension, + tenant: proto_collection.tenant, + database: proto_collection.database, + }) + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_collection_try_from() { + let proto_collection = chroma_proto::Collection { + id: "00000000-0000-0000-0000-000000000000".to_string(), + name: "foo".to_string(), + topic: "bar".to_string(), + metadata: None, + dimension: None, + tenant: "baz".to_string(), + database: "qux".to_string(), + }; + let converted_collection: Collection = proto_collection.try_into().unwrap(); + assert_eq!(converted_collection.id, Uuid::nil()); + assert_eq!(converted_collection.name, "foo".to_string()); + assert_eq!(converted_collection.topic, "bar".to_string()); + assert_eq!(converted_collection.metadata, None); + assert_eq!(converted_collection.dimension, None); + assert_eq!(converted_collection.tenant, "baz".to_string()); + assert_eq!(converted_collection.database, "qux".to_string()); + } +} diff --git a/rust/worker/src/types/embedding_record.rs b/rust/worker/src/types/embedding_record.rs new file mode 100644 index 00000000000..2b4f2361e0a --- /dev/null +++ b/rust/worker/src/types/embedding_record.rs @@ -0,0 +1,229 @@ +use super::{ + ConversionError, Operation, OperationConversionError, ScalarEncoding, + ScalarEncodingConversionError, SeqId, UpdateMetadata, UpdateMetadataValueConversionError, +}; +use crate::{ + chroma_proto, + errors::{ChromaError, ErrorCodes}, +}; +use thiserror::Error; +use uuid::Uuid; + +#[derive(Debug)] +pub(crate) struct EmbeddingRecord { + pub(crate) id: String, + pub(crate) seq_id: SeqId, + pub(crate) embedding: Option>, // NOTE: we only support float32 embeddings for now + pub(crate) encoding: Option, + pub(crate) metadata: Option, + pub(crate) operation: Operation, + pub(crate) collection_id: Uuid, +} + +pub(crate) type SubmitEmbeddingRecordWithSeqId = (chroma_proto::SubmitEmbeddingRecord, SeqId); + +#[derive(Error, Debug)] +pub(crate) enum EmbeddingRecordConversionError { + #[error("Invalid UUID")] + InvalidUuid, + #[error(transparent)] + DecodeError(#[from] ConversionError), + #[error(transparent)] + OperationConversionError(#[from] OperationConversionError), + #[error(transparent)] + ScalarEncodingConversionError(#[from] ScalarEncodingConversionError), + #[error(transparent)] + UpdateMetadataValueConversionError(#[from] UpdateMetadataValueConversionError), + #[error(transparent)] + VectorConversionError(#[from] VectorConversionError), +} + +impl_base_convert_error!(EmbeddingRecordConversionError, { + EmbeddingRecordConversionError::InvalidUuid => ErrorCodes::InvalidArgument, + EmbeddingRecordConversionError::OperationConversionError(inner) => inner.code(), + EmbeddingRecordConversionError::ScalarEncodingConversionError(inner) => inner.code(), + EmbeddingRecordConversionError::UpdateMetadataValueConversionError(inner) => inner.code(), + EmbeddingRecordConversionError::VectorConversionError(inner) => inner.code(), +}); + +impl TryFrom for EmbeddingRecord { + type Error = EmbeddingRecordConversionError; + + fn try_from( + proto_submit_with_seq_id: SubmitEmbeddingRecordWithSeqId, + ) -> Result { + let proto_submit = proto_submit_with_seq_id.0; + let seq_id = proto_submit_with_seq_id.1; + let op = match proto_submit.operation.try_into() { + Ok(op) => op, + Err(e) => return Err(EmbeddingRecordConversionError::OperationConversionError(e)), + }; + + let collection_uuid = match Uuid::try_parse(&proto_submit.collection_id) { + Ok(uuid) => uuid, + Err(_) => return Err(EmbeddingRecordConversionError::InvalidUuid), + }; + + let (embedding, encoding) = match proto_submit.vector { + Some(proto_vector) => match proto_vector.try_into() { + Ok((embedding, encoding)) => (Some(embedding), Some(encoding)), + Err(e) => return Err(EmbeddingRecordConversionError::VectorConversionError(e)), + }, + // If there is no vector, there is no encoding + None => (None, None), + }; + + let metadata: Option = match proto_submit.metadata { + Some(proto_metadata) => match proto_metadata.try_into() { + Ok(metadata) => Some(metadata), + Err(e) => { + return Err( + EmbeddingRecordConversionError::UpdateMetadataValueConversionError(e), + ) + } + }, + None => None, + }; + + Ok(EmbeddingRecord { + id: proto_submit.id, + seq_id: seq_id, + embedding: embedding, + encoding: encoding, + metadata: metadata, + operation: op, + collection_id: collection_uuid, + }) + } +} + +/* +=========================================== +Vector +=========================================== +*/ +impl TryFrom for (Vec, ScalarEncoding) { + type Error = VectorConversionError; + + fn try_from(proto_vector: chroma_proto::Vector) -> Result { + let out_encoding: ScalarEncoding = match proto_vector.encoding.try_into() { + Ok(encoding) => encoding, + Err(e) => return Err(VectorConversionError::ScalarEncodingConversionError(e)), + }; + + if out_encoding != ScalarEncoding::FLOAT32 { + // We only support float32 embeddings for now + return Err(VectorConversionError::UnsupportedEncoding); + } + + let out_vector = vec_to_f32(&proto_vector.vector); + match (out_vector, out_encoding) { + (Ok(vector), encoding) => Ok((vector.to_vec(), encoding)), + _ => Err(VectorConversionError::DecodeError( + ConversionError::DecodeError, + )), + } + } +} + +#[derive(Error, Debug)] +pub(crate) enum VectorConversionError { + #[error("Invalid byte length, must be divisible by 4")] + InvalidByteLength, + #[error(transparent)] + ScalarEncodingConversionError(#[from] ScalarEncodingConversionError), + #[error("Unsupported encoding")] + UnsupportedEncoding, + #[error(transparent)] + DecodeError(#[from] ConversionError), +} + +impl_base_convert_error!(VectorConversionError, { + VectorConversionError::InvalidByteLength => ErrorCodes::InvalidArgument, + VectorConversionError::UnsupportedEncoding => ErrorCodes::InvalidArgument, + VectorConversionError::ScalarEncodingConversionError(inner) => inner.code(), +}); + +/// Converts a vector of bytes to a vector of f32s +/// # WARNING +/// - This will only work if the machine is little endian since protobufs are little endian +/// - TODO: convert to big endian if the machine is big endian +/// # Notes +/// This method internally uses unsafe code to convert the bytes to f32s +fn vec_to_f32(bytes: &[u8]) -> Result<&[f32], VectorConversionError> { + // Transmutes a vector of bytes into vector of f32s + + if bytes.len() % 4 != 0 { + return Err(VectorConversionError::InvalidByteLength); + } + + unsafe { + let (pre, mid, post) = bytes.align_to::(); + if pre.len() != 0 || post.len() != 0 { + return Err(VectorConversionError::InvalidByteLength); + } + return Ok(mid); + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use num_bigint::BigInt; + + use super::*; + use crate::{chroma_proto, types::UpdateMetadataValue}; + + fn as_byte_view(input: &[f32]) -> Vec { + unsafe { + std::slice::from_raw_parts( + input.as_ptr() as *const u8, + input.len() * std::mem::size_of::(), + ) + } + .to_vec() + } + + #[test] + fn test_embedding_record_try_from() { + let mut metadata = chroma_proto::UpdateMetadata { + metadata: HashMap::new(), + }; + metadata.metadata.insert( + "foo".to_string(), + chroma_proto::UpdateMetadataValue { + value: Some(chroma_proto::update_metadata_value::Value::IntValue(42)), + }, + ); + let proto_vector = chroma_proto::Vector { + vector: as_byte_view(&[1.0, 2.0, 3.0]), + encoding: chroma_proto::ScalarEncoding::Float32 as i32, + dimension: 3, + }; + let proto_submit = chroma_proto::SubmitEmbeddingRecord { + id: "00000000-0000-0000-0000-000000000000".to_string(), + vector: Some(proto_vector), + metadata: Some(metadata), + operation: chroma_proto::Operation::Add as i32, + collection_id: "00000000-0000-0000-0000-000000000000".to_string(), + }; + let converted_embedding_record: EmbeddingRecord = + EmbeddingRecord::try_from((proto_submit, BigInt::from(42))).unwrap(); + assert_eq!(converted_embedding_record.id, Uuid::nil().to_string()); + assert_eq!(converted_embedding_record.seq_id, BigInt::from(42)); + assert_eq!( + converted_embedding_record.embedding, + Some(vec![1.0, 2.0, 3.0]) + ); + assert_eq!( + converted_embedding_record.encoding, + Some(ScalarEncoding::FLOAT32) + ); + let metadata = converted_embedding_record.metadata.unwrap(); + assert_eq!(metadata.len(), 1); + assert_eq!(metadata.get("foo").unwrap(), &UpdateMetadataValue::Int(42)); + assert_eq!(converted_embedding_record.operation, Operation::Add); + assert_eq!(converted_embedding_record.collection_id, Uuid::nil()); + } +} diff --git a/rust/worker/src/types/metadata.rs b/rust/worker/src/types/metadata.rs new file mode 100644 index 00000000000..8dd37f70202 --- /dev/null +++ b/rust/worker/src/types/metadata.rs @@ -0,0 +1,229 @@ +use crate::{ + chroma_proto, + errors::{ChromaError, ErrorCodes}, +}; +use std::collections::HashMap; +use thiserror::Error; + +#[derive(Debug, PartialEq)] +pub(crate) enum UpdateMetadataValue { + Int(i32), + Float(f64), + Str(String), + None, +} + +#[derive(Error, Debug)] +pub(crate) enum UpdateMetadataValueConversionError { + #[error("Invalid metadata value, valid values are: Int, Float, Str, Bool, None")] + InvalidValue, +} + +impl ChromaError for UpdateMetadataValueConversionError { + fn code(&self) -> crate::errors::ErrorCodes { + match self { + UpdateMetadataValueConversionError::InvalidValue => ErrorCodes::InvalidArgument, + } + } +} + +impl TryFrom<&chroma_proto::UpdateMetadataValue> for UpdateMetadataValue { + type Error = UpdateMetadataValueConversionError; + + fn try_from(value: &chroma_proto::UpdateMetadataValue) -> Result { + match &value.value { + Some(chroma_proto::update_metadata_value::Value::IntValue(value)) => { + Ok(UpdateMetadataValue::Int(*value as i32)) + } + Some(chroma_proto::update_metadata_value::Value::FloatValue(value)) => { + Ok(UpdateMetadataValue::Float(*value)) + } + Some(chroma_proto::update_metadata_value::Value::StringValue(value)) => { + Ok(UpdateMetadataValue::Str(value.clone())) + } + _ => Err(UpdateMetadataValueConversionError::InvalidValue), + } + } +} + +/* +=========================================== +MetadataValue +=========================================== +*/ + +#[derive(Debug, PartialEq)] +pub(crate) enum MetadataValue { + Int(i32), + Float(f64), + Str(String), +} + +#[derive(Error, Debug)] +pub(crate) enum MetadataValueConversionError { + #[error("Invalid metadata value, valid values are: Int, Float, Str")] + InvalidValue, +} + +impl ChromaError for MetadataValueConversionError { + fn code(&self) -> crate::errors::ErrorCodes { + match self { + MetadataValueConversionError::InvalidValue => ErrorCodes::InvalidArgument, + } + } +} + +impl TryFrom<&chroma_proto::UpdateMetadataValue> for MetadataValue { + type Error = MetadataValueConversionError; + + fn try_from(value: &chroma_proto::UpdateMetadataValue) -> Result { + match &value.value { + Some(chroma_proto::update_metadata_value::Value::IntValue(value)) => { + Ok(MetadataValue::Int(*value as i32)) + } + Some(chroma_proto::update_metadata_value::Value::FloatValue(value)) => { + Ok(MetadataValue::Float(*value)) + } + Some(chroma_proto::update_metadata_value::Value::StringValue(value)) => { + Ok(MetadataValue::Str(value.clone())) + } + _ => Err(MetadataValueConversionError::InvalidValue), + } + } +} + +/* +=========================================== +UpdateMetadata +=========================================== +*/ + +pub(crate) type UpdateMetadata = HashMap; + +impl TryFrom for UpdateMetadata { + type Error = UpdateMetadataValueConversionError; + + fn try_from(proto_metadata: chroma_proto::UpdateMetadata) -> Result { + let mut metadata = UpdateMetadata::new(); + for (key, value) in proto_metadata.metadata.iter() { + let value = match value.try_into() { + Ok(value) => value, + Err(_) => return Err(UpdateMetadataValueConversionError::InvalidValue), + }; + metadata.insert(key.clone(), value); + } + Ok(metadata) + } +} + +/* +=========================================== +Metadata +=========================================== +*/ + +pub(crate) type Metadata = HashMap; + +impl TryFrom for Metadata { + type Error = MetadataValueConversionError; + + fn try_from(proto_metadata: chroma_proto::UpdateMetadata) -> Result { + let mut metadata = Metadata::new(); + for (key, value) in proto_metadata.metadata.iter() { + let maybe_value: Result = value.try_into(); + if maybe_value.is_err() { + return Err(MetadataValueConversionError::InvalidValue); + } + let value = maybe_value.unwrap(); + metadata.insert(key.clone(), value); + } + Ok(metadata) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_update_metadata_try_from() { + let mut proto_metadata = chroma_proto::UpdateMetadata { + metadata: HashMap::new(), + }; + proto_metadata.metadata.insert( + "foo".to_string(), + chroma_proto::UpdateMetadataValue { + value: Some(chroma_proto::update_metadata_value::Value::IntValue(42)), + }, + ); + proto_metadata.metadata.insert( + "bar".to_string(), + chroma_proto::UpdateMetadataValue { + value: Some(chroma_proto::update_metadata_value::Value::FloatValue(42.0)), + }, + ); + proto_metadata.metadata.insert( + "baz".to_string(), + chroma_proto::UpdateMetadataValue { + value: Some(chroma_proto::update_metadata_value::Value::StringValue( + "42".to_string(), + )), + }, + ); + let converted_metadata: UpdateMetadata = proto_metadata.try_into().unwrap(); + assert_eq!(converted_metadata.len(), 3); + assert_eq!( + converted_metadata.get("foo").unwrap(), + &UpdateMetadataValue::Int(42) + ); + assert_eq!( + converted_metadata.get("bar").unwrap(), + &UpdateMetadataValue::Float(42.0) + ); + assert_eq!( + converted_metadata.get("baz").unwrap(), + &UpdateMetadataValue::Str("42".to_string()) + ); + } + + #[test] + fn test_metadata_try_from() { + let mut proto_metadata = chroma_proto::UpdateMetadata { + metadata: HashMap::new(), + }; + proto_metadata.metadata.insert( + "foo".to_string(), + chroma_proto::UpdateMetadataValue { + value: Some(chroma_proto::update_metadata_value::Value::IntValue(42)), + }, + ); + proto_metadata.metadata.insert( + "bar".to_string(), + chroma_proto::UpdateMetadataValue { + value: Some(chroma_proto::update_metadata_value::Value::FloatValue(42.0)), + }, + ); + proto_metadata.metadata.insert( + "baz".to_string(), + chroma_proto::UpdateMetadataValue { + value: Some(chroma_proto::update_metadata_value::Value::StringValue( + "42".to_string(), + )), + }, + ); + let converted_metadata: Metadata = proto_metadata.try_into().unwrap(); + assert_eq!(converted_metadata.len(), 3); + assert_eq!( + converted_metadata.get("foo").unwrap(), + &MetadataValue::Int(42) + ); + assert_eq!( + converted_metadata.get("bar").unwrap(), + &MetadataValue::Float(42.0) + ); + assert_eq!( + converted_metadata.get("baz").unwrap(), + &MetadataValue::Str("42".to_string()) + ); + } +} diff --git a/rust/worker/src/types/mod.rs b/rust/worker/src/types/mod.rs new file mode 100644 index 00000000000..edda924c42c --- /dev/null +++ b/rust/worker/src/types/mod.rs @@ -0,0 +1,19 @@ +#[macro_use] +mod types; +mod collection; +mod embedding_record; +mod metadata; +mod operation; +mod scalar_encoding; +mod segment; +mod segment_scope; + +// Re-export the types module, so that we can use it as a single import in other modules. +pub use collection::*; +pub use embedding_record::*; +pub use metadata::*; +pub use operation::*; +pub use scalar_encoding::*; +pub use segment::*; +pub use segment_scope::*; +pub use types::*; diff --git a/rust/worker/src/types/operation.rs b/rust/worker/src/types/operation.rs new file mode 100644 index 00000000000..581e5c39f8e --- /dev/null +++ b/rust/worker/src/types/operation.rs @@ -0,0 +1,73 @@ +use super::ConversionError; +use crate::{ + chroma_proto, + errors::{ChromaError, ErrorCodes}, +}; +use thiserror::Error; + +#[derive(Debug, PartialEq)] +pub(crate) enum Operation { + Add, + Update, + Upsert, + Delete, +} + +#[derive(Error, Debug)] +pub(crate) enum OperationConversionError { + #[error("Invalid operation, valid operations are: Add, Upsert, Update, Delete")] + InvalidOperation, + #[error(transparent)] + DecodeError(#[from] ConversionError), +} + +impl_base_convert_error!(OperationConversionError, { + OperationConversionError::InvalidOperation => ErrorCodes::InvalidArgument, +}); + +impl TryFrom for Operation { + type Error = OperationConversionError; + + fn try_from(op: chroma_proto::Operation) -> Result { + match op { + chroma_proto::Operation::Add => Ok(Operation::Add), + chroma_proto::Operation::Upsert => Ok(Operation::Upsert), + chroma_proto::Operation::Update => Ok(Operation::Update), + chroma_proto::Operation::Delete => Ok(Operation::Delete), + _ => Err(OperationConversionError::InvalidOperation), + } + } +} + +impl TryFrom for Operation { + type Error = OperationConversionError; + + fn try_from(op: i32) -> Result { + let maybe_op = chroma_proto::Operation::try_from(op); + match maybe_op { + Ok(op) => match op { + chroma_proto::Operation::Add => Ok(Operation::Add), + chroma_proto::Operation::Upsert => Ok(Operation::Upsert), + chroma_proto::Operation::Update => Ok(Operation::Update), + chroma_proto::Operation::Delete => Ok(Operation::Delete), + _ => Err(OperationConversionError::InvalidOperation), + }, + Err(_) => Err(OperationConversionError::DecodeError( + ConversionError::DecodeError, + )), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::chroma_proto; + + #[test] + fn test_operation_try_from() { + let proto_op = chroma_proto::Operation::Add; + let converted_op: Operation = proto_op.try_into().unwrap(); + assert_eq!(converted_op, Operation::Add); + } +} diff --git a/rust/worker/src/types/scalar_encoding.rs b/rust/worker/src/types/scalar_encoding.rs new file mode 100644 index 00000000000..afcaf6b2e30 --- /dev/null +++ b/rust/worker/src/types/scalar_encoding.rs @@ -0,0 +1,66 @@ +use super::ConversionError; +use crate::{ + chroma_proto, + errors::{ChromaError, ErrorCodes}, +}; +use thiserror::Error; + +#[derive(Debug, PartialEq)] +pub(crate) enum ScalarEncoding { + FLOAT32, + INT32, +} + +#[derive(Error, Debug)] +pub(crate) enum ScalarEncodingConversionError { + #[error("Invalid encoding, valid encodings are: Float32, Int32")] + InvalidEncoding, + #[error(transparent)] + DecodeError(#[from] ConversionError), +} + +impl_base_convert_error!(ScalarEncodingConversionError, { + ScalarEncodingConversionError::InvalidEncoding => ErrorCodes::InvalidArgument, +}); + +impl TryFrom for ScalarEncoding { + type Error = ScalarEncodingConversionError; + + fn try_from(encoding: chroma_proto::ScalarEncoding) -> Result { + match encoding { + chroma_proto::ScalarEncoding::Float32 => Ok(ScalarEncoding::FLOAT32), + chroma_proto::ScalarEncoding::Int32 => Ok(ScalarEncoding::INT32), + _ => Err(ScalarEncodingConversionError::InvalidEncoding), + } + } +} + +impl TryFrom for ScalarEncoding { + type Error = ScalarEncodingConversionError; + + fn try_from(encoding: i32) -> Result { + let maybe_encoding = chroma_proto::ScalarEncoding::try_from(encoding); + match maybe_encoding { + Ok(encoding) => match encoding { + chroma_proto::ScalarEncoding::Float32 => Ok(ScalarEncoding::FLOAT32), + chroma_proto::ScalarEncoding::Int32 => Ok(ScalarEncoding::INT32), + _ => Err(ScalarEncodingConversionError::InvalidEncoding), + }, + Err(_) => Err(ScalarEncodingConversionError::DecodeError( + ConversionError::DecodeError, + )), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_scalar_encoding_try_from() { + let proto_encoding = chroma_proto::ScalarEncoding::Float32; + let converted_encoding: ScalarEncoding = proto_encoding.try_into().unwrap(); + assert_eq!(converted_encoding, ScalarEncoding::FLOAT32); + } +} diff --git a/rust/worker/src/types/segment.rs b/rust/worker/src/types/segment.rs new file mode 100644 index 00000000000..e77c720326c --- /dev/null +++ b/rust/worker/src/types/segment.rs @@ -0,0 +1,114 @@ +use super::{Metadata, MetadataValueConversionError, SegmentScope, SegmentScopeConversionError}; +use crate::{ + chroma_proto, + errors::{ChromaError, ErrorCodes}, +}; +use thiserror::Error; +use uuid::Uuid; + +#[derive(Debug, PartialEq)] +pub(crate) struct Segment { + pub(crate) id: Uuid, + pub(crate) r#type: String, + pub(crate) scope: SegmentScope, + pub(crate) topic: Option, + pub(crate) collection: Option, + pub(crate) metadata: Option, +} + +#[derive(Error, Debug)] +pub(crate) enum SegmentConversionError { + #[error("Invalid UUID")] + InvalidUuid, + #[error(transparent)] + MetadataValueConversionError(#[from] MetadataValueConversionError), + #[error(transparent)] + SegmentScopeConversionError(#[from] SegmentScopeConversionError), +} + +impl ChromaError for SegmentConversionError { + fn code(&self) -> crate::errors::ErrorCodes { + match self { + SegmentConversionError::InvalidUuid => ErrorCodes::InvalidArgument, + SegmentConversionError::SegmentScopeConversionError(e) => e.code(), + SegmentConversionError::MetadataValueConversionError(e) => e.code(), + } + } +} + +impl TryFrom for Segment { + type Error = SegmentConversionError; + + fn try_from(proto_segment: chroma_proto::Segment) -> Result { + let segment_uuid = match Uuid::try_parse(&proto_segment.id) { + Ok(uuid) => uuid, + Err(_) => return Err(SegmentConversionError::InvalidUuid), + }; + let collection_uuid = match proto_segment.collection { + Some(collection_id) => match Uuid::try_parse(&collection_id) { + Ok(uuid) => Some(uuid), + Err(_) => return Err(SegmentConversionError::InvalidUuid), + }, + // The UUID can be none in the local version of chroma but not distributed + None => return Err(SegmentConversionError::InvalidUuid), + }; + let segment_metadata: Option = match proto_segment.metadata { + Some(proto_metadata) => match proto_metadata.try_into() { + Ok(metadata) => Some(metadata), + Err(e) => return Err(SegmentConversionError::MetadataValueConversionError(e)), + }, + None => None, + }; + let scope: SegmentScope = match proto_segment.scope.try_into() { + Ok(scope) => scope, + Err(e) => return Err(SegmentConversionError::SegmentScopeConversionError(e)), + }; + + Ok(Segment { + id: segment_uuid, + r#type: proto_segment.r#type, + scope: scope, + topic: proto_segment.topic, + collection: collection_uuid, + metadata: segment_metadata, + }) + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use super::*; + use crate::types::MetadataValue; + + #[test] + fn test_segment_try_from() { + let mut metadata = chroma_proto::UpdateMetadata { + metadata: HashMap::new(), + }; + metadata.metadata.insert( + "foo".to_string(), + chroma_proto::UpdateMetadataValue { + value: Some(chroma_proto::update_metadata_value::Value::IntValue(42)), + }, + ); + let proto_segment = chroma_proto::Segment { + id: "00000000-0000-0000-0000-000000000000".to_string(), + r#type: "foo".to_string(), + scope: chroma_proto::SegmentScope::Vector as i32, + topic: Some("test".to_string()), + collection: Some("00000000-0000-0000-0000-000000000000".to_string()), + metadata: Some(metadata), + }; + let converted_segment: Segment = proto_segment.try_into().unwrap(); + assert_eq!(converted_segment.id, Uuid::nil()); + assert_eq!(converted_segment.r#type, "foo".to_string()); + assert_eq!(converted_segment.scope, SegmentScope::VECTOR); + assert_eq!(converted_segment.topic, Some("test".to_string())); + assert_eq!(converted_segment.collection, Some(Uuid::nil())); + let metadata = converted_segment.metadata.unwrap(); + assert_eq!(metadata.len(), 1); + assert_eq!(metadata.get("foo").unwrap(), &MetadataValue::Int(42)); + } +} diff --git a/rust/worker/src/types/segment_scope.rs b/rust/worker/src/types/segment_scope.rs new file mode 100644 index 00000000000..d2c1fb5392f --- /dev/null +++ b/rust/worker/src/types/segment_scope.rs @@ -0,0 +1,70 @@ +use super::ConversionError; +use crate::{ + chroma_proto, + errors::{ChromaError, ErrorCodes}, +}; +use thiserror::Error; + +#[derive(Debug, PartialEq)] +pub(crate) enum SegmentScope { + VECTOR, + METADATA, +} + +#[derive(Error, Debug)] +pub(crate) enum SegmentScopeConversionError { + #[error("Invalid segment scope, valid scopes are: Vector, Metadata")] + InvalidScope, + #[error(transparent)] + DecodeError(#[from] ConversionError), +} + +impl_base_convert_error!(SegmentScopeConversionError, { + SegmentScopeConversionError::InvalidScope => ErrorCodes::InvalidArgument, +}); + +impl TryFrom for SegmentScope { + type Error = SegmentScopeConversionError; + + fn try_from(scope: chroma_proto::SegmentScope) -> Result { + match scope { + chroma_proto::SegmentScope::Vector => Ok(SegmentScope::VECTOR), + chroma_proto::SegmentScope::Metadata => Ok(SegmentScope::METADATA), + _ => Err(SegmentScopeConversionError::InvalidScope), + } + } +} + +impl TryFrom for SegmentScope { + type Error = SegmentScopeConversionError; + + fn try_from(scope: i32) -> Result { + let maybe_scope = chroma_proto::SegmentScope::try_from(scope); + match maybe_scope { + Ok(scope) => match scope { + chroma_proto::SegmentScope::Vector => Ok(SegmentScope::VECTOR), + chroma_proto::SegmentScope::Metadata => Ok(SegmentScope::METADATA), + _ => Err(SegmentScopeConversionError::InvalidScope), + }, + Err(_) => Err(SegmentScopeConversionError::DecodeError( + ConversionError::DecodeError, + )), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_segment_scope_try_from() { + let proto_scope = chroma_proto::SegmentScope::Vector; + let converted_scope: SegmentScope = proto_scope.try_into().unwrap(); + assert_eq!(converted_scope, SegmentScope::VECTOR); + + let proto_scope = chroma_proto::SegmentScope::Metadata; + let converted_scope: SegmentScope = proto_scope.try_into().unwrap(); + assert_eq!(converted_scope, SegmentScope::METADATA); + } +} diff --git a/rust/worker/src/types/types.rs b/rust/worker/src/types/types.rs new file mode 100644 index 00000000000..e87337cc511 --- /dev/null +++ b/rust/worker/src/types/types.rs @@ -0,0 +1,36 @@ +use crate::errors::{ChromaError, ErrorCodes}; +use num_bigint::BigInt; +use thiserror::Error; + +/// A macro for easily implementing match arms for a base error type with common errors. +/// Other types can wrap it and still implement the ChromaError trait +/// without boilerplate. +macro_rules! impl_base_convert_error { + ($err:ty, { $($variant:pat => $action:expr),* $(,)? }) => { + impl ChromaError for $err { + fn code(&self) -> ErrorCodes { + match self { + Self::DecodeError(inner) => inner.code(), + // Handle custom variants + $( $variant => $action, )* + } + } + } + }; +} + +#[derive(Error, Debug)] +pub(crate) enum ConversionError { + #[error("Error decoding protobuf message")] + DecodeError, +} + +impl ChromaError for ConversionError { + fn code(&self) -> crate::errors::ErrorCodes { + match self { + ConversionError::DecodeError => ErrorCodes::Internal, + } + } +} + +pub(crate) type SeqId = BigInt;