diff --git a/.vscode/settings.json b/.vscode/settings.json index ccddc8d4c8c..a5def08ba63 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -128,4 +128,4 @@ "unordered_set": "cpp", "algorithm": "cpp" }, -} +} \ No newline at end of file diff --git a/rust/worker/src/compactor/scheduler.rs b/rust/worker/src/compactor/scheduler.rs index bd51ed0320b..1ec961a077f 100644 --- a/rust/worker/src/compactor/scheduler.rs +++ b/rust/worker/src/compactor/scheduler.rs @@ -281,7 +281,7 @@ mod tests { collection_id: collection_id_1.clone(), log_id: 1, log_id_ts: 1, - record: Box::new(EmbeddingRecord { + record: EmbeddingRecord { id: "embedding_id_1".to_string(), seq_id: BigInt::from(1), embedding: None, @@ -289,7 +289,7 @@ mod tests { metadata: None, operation: Operation::Add, collection_id: collection_uuid_1, - }), + }, }), ); @@ -301,7 +301,7 @@ mod tests { collection_id: collection_id_2.clone(), log_id: 2, log_id_ts: 2, - record: Box::new(EmbeddingRecord { + record: EmbeddingRecord { id: "embedding_id_2".to_string(), seq_id: BigInt::from(2), embedding: None, @@ -309,7 +309,7 @@ mod tests { metadata: None, operation: Operation::Add, collection_id: collection_uuid_2, - }), + }, }), ); diff --git a/rust/worker/src/execution/data/data_chunk.rs b/rust/worker/src/execution/data/data_chunk.rs new file mode 100644 index 00000000000..5f13d57cb2a --- /dev/null +++ b/rust/worker/src/execution/data/data_chunk.rs @@ -0,0 +1,170 @@ +use std::sync::Arc; + +use crate::types::EmbeddingRecord; + +#[derive(Clone, Debug)] +pub(crate) struct DataChunk { + data: Arc<[EmbeddingRecord]>, + visibility: Arc<[bool]>, +} + +impl DataChunk { + pub fn new(data: Arc<[EmbeddingRecord]>) -> Self { + let len = data.len(); + DataChunk { + data, + visibility: vec![true; len].into(), + } + } + + /// Returns the total length of the data chunk + pub fn total_len(&self) -> usize { + self.data.len() + } + + /// Returns the number of visible elements in the data chunk + pub fn len(&self) -> usize { + self.visibility.iter().filter(|&v| *v).count() + } + + /// Returns the element at the given index + /// if the index is out of bounds, it returns None + /// # Arguments + /// * `index` - The index of the element + pub fn get(&self, index: usize) -> Option<&EmbeddingRecord> { + if index < self.data.len() { + Some(&self.data[index]) + } else { + None + } + } + + /// Returns the visibility of the element at the given index + /// if the index is out of bounds, it returns None + /// # Arguments + /// * `index` - The index of the element + pub fn get_visibility(&self, index: usize) -> Option { + if index < self.visibility.len() { + Some(self.visibility[index]) + } else { + None + } + } + + /// Sets the visibility of the elements in the data chunk. + /// Note that the length of the visibility vector should be + /// equal to the length of the data chunk. + /// + /// Note that this is the only way to change the visibility of the elements in the data chunk, + /// the data chunk does not provide a way to change the visibility of individual elements. + /// This is to ensure that the visibility of the elements is always in sync with the data. + /// If you want to change the visibility of individual elements, you should create a new data chunk. + /// + /// # Arguments + /// * `visibility` - A vector of boolean values indicating the visibility of the elements + pub fn set_visibility(&mut self, visibility: Vec) { + self.visibility = visibility.into(); + } + + /// Returns an iterator over the visible elements in the data chunk + /// The iterator returns a tuple of the element and its index + /// # Returns + /// An iterator over the visible elements in the data chunk + pub fn iter(&self) -> DataChunkIteraror<'_> { + DataChunkIteraror { + chunk: self, + index: 0, + } + } +} + +pub(crate) struct DataChunkIteraror<'a> { + chunk: &'a DataChunk, + index: usize, +} + +impl<'a> Iterator for DataChunkIteraror<'a> { + type Item = (&'a EmbeddingRecord, usize); + + fn next(&mut self) -> Option { + while self.index < self.chunk.total_len() { + let index = self.index; + match self.chunk.get_visibility(index) { + Some(true) => { + self.index += 1; + return self.chunk.get(index).map(|record| (record, index)); + } + Some(false) => { + self.index += 1; + } + None => { + break; + } + } + } + None + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::types::EmbeddingRecord; + use crate::types::Operation; + use num_bigint::BigInt; + use std::str::FromStr; + use uuid::Uuid; + + #[test] + fn test_data_chunk() { + let collection_uuid_1 = Uuid::from_str("00000000-0000-0000-0000-000000000001").unwrap(); + let data = vec![ + EmbeddingRecord { + id: "embedding_id_1".to_string(), + seq_id: BigInt::from(1), + embedding: None, + encoding: None, + metadata: None, + operation: Operation::Add, + collection_id: collection_uuid_1, + }, + EmbeddingRecord { + id: "embedding_id_2".to_string(), + seq_id: BigInt::from(2), + embedding: None, + encoding: None, + metadata: None, + operation: Operation::Add, + collection_id: collection_uuid_1, + }, + ]; + let data = data.into(); + let mut chunk = DataChunk::new(data); + assert_eq!(chunk.len(), 2); + let mut iter = chunk.iter(); + let elem = iter.next(); + assert_eq!(elem.is_some(), true); + let (record, index) = elem.unwrap(); + assert_eq!(record.id, "embedding_id_1"); + assert_eq!(index, 0); + let elem = iter.next(); + assert_eq!(elem.is_some(), true); + let (record, index) = elem.unwrap(); + assert_eq!(record.id, "embedding_id_2"); + assert_eq!(index, 1); + let elem = iter.next(); + assert_eq!(elem.is_none(), true); + + let visibility = vec![true, false].into(); + chunk.set_visibility(visibility); + assert_eq!(chunk.len(), 1); + let mut iter = chunk.iter(); + let elem = iter.next(); + assert_eq!(elem.is_some(), true); + let (record, index) = elem.unwrap(); + assert_eq!(record.id, "embedding_id_1"); + assert_eq!(index, 0); + let elem = iter.next(); + assert_eq!(elem.is_none(), true); + } +} diff --git a/rust/worker/src/execution/data/mod.rs b/rust/worker/src/execution/data/mod.rs new file mode 100644 index 00000000000..ecbe39f3445 --- /dev/null +++ b/rust/worker/src/execution/data/mod.rs @@ -0,0 +1 @@ +pub(crate) mod data_chunk; diff --git a/rust/worker/src/execution/mod.rs b/rust/worker/src/execution/mod.rs index 0000e23f3a3..1d361780d77 100644 --- a/rust/worker/src/execution/mod.rs +++ b/rust/worker/src/execution/mod.rs @@ -1,4 +1,5 @@ pub(crate) mod config; +mod data; pub(crate) mod dispatcher; pub(crate) mod operator; mod operators; diff --git a/rust/worker/src/execution/operators/brute_force_knn.rs b/rust/worker/src/execution/operators/brute_force_knn.rs index 3a9d05c8426..13e02dc9af9 100644 --- a/rust/worker/src/execution/operators/brute_force_knn.rs +++ b/rust/worker/src/execution/operators/brute_force_knn.rs @@ -1,6 +1,10 @@ +use crate::execution::data::data_chunk::DataChunk; use crate::{distance::DistanceFunction, execution::operator::Operator}; use async_trait::async_trait; -use std::cmp; +use std::cmp::Ord; +use std::cmp::Ordering; +use std::cmp::PartialOrd; +use std::collections::BinaryHeap; /// The brute force k-nearest neighbors operator is responsible for computing the k-nearest neighbors /// of a given query vector against a set of vectors using brute force calculation. @@ -17,7 +21,7 @@ pub struct BruteForceKnnOperator {} /// * `distance_metric` - The distance metric to use. #[derive(Debug)] pub struct BruteForceKnnOperatorInput { - pub data: Vec>, + pub data: DataChunk, pub query: Vec, pub k: usize, pub distance_metric: DistanceFunction, @@ -25,37 +29,95 @@ pub struct BruteForceKnnOperatorInput { /// The output of the brute force k-nearest neighbors operator. /// # Parameters +/// * `data` - The vectors to query against. Only the vectors that are nearest neighbors are visible. /// * `indices` - The indices of the nearest neighbors. This is a mask against the `query_vecs` input. /// One row for each query vector. /// * `distances` - The distances of the nearest neighbors. /// One row for each query vector. #[derive(Debug)] pub struct BruteForceKnnOperatorOutput { + pub data: DataChunk, pub indices: Vec, pub distances: Vec, } pub type BruteForceKnnOperatorResult = Result; +#[derive(Debug)] +struct Entry { + index: usize, + distance: f32, +} + +impl Ord for Entry { + fn cmp(&self, other: &Self) -> Ordering { + if self.distance == other.distance { + Ordering::Equal + } else if self.distance > other.distance { + // This is a min heap, so we need to reverse the ordering. + Ordering::Less + } else { + // This is a min heap, so we need to reverse the ordering. + Ordering::Greater + } + } +} + +impl PartialOrd for Entry { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl PartialEq for Entry { + fn eq(&self, other: &Self) -> bool { + self.distance == other.distance + } +} + +impl Eq for Entry {} + #[async_trait] impl Operator for BruteForceKnnOperator { type Error = (); async fn run(&self, input: &BruteForceKnnOperatorInput) -> BruteForceKnnOperatorResult { - // We could use a heap approach here, but for now we just sort the distances and take the - // first k. - let mut sorted_indices_distances = input - .data - .iter() - .map(|data| input.distance_metric.distance(&input.query, data)) - .enumerate() - .collect::>(); - sorted_indices_distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap()); - let (sorted_indices, sorted_distances) = sorted_indices_distances - .drain(..cmp::min(input.k, input.data.len())) - .unzip(); + let mut heap = BinaryHeap::with_capacity(input.k); + let data_chunk = &input.data; + for data in data_chunk.iter() { + let embedding_record = data.0; + let index = data.1; + let embedding = match &embedding_record.embedding { + Some(embedding) => embedding, + None => { + continue; + } + }; + let distance = input.distance_metric.distance(&embedding[..], &input.query); + heap.push(Entry { index, distance }); + } + + let mut visibility = vec![false; data_chunk.total_len()]; + let mut sorted_indices = Vec::with_capacity(input.k); + let mut sorted_distances = Vec::with_capacity(input.k); + let mut i = 0; + while i < input.k { + let entry = match heap.pop() { + Some(entry) => entry, + None => { + break; + } + }; + sorted_indices.push(entry.index); + sorted_distances.push(entry.distance); + visibility[entry.index] = true; + i += 1; + } + let mut data_chunk = data_chunk.clone(); + data_chunk.set_visibility(visibility); Ok(BruteForceKnnOperatorOutput { + data: data_chunk, indices: sorted_indices, distances: sorted_distances, }) @@ -64,17 +126,49 @@ impl Operator for Brute #[cfg(test)] mod tests { + use crate::types::EmbeddingRecord; + use crate::types::Operation; + use num_bigint::BigInt; + use uuid::Uuid; + use super::*; #[tokio::test] async fn test_brute_force_knn_l2sqr() { let operator = BruteForceKnnOperator {}; + let data = vec![ + EmbeddingRecord { + id: "embedding_id_1".to_string(), + seq_id: BigInt::from(0), + embedding: Some(vec![0.0, 0.0, 0.0]), + encoding: None, + metadata: None, + operation: Operation::Add, + collection_id: Uuid::new_v4(), + }, + EmbeddingRecord { + id: "embedding_id_2".to_string(), + seq_id: BigInt::from(1), + embedding: Some(vec![0.0, 1.0, 1.0]), + encoding: None, + metadata: None, + operation: Operation::Add, + collection_id: Uuid::new_v4(), + }, + EmbeddingRecord { + id: "embedding_id_3".to_string(), + seq_id: BigInt::from(2), + embedding: Some(vec![7.0, 8.0, 9.0]), + encoding: None, + metadata: None, + operation: Operation::Add, + collection_id: Uuid::new_v4(), + }, + ]; + let data_chunk = DataChunk::new(data.into()); + let input = BruteForceKnnOperatorInput { - data: vec![ - vec![0.0, 0.0, 0.0], - vec![0.0, 1.0, 1.0], - vec![7.0, 8.0, 9.0], - ], + data: data_chunk, query: vec![0.0, 0.0, 0.0], k: 2, distance_metric: DistanceFunction::Euclidean, @@ -83,6 +177,9 @@ mod tests { assert_eq!(output.indices, vec![0, 1]); let distance_1 = 0.0_f32.powi(2) + 1.0_f32.powi(2) + 1.0_f32.powi(2); assert_eq!(output.distances, vec![0.0, distance_1]); + assert_eq!(output.data.get_visibility(0), Some(true)); + assert_eq!(output.data.get_visibility(1), Some(true)); + assert_eq!(output.data.get_visibility(2), Some(false)); } #[tokio::test] @@ -95,8 +192,39 @@ mod tests { let norm_2 = (0.0_f32.powi(2) + -1.0_f32.powi(2) + 6.0_f32.powi(2)).sqrt(); let data_2 = vec![0.0 / norm_2, -1.0 / norm_2, 6.0 / norm_2]; + let data = vec![ + EmbeddingRecord { + id: "embedding_id_1".to_string(), + seq_id: BigInt::from(0), + embedding: Some(vec![0.0, 1.0, 0.0]), + encoding: None, + metadata: None, + operation: Operation::Add, + collection_id: Uuid::new_v4(), + }, + EmbeddingRecord { + id: "embedding_id_2".to_string(), + seq_id: BigInt::from(1), + embedding: Some(data_1.clone()), + encoding: None, + metadata: None, + operation: Operation::Add, + collection_id: Uuid::new_v4(), + }, + EmbeddingRecord { + id: "embedding_id_3".to_string(), + seq_id: BigInt::from(2), + embedding: Some(data_2.clone()), + encoding: None, + metadata: None, + operation: Operation::Add, + collection_id: Uuid::new_v4(), + }, + ]; + let data_chunk = DataChunk::new(data.into()); + let input = BruteForceKnnOperatorInput { - data: vec![vec![0.0, 1.0, 0.0], data_1.clone(), data_2.clone()], + data: data_chunk, query: vec![0.0, 1.0, 0.0], k: 2, distance_metric: DistanceFunction::InnerProduct, @@ -107,14 +235,29 @@ mod tests { let expected_distance_1 = 1.0f32 - ((data_1[0] * 0.0) + (data_1[1] * 1.0) + (data_1[2] * 0.0)); assert_eq!(output.distances, vec![0.0, expected_distance_1]); + assert_eq!(output.data.get_visibility(0), Some(true)); + assert_eq!(output.data.get_visibility(1), Some(true)); + assert_eq!(output.data.get_visibility(2), Some(false)); } #[tokio::test] async fn test_data_less_than_k() { // If we have less data than k, we should return all the data, sorted by distance. let operator = BruteForceKnnOperator {}; + let data = vec![EmbeddingRecord { + id: "embedding_id_1".to_string(), + seq_id: BigInt::from(0), + embedding: Some(vec![0.0, 0.0, 0.0]), + encoding: None, + metadata: None, + operation: Operation::Add, + collection_id: Uuid::new_v4(), + }]; + + let data_chunk = DataChunk::new(data.into()); + let input = BruteForceKnnOperatorInput { - data: vec![vec![0.0, 0.0, 0.0]], + data: data_chunk, query: vec![0.0, 0.0, 0.0], k: 2, distance_metric: DistanceFunction::Euclidean, @@ -122,5 +265,6 @@ mod tests { let output = operator.run(&input).await.unwrap(); assert_eq!(output.indices, vec![0]); assert_eq!(output.distances, vec![0.0]); + assert_eq!(output.data.get_visibility(0), Some(true)); } } diff --git a/rust/worker/src/execution/operators/mod.rs b/rust/worker/src/execution/operators/mod.rs index 60198481545..ed31dca33db 100644 --- a/rust/worker/src/execution/operators/mod.rs +++ b/rust/worker/src/execution/operators/mod.rs @@ -1,3 +1,4 @@ pub(super) mod brute_force_knn; pub(super) mod normalize_vectors; +pub(super) mod partition; pub(super) mod pull_log; diff --git a/rust/worker/src/execution/operators/partition.rs b/rust/worker/src/execution/operators/partition.rs new file mode 100644 index 00000000000..e32b693ff73 --- /dev/null +++ b/rust/worker/src/execution/operators/partition.rs @@ -0,0 +1,221 @@ +use crate::errors::{ChromaError, ErrorCodes}; +use crate::execution::data::data_chunk::DataChunk; +use crate::execution::operator::Operator; +use async_trait::async_trait; +use std::collections::HashMap; +use thiserror::Error; + +#[derive(Debug)] +/// The partition Operator takes a DataChunk and presents a copy-free +/// view of N partitions by breaking the data into partitions by max_partition_size. It will group operations +/// on the same key into the same partition. Due to this, the max_partition_size is a +/// soft-limit, since if there are more operations to a key than max_partition_size we cannot +/// partition the data. +pub struct PartitionOperator {} + +/// The input to the partition operator. +/// # Parameters +/// * `records` - The records to partition. +#[derive(Debug)] +pub struct PartitionInput { + pub(crate) records: DataChunk, + pub(crate) max_partition_size: usize, +} + +impl PartitionInput { + /// Create a new partition input. + /// # Parameters + /// * `records` - The records to partition. + /// * `max_partition_size` - The maximum size of a partition. Since we are trying to + /// partition the records by id, which can casue the partition size to be larger than this + /// value. + pub fn new(records: DataChunk, max_partition_size: usize) -> Self { + PartitionInput { + records, + max_partition_size, + } + } +} + +/// The output of the partition operator. +/// # Parameters +/// * `records` - The partitioned records. +#[derive(Debug)] +pub struct PartitionOutput { + pub(crate) records: Vec, +} + +#[derive(Debug, Error)] +pub enum PartitionError { + #[error("Failed to partition records.")] + PartitionError, +} + +impl ChromaError for PartitionError { + fn code(&self) -> ErrorCodes { + match self { + PartitionError::PartitionError => ErrorCodes::Internal, + } + } +} + +pub type PartitionResult = Result; + +impl PartitionOperator { + pub fn new() -> Box { + Box::new(PartitionOperator {}) + } + + pub fn partition(&self, records: &DataChunk, partition_size: usize) -> Vec { + let mut map = HashMap::new(); + for data in records.iter() { + let record = data.0; + let index = data.1; + let key = record.id.clone(); + map.entry(key).or_insert_with(Vec::new).push(index); + } + let mut result = Vec::new(); + // Create a new DataChunk for each parition of records with partition_size without + // data copying. + let mut current_batch_size = 0; + let mut new_partition = true; + let mut visibility = vec![false; records.total_len()]; + for (_, v) in map.iter() { + // create DataChunk with partition_size by masking the visibility of the records + // in the partition. + if new_partition { + visibility = vec![false; records.total_len()]; + new_partition = false; + } + for i in v.iter() { + visibility[*i] = true; + } + current_batch_size += v.len(); + if current_batch_size >= partition_size { + let mut new_data_chunk = records.clone(); + new_data_chunk.set_visibility(visibility.clone()); + result.push(new_data_chunk); + new_partition = true; + current_batch_size = 0; + } + } + // handle the case that the last group is smaller than the group_size. + if !new_partition { + let mut new_data_chunk = records.clone(); + new_data_chunk.set_visibility(visibility.clone()); + result.push(new_data_chunk); + } + result + } + + fn determine_partition_size(&self, num_records: usize, threshold: usize) -> usize { + if num_records < threshold { + return num_records; + } else { + return threshold; + } + } +} + +#[async_trait] +impl Operator for PartitionOperator { + type Error = PartitionError; + + async fn run(&self, input: &PartitionInput) -> PartitionResult { + let records = &input.records; + let partition_size = self.determine_partition_size(records.len(), input.max_partition_size); + let deduped_records = self.partition(records, partition_size); + return Ok(PartitionOutput { + records: deduped_records, + }); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::types::EmbeddingRecord; + use crate::types::Operation; + use num_bigint::BigInt; + use std::str::FromStr; + use std::sync::Arc; + use uuid::Uuid; + + #[tokio::test] + async fn test_partition_operator() { + let collection_uuid_1 = Uuid::from_str("00000000-0000-0000-0000-000000000001").unwrap(); + let collection_uuid_2 = Uuid::from_str("00000000-0000-0000-0000-000000000002").unwrap(); + let data = vec![ + EmbeddingRecord { + id: "embedding_id_1".to_string(), + seq_id: BigInt::from(1), + embedding: None, + encoding: None, + metadata: None, + operation: Operation::Add, + collection_id: collection_uuid_1, + }, + EmbeddingRecord { + id: "embedding_id_2".to_string(), + seq_id: BigInt::from(2), + embedding: None, + encoding: None, + metadata: None, + operation: Operation::Add, + collection_id: collection_uuid_1, + }, + EmbeddingRecord { + id: "embedding_id_1".to_string(), + seq_id: BigInt::from(3), + embedding: None, + encoding: None, + metadata: None, + operation: Operation::Add, + collection_id: collection_uuid_2, + }, + ]; + let data: Arc<[EmbeddingRecord]> = data.into(); + + // Test group size is larger than the number of records + let chunk = DataChunk::new(data.clone()); + let operator = PartitionOperator::new(); + let input = PartitionInput::new(chunk, 4); + let result = operator.run(&input).await.unwrap(); + assert_eq!(result.records.len(), 1); + assert_eq!(result.records[0].len(), 3); + + // Test group size is the same as the number of records + let chunk = DataChunk::new(data.clone()); + let operator = PartitionOperator::new(); + let input = PartitionInput::new(chunk, 3); + let result = operator.run(&input).await.unwrap(); + assert_eq!(result.records.len(), 1); + assert_eq!(result.records[0].len(), 3); + + // Test group size is smaller than the number of records + let chunk = DataChunk::new(data.clone()); + let operator = PartitionOperator::new(); + let input = PartitionInput::new(chunk, 2); + let mut result = operator.run(&input).await.unwrap(); + + // The result can be 1 or 2 groups depending on the order of the records. + assert!(result.records.len() == 2 || result.records.len() == 1); + if result.records.len() == 2 { + result.records.sort_by(|a, b| a.len().cmp(&b.len())); + assert_eq!(result.records[0].len(), 1); + assert_eq!(result.records[1].len(), 2); + } else { + assert_eq!(result.records[0].len(), 3); + } + + // Test group size is smaller than the number of records + let chunk = DataChunk::new(data.clone()); + let operator = PartitionOperator::new(); + let input = PartitionInput::new(chunk, 1); + let mut result = operator.run(&input).await.unwrap(); + assert_eq!(result.records.len(), 2); + result.records.sort_by(|a, b| a.len().cmp(&b.len())); + assert_eq!(result.records[0].len(), 1); + assert_eq!(result.records[1].len(), 2); + } +} diff --git a/rust/worker/src/execution/operators/pull_log.rs b/rust/worker/src/execution/operators/pull_log.rs index ad564662654..d0a9cc7ae61 100644 --- a/rust/worker/src/execution/operators/pull_log.rs +++ b/rust/worker/src/execution/operators/pull_log.rs @@ -1,8 +1,7 @@ -use crate::{ - execution::operator::Operator, - log::log::{Log, PullLogsError}, - types::EmbeddingRecord, -}; +use crate::execution::data::data_chunk::DataChunk; +use crate::execution::operator::Operator; +use crate::log::log::Log; +use crate::log::log::PullLogsError; use async_trait::async_trait; use uuid::Uuid; @@ -65,22 +64,22 @@ impl PullLogsInput { /// The output of the pull logs operator. #[derive(Debug)] pub struct PullLogsOutput { - logs: Vec>, + logs: DataChunk, } impl PullLogsOutput { /// Create a new pull logs output. /// # Parameters /// * `logs` - The logs that were read. - pub fn new(logs: Vec>) -> Self { + pub fn new(logs: DataChunk) -> Self { PullLogsOutput { logs } } /// Get the log entries that were read by an invocation of the pull logs operator. /// # Returns /// The log entries that were read. - pub fn logs(&self) -> &Vec> { - &self.logs + pub fn logs(&self) -> DataChunk { + self.logs.clone() } } @@ -139,7 +138,9 @@ impl Operator for PullLogsOperator { if input.num_records.is_some() && result.len() > input.num_records.unwrap() as usize { result.truncate(input.num_records.unwrap() as usize); } - Ok(PullLogsOutput::new(result)) + // Convert to DataChunk + let data_chunk = DataChunk::new(result.into()); + Ok(PullLogsOutput::new(data_chunk)) } } @@ -166,7 +167,7 @@ mod tests { collection_id: collection_id_1.clone(), log_id: 1, log_id_ts: 1, - record: Box::new(EmbeddingRecord { + record: EmbeddingRecord { id: "embedding_id_1".to_string(), seq_id: BigInt::from(1), embedding: None, @@ -174,7 +175,7 @@ mod tests { metadata: None, operation: Operation::Add, collection_id: collection_uuid_1, - }), + }, }), ); log.add_log( @@ -183,7 +184,7 @@ mod tests { collection_id: collection_id_1.clone(), log_id: 2, log_id_ts: 2, - record: Box::new(EmbeddingRecord { + record: EmbeddingRecord { id: "embedding_id_2".to_string(), seq_id: BigInt::from(2), embedding: None, @@ -191,7 +192,7 @@ mod tests { metadata: None, operation: Operation::Add, collection_id: collection_uuid_1, - }), + }, }), ); diff --git a/rust/worker/src/execution/orchestration/compact.rs b/rust/worker/src/execution/orchestration/compact.rs new file mode 100644 index 00000000000..bce69aa59e5 --- /dev/null +++ b/rust/worker/src/execution/orchestration/compact.rs @@ -0,0 +1,221 @@ +use super::super::operator::{wrap, TaskMessage}; +use crate::errors::ChromaError; +use crate::execution::data::data_chunk::DataChunk; +use crate::execution::operators::partition::PartitionInput; +use crate::execution::operators::partition::PartitionOperator; +use crate::execution::operators::partition::PartitionResult; +use crate::execution::operators::pull_log::PullLogsInput; +use crate::execution::operators::pull_log::PullLogsOperator; +use crate::execution::operators::pull_log::PullLogsResult; +use crate::log::log::Log; +use crate::sysdb::sysdb::SysDb; +use crate::system::Component; +use crate::system::Handler; +use crate::system::Receiver; +use crate::system::System; +use async_trait::async_trait; +use std::time::SystemTime; +use std::time::UNIX_EPOCH; +use uuid::Uuid; + +/** The state of the orchestrator. +In chroma, we have a relatively fixed number of query plans that we can execute. Rather +than a flexible state machine abstraction, we just manually define the states that we +expect to encounter for a given query plan. This is a bit more rigid, but it's also simpler and easier to +understand. We can always add more abstraction later if we need it. +```plaintext + + ┌───► Write─────-------┐ + │ │ + Pending ─► PullLogs ─► Group │ ├─► Flush ─► Finished + │ │ + └───► Write ───────────┘ + +``` +*/ +#[derive(Debug)] +enum ExecutionState { + Pending, + PullLogs, + Partition, + Write, + Flush, + Finished, +} + +#[derive(Debug)] +pub struct CompactOrchestrator { + state: ExecutionState, + // Component Execution + system: System, + segment_id: Uuid, + // Dependencies + log: Box, + sysdb: Box, + // Dispatcher + dispatcher: Box>, + // Result Channel + result_channel: Option>>>, +} + +impl CompactOrchestrator { + pub fn new( + system: System, + segment_id: Uuid, + log: Box, + sysdb: Box, + dispatcher: Box>, + result_channel: Option>>>, + ) -> Self { + CompactOrchestrator { + state: ExecutionState::Pending, + system, + segment_id, + log, + sysdb, + dispatcher, + result_channel, + } + } + + /// Get the collection id for a segment id. + /// TODO: This can be cached + async fn get_collection_id_for_segment_id(&mut self, segment_id: Uuid) -> Option { + let segments = self + .sysdb + .get_segments(Some(segment_id), None, None, None, None) + .await; + match segments { + Ok(segments) => match segments.get(0) { + Some(segment) => segment.collection, + None => None, + }, + Err(e) => { + // Log an error and return + return None; + } + } + } + + async fn pull_logs(&mut self, self_address: Box>) { + self.state = ExecutionState::PullLogs; + let operator = PullLogsOperator::new(self.log.clone()); + let collection_id = match self.get_collection_id_for_segment_id(self.segment_id).await { + Some(collection_id) => collection_id, + None => { + // Log an error and reply + return + return; + } + }; + let end_timestamp = SystemTime::now().duration_since(UNIX_EPOCH); + let end_timestamp = match end_timestamp { + // TODO: change protobuf definition to use u64 instead of i64 + Ok(end_timestamp) => end_timestamp.as_secs() as i64, + Err(e) => { + // Log an error and reply + return + return; + } + }; + let input = PullLogsInput::new(collection_id, 0, 100, None, Some(end_timestamp)); + let task = wrap(operator, input, self_address); + match self.dispatcher.send(task).await { + Ok(_) => (), + Err(e) => { + // TODO: log an error and reply to caller + } + } + } + + async fn group( + &mut self, + records: DataChunk, + self_address: Box>, + ) { + self.state = ExecutionState::Partition; + // TODO: make this configurable + let max_partition_size = 100; + let operator = PartitionOperator::new(); + let input = PartitionInput::new(records, max_partition_size); + let task = wrap(operator, input, self_address); + match self.dispatcher.send(task).await { + Ok(_) => (), + Err(e) => { + // TODO: log an error and reply to caller + } + } + } + + async fn write(&mut self, records: Vec) { + self.state = ExecutionState::Write; + + for record in records { + // TODO: implement write + } + } +} + +// ============== Component Implementation ============== + +#[async_trait] +impl Component for CompactOrchestrator { + fn queue_size(&self) -> usize { + 1000 // TODO: make configurable + } + + async fn on_start(&mut self, ctx: &crate::system::ComponentContext) -> () { + self.pull_logs(ctx.sender.as_receiver()).await; + } +} + +// ============== Handlers ============== +#[async_trait] +impl Handler for CompactOrchestrator { + async fn handle( + &mut self, + message: PullLogsResult, + ctx: &crate::system::ComponentContext, + ) { + let records = match message { + Ok(result) => result.logs(), + Err(e) => { + // Log an error and return + let result_channel = match self.result_channel.take() { + Some(tx) => tx, + None => { + // Log an error + return; + } + }; + let _ = result_channel.send(Err(Box::new(e))); + return; + } + }; + self.group(records, ctx.sender.as_receiver()).await; + } +} + +#[async_trait] +impl Handler for CompactOrchestrator { + async fn handle( + &mut self, + message: PartitionResult, + ctx: &crate::system::ComponentContext, + ) { + let records = match message { + Ok(result) => result.records, + Err(e) => { + // Log an error and return + let result_channel = match self.result_channel.take() { + Some(tx) => tx, + None => { + // Log an error + return; + } + }; + let _ = result_channel.send(Err(Box::new(e))); + return; + } + }; + // TODO: implement write records + } +} diff --git a/rust/worker/src/execution/orchestration/hnsw.rs b/rust/worker/src/execution/orchestration/hnsw.rs index 699bfde4b3f..579497cc03b 100644 --- a/rust/worker/src/execution/orchestration/hnsw.rs +++ b/rust/worker/src/execution/orchestration/hnsw.rs @@ -1,11 +1,9 @@ use super::super::operator::{wrap, TaskMessage}; -use super::super::operators::pull_log::{PullLogsInput, PullLogsOperator, PullLogsOutput}; -use crate::distance; +use super::super::operators::pull_log::{PullLogsInput, PullLogsOperator}; use crate::distance::DistanceFunction; use crate::errors::ChromaError; use crate::execution::operators::brute_force_knn::{ - BruteForceKnnOperator, BruteForceKnnOperatorInput, BruteForceKnnOperatorOutput, - BruteForceKnnOperatorResult, + BruteForceKnnOperator, BruteForceKnnOperatorInput, BruteForceKnnOperatorResult, }; use crate::execution::operators::pull_log::PullLogsResult; use crate::sysdb::sysdb::SysDb; @@ -30,7 +28,7 @@ understand. We can always add more abstraction later if we need it. ┌───► Brute Force ─────┐ │ │ - Pending ─► PullLogs ─► Dedupe│ ├─► MergeResults ─► Finished + Pending ─► PullLogs ─► Group│ ├─► MergeResults ─► Finished │ │ └───► HNSW ────────────┘ @@ -40,7 +38,7 @@ understand. We can always add more abstraction later if we need it. enum ExecutionState { Pending, PullLogs, - Dedupe, + Partition, QueryKnn, MergeResults, Finished, @@ -142,7 +140,7 @@ impl HnswQueryOrchestrator { /// Run the orchestrator and return the result. /// # Note /// Use this over spawning the component directly. This method will start the component and - /// wait for it to finish before returning the result. + /// wait for it to finish before returning the result. pub(crate) async fn run(mut self) -> Result>, Box> { let (tx, rx) = tokio::sync::oneshot::channel(); self.result_channel = Some(tx); @@ -175,19 +173,14 @@ impl Handler for HnswQueryOrchestrator { message: PullLogsResult, ctx: &crate::system::ComponentContext, ) { - self.state = ExecutionState::Dedupe; + self.state = ExecutionState::Partition; // TODO: implement the remaining state transitions and operators // TODO: don't need all this cloning and data shuffling, once we land the chunk abstraction - let mut dataset = Vec::new(); match message { Ok(logs) => { - for log in logs.logs().iter() { - // TODO: only adds have embeddings, unwrap is fine for now - dataset.push(log.embedding.clone().unwrap()); - } let bf_input = BruteForceKnnOperatorInput { - data: dataset, + data: logs.logs(), query: self.query_vectors[0].clone(), k: self.k as usize, distance_metric: DistanceFunction::Euclidean, diff --git a/rust/worker/src/execution/orchestration/mod.rs b/rust/worker/src/execution/orchestration/mod.rs index 902c3eaf84d..2828cfd365a 100644 --- a/rust/worker/src/execution/orchestration/mod.rs +++ b/rust/worker/src/execution/orchestration/mod.rs @@ -1,3 +1,3 @@ +mod compact; mod hnsw; - pub(crate) use hnsw::*; diff --git a/rust/worker/src/log/log.rs b/rust/worker/src/log/log.rs index 5dae688e6d3..0f1c5c6c16c 100644 --- a/rust/worker/src/log/log.rs +++ b/rust/worker/src/log/log.rs @@ -38,7 +38,7 @@ pub(crate) trait Log: Send + Sync + LogClone + Debug { offset: i64, batch_size: i32, end_timestamp: Option, - ) -> Result>, PullLogsError>; + ) -> Result, PullLogsError>; async fn get_collections_with_new_data( &mut self, @@ -121,7 +121,7 @@ impl Log for GrpcLog { offset: i64, batch_size: i32, end_timestamp: Option, - ) -> Result>, PullLogsError> { + ) -> Result, PullLogsError> { let end_timestamp = match end_timestamp { Some(end_timestamp) => end_timestamp, None => -1, @@ -227,7 +227,7 @@ pub(crate) struct LogRecord { pub(crate) collection_id: String, pub(crate) log_id: i64, pub(crate) log_id_ts: i64, - pub(crate) record: Box, + pub(crate) record: EmbeddingRecord, } impl Debug for LogRecord { @@ -268,7 +268,7 @@ impl Log for InMemoryLog { offset: i64, batch_size: i32, end_timestamp: Option, - ) -> Result>, PullLogsError> { + ) -> Result, PullLogsError> { let end_timestamp = match end_timestamp { Some(end_timestamp) => end_timestamp, None => i64::MAX, diff --git a/rust/worker/src/types/embedding_record.rs b/rust/worker/src/types/embedding_record.rs index cc53631d4bf..6ba9bdf255a 100644 --- a/rust/worker/src/types/embedding_record.rs +++ b/rust/worker/src/types/embedding_record.rs @@ -101,7 +101,7 @@ impl TryFrom for EmbeddingRecord { } } -impl TryFrom for Box { +impl TryFrom for EmbeddingRecord { type Error = EmbeddingRecordConversionError; fn try_from(record_log: RecordLog) -> Result { @@ -143,7 +143,7 @@ impl TryFrom for Box { None => None, }; - Ok(Box::new(EmbeddingRecord { + Ok(EmbeddingRecord { id: proto_submit.id, seq_id: seq_id, embedding: embedding, @@ -151,7 +151,7 @@ impl TryFrom for Box { metadata: metadata, operation: op, collection_id: collection_uuid, - })) + }) } } @@ -364,7 +364,7 @@ mod tests { log_id: 42, record: Some(proto_submit), }; - let converted_embedding_record = Box::::try_from(record_log).unwrap(); + let converted_embedding_record = EmbeddingRecord::try_from(record_log).unwrap(); assert_eq!(converted_embedding_record.id, Uuid::nil().to_string()); assert_eq!(converted_embedding_record.seq_id, BigInt::from(42)); assert_eq!(