From c7c4dda331bfd7ee9279839e54514f444248372e Mon Sep 17 00:00:00 2001 From: Macronova <60079945+Sicheng-Pan@users.noreply.github.com> Date: Tue, 17 Dec 2024 14:51:47 -0800 Subject: [PATCH] [ENH] Deprecate FetchSegmentOperator (#3261) ## Description of changes *Summarize the changes made by this PR.* - Improvements & Bug fixes - N/A - New functionality - Deprecate `FetchSegmentOperator` since the query node already receives the full collection information from the frontend. ## Test plan *How are these changes tested?* - [ ] Tests pass locally with `pytest` for python, `yarn test` for js, `cargo test` for rust ## Documentation Changes *Are all docstrings for user-facing APIs updated if required? Do we need to make documentation changes in the [docs repository](https://github.com/chroma-core/docs)?* N/A --- chromadb/proto/convert.py | 3 - chromadb/proto/query_executor_pb2.py | 64 ++--- chromadb/proto/query_executor_pb2.pyi | 10 +- idl/chromadb/proto/query_executor.proto | 8 +- rust/types/src/collection.rs | 39 ++- rust/types/src/types.rs | 2 +- .../src/execution/operators/fetch_segment.rs | 137 ---------- rust/worker/src/execution/operators/mod.rs | 1 - .../execution/operators/prefetch_record.rs | 17 +- .../src/execution/operators/projection.rs | 2 +- .../worker/src/execution/orchestration/get.rs | 248 ++++++------------ .../worker/src/execution/orchestration/knn.rs | 99 +++---- .../src/execution/orchestration/knn_filter.rs | 162 +++++------- rust/worker/src/server.rs | 163 ++++-------- 14 files changed, 307 insertions(+), 648 deletions(-) delete mode 100644 rust/worker/src/execution/operators/fetch_segment.rs diff --git a/chromadb/proto/convert.py b/chromadb/proto/convert.py index 2a18475cf98..51d30bc608a 100644 --- a/chromadb/proto/convert.py +++ b/chromadb/proto/convert.py @@ -573,9 +573,6 @@ def to_proto_where_document(where_document: WhereDocument) -> chroma_pb.WhereDoc def to_proto_scan(scan: Scan) -> query_pb.ScanOperator: return query_pb.ScanOperator( collection=to_proto_collection(scan.collection), - knn_id=scan.knn["id"].hex, - metadata_id=scan.metadata["id"].hex, - record_id=scan.record["id"].hex, knn=to_proto_segment(scan.knn), metadata=to_proto_segment(scan.metadata), record=to_proto_segment(scan.record), diff --git a/chromadb/proto/query_executor_pb2.py b/chromadb/proto/query_executor_pb2.py index 8131f121198..80c322489d3 100644 --- a/chromadb/proto/query_executor_pb2.py +++ b/chromadb/proto/query_executor_pb2.py @@ -14,7 +14,7 @@ from chromadb.proto import chroma_pb2 as chromadb_dot_proto_dot_chroma__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n#chromadb/proto/query_executor.proto\x12\x06\x63hroma\x1a\x1b\x63hromadb/proto/chroma.proto\"\xd0\x01\n\x0cScanOperator\x12&\n\ncollection\x18\x01 \x01(\x0b\x32\x12.chroma.Collection\x12\x0e\n\x06knn_id\x18\x02 \x01(\t\x12\x13\n\x0bmetadata_id\x18\x03 \x01(\t\x12\x11\n\trecord_id\x18\x04 \x01(\t\x12\x1c\n\x03knn\x18\x05 \x01(\x0b\x32\x0f.chroma.Segment\x12!\n\x08metadata\x18\x06 \x01(\x0b\x32\x0f.chroma.Segment\x12\x1f\n\x06record\x18\x07 \x01(\x0b\x32\x0f.chroma.Segment\"\xaf\x01\n\x0e\x46ilterOperator\x12!\n\x03ids\x18\x01 \x01(\x0b\x32\x0f.chroma.UserIdsH\x00\x88\x01\x01\x12!\n\x05where\x18\x02 \x01(\x0b\x32\r.chroma.WhereH\x01\x88\x01\x01\x12\x32\n\x0ewhere_document\x18\x03 \x01(\x0b\x32\x15.chroma.WhereDocumentH\x02\x88\x01\x01\x42\x06\n\x04_idsB\x08\n\x06_whereB\x11\n\x0f_where_document\"@\n\x0bKNNOperator\x12\"\n\nembeddings\x18\x01 \x03(\x0b\x32\x0e.chroma.Vector\x12\r\n\x05\x66\x65tch\x18\x02 
\x01(\r\";\n\rLimitOperator\x12\x0c\n\x04skip\x18\x01 \x01(\r\x12\x12\n\x05\x66\x65tch\x18\x02 \x01(\rH\x00\x88\x01\x01\x42\x08\n\x06_fetch\"K\n\x12ProjectionOperator\x12\x10\n\x08\x64ocument\x18\x01 \x01(\x08\x12\x11\n\tembedding\x18\x02 \x01(\x08\x12\x10\n\x08metadata\x18\x03 \x01(\x08\"Y\n\x15KNNProjectionOperator\x12.\n\nprojection\x18\x01 \x01(\x0b\x32\x1a.chroma.ProjectionOperator\x12\x10\n\x08\x64istance\x18\x02 \x01(\x08\"/\n\tCountPlan\x12\"\n\x04scan\x18\x01 \x01(\x0b\x32\x14.chroma.ScanOperator\"\x1c\n\x0b\x43ountResult\x12\r\n\x05\x63ount\x18\x01 \x01(\r\"\xab\x01\n\x07GetPlan\x12\"\n\x04scan\x18\x01 \x01(\x0b\x32\x14.chroma.ScanOperator\x12&\n\x06\x66ilter\x18\x02 \x01(\x0b\x32\x16.chroma.FilterOperator\x12$\n\x05limit\x18\x03 \x01(\x0b\x32\x15.chroma.LimitOperator\x12.\n\nprojection\x18\x04 \x01(\x0b\x32\x1a.chroma.ProjectionOperator\"\xb4\x01\n\x10ProjectionRecord\x12\n\n\x02id\x18\x01 \x01(\t\x12\x15\n\x08\x64ocument\x18\x02 \x01(\tH\x00\x88\x01\x01\x12&\n\tembedding\x18\x03 \x01(\x0b\x32\x0e.chroma.VectorH\x01\x88\x01\x01\x12-\n\x08metadata\x18\x04 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x02\x88\x01\x01\x42\x0b\n\t_documentB\x0c\n\n_embeddingB\x0b\n\t_metadata\"6\n\tGetResult\x12)\n\x07records\x18\x01 \x03(\x0b\x32\x18.chroma.ProjectionRecord\"\xaa\x01\n\x07KNNPlan\x12\"\n\x04scan\x18\x01 \x01(\x0b\x32\x14.chroma.ScanOperator\x12&\n\x06\x66ilter\x18\x02 \x01(\x0b\x32\x16.chroma.FilterOperator\x12 \n\x03knn\x18\x03 \x01(\x0b\x32\x13.chroma.KNNOperator\x12\x31\n\nprojection\x18\x04 \x01(\x0b\x32\x1d.chroma.KNNProjectionOperator\"c\n\x13KNNProjectionRecord\x12(\n\x06record\x18\x01 \x01(\x0b\x32\x18.chroma.ProjectionRecord\x12\x15\n\x08\x64istance\x18\x02 \x01(\x02H\x00\x88\x01\x01\x42\x0b\n\t_distance\"9\n\tKNNResult\x12,\n\x07records\x18\x01 \x03(\x0b\x32\x1b.chroma.KNNProjectionRecord\"4\n\x0eKNNBatchResult\x12\"\n\x07results\x18\x01 \x03(\x0b\x32\x11.chroma.KNNResult2\xa1\x01\n\rQueryExecutor\x12\x31\n\x05\x43ount\x12\x11.chroma.CountPlan\x1a\x13.chroma.CountResult\"\x00\x12+\n\x03Get\x12\x0f.chroma.GetPlan\x1a\x11.chroma.GetResult\"\x00\x12\x30\n\x03KNN\x12\x0f.chroma.KNNPlan\x1a\x16.chroma.KNNBatchResult\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n#chromadb/proto/query_executor.proto\x12\x06\x63hroma\x1a\x1b\x63hromadb/proto/chroma.proto\"\xaa\x01\n\x0cScanOperator\x12&\n\ncollection\x18\x01 \x01(\x0b\x32\x12.chroma.Collection\x12\x1c\n\x03knn\x18\x05 \x01(\x0b\x32\x0f.chroma.Segment\x12!\n\x08metadata\x18\x06 \x01(\x0b\x32\x0f.chroma.Segment\x12\x1f\n\x06record\x18\x07 \x01(\x0b\x32\x0f.chroma.SegmentJ\x04\x08\x02\x10\x03J\x04\x08\x03\x10\x04J\x04\x08\x04\x10\x05\"\xaf\x01\n\x0e\x46ilterOperator\x12!\n\x03ids\x18\x01 \x01(\x0b\x32\x0f.chroma.UserIdsH\x00\x88\x01\x01\x12!\n\x05where\x18\x02 \x01(\x0b\x32\r.chroma.WhereH\x01\x88\x01\x01\x12\x32\n\x0ewhere_document\x18\x03 \x01(\x0b\x32\x15.chroma.WhereDocumentH\x02\x88\x01\x01\x42\x06\n\x04_idsB\x08\n\x06_whereB\x11\n\x0f_where_document\"@\n\x0bKNNOperator\x12\"\n\nembeddings\x18\x01 \x03(\x0b\x32\x0e.chroma.Vector\x12\r\n\x05\x66\x65tch\x18\x02 \x01(\r\";\n\rLimitOperator\x12\x0c\n\x04skip\x18\x01 \x01(\r\x12\x12\n\x05\x66\x65tch\x18\x02 \x01(\rH\x00\x88\x01\x01\x42\x08\n\x06_fetch\"K\n\x12ProjectionOperator\x12\x10\n\x08\x64ocument\x18\x01 \x01(\x08\x12\x11\n\tembedding\x18\x02 \x01(\x08\x12\x10\n\x08metadata\x18\x03 \x01(\x08\"Y\n\x15KNNProjectionOperator\x12.\n\nprojection\x18\x01 \x01(\x0b\x32\x1a.chroma.ProjectionOperator\x12\x10\n\x08\x64istance\x18\x02 
\x01(\x08\"/\n\tCountPlan\x12\"\n\x04scan\x18\x01 \x01(\x0b\x32\x14.chroma.ScanOperator\"\x1c\n\x0b\x43ountResult\x12\r\n\x05\x63ount\x18\x01 \x01(\r\"\xab\x01\n\x07GetPlan\x12\"\n\x04scan\x18\x01 \x01(\x0b\x32\x14.chroma.ScanOperator\x12&\n\x06\x66ilter\x18\x02 \x01(\x0b\x32\x16.chroma.FilterOperator\x12$\n\x05limit\x18\x03 \x01(\x0b\x32\x15.chroma.LimitOperator\x12.\n\nprojection\x18\x04 \x01(\x0b\x32\x1a.chroma.ProjectionOperator\"\xb4\x01\n\x10ProjectionRecord\x12\n\n\x02id\x18\x01 \x01(\t\x12\x15\n\x08\x64ocument\x18\x02 \x01(\tH\x00\x88\x01\x01\x12&\n\tembedding\x18\x03 \x01(\x0b\x32\x0e.chroma.VectorH\x01\x88\x01\x01\x12-\n\x08metadata\x18\x04 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x02\x88\x01\x01\x42\x0b\n\t_documentB\x0c\n\n_embeddingB\x0b\n\t_metadata\"6\n\tGetResult\x12)\n\x07records\x18\x01 \x03(\x0b\x32\x18.chroma.ProjectionRecord\"\xaa\x01\n\x07KNNPlan\x12\"\n\x04scan\x18\x01 \x01(\x0b\x32\x14.chroma.ScanOperator\x12&\n\x06\x66ilter\x18\x02 \x01(\x0b\x32\x16.chroma.FilterOperator\x12 \n\x03knn\x18\x03 \x01(\x0b\x32\x13.chroma.KNNOperator\x12\x31\n\nprojection\x18\x04 \x01(\x0b\x32\x1d.chroma.KNNProjectionOperator\"c\n\x13KNNProjectionRecord\x12(\n\x06record\x18\x01 \x01(\x0b\x32\x18.chroma.ProjectionRecord\x12\x15\n\x08\x64istance\x18\x02 \x01(\x02H\x00\x88\x01\x01\x42\x0b\n\t_distance\"9\n\tKNNResult\x12,\n\x07records\x18\x01 \x03(\x0b\x32\x1b.chroma.KNNProjectionRecord\"4\n\x0eKNNBatchResult\x12\"\n\x07results\x18\x01 \x03(\x0b\x32\x11.chroma.KNNResult2\xa1\x01\n\rQueryExecutor\x12\x31\n\x05\x43ount\x12\x11.chroma.CountPlan\x1a\x13.chroma.CountResult\"\x00\x12+\n\x03Get\x12\x0f.chroma.GetPlan\x1a\x11.chroma.GetResult\"\x00\x12\x30\n\x03KNN\x12\x0f.chroma.KNNPlan\x1a\x16.chroma.KNNBatchResult\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -23,35 +23,35 @@ DESCRIPTOR._options = None _globals['_SCANOPERATOR']._serialized_start=77 - _globals['_SCANOPERATOR']._serialized_end=285 - _globals['_FILTEROPERATOR']._serialized_start=288 - _globals['_FILTEROPERATOR']._serialized_end=463 - _globals['_KNNOPERATOR']._serialized_start=465 - _globals['_KNNOPERATOR']._serialized_end=529 - _globals['_LIMITOPERATOR']._serialized_start=531 - _globals['_LIMITOPERATOR']._serialized_end=590 - _globals['_PROJECTIONOPERATOR']._serialized_start=592 - _globals['_PROJECTIONOPERATOR']._serialized_end=667 - _globals['_KNNPROJECTIONOPERATOR']._serialized_start=669 - _globals['_KNNPROJECTIONOPERATOR']._serialized_end=758 - _globals['_COUNTPLAN']._serialized_start=760 - _globals['_COUNTPLAN']._serialized_end=807 - _globals['_COUNTRESULT']._serialized_start=809 - _globals['_COUNTRESULT']._serialized_end=837 - _globals['_GETPLAN']._serialized_start=840 - _globals['_GETPLAN']._serialized_end=1011 - _globals['_PROJECTIONRECORD']._serialized_start=1014 - _globals['_PROJECTIONRECORD']._serialized_end=1194 - _globals['_GETRESULT']._serialized_start=1196 - _globals['_GETRESULT']._serialized_end=1250 - _globals['_KNNPLAN']._serialized_start=1253 - _globals['_KNNPLAN']._serialized_end=1423 - _globals['_KNNPROJECTIONRECORD']._serialized_start=1425 - _globals['_KNNPROJECTIONRECORD']._serialized_end=1524 - _globals['_KNNRESULT']._serialized_start=1526 - _globals['_KNNRESULT']._serialized_end=1583 - _globals['_KNNBATCHRESULT']._serialized_start=1585 - _globals['_KNNBATCHRESULT']._serialized_end=1637 - _globals['_QUERYEXECUTOR']._serialized_start=1640 - _globals['_QUERYEXECUTOR']._serialized_end=1801 + _globals['_SCANOPERATOR']._serialized_end=247 + 
_globals['_FILTEROPERATOR']._serialized_start=250 + _globals['_FILTEROPERATOR']._serialized_end=425 + _globals['_KNNOPERATOR']._serialized_start=427 + _globals['_KNNOPERATOR']._serialized_end=491 + _globals['_LIMITOPERATOR']._serialized_start=493 + _globals['_LIMITOPERATOR']._serialized_end=552 + _globals['_PROJECTIONOPERATOR']._serialized_start=554 + _globals['_PROJECTIONOPERATOR']._serialized_end=629 + _globals['_KNNPROJECTIONOPERATOR']._serialized_start=631 + _globals['_KNNPROJECTIONOPERATOR']._serialized_end=720 + _globals['_COUNTPLAN']._serialized_start=722 + _globals['_COUNTPLAN']._serialized_end=769 + _globals['_COUNTRESULT']._serialized_start=771 + _globals['_COUNTRESULT']._serialized_end=799 + _globals['_GETPLAN']._serialized_start=802 + _globals['_GETPLAN']._serialized_end=973 + _globals['_PROJECTIONRECORD']._serialized_start=976 + _globals['_PROJECTIONRECORD']._serialized_end=1156 + _globals['_GETRESULT']._serialized_start=1158 + _globals['_GETRESULT']._serialized_end=1212 + _globals['_KNNPLAN']._serialized_start=1215 + _globals['_KNNPLAN']._serialized_end=1385 + _globals['_KNNPROJECTIONRECORD']._serialized_start=1387 + _globals['_KNNPROJECTIONRECORD']._serialized_end=1486 + _globals['_KNNRESULT']._serialized_start=1488 + _globals['_KNNRESULT']._serialized_end=1545 + _globals['_KNNBATCHRESULT']._serialized_start=1547 + _globals['_KNNBATCHRESULT']._serialized_end=1599 + _globals['_QUERYEXECUTOR']._serialized_start=1602 + _globals['_QUERYEXECUTOR']._serialized_end=1763 # @@protoc_insertion_point(module_scope) diff --git a/chromadb/proto/query_executor_pb2.pyi b/chromadb/proto/query_executor_pb2.pyi index 33b8bb63e1a..aa800835e85 100644 --- a/chromadb/proto/query_executor_pb2.pyi +++ b/chromadb/proto/query_executor_pb2.pyi @@ -7,22 +7,16 @@ from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Map DESCRIPTOR: _descriptor.FileDescriptor class ScanOperator(_message.Message): - __slots__ = ["collection", "knn_id", "metadata_id", "record_id", "knn", "metadata", "record"] + __slots__ = ["collection", "knn", "metadata", "record"] COLLECTION_FIELD_NUMBER: _ClassVar[int] - KNN_ID_FIELD_NUMBER: _ClassVar[int] - METADATA_ID_FIELD_NUMBER: _ClassVar[int] - RECORD_ID_FIELD_NUMBER: _ClassVar[int] KNN_FIELD_NUMBER: _ClassVar[int] METADATA_FIELD_NUMBER: _ClassVar[int] RECORD_FIELD_NUMBER: _ClassVar[int] collection: _chroma_pb2.Collection - knn_id: str - metadata_id: str - record_id: str knn: _chroma_pb2.Segment metadata: _chroma_pb2.Segment record: _chroma_pb2.Segment - def __init__(self, collection: _Optional[_Union[_chroma_pb2.Collection, _Mapping]] = ..., knn_id: _Optional[str] = ..., metadata_id: _Optional[str] = ..., record_id: _Optional[str] = ..., knn: _Optional[_Union[_chroma_pb2.Segment, _Mapping]] = ..., metadata: _Optional[_Union[_chroma_pb2.Segment, _Mapping]] = ..., record: _Optional[_Union[_chroma_pb2.Segment, _Mapping]] = ...) -> None: ... + def __init__(self, collection: _Optional[_Union[_chroma_pb2.Collection, _Mapping]] = ..., knn: _Optional[_Union[_chroma_pb2.Segment, _Mapping]] = ..., metadata: _Optional[_Union[_chroma_pb2.Segment, _Mapping]] = ..., record: _Optional[_Union[_chroma_pb2.Segment, _Mapping]] = ...) -> None: ... 
class FilterOperator(_message.Message):
    __slots__ = ["ids", "where", "where_document"]
diff --git a/idl/chromadb/proto/query_executor.proto b/idl/chromadb/proto/query_executor.proto
index 53aa37a8b52..f434070c9cd 100644
--- a/idl/chromadb/proto/query_executor.proto
+++ b/idl/chromadb/proto/query_executor.proto
@@ -6,12 +6,8 @@ import "chromadb/proto/chroma.proto";
message ScanOperator {
  Collection collection = 1;
-  // Deprecated
-  string knn_id = 2;
-  // Deprecated
-  string metadata_id = 3;
-  // Deprecated
-  string record_id = 4;
+  // Reserved for deprecated fields
+  reserved 2, 3, 4;
  Segment knn = 5;
  Segment metadata = 6;
  Segment record = 7;
diff --git a/rust/types/src/collection.rs b/rust/types/src/collection.rs
index 23e55a0a588..9c7bd8914f2 100644
--- a/rust/types/src/collection.rs
+++ b/rust/types/src/collection.rs
@@ -1,5 +1,5 @@
 use super::{Metadata, MetadataValueConversionError};
-use crate::chroma_proto;
+use crate::{chroma_proto, ConversionError, Segment};
 use chroma_error::{ChromaError, ErrorCodes};
 use thiserror::Error;
 use uuid::Uuid;
@@ -89,6 +89,43 @@ impl TryFrom<chroma_proto::Collection> for Collection {
     }
 }

+#[derive(Clone, Debug)]
+pub struct CollectionAndSegments {
+    pub collection: Collection,
+    pub metadata_segment: Segment,
+    pub record_segment: Segment,
+    pub vector_segment: Segment,
+}
+
+impl TryFrom<chroma_proto::ScanOperator> for CollectionAndSegments {
+    type Error = ConversionError;
+
+    fn try_from(value: chroma_proto::ScanOperator) -> Result<Self, Self::Error> {
+        Ok(Self {
+            collection: value
+                .collection
+                .ok_or(ConversionError::DecodeError)?
+                .try_into()
+                .map_err(|_| ConversionError::DecodeError)?,
+            metadata_segment: value
+                .metadata
+                .ok_or(ConversionError::DecodeError)?
+                .try_into()
+                .map_err(|_| ConversionError::DecodeError)?,
+            record_segment: value
+                .record
+                .ok_or(ConversionError::DecodeError)?
+                .try_into()
+                .map_err(|_| ConversionError::DecodeError)?,
+            vector_segment: value
+                .knn
+                .ok_or(ConversionError::DecodeError)?
+                .try_into()
+                .map_err(|_| ConversionError::DecodeError)?,
+        })
+    }
+}
+
 #[cfg(test)]
 mod test {
     use super::*;
diff --git a/rust/types/src/types.rs b/rust/types/src/types.rs
index ddd07d91b89..cc81189f5b6 100644
--- a/rust/types/src/types.rs
+++ b/rust/types/src/types.rs
@@ -28,7 +28,7 @@ pub enum ConversionError {
 impl ChromaError for ConversionError {
     fn code(&self) -> ErrorCodes {
         match self {
-            ConversionError::DecodeError => ErrorCodes::Internal,
+            ConversionError::DecodeError => ErrorCodes::InvalidArgument,
         }
     }
 }
diff --git a/rust/worker/src/execution/operators/fetch_segment.rs b/rust/worker/src/execution/operators/fetch_segment.rs
deleted file mode 100644
index d7b49f17e53..00000000000
--- a/rust/worker/src/execution/operators/fetch_segment.rs
+++ /dev/null
@@ -1,137 +0,0 @@
-use chroma_error::{ChromaError, ErrorCodes};
-use chroma_types::{Collection, CollectionUuid, Segment, SegmentScope, SegmentType, SegmentUuid};
-use thiserror::Error;
-use tonic::async_trait;
-use tracing::trace;
-
-use crate::{
-    execution::operator::{Operator, OperatorType},
-    sysdb::sysdb::{GetCollectionsError, GetSegmentsError, SysDb},
-};
-
-/// The `FetchSegmentOperator` fetches collection and segment information from SysDB
-///
-/// # Parameters
-/// - `sysdb`: The SysDB reader
-/// - `*_uuid`: The uuids of the collection and segments
-/// - `collection_version`: The version of the collection to verify against
-///
-/// # Inputs
-/// - No input is required
-///
-/// # Outputs
-/// - `collection`: The collection information
-/// - `*_segment`: The segment information
-///
-/// # Usage
-/// It should be run at the start of an orchestrator to get the latest data of a collection
-#[derive(Clone, Debug)]
-pub struct FetchSegmentOperator {
-    pub(crate) sysdb: Box<SysDb>,
-    pub collection_uuid: CollectionUuid,
-    pub collection_version: u32,
-    pub metadata_uuid: SegmentUuid,
-    pub record_uuid: SegmentUuid,
-    pub vector_uuid: SegmentUuid,
-}
-
-type FetchSegmentInput = ();
-
-#[derive(Clone, Debug)]
-pub struct FetchSegmentOutput {
-    pub collection: Collection,
-    pub metadata_segment: Segment,
-    pub record_segment: Segment,
-    pub vector_segment: Segment,
-}
-
-#[derive(Error, Debug)]
-pub enum FetchSegmentError {
-    #[error("Error when getting collection: {0}")]
-    GetCollection(#[from] GetCollectionsError),
-    #[error("Error when getting segment: {0}")]
-    GetSegment(#[from] GetSegmentsError),
-    #[error("No collection found")]
-    NoCollection,
-    #[error("No segment found")]
-    NoSegment,
-    // The frontend relies on the content of the error message here to detect version mismatch
-    // TODO: Refactor frontend to properly detect version mismatch
-    #[error("Collection version mismatch")]
-    VersionMismatch,
-}
-
-impl ChromaError for FetchSegmentError {
-    fn code(&self) -> ErrorCodes {
-        match self {
-            FetchSegmentError::GetCollection(e) => e.code(),
-            FetchSegmentError::GetSegment(e) => e.code(),
-            FetchSegmentError::NoCollection => ErrorCodes::NotFound,
-            FetchSegmentError::NoSegment => ErrorCodes::NotFound,
-            FetchSegmentError::VersionMismatch => ErrorCodes::VersionMismatch,
-        }
-    }
-}
-
-impl FetchSegmentOperator {
-    async fn get_collection(&self) -> Result<Collection, FetchSegmentError> {
-        let collection = self
-            .sysdb
-            .clone()
-            .get_collections(Some(self.collection_uuid), None, None, None)
-            .await?
-            .pop()
-            .ok_or(FetchSegmentError::NoCollection)?;
-        if collection.version != self.collection_version as i32 {
-            Err(FetchSegmentError::VersionMismatch)
-        } else {
-            Ok(collection)
-        }
-    }
-    async fn get_segment(&self, scope: SegmentScope) -> Result<Segment, FetchSegmentError> {
-        let segment_type = match scope {
-            SegmentScope::METADATA => SegmentType::BlockfileMetadata,
-            SegmentScope::RECORD => SegmentType::BlockfileRecord,
-            SegmentScope::SQLITE => unimplemented!("Unexpected Sqlite segment"),
-            SegmentScope::VECTOR => SegmentType::HnswDistributed,
-        };
-        let segment_id = match scope {
-            SegmentScope::METADATA => self.metadata_uuid,
-            SegmentScope::RECORD => self.record_uuid,
-            SegmentScope::SQLITE => unimplemented!("Unexpected Sqlite segment"),
-            SegmentScope::VECTOR => self.vector_uuid,
-        };
-        self.sysdb
-            .clone()
-            .get_segments(
-                Some(segment_id),
-                Some(segment_type.into()),
-                Some(scope),
-                self.collection_uuid,
-            )
-            .await?
-            // Each scope should have a single segment
-            .pop()
-            .ok_or(FetchSegmentError::NoSegment)
-    }
-}
-
-#[async_trait]
-impl Operator<FetchSegmentInput, FetchSegmentOutput> for FetchSegmentOperator {
-    type Error = FetchSegmentError;
-
-    fn get_type(&self) -> OperatorType {
-        OperatorType::IO
-    }
-
-    async fn run(&self, _: &FetchSegmentInput) -> Result<FetchSegmentOutput, FetchSegmentError> {
-        trace!("[{}]: {:?}", self.get_name(), self);
-
-        Ok(FetchSegmentOutput {
-            collection: self.get_collection().await?,
-            metadata_segment: self.get_segment(SegmentScope::METADATA).await?,
-            record_segment: self.get_segment(SegmentScope::RECORD).await?,
-            vector_segment: self.get_segment(SegmentScope::VECTOR).await?,
-        })
-    }
-}
diff --git a/rust/worker/src/execution/operators/mod.rs b/rust/worker/src/execution/operators/mod.rs
index f0e02207c65..76eccedaab9 100644
--- a/rust/worker/src/execution/operators/mod.rs
+++ b/rust/worker/src/execution/operators/mod.rs
@@ -10,7 +10,6 @@ pub(super) mod write_segments;

 // Required for benchmark
 pub mod fetch_log;
-pub mod fetch_segment;
 pub mod filter;
 pub mod knn;
 pub mod knn_hnsw;
diff --git a/rust/worker/src/execution/operators/prefetch_record.rs b/rust/worker/src/execution/operators/prefetch_record.rs
index b97594d778b..6b11fabe458 100644
--- a/rust/worker/src/execution/operators/prefetch_record.rs
+++ b/rust/worker/src/execution/operators/prefetch_record.rs
@@ -1,8 +1,6 @@
 use std::collections::HashSet;

-use chroma_blockstore::provider::BlockfileProvider;
 use chroma_error::{ChromaError, ErrorCodes};
-use chroma_types::{Chunk, LogRecord, Segment};
 use thiserror::Error;
 use tonic::async_trait;
 use tracing::{trace, Instrument, Span};
@@ -16,16 +14,15 @@ use crate::{
     },
 };

+use super::projection::ProjectionInput;
+
 /// The `PrefetchRecordOperator` prefetches the relevant records from the record segments to the cache
 ///
 /// # Parameters
 /// None
 ///
 /// # Inputs
-/// - `logs`: The latest logs of the collection
-/// - `blockfile_provider`: The blockfile provider
-/// - `record_segment`: The record segment information
-/// - `offset_ids`: The offset ids of the records to prefetch
+/// Identical to `ProjectionInput`
 ///
 /// # Outputs
 /// None
@@ -35,13 +32,7 @@ use crate::{
 #[derive(Debug)]
 pub struct PrefetchRecordOperator {}

-#[derive(Debug)]
-pub struct PrefetchRecordInput {
-    pub logs: Chunk<LogRecord>,
-    pub blockfile_provider: BlockfileProvider,
-    pub record_segment: Segment,
-    pub offset_ids: Vec<u32>,
-}
+pub type PrefetchRecordInput = ProjectionInput;

 pub type PrefetchRecordOutput = ();

diff --git a/rust/worker/src/execution/operators/projection.rs b/rust/worker/src/execution/operators/projection.rs
index f02f89acbfe..39555af4ead 100644
---
a/rust/worker/src/execution/operators/projection.rs +++ b/rust/worker/src/execution/operators/projection.rs @@ -42,7 +42,7 @@ pub struct ProjectionOperator { pub metadata: bool, } -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct ProjectionInput { pub logs: Chunk, pub blockfile_provider: BlockfileProvider, diff --git a/rust/worker/src/execution/orchestration/get.rs b/rust/worker/src/execution/orchestration/get.rs index 443b5be1b59..9eaf2995e8a 100644 --- a/rust/worker/src/execution/orchestration/get.rs +++ b/rust/worker/src/execution/orchestration/get.rs @@ -1,5 +1,6 @@ use chroma_blockstore::provider::BlockfileProvider; use chroma_error::{ChromaError, ErrorCodes}; +use chroma_types::CollectionAndSegments; use thiserror::Error; use tokio::sync::oneshot::{self, error::RecvError, Sender}; use tonic::async_trait; @@ -11,13 +12,9 @@ use crate::{ operator::{wrap, TaskError, TaskResult}, operators::{ fetch_log::{FetchLogError, FetchLogOperator, FetchLogOutput}, - fetch_segment::{FetchSegmentError, FetchSegmentOperator, FetchSegmentOutput}, filter::{FilterError, FilterInput, FilterOperator, FilterOutput}, limit::{LimitError, LimitInput, LimitOperator, LimitOutput}, - prefetch_record::{ - PrefetchRecordError, PrefetchRecordInput, PrefetchRecordOperator, - PrefetchRecordOutput, - }, + prefetch_record::{PrefetchRecordError, PrefetchRecordOperator, PrefetchRecordOutput}, projection::{ProjectionError, ProjectionInput, ProjectionOperator, ProjectionOutput}, }, orchestration::common::terminate_with_error, @@ -31,8 +28,6 @@ pub enum GetError { Channel(#[from] ChannelError), #[error("Error running Fetch Log Operator: {0}")] FetchLog(#[from] FetchLogError), - #[error("Error running Fetch Segment Operator: {0}")] - FetchSegment(#[from] FetchSegmentError), #[error("Error running Filter Operator: {0}")] Filter(#[from] FilterError), #[error("Error running Limit Operator: {0}")] @@ -50,7 +45,6 @@ impl ChromaError for GetError { match self { GetError::Channel(e) => e.code(), GetError::FetchLog(e) => e.code(), - GetError::FetchSegment(e) => e.code(), GetError::Filter(e) => e.code(), GetError::Limit(e) => e.code(), GetError::Panic(_) => ErrorCodes::Aborted, @@ -81,62 +75,47 @@ type GetResult = Result; /// /// # Pipeline /// ```text -/// ┌────────────┐ -/// │ │ -/// ┌───────────┤ on_start ├────────────────┐ -/// │ │ │ │ -/// │ └────────────┘ │ -/// │ │ -/// ▼ ▼ -/// ┌────────────────────┐ ┌────────────────────────┐ -/// │ │ │ │ -/// │ FetchLogOperator │ │ FetchSegmentOperator │ -/// │ │ │ │ -/// └────────┬───────────┘ └────────────────┬───────┘ -/// │ │ -/// │ │ -/// │ ┌─────────────────────────────┐ │ -/// │ │ │ │ -/// └────►│ try_start_filter_operator │◄────┘ -/// │ │ -/// └────────────┬────────────────┘ -/// │ -/// ▼ -/// ┌───────────────────┐ -/// │ │ -/// │ FilterOperator │ -/// │ │ -/// └─────────┬─────────┘ -/// │ -/// ▼ -/// ┌─────────────────┐ -/// │ │ -/// │ LimitOperator │ -/// │ │ -/// └────────┬────────┘ -/// │ -/// ▼ -/// ┌──────────────────────┐ -/// │ │ -/// │ ProjectionOperator │ -/// │ │ -/// └──────────┬───────────┘ -/// │ -/// ▼ -/// ┌──────────────────┐ -/// │ │ -/// │ result_channel │ -/// │ │ -/// └──────────────────┘ +/// ┌────────────┐ +/// │ │ +/// │ on_start │ +/// │ │ +/// └──────┬─────┘ +/// │ +/// ▼ +/// ┌────────────────────┐ +/// │ │ +/// │ FetchLogOperator │ +/// │ │ +/// └─────────┬──────────┘ +/// │ +/// ▼ +/// ┌───────────────────┐ +/// │ │ +/// │ FilterOperator │ +/// │ │ +/// └─────────┬─────────┘ +/// │ +/// ▼ +/// ┌─────────────────┐ +/// │ │ +/// │ LimitOperator │ +/// │ 
│ +/// └────────┬────────┘ +/// │ +/// ▼ +/// ┌──────────────────────┐ +/// │ │ +/// │ ProjectionOperator │ +/// │ │ +/// └──────────┬───────────┘ +/// │ +/// ▼ +/// ┌──────────────────┐ +/// │ │ +/// │ result_channel │ +/// │ │ +/// └──────────────────┘ /// ``` -/// -/// # State tracking -/// As suggested by the pipeline diagram above, the orchestrator only need to -/// keep track of the outputs from `FetchLogOperator` and `FetchSegmentOperator`. -/// The orchestrator invokes `try_start_filter_operator` when it receives output -/// from either operators, and if both outputs are present it composes the input -/// for `FilterOperator` and proceeds with execution. The outputs of other -/// operators are directly forwarded without being tracked by the orchestrator. #[derive(Debug)] pub struct GetOrchestrator { // Orchestrator parameters @@ -144,13 +123,14 @@ pub struct GetOrchestrator { dispatcher: ComponentHandle, queue: usize, - // Fetch logs and segments + // Collection segments + collection_and_segments: CollectionAndSegments, + + // Fetch logs fetch_log: FetchLogOperator, - fetch_segment: FetchSegmentOperator, - // Fetch output - fetch_log_output: Option, - fetch_segment_output: Option, + // Fetched logs + fetched_logs: Option, // Pipelined operators filter: FilterOperator, @@ -167,8 +147,8 @@ impl GetOrchestrator { blockfile_provider: BlockfileProvider, dispatcher: ComponentHandle, queue: usize, + collection_and_segments: CollectionAndSegments, fetch_log: FetchLogOperator, - fetch_segment: FetchSegmentOperator, filter: FilterOperator, limit: LimitOperator, projection: ProjectionOperator, @@ -177,10 +157,9 @@ impl GetOrchestrator { blockfile_provider, dispatcher, queue, + collection_and_segments, fetch_log, - fetch_segment, - fetch_log_output: None, - fetch_segment_output: None, + fetched_logs: None, filter, limit, projection, @@ -205,28 +184,6 @@ impl GetOrchestrator { tracing::error!("Error running orchestrator: {}", &get_err); terminate_with_error(self.result_channel.take(), get_err, ctx); } - - /// Try to start the filter operator once both `FetchLogOperator` and `FetchSegmentOperator` completes - async fn try_start_filter_operator(&mut self, ctx: &ComponentContext) { - if let (Some(logs), Some(segments)) = ( - self.fetch_log_output.as_ref(), - self.fetch_segment_output.as_ref(), - ) { - let task = wrap( - Box::new(self.filter.clone()), - FilterInput { - logs: logs.clone(), - blockfile_provider: self.blockfile_provider.clone(), - metadata_segment: segments.metadata_segment.clone(), - record_segment: segments.record_segment.clone(), - }, - ctx.receiver(), - ); - if let Err(err) = self.dispatcher.send(task, Some(Span::current())).await { - self.terminate_with_error(ctx, err); - } - } - } } #[async_trait] @@ -240,16 +197,8 @@ impl Component for GetOrchestrator { } async fn on_start(&mut self, ctx: &ComponentContext) { - let log_task = wrap(Box::new(self.fetch_log.clone()), (), ctx.receiver()); - let segment_task = wrap(Box::new(self.fetch_segment.clone()), (), ctx.receiver()); - if let Err(err) = self.dispatcher.send(log_task, Some(Span::current())).await { - self.terminate_with_error(ctx, err); - return; - } else if let Err(err) = self - .dispatcher - .send(segment_task, Some(Span::current())) - .await - { + let task = wrap(Box::new(self.fetch_log.clone()), (), ctx.receiver()); + if let Err(err) = self.dispatcher.send(task, Some(Span::current())).await { self.terminate_with_error(ctx, err); return; } @@ -272,29 +221,22 @@ impl Handler> for GetOrchestrator { return; } }; - 
self.fetch_log_output = Some(output); - self.try_start_filter_operator(ctx).await; - } -} -#[async_trait] -impl Handler> for GetOrchestrator { - type Result = (); + self.fetched_logs = Some(output.clone()); - async fn handle( - &mut self, - message: TaskResult, - ctx: &ComponentContext, - ) { - let output = match message.into_inner() { - Ok(output) => output, - Err(err) => { - self.terminate_with_error(ctx, err); - return; - } - }; - self.fetch_segment_output = Some(output); - self.try_start_filter_operator(ctx).await; + let task = wrap( + Box::new(self.filter.clone()), + FilterInput { + logs: output, + blockfile_provider: self.blockfile_provider.clone(), + metadata_segment: self.collection_and_segments.metadata_segment.clone(), + record_segment: self.collection_and_segments.record_segment.clone(), + }, + ctx.receiver(), + ); + if let Err(err) = self.dispatcher.send(task, Some(Span::current())).await { + self.terminate_with_error(ctx, err); + } } } @@ -318,17 +260,12 @@ impl Handler> for GetOrchestrator { Box::new(self.limit.clone()), LimitInput { logs: self - .fetch_log_output + .fetched_logs .as_ref() .expect("FetchLogOperator should have finished already") .clone(), blockfile_provider: self.blockfile_provider.clone(), - record_segment: self - .fetch_segment_output - .as_ref() - .expect("FetchSegmentOperator should have finished already") - .record_segment - .clone(), + record_segment: self.collection_and_segments.record_segment.clone(), log_offset_ids: output.log_offset_ids, compact_offset_ids: output.compact_offset_ids, }, @@ -357,24 +294,21 @@ impl Handler> for GetOrchestrator { } }; + let input = ProjectionInput { + logs: self + .fetched_logs + .as_ref() + .expect("FetchLogOperator should have finished already") + .clone(), + blockfile_provider: self.blockfile_provider.clone(), + record_segment: self.collection_and_segments.record_segment.clone(), + offset_ids: output.offset_ids.iter().collect(), + }; + // Prefetch records before projection let prefetch_task = wrap( Box::new(PrefetchRecordOperator {}), - PrefetchRecordInput { - logs: self - .fetch_log_output - .as_ref() - .expect("FetchLogOperator should have finished already") - .clone(), - blockfile_provider: self.blockfile_provider.clone(), - record_segment: self - .fetch_segment_output - .as_ref() - .expect("FetchSegmentOperator should have finished already") - .record_segment - .clone(), - offset_ids: output.offset_ids.iter().collect(), - }, + input.clone(), ctx.receiver(), ); if let Err(err) = self @@ -385,25 +319,7 @@ impl Handler> for GetOrchestrator { self.terminate_with_error(ctx, err); } - let task = wrap( - Box::new(self.projection.clone()), - ProjectionInput { - logs: self - .fetch_log_output - .as_ref() - .expect("FetchLogOperator should have finished already") - .clone(), - blockfile_provider: self.blockfile_provider.clone(), - record_segment: self - .fetch_segment_output - .as_ref() - .expect("FetchSegmentOperator should have finished already") - .record_segment - .clone(), - offset_ids: output.offset_ids.into_iter().collect(), - }, - ctx.receiver(), - ); + let task = wrap(Box::new(self.projection.clone()), input, ctx.receiver()); if let Err(err) = self.dispatcher.send(task, Some(Span::current())).await { self.terminate_with_error(ctx, err); } diff --git a/rust/worker/src/execution/orchestration/knn.rs b/rust/worker/src/execution/orchestration/knn.rs index dfc533c6690..c57f6d9605e 100644 --- a/rust/worker/src/execution/orchestration/knn.rs +++ b/rust/worker/src/execution/orchestration/knn.rs @@ -27,8 +27,8 @@ use 
crate::{ use super::knn_filter::{KnnError, KnnFilterOutput, KnnResult}; -/// The `knn` module contains two orchestrator: `KnnFilterOrchestrator` and `KnnOrchestrator`. -/// When used together, they carry out the evaluation of a `.query(...)` query +/// The `KnnOrchestrator` finds the nearest neighbor of a target embedding given the search domain. +/// When used together with `KnnFilterOrchestrator`, they evaluate a `.query(...)` query /// for the user. We breakdown the evaluation into two parts because a `.query(...)` /// is inherently multiple queries sharing the same filter criteria. Thus we first evaluate /// the filter criteria with `KnnFilterOrchestrator`. Then we spawn a `KnnOrchestrator` for each @@ -38,53 +38,23 @@ use super::knn_filter::{KnnError, KnnFilterOutput, KnnResult}; /// /// # Pipeline /// ```text -/// │ -/// │ -/// │ -/// ┌──────────────────────────── │ ───────────────────────────────┐ -/// │ ▼ │ -/// │ ┌────────────┐ KnnFilterOrchestrator │ -/// │ │ │ │ -/// │ ┌───────────┤ on_start ├────────────────┐ │ -/// │ │ │ │ │ │ -/// │ │ └────────────┘ │ │ -/// │ │ │ │ -/// │ ▼ ▼ │ -/// │ ┌────────────────────┐ ┌────────────────────────┐ │ -/// │ │ │ │ │ │ -/// │ │ FetchLogOperator │ │ FetchSegmentOperator │ │ -/// │ │ │ │ │ │ -/// │ └────────┬───────────┘ └────────────────┬───────┘ │ -/// │ │ │ │ -/// │ │ │ │ -/// │ │ ┌─────────────────────────────┐ │ │ -/// │ │ │ │ │ │ -/// │ └────►│ try_start_filter_operator │◄────┘ │ -/// │ │ │ │ -/// │ └────────────┬────────────────┘ │ -/// │ │ │ -/// │ ▼ │ -/// │ ┌───────────────────┐ │ -/// │ │ │ │ -/// │ │ FilterOperator │ │ -/// │ │ │ │ -/// │ └─────────┬─────────┘ │ -/// │ │ │ -/// │ ▼ │ -/// │ ┌──────────────────┐ │ -/// │ │ │ │ -/// │ │ result_channel │ │ -/// │ │ │ │ -/// │ └────────┬─────────┘ │ -/// │ │ │ -/// └──────────────────────────── │ ───────────────────────────────┘ -/// │ -/// │ -/// │ -/// ┌──────────────────────────────────┴─────────────────────────────────────┐ -/// │ │ -/// │ ... One branch per embedding ... │ -/// │ │ +/// │ +/// │ +/// │ +/// │ +/// ▼ +/// ┌───────────────────────┐ +/// │ │ +/// │ KnnFilterOrchestrator │ +/// │ │ +/// └───────────┬───────────┘ +/// │ +/// │ +/// │ +/// ┌──────────────────────────────────┴─────────────────────────────────────┐ +/// │ │ +/// │ ... One branch per embedding ... │ +/// │ │ /// ┌──────────────────── │ ─────────────────────┐ ┌──────────────────── │ ─────────────────────┐ /// │ ▼ │ │ ▼ │ /// │ ┌────────────┐ KnnOrchestrator │ │ ┌────────────┐ KnnOrchestrator │ @@ -129,27 +99,18 @@ use super::knn_filter::{KnnError, KnnFilterOutput, KnnResult}; /// │ └────────┬─────────┘ │ │ └────────┬─────────┘ │ /// │ │ │ │ │ │ /// └──────────────────── │ ─────────────────────┘ └──────────────────── │ ─────────────────────┘ -/// │ │ -/// │ │ -/// │ │ -/// │ ┌────────────────┐ │ -/// │ │ │ │ -/// └──────────────────────────►│ try_join_all │◄──────────────────────────┘ -/// │ │ -/// └───────┬────────┘ -/// │ -/// │ -/// ▼ +/// │ │ +/// │ │ +/// │ │ +/// │ ┌────────────────┐ │ +/// │ │ │ │ +/// └──────────────────────────►│ try_join_all │◄──────────────────────────┘ +/// │ │ +/// └───────┬────────┘ +/// │ +/// │ +/// ▼ /// ``` -/// -/// # State tracking -/// Similar to the `GetOrchestrator`, the `KnnFilterOrchestrator` need to keep track of the outputs from -/// `FetchLogOperator` and `FetchSegmentOperator`. For `KnnOrchestrator`, it needs to track the outputs from -/// `KnnLogOperator` and `KnnHnswOperator`. 
It invokes `try_start_knn_merge_operator` when it receives outputs
-/// from either operators, and if both outputs are present it composes the input for `KnnMergeOperator` and
-/// proceeds with execution. The outputs of other operators are directly forwarded without being tracked
-/// by the orchestrator.
-
 #[derive(Debug)]
 pub struct KnnOrchestrator {
     // Orchestrator parameters
diff --git a/rust/worker/src/execution/orchestration/knn_filter.rs b/rust/worker/src/execution/orchestration/knn_filter.rs
index f634e11c7a5..4ccc549d789 100644
--- a/rust/worker/src/execution/orchestration/knn_filter.rs
+++ b/rust/worker/src/execution/orchestration/knn_filter.rs
@@ -2,7 +2,7 @@ use chroma_blockstore::provider::BlockfileProvider;
 use chroma_distance::DistanceFunction;
 use chroma_error::{ChromaError, ErrorCodes};
 use chroma_index::hnsw_provider::HnswIndexProvider;
-use chroma_types::Segment;
+use chroma_types::{CollectionAndSegments, Segment};
 use thiserror::Error;
 use tokio::sync::oneshot::{self, error::RecvError, Sender};
 use tonic::async_trait;
@@ -14,7 +14,6 @@ use crate::{
     operator::{wrap, TaskError, TaskResult},
     operators::{
         fetch_log::{FetchLogError, FetchLogOperator, FetchLogOutput},
-        fetch_segment::{FetchSegmentError, FetchSegmentOperator, FetchSegmentOutput},
         filter::{FilterError, FilterInput, FilterOperator, FilterOutput},
         knn_hnsw::KnnHnswError,
         knn_log::KnnLogError,
@@ -38,12 +37,8 @@ use crate::{
 pub enum KnnError {
     #[error("Error sending message through channel: {0}")]
     Channel(#[from] ChannelError),
-    #[error("Empty collection")]
-    EmptyCollection,
     #[error("Error running Fetch Log Operator: {0}")]
     FetchLog(#[from] FetchLogError),
-    #[error("Error running Fetch Segment Operator: {0}")]
-    FetchSegment(#[from] FetchSegmentError),
     #[error("Error running Filter Operator: {0}")]
     Filter(#[from] FilterError),
     #[error("Error creating hnsw segment reader: {0}")]
@@ -74,9 +69,7 @@ impl ChromaError for KnnError {
     fn code(&self) -> ErrorCodes {
         match self {
             KnnError::Channel(e) => e.code(),
-            KnnError::EmptyCollection => ErrorCodes::Internal,
             KnnError::FetchLog(e) => e.code(),
-            KnnError::FetchSegment(e) => e.code(),
             KnnError::Filter(e) => e.code(),
             KnnError::HnswReader(e) => e.code(),
             KnnError::KnnLog(e) => e.code(),
@@ -118,6 +111,38 @@ pub struct KnnFilterOutput {

 type KnnFilterResult = Result<KnnFilterOutput, KnnError>;

+/// The `KnnFilterOrchestrator` chains a sequence of operators to evaluate
+/// the first half of a `.query(...)` query from the user
+///
+/// # Pipeline
+/// ```text
+///        ┌────────────┐
+///        │            │
+///        │  on_start  │
+///        │            │
+///        └──────┬─────┘
+///               │
+///               ▼
+///     ┌────────────────────┐
+///     │                    │
+///     │  FetchLogOperator  │
+///     │                    │
+///     └─────────┬──────────┘
+///               │
+///               ▼
+///     ┌───────────────────┐
+///     │                   │
+///     │   FilterOperator  │
+///     │                   │
+///     └─────────┬─────────┘
+///               │
+///               ▼
+///     ┌──────────────────┐
+///     │                  │
+///     │  result_channel  │
+///     │                  │
+///     └──────────────────┘
+/// ```
 #[derive(Debug)]
 pub struct KnnFilterOrchestrator {
     // Orchestrator parameters
@@ -126,13 +151,14 @@ pub struct KnnFilterOrchestrator {
     blockfile_provider: BlockfileProvider,
     dispatcher: ComponentHandle<Dispatcher>,
     hnsw_provider: HnswIndexProvider,
     queue: usize,

-    // Fetch logs and segments
+    // Collection segments
+    collection_and_segments: CollectionAndSegments,
+
+    // Fetch logs
     fetch_log: FetchLogOperator,
-    fetch_segment: FetchSegmentOperator,

-    // Fetch output
-    fetch_log_output: Option<FetchLogOutput>,
-    fetch_segment_output: Option<FetchSegmentOutput>,
+    // Fetched logs
+    fetched_logs: Option<FetchLogOutput>,

     // Pipelined operators
     filter: FilterOperator,
@@ -147,8 +173,8 @@ impl KnnFilterOrchestrator {
         dispatcher: ComponentHandle<Dispatcher>,
hnsw_provider: HnswIndexProvider, queue: usize, + collection_and_segments: CollectionAndSegments, fetch_log: FetchLogOperator, - fetch_segment: FetchSegmentOperator, filter: FilterOperator, ) -> Self { Self { @@ -156,10 +182,9 @@ impl KnnFilterOrchestrator { dispatcher, hnsw_provider, queue, + collection_and_segments, fetch_log, - fetch_segment, - fetch_log_output: None, - fetch_segment_output: None, + fetched_logs: None, filter, result_channel: None, } @@ -182,27 +207,6 @@ impl KnnFilterOrchestrator { tracing::error!("Error running orchestrator: {}", &knn_err); terminate_with_error(self.result_channel.take(), knn_err, ctx); } - - async fn try_start_filter_operator(&mut self, ctx: &ComponentContext) { - if let (Some(logs), Some(segments)) = ( - self.fetch_log_output.as_ref(), - self.fetch_segment_output.as_ref(), - ) { - let task = wrap( - Box::new(self.filter.clone()), - FilterInput { - logs: logs.clone(), - blockfile_provider: self.blockfile_provider.clone(), - metadata_segment: segments.metadata_segment.clone(), - record_segment: segments.record_segment.clone(), - }, - ctx.receiver(), - ); - if let Err(err) = self.dispatcher.send(task, Some(Span::current())).await { - self.terminate_with_error(ctx, err); - } - } - } } #[async_trait] @@ -216,16 +220,10 @@ impl Component for KnnFilterOrchestrator { } async fn on_start(&mut self, ctx: &ComponentContext) { - let log_task = wrap(Box::new(self.fetch_log.clone()), (), ctx.receiver()); - let segment_task = wrap(Box::new(self.fetch_segment.clone()), (), ctx.receiver()); - if let Err(err) = self.dispatcher.send(log_task, Some(Span::current())).await { - self.terminate_with_error(ctx, err); - } else if let Err(err) = self - .dispatcher - .send(segment_task, Some(Span::current())) - .await - { + let task = wrap(Box::new(self.fetch_log.clone()), (), ctx.receiver()); + if let Err(err) = self.dispatcher.send(task, Some(Span::current())).await { self.terminate_with_error(ctx, err); + return; } } } @@ -246,37 +244,22 @@ impl Handler> for KnnFilterOrchestrato return; } }; - self.fetch_log_output = Some(output); - self.try_start_filter_operator(ctx).await; - } -} - -#[async_trait] -impl Handler> for KnnFilterOrchestrator { - type Result = (); - async fn handle( - &mut self, - message: TaskResult, - ctx: &ComponentContext, - ) { - let output = match message.into_inner() { - Ok(output) => output, - Err(err) => { - self.terminate_with_error(ctx, err); - return; - } - }; + self.fetched_logs = Some(output.clone()); - // If dimension is not set and segment is uninitialized, we assume - // this is a query on empty collection, so we return early here - if output.collection.dimension.is_none() && output.vector_segment.file_path.is_empty() { - self.terminate_with_error(ctx, KnnError::EmptyCollection); - return; + let task = wrap( + Box::new(self.filter.clone()), + FilterInput { + logs: output, + blockfile_provider: self.blockfile_provider.clone(), + metadata_segment: self.collection_and_segments.metadata_segment.clone(), + record_segment: self.collection_and_segments.record_segment.clone(), + }, + ctx.receiver(), + ); + if let Err(err) = self.dispatcher.send(task, Some(Span::current())).await { + self.terminate_with_error(ctx, err); } - - self.fetch_segment_output = Some(output); - self.try_start_filter_operator(ctx).await; } } @@ -296,26 +279,23 @@ impl Handler> for KnnFilterOrchestrator { return; } }; - let segments = self - .fetch_segment_output - .take() - .expect("FetchSegmentOperator should have finished already"); - let collection_dimension = match 
segments.collection.dimension { + let collection_dimension = match self.collection_and_segments.collection.dimension { Some(dimension) => dimension as u32, None => { self.terminate_with_error(ctx, KnnError::NoCollectionDimension); return; } }; - let distance_function = match distance_function_from_segment(&segments.vector_segment) { - Ok(distance_function) => distance_function, - Err(_) => { - self.terminate_with_error(ctx, KnnError::InvalidDistanceFunction); - return; - } - }; + let distance_function = + match distance_function_from_segment(&self.collection_and_segments.vector_segment) { + Ok(distance_function) => distance_function, + Err(_) => { + self.terminate_with_error(ctx, KnnError::InvalidDistanceFunction); + return; + } + }; let hnsw_reader = match DistributedHNSWSegmentReader::from_segment( - &segments.vector_segment, + &self.collection_and_segments.vector_segment, collection_dimension as usize, self.hnsw_provider.clone(), ) @@ -335,14 +315,14 @@ impl Handler> for KnnFilterOrchestrator { if chan .send(Ok(KnnFilterOutput { logs: self - .fetch_log_output + .fetched_logs .take() .expect("FetchLogOperator should have finished already"), distance_function, filter_output: output, hnsw_reader, - record_segment: segments.record_segment, - vector_segment: segments.vector_segment, + record_segment: self.collection_and_segments.record_segment.clone(), + vector_segment: self.collection_and_segments.vector_segment.clone(), dimension: collection_dimension as usize, })) .is_err() diff --git a/rust/worker/src/server.rs b/rust/worker/src/server.rs index da550483fc7..6e68734be5d 100644 --- a/rust/worker/src/server.rs +++ b/rust/worker/src/server.rs @@ -1,4 +1,4 @@ -use std::{iter::once, str::FromStr}; +use std::iter::once; use chroma_blockstore::provider::BlockfileProvider; use chroma_config::Configurable; @@ -9,7 +9,7 @@ use chroma_types::{ self, query_executor_server::QueryExecutor, CountPlan, CountResult, GetPlan, GetResult, KnnBatchResult, KnnPlan, }, - CollectionUuid, SegmentUuid, + CollectionAndSegments, }; use futures::{stream, StreamExt, TryStreamExt}; use tokio::signal::unix::{signal, SignalKind}; @@ -20,14 +20,9 @@ use crate::{ config::QueryServiceConfig, execution::{ dispatcher::Dispatcher, - operators::{ - fetch_log::FetchLogOperator, fetch_segment::FetchSegmentOperator, - knn_projection::KnnProjectionOperator, - }, + operators::{fetch_log::FetchLogOperator, knn_projection::KnnProjectionOperator}, orchestration::{ - get::GetOrchestrator, - knn::KnnOrchestrator, - knn_filter::{KnnError, KnnFilterOrchestrator}, + get::GetOrchestrator, knn::KnnOrchestrator, knn_filter::KnnFilterOrchestrator, CountQueryOrchestrator, }, }, @@ -136,44 +131,17 @@ impl WorkerServer { self.system = Some(system); } - fn decompose_proto_scan( - &self, - scan: chroma_proto::ScanOperator, - ) -> Result<(FetchLogOperator, FetchSegmentOperator), Status> { - let collection = scan - .collection - .ok_or(Status::invalid_argument("Invalid Collection"))?; - - let collection_uuid = CollectionUuid::from_str(&collection.id) - .map_err(|_| Status::invalid_argument("Invalid Collection UUID"))?; - - let metadata_uuid = - SegmentUuid::from_str(&scan.metadata.map(|seg| seg.id).unwrap_or(scan.metadata_id))?; - let record_uuid = - SegmentUuid::from_str(&scan.record.map(|seg| seg.id).unwrap_or(scan.record_id))?; - let vector_uuid = - SegmentUuid::from_str(&scan.knn.map(|seg| seg.id).unwrap_or(scan.knn_id))?; - - Ok(( - FetchLogOperator { - log_client: self.log.clone(), - // TODO: Make this configurable - batch_size: 100, - // The 
collection log position is inclusive, and we want to start from the next log
-                // Note that we query using the incoming log position this is critical for correctness
-                start_log_offset_id: collection.log_position as u32 + 1,
-                maximum_fetch_count: None,
-                collection_uuid,
-            },
-            FetchSegmentOperator {
-                sysdb: self.sysdb.clone(),
-                collection_uuid,
-                collection_version: collection.version as u32,
-                metadata_uuid,
-                record_uuid,
-                vector_uuid,
-            },
-        ))
+    fn fetch_log(&self, collection_and_segments: &CollectionAndSegments) -> FetchLogOperator {
+        FetchLogOperator {
+            log_client: self.log.clone(),
+            // TODO: Make this configurable
+            batch_size: 100,
+            // The collection log position is inclusive, and we want to start from the next log
+            // Note that we query using the incoming log position; this is critical for correctness
+            start_log_offset_id: collection_and_segments.collection.log_position as u32 + 1,
+            maximum_fetch_count: None,
+            collection_uuid: collection_and_segments.collection.collection_id,
+        }
     }

     async fn orchestrate_count(
         &self,
         count: Request<CountPlan>,
     ) -> Result<Response<CountResult>, Status> {
         let scan = count
             .into_inner()
             .scan
             .ok_or(Status::invalid_argument("Invalid Scan Operator"))?;

-        let collection = &scan
-            .collection
-            .ok_or(Status::invalid_argument("Invalid collection"))?;
+        let collection_and_segments = CollectionAndSegments::try_from(scan)?;
+        let collection = &collection_and_segments.collection;

         let count_orchestrator = CountQueryOrchestrator::new(
             self.clone_system()?,
-            &SegmentUuid::from_str(&scan.metadata.map(|seg| seg.id).unwrap_or(scan.metadata_id))?.0,
-            &CollectionUuid::from_str(&collection.id)
-                .map_err(|e| Status::invalid_argument(e.to_string()))?,
+            &collection_and_segments.metadata_segment.id.0,
+            &collection.collection_id,
             self.log.clone(),
             self.sysdb.clone(),
             self.clone_dispatcher()?,
@@ -216,7 +182,8 @@
             .scan
             .ok_or(Status::invalid_argument("Invalid Scan Operator"))?;

-        let (fetch_log_operator, fetch_segment_operator) = self.decompose_proto_scan(scan)?;
+        let collection_and_segments = scan.try_into()?;
+        let fetch_log = self.fetch_log(&collection_and_segments);

         let filter = get_inner
             .filter
@@ -235,8 +202,8 @@
             self.clone_dispatcher()?,
             // TODO: Make this configurable
             1000,
-            fetch_log_operator,
-            fetch_segment_operator,
+            collection_and_segments,
+            fetch_log,
             filter.try_into()?,
             limit.into(),
             projection.into(),
@@ -261,7 +228,9 @@
             .scan
             .ok_or(Status::invalid_argument("Invalid Scan Operator"))?;

-        let (fetch_log_operator, fetch_segment_operator) = self.decompose_proto_scan(scan)?;
+        let collection_and_segments = scan.try_into()?;
+
+        let fetch_log = self.fetch_log(&collection_and_segments);

         let filter = knn_inner
             .filter
@@ -281,27 +250,32 @@
             return Ok(Response::new(to_proto_knn_batch_result(Vec::new())?));
         }

+        // If dimension is not set and segment is uninitialized, we assume
+        // this is a query on empty collection, so we return early here
+        if collection_and_segments.collection.dimension.is_none()
+            && collection_and_segments.vector_segment.file_path.is_empty()
+        {
+            return Ok(Response::new(to_proto_knn_batch_result(
+                once(Default::default())
+                    .cycle()
+                    .take(knn.embeddings.len())
+                    .collect(),
+            )?));
+        }
+
         let knn_filter_orchestrator = KnnFilterOrchestrator::new(
             self.blockfile_provider.clone(),
             dispatcher.clone(),
             self.hnsw_index_provider.clone(),
             // TODO: Make this configurable
             1000,
-            fetch_log_operator,
-            fetch_segment_operator,
+            collection_and_segments,
+            fetch_log,
             filter.try_into()?,
         );

         let matching_records = match
knn_filter_orchestrator.run(system.clone()).await { Ok(output) => output, - Err(KnnError::EmptyCollection) => { - return Ok(Response::new(to_proto_knn_batch_result( - once(Default::default()) - .cycle() - .take(knn.embeddings.len()) - .collect(), - )?)); - } Err(e) => { return Err(Status::new(e.code().into(), e.to_string())); } @@ -472,21 +446,15 @@ mod tests { chroma_proto::ScanOperator { collection: Some(chroma_proto::Collection { id: collection_id.clone(), - name: "Test-Collection".to_string(), + name: "test-collection".to_string(), configuration_json_str: String::new(), metadata: None, dimension: None, - tenant: "Test-Tenant".to_string(), - database: "Test-Database".to_string(), + tenant: "test-tenant".to_string(), + database: "test-database".to_string(), log_position: 0, version: 0, }), - // Deprecated - knn_id: "".to_string(), - // Deprecated - metadata_id: "".to_string(), - // Deprecated - record_id: "".to_string(), knn: Some(chroma_proto::Segment { id: Uuid::new_v4().to_string(), r#type: "urn:chroma:segment/vector/hnsw-distributed".to_string(), @@ -567,29 +535,6 @@ mod tests { async fn validate_get_plan() { let mut executor = QueryExecutorClient::connect(run_server()).await.unwrap(); let mut scan_operator = scan(); - let request = chroma_proto::GetPlan { - scan: Some(scan_operator.clone()), - filter: Some(chroma_proto::FilterOperator { - ids: None, - r#where: None, - where_document: None, - }), - limit: Some(chroma_proto::LimitOperator { - skip: 0, - fetch: None, - }), - projection: Some(chroma_proto::ProjectionOperator { - document: false, - embedding: false, - metadata: false, - }), - }; - - // segment or collection not found - let response = executor.get(request.clone()).await; - assert!(response.is_err()); - assert_eq!(response.unwrap_err().code(), tonic::Code::NotFound); - let request = chroma_proto::GetPlan { scan: Some(scan_operator.clone()), filter: None, @@ -766,11 +711,6 @@ mod tests { assert!(response.is_err()); let err = response.unwrap_err(); assert_eq!(err.code(), tonic::Code::InvalidArgument); - assert!( - err.message().to_lowercase().contains("uuid"), - "{}", - err.message() - ); } #[tokio::test] @@ -795,11 +735,6 @@ mod tests { assert!(response.is_err()); let err = response.unwrap_err(); assert_eq!(err.code(), tonic::Code::InvalidArgument); - assert!( - err.message().to_lowercase().contains("uuid"), - "{}", - err.message() - ); } #[tokio::test] @@ -823,11 +758,6 @@ mod tests { assert!(response.is_err()); let err = response.unwrap_err(); assert_eq!(err.code(), tonic::Code::InvalidArgument); - assert!( - err.message().to_lowercase().contains("uuid"), - "{}", - err.message() - ); } #[tokio::test] @@ -851,10 +781,5 @@ mod tests { assert!(response.is_err()); let err = response.unwrap_err(); assert_eq!(err.code(), tonic::Code::InvalidArgument); - assert!( - err.message().to_lowercase().contains("uuid"), - "{}", - err.message() - ); } }
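
Note on the new scan decoding path: the proto `ScanOperator` now carries full `Segment` messages, and the deprecated `knn_id`/`metadata_id`/`record_id` field numbers are `reserved`, which prevents them from ever being reused for a different purpose. The sketch below illustrates the same fallible-field `TryFrom` pattern in isolation; `ProtoScan`, `DecodeError`, and the string-typed fields are simplified stand-ins for the generated and domain types, not the real definitions.

```rust
#[derive(Debug)]
struct ProtoCollection { id: String }

#[derive(Debug)]
struct ProtoSegment { id: String }

// Stand-in for the generated ScanOperator message: every field is optional.
#[derive(Debug, Default)]
struct ProtoScan {
    collection: Option<ProtoCollection>,
    knn: Option<ProtoSegment>,
    metadata: Option<ProtoSegment>,
    record: Option<ProtoSegment>,
}

#[derive(Debug)]
struct CollectionAndSegments {
    collection: String,
    vector_segment: String,
    metadata_segment: String,
    record_segment: String,
}

#[derive(Debug, PartialEq)]
struct DecodeError;

impl TryFrom<ProtoScan> for CollectionAndSegments {
    type Error = DecodeError;

    fn try_from(scan: ProtoScan) -> Result<Self, Self::Error> {
        // Each optional field is unwrapped with `ok_or` before conversion,
        // mirroring the structure of the real impl in rust/types/src/collection.rs.
        Ok(Self {
            collection: scan.collection.ok_or(DecodeError)?.id,
            vector_segment: scan.knn.ok_or(DecodeError)?.id,
            metadata_segment: scan.metadata.ok_or(DecodeError)?.id,
            record_segment: scan.record.ok_or(DecodeError)?.id,
        })
    }
}

fn main() {
    // A scan with a missing segment fails fast, which the server now surfaces
    // as InvalidArgument (ConversionError::DecodeError maps to that code).
    let incomplete = ProtoScan {
        collection: Some(ProtoCollection { id: "c0".into() }),
        ..Default::default()
    };
    assert_eq!(CollectionAndSegments::try_from(incomplete).unwrap_err(), DecodeError);
}
```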
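The empty-collection check also moved: instead of the orchestrator failing with the now-removed `KnnError::EmptyCollection`, the server returns early with one default (empty) result per query embedding. A minimal sketch of that shape, where `empty_knn_results` is a hypothetical helper standing in for the inline `once(Default::default()).cycle().take(n).collect()` expression:

```rust
use std::iter::once;

// Build one default (empty) result per query embedding, as the server does
// when the collection has no dimension and the vector segment has no files.
fn empty_knn_results<T: Default + Clone>(num_embeddings: usize) -> Vec<T> {
    once(T::default()).cycle().take(num_embeddings).collect()
}

fn main() {
    // Three embeddings in the batch -> three empty result lists.
    let results: Vec<Vec<(u32, f32)>> = empty_knn_results(3);
    assert_eq!(results.len(), 3);
    assert!(results.iter().all(|r| r.is_empty()));
}
```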
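For the fan-out stage in the `KnnOrchestrator` pipeline diagram (one branch per embedding, joined by `try_join_all`), here is a runnable sketch of the same control flow, assuming the `futures` crate and a `tokio` runtime; `knn_for_embedding` is a hypothetical placeholder for spawning a `KnnOrchestrator` on the shared filter output:

```rust
use futures::future::try_join_all;

// Placeholder per-embedding query: the real code spawns one KnnOrchestrator
// per query embedding and awaits all of their result channels.
async fn knn_for_embedding(embedding: Vec<f32>) -> Result<Vec<(u32, f32)>, String> {
    let score: f32 = embedding.iter().sum();
    Ok(vec![(0u32, score)])
}

#[tokio::main]
async fn main() -> Result<(), String> {
    let embeddings = vec![vec![0.1_f32, 0.2], vec![0.3, 0.4]];
    // One future per embedding; the first error aborts the whole batch,
    // matching the try_join_all shape in the pipeline diagram.
    let results = try_join_all(embeddings.into_iter().map(knn_for_embedding)).await?;
    assert_eq!(results.len(), 2);
    Ok(())
}
```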