From e02d7a5f362a5afac0c2ac7b9afd94483a4d911e Mon Sep 17 00:00:00 2001 From: Sicheng Pan Date: Thu, 5 Dec 2024 17:04:06 -0800 Subject: [PATCH] [ENH] Propagate segment information from frontend to query node --- chromadb/proto/convert.py | 6 +- chromadb/proto/query_executor_pb2.py | 66 +++++------ chromadb/proto/query_executor_pb2.pyi | 16 +-- idl/chromadb/proto/query_executor.proto | 6 +- rust/types/src/segment.rs | 7 ++ rust/worker/src/server.rs | 140 +++++++++++++++++++----- 6 files changed, 165 insertions(+), 76 deletions(-) diff --git a/chromadb/proto/convert.py b/chromadb/proto/convert.py index 2f2a16cf20b..51d30bc608a 100644 --- a/chromadb/proto/convert.py +++ b/chromadb/proto/convert.py @@ -573,9 +573,9 @@ def to_proto_where_document(where_document: WhereDocument) -> chroma_pb.WhereDoc def to_proto_scan(scan: Scan) -> query_pb.ScanOperator: return query_pb.ScanOperator( collection=to_proto_collection(scan.collection), - knn_id=scan.knn["id"].hex, - metadata_id=scan.metadata["id"].hex, - record_id=scan.record["id"].hex, + knn=to_proto_segment(scan.knn), + metadata=to_proto_segment(scan.metadata), + record=to_proto_segment(scan.record), ) diff --git a/chromadb/proto/query_executor_pb2.py b/chromadb/proto/query_executor_pb2.py index d89a21c189a..3cf16653463 100644 --- a/chromadb/proto/query_executor_pb2.py +++ b/chromadb/proto/query_executor_pb2.py @@ -14,7 +14,7 @@ from chromadb.proto import chroma_pb2 as chromadb_dot_proto_dot_chroma__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n#chromadb/proto/query_executor.proto\x12\x06\x63hroma\x1a\x1b\x63hromadb/proto/chroma.proto\"n\n\x0cScanOperator\x12&\n\ncollection\x18\x01 \x01(\x0b\x32\x12.chroma.Collection\x12\x0e\n\x06knn_id\x18\x02 \x01(\t\x12\x13\n\x0bmetadata_id\x18\x03 \x01(\t\x12\x11\n\trecord_id\x18\x04 \x01(\t\"\xaf\x01\n\x0e\x46ilterOperator\x12!\n\x03ids\x18\x01 \x01(\x0b\x32\x0f.chroma.UserIdsH\x00\x88\x01\x01\x12!\n\x05where\x18\x02 \x01(\x0b\x32\r.chroma.WhereH\x01\x88\x01\x01\x12\x32\n\x0ewhere_document\x18\x03 \x01(\x0b\x32\x15.chroma.WhereDocumentH\x02\x88\x01\x01\x42\x06\n\x04_idsB\x08\n\x06_whereB\x11\n\x0f_where_document\"@\n\x0bKNNOperator\x12\"\n\nembeddings\x18\x01 \x03(\x0b\x32\x0e.chroma.Vector\x12\r\n\x05\x66\x65tch\x18\x02 \x01(\r\";\n\rLimitOperator\x12\x0c\n\x04skip\x18\x01 \x01(\r\x12\x12\n\x05\x66\x65tch\x18\x02 \x01(\rH\x00\x88\x01\x01\x42\x08\n\x06_fetch\"K\n\x12ProjectionOperator\x12\x10\n\x08\x64ocument\x18\x01 \x01(\x08\x12\x11\n\tembedding\x18\x02 \x01(\x08\x12\x10\n\x08metadata\x18\x03 \x01(\x08\"Y\n\x15KNNProjectionOperator\x12.\n\nprojection\x18\x01 \x01(\x0b\x32\x1a.chroma.ProjectionOperator\x12\x10\n\x08\x64istance\x18\x02 \x01(\x08\"/\n\tCountPlan\x12\"\n\x04scan\x18\x01 \x01(\x0b\x32\x14.chroma.ScanOperator\"\x1c\n\x0b\x43ountResult\x12\r\n\x05\x63ount\x18\x01 \x01(\r\"\xab\x01\n\x07GetPlan\x12\"\n\x04scan\x18\x01 \x01(\x0b\x32\x14.chroma.ScanOperator\x12&\n\x06\x66ilter\x18\x02 \x01(\x0b\x32\x16.chroma.FilterOperator\x12$\n\x05limit\x18\x03 \x01(\x0b\x32\x15.chroma.LimitOperator\x12.\n\nprojection\x18\x04 \x01(\x0b\x32\x1a.chroma.ProjectionOperator\"\xb4\x01\n\x10ProjectionRecord\x12\n\n\x02id\x18\x01 \x01(\t\x12\x15\n\x08\x64ocument\x18\x02 \x01(\tH\x00\x88\x01\x01\x12&\n\tembedding\x18\x03 \x01(\x0b\x32\x0e.chroma.VectorH\x01\x88\x01\x01\x12-\n\x08metadata\x18\x04 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x02\x88\x01\x01\x42\x0b\n\t_documentB\x0c\n\n_embeddingB\x0b\n\t_metadata\"6\n\tGetResult\x12)\n\x07records\x18\x01 \x03(\x0b\x32\x18.chroma.ProjectionRecord\"\xaa\x01\n\x07KNNPlan\x12\"\n\x04scan\x18\x01 \x01(\x0b\x32\x14.chroma.ScanOperator\x12&\n\x06\x66ilter\x18\x02 \x01(\x0b\x32\x16.chroma.FilterOperator\x12 \n\x03knn\x18\x03 \x01(\x0b\x32\x13.chroma.KNNOperator\x12\x31\n\nprojection\x18\x04 \x01(\x0b\x32\x1d.chroma.KNNProjectionOperator\"c\n\x13KNNProjectionRecord\x12(\n\x06record\x18\x01 \x01(\x0b\x32\x18.chroma.ProjectionRecord\x12\x15\n\x08\x64istance\x18\x02 \x01(\x02H\x00\x88\x01\x01\x42\x0b\n\t_distance\"9\n\tKNNResult\x12,\n\x07records\x18\x01 \x03(\x0b\x32\x1b.chroma.KNNProjectionRecord\"4\n\x0eKNNBatchResult\x12\"\n\x07results\x18\x01 \x03(\x0b\x32\x11.chroma.KNNResult2\xa1\x01\n\rQueryExecutor\x12\x31\n\x05\x43ount\x12\x11.chroma.CountPlan\x1a\x13.chroma.CountResult\"\x00\x12+\n\x03Get\x12\x0f.chroma.GetPlan\x1a\x11.chroma.GetResult\"\x00\x12\x30\n\x03KNN\x12\x0f.chroma.KNNPlan\x1a\x16.chroma.KNNBatchResult\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n#chromadb/proto/query_executor.proto\x12\x06\x63hroma\x1a\x1b\x63hromadb/proto/chroma.proto\"\x98\x01\n\x0cScanOperator\x12&\n\ncollection\x18\x01 \x01(\x0b\x32\x12.chroma.Collection\x12\x1c\n\x03knn\x18\x02 \x01(\x0b\x32\x0f.chroma.Segment\x12!\n\x08metadata\x18\x03 \x01(\x0b\x32\x0f.chroma.Segment\x12\x1f\n\x06record\x18\x04 \x01(\x0b\x32\x0f.chroma.Segment\"\xaf\x01\n\x0e\x46ilterOperator\x12!\n\x03ids\x18\x01 \x01(\x0b\x32\x0f.chroma.UserIdsH\x00\x88\x01\x01\x12!\n\x05where\x18\x02 \x01(\x0b\x32\r.chroma.WhereH\x01\x88\x01\x01\x12\x32\n\x0ewhere_document\x18\x03 \x01(\x0b\x32\x15.chroma.WhereDocumentH\x02\x88\x01\x01\x42\x06\n\x04_idsB\x08\n\x06_whereB\x11\n\x0f_where_document\"@\n\x0bKNNOperator\x12\"\n\nembeddings\x18\x01 \x03(\x0b\x32\x0e.chroma.Vector\x12\r\n\x05\x66\x65tch\x18\x02 \x01(\r\";\n\rLimitOperator\x12\x0c\n\x04skip\x18\x01 \x01(\r\x12\x12\n\x05\x66\x65tch\x18\x02 \x01(\rH\x00\x88\x01\x01\x42\x08\n\x06_fetch\"K\n\x12ProjectionOperator\x12\x10\n\x08\x64ocument\x18\x01 \x01(\x08\x12\x11\n\tembedding\x18\x02 \x01(\x08\x12\x10\n\x08metadata\x18\x03 \x01(\x08\"Y\n\x15KNNProjectionOperator\x12.\n\nprojection\x18\x01 \x01(\x0b\x32\x1a.chroma.ProjectionOperator\x12\x10\n\x08\x64istance\x18\x02 \x01(\x08\"/\n\tCountPlan\x12\"\n\x04scan\x18\x01 \x01(\x0b\x32\x14.chroma.ScanOperator\"\x1c\n\x0b\x43ountResult\x12\r\n\x05\x63ount\x18\x01 \x01(\r\"\xab\x01\n\x07GetPlan\x12\"\n\x04scan\x18\x01 \x01(\x0b\x32\x14.chroma.ScanOperator\x12&\n\x06\x66ilter\x18\x02 \x01(\x0b\x32\x16.chroma.FilterOperator\x12$\n\x05limit\x18\x03 \x01(\x0b\x32\x15.chroma.LimitOperator\x12.\n\nprojection\x18\x04 \x01(\x0b\x32\x1a.chroma.ProjectionOperator\"\xb4\x01\n\x10ProjectionRecord\x12\n\n\x02id\x18\x01 \x01(\t\x12\x15\n\x08\x64ocument\x18\x02 \x01(\tH\x00\x88\x01\x01\x12&\n\tembedding\x18\x03 \x01(\x0b\x32\x0e.chroma.VectorH\x01\x88\x01\x01\x12-\n\x08metadata\x18\x04 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x02\x88\x01\x01\x42\x0b\n\t_documentB\x0c\n\n_embeddingB\x0b\n\t_metadata\"6\n\tGetResult\x12)\n\x07records\x18\x01 \x03(\x0b\x32\x18.chroma.ProjectionRecord\"\xaa\x01\n\x07KNNPlan\x12\"\n\x04scan\x18\x01 \x01(\x0b\x32\x14.chroma.ScanOperator\x12&\n\x06\x66ilter\x18\x02 \x01(\x0b\x32\x16.chroma.FilterOperator\x12 \n\x03knn\x18\x03 \x01(\x0b\x32\x13.chroma.KNNOperator\x12\x31\n\nprojection\x18\x04 \x01(\x0b\x32\x1d.chroma.KNNProjectionOperator\"c\n\x13KNNProjectionRecord\x12(\n\x06record\x18\x01 \x01(\x0b\x32\x18.chroma.ProjectionRecord\x12\x15\n\x08\x64istance\x18\x02 \x01(\x02H\x00\x88\x01\x01\x42\x0b\n\t_distance\"9\n\tKNNResult\x12,\n\x07records\x18\x01 \x03(\x0b\x32\x1b.chroma.KNNProjectionRecord\"4\n\x0eKNNBatchResult\x12\"\n\x07results\x18\x01 \x03(\x0b\x32\x11.chroma.KNNResult2\xa1\x01\n\rQueryExecutor\x12\x31\n\x05\x43ount\x12\x11.chroma.CountPlan\x1a\x13.chroma.CountResult\"\x00\x12+\n\x03Get\x12\x0f.chroma.GetPlan\x1a\x11.chroma.GetResult\"\x00\x12\x30\n\x03KNN\x12\x0f.chroma.KNNPlan\x1a\x16.chroma.KNNBatchResult\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -22,36 +22,36 @@ if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None - _globals['_SCANOPERATOR']._serialized_start=76 - _globals['_SCANOPERATOR']._serialized_end=186 - _globals['_FILTEROPERATOR']._serialized_start=189 - _globals['_FILTEROPERATOR']._serialized_end=364 - _globals['_KNNOPERATOR']._serialized_start=366 - _globals['_KNNOPERATOR']._serialized_end=430 - _globals['_LIMITOPERATOR']._serialized_start=432 - _globals['_LIMITOPERATOR']._serialized_end=491 - _globals['_PROJECTIONOPERATOR']._serialized_start=493 - _globals['_PROJECTIONOPERATOR']._serialized_end=568 - _globals['_KNNPROJECTIONOPERATOR']._serialized_start=570 - _globals['_KNNPROJECTIONOPERATOR']._serialized_end=659 - _globals['_COUNTPLAN']._serialized_start=661 - _globals['_COUNTPLAN']._serialized_end=708 - _globals['_COUNTRESULT']._serialized_start=710 - _globals['_COUNTRESULT']._serialized_end=738 - _globals['_GETPLAN']._serialized_start=741 - _globals['_GETPLAN']._serialized_end=912 - _globals['_PROJECTIONRECORD']._serialized_start=915 - _globals['_PROJECTIONRECORD']._serialized_end=1095 - _globals['_GETRESULT']._serialized_start=1097 - _globals['_GETRESULT']._serialized_end=1151 - _globals['_KNNPLAN']._serialized_start=1154 - _globals['_KNNPLAN']._serialized_end=1324 - _globals['_KNNPROJECTIONRECORD']._serialized_start=1326 - _globals['_KNNPROJECTIONRECORD']._serialized_end=1425 - _globals['_KNNRESULT']._serialized_start=1427 - _globals['_KNNRESULT']._serialized_end=1484 - _globals['_KNNBATCHRESULT']._serialized_start=1486 - _globals['_KNNBATCHRESULT']._serialized_end=1538 - _globals['_QUERYEXECUTOR']._serialized_start=1541 - _globals['_QUERYEXECUTOR']._serialized_end=1702 + _globals['_SCANOPERATOR']._serialized_start=77 + _globals['_SCANOPERATOR']._serialized_end=229 + _globals['_FILTEROPERATOR']._serialized_start=232 + _globals['_FILTEROPERATOR']._serialized_end=407 + _globals['_KNNOPERATOR']._serialized_start=409 + _globals['_KNNOPERATOR']._serialized_end=473 + _globals['_LIMITOPERATOR']._serialized_start=475 + _globals['_LIMITOPERATOR']._serialized_end=534 + _globals['_PROJECTIONOPERATOR']._serialized_start=536 + _globals['_PROJECTIONOPERATOR']._serialized_end=611 + _globals['_KNNPROJECTIONOPERATOR']._serialized_start=613 + _globals['_KNNPROJECTIONOPERATOR']._serialized_end=702 + _globals['_COUNTPLAN']._serialized_start=704 + _globals['_COUNTPLAN']._serialized_end=751 + _globals['_COUNTRESULT']._serialized_start=753 + _globals['_COUNTRESULT']._serialized_end=781 + _globals['_GETPLAN']._serialized_start=784 + _globals['_GETPLAN']._serialized_end=955 + _globals['_PROJECTIONRECORD']._serialized_start=958 + _globals['_PROJECTIONRECORD']._serialized_end=1138 + _globals['_GETRESULT']._serialized_start=1140 + _globals['_GETRESULT']._serialized_end=1194 + _globals['_KNNPLAN']._serialized_start=1197 + _globals['_KNNPLAN']._serialized_end=1367 + _globals['_KNNPROJECTIONRECORD']._serialized_start=1369 + _globals['_KNNPROJECTIONRECORD']._serialized_end=1468 + _globals['_KNNRESULT']._serialized_start=1470 + _globals['_KNNRESULT']._serialized_end=1527 + _globals['_KNNBATCHRESULT']._serialized_start=1529 + _globals['_KNNBATCHRESULT']._serialized_end=1581 + _globals['_QUERYEXECUTOR']._serialized_start=1584 + _globals['_QUERYEXECUTOR']._serialized_end=1745 # @@protoc_insertion_point(module_scope) diff --git a/chromadb/proto/query_executor_pb2.pyi b/chromadb/proto/query_executor_pb2.pyi index 53483f8445f..aa800835e85 100644 --- a/chromadb/proto/query_executor_pb2.pyi +++ b/chromadb/proto/query_executor_pb2.pyi @@ -7,16 +7,16 @@ from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Map DESCRIPTOR: _descriptor.FileDescriptor class ScanOperator(_message.Message): - __slots__ = ["collection", "knn_id", "metadata_id", "record_id"] + __slots__ = ["collection", "knn", "metadata", "record"] COLLECTION_FIELD_NUMBER: _ClassVar[int] - KNN_ID_FIELD_NUMBER: _ClassVar[int] - METADATA_ID_FIELD_NUMBER: _ClassVar[int] - RECORD_ID_FIELD_NUMBER: _ClassVar[int] + KNN_FIELD_NUMBER: _ClassVar[int] + METADATA_FIELD_NUMBER: _ClassVar[int] + RECORD_FIELD_NUMBER: _ClassVar[int] collection: _chroma_pb2.Collection - knn_id: str - metadata_id: str - record_id: str - def __init__(self, collection: _Optional[_Union[_chroma_pb2.Collection, _Mapping]] = ..., knn_id: _Optional[str] = ..., metadata_id: _Optional[str] = ..., record_id: _Optional[str] = ...) -> None: ... + knn: _chroma_pb2.Segment + metadata: _chroma_pb2.Segment + record: _chroma_pb2.Segment + def __init__(self, collection: _Optional[_Union[_chroma_pb2.Collection, _Mapping]] = ..., knn: _Optional[_Union[_chroma_pb2.Segment, _Mapping]] = ..., metadata: _Optional[_Union[_chroma_pb2.Segment, _Mapping]] = ..., record: _Optional[_Union[_chroma_pb2.Segment, _Mapping]] = ...) -> None: ... class FilterOperator(_message.Message): __slots__ = ["ids", "where", "where_document"] diff --git a/idl/chromadb/proto/query_executor.proto b/idl/chromadb/proto/query_executor.proto index 0149767217c..2dba007e3fd 100644 --- a/idl/chromadb/proto/query_executor.proto +++ b/idl/chromadb/proto/query_executor.proto @@ -6,9 +6,9 @@ import "chromadb/proto/chroma.proto"; message ScanOperator { Collection collection = 1; - string knn_id = 2; - string metadata_id = 3; - string record_id = 4; + Segment knn = 2; + Segment metadata = 3; + Segment record = 4; } message FilterOperator { diff --git a/rust/types/src/segment.rs b/rust/types/src/segment.rs index 29fba56af84..8858365554a 100644 --- a/rust/types/src/segment.rs +++ b/rust/types/src/segment.rs @@ -6,6 +6,7 @@ use crate::chroma_proto; use chroma_error::{ChromaError, ErrorCodes}; use std::{collections::HashMap, str::FromStr}; use thiserror::Error; +use tonic::Status; use uuid::Uuid; /// SegmentUuid is a wrapper around Uuid to provide a type for the segment id. @@ -106,6 +107,12 @@ impl ChromaError for SegmentConversionError { } } +impl From for Status { + fn from(value: SegmentConversionError) -> Self { + Status::invalid_argument(value.to_string()) + } +} + impl TryFrom for Segment { type Error = SegmentConversionError; diff --git a/rust/worker/src/server.rs b/rust/worker/src/server.rs index bda9641b300..f2bafac2c66 100644 --- a/rust/worker/src/server.rs +++ b/rust/worker/src/server.rs @@ -9,13 +9,12 @@ use chroma_types::{ self, query_executor_server::QueryExecutor, CountPlan, CountResult, GetPlan, GetResult, KnnBatchResult, KnnPlan, }, - CollectionUuid, SegmentUuid, + CollectionUuid, Segment, }; use futures::{stream, StreamExt, TryStreamExt}; use tokio::signal::unix::{signal, SignalKind}; use tonic::{transport::Server, Request, Response, Status}; use tracing::{trace_span, Instrument}; -use uuid::Uuid; use crate::{ config::QueryServiceConfig, @@ -143,14 +142,19 @@ impl WorkerServer { let collection_uuid = CollectionUuid::from_str(&collection.id) .map_err(|_| Status::invalid_argument("Invalid Collection UUID"))?; - let vector_uuid = SegmentUuid::from_str(&scan.knn_id) - .map_err(|_| Status::invalid_argument("Invalid UUID for Vector segment"))?; - - let metadata_uuid = SegmentUuid::from_str(&scan.metadata_id) - .map_err(|_| Status::invalid_argument("Invalid UUID for Metadata segment"))?; + let metadata_segment = scan + .metadata + .ok_or(Status::invalid_argument("Invalid metadata segment"))?; + let record_segment = scan + .record + .ok_or(Status::invalid_argument("Invalid record segment"))?; + let vector_segment = scan + .knn + .ok_or(Status::invalid_argument("Invalid vector segment"))?; - let record_uuid = SegmentUuid::from_str(&scan.record_id) - .map_err(|_| Status::invalid_argument("Invalid UUID for Record segment"))?; + let metadata_uuid = Segment::try_from(metadata_segment)?.id; + let record_uuid = Segment::try_from(record_segment)?.id; + let vector_uuid = Segment::try_from(vector_segment)?.id; Ok(( FetchLogOperator { @@ -189,8 +193,12 @@ impl WorkerServer { let count_orchestrator = CountQueryOrchestrator::new( self.clone_system()?, - &Uuid::parse_str(&scan.metadata_id) - .map_err(|e| Status::invalid_argument(e.to_string()))?, + &Segment::try_from( + scan.metadata + .ok_or(Status::invalid_argument("Invalid metadata segment"))?, + )? + .id + .0, &CollectionUuid::from_str(&collection.id) .map_err(|e| Status::invalid_argument(e.to_string()))?, self.log.clone(), @@ -422,6 +430,8 @@ impl chroma_proto::debug_server::Debug for WorkerServer { #[cfg(test)] mod tests { + use std::collections::HashMap; + use super::*; use crate::execution::dispatcher; use crate::log::log::InMemoryLog; @@ -465,9 +475,10 @@ mod tests { } fn scan() -> chroma_proto::ScanOperator { + let collection_id = Uuid::new_v4().to_string(); chroma_proto::ScanOperator { collection: Some(chroma_proto::Collection { - id: Uuid::new_v4().to_string(), + id: collection_id.clone(), name: "Test-Collection".to_string(), configuration_json_str: String::new(), metadata: None, @@ -477,9 +488,30 @@ mod tests { log_position: 0, version: 0, }), - knn_id: Uuid::new_v4().to_string(), - metadata_id: Uuid::new_v4().to_string(), - record_id: Uuid::new_v4().to_string(), + knn: Some(chroma_proto::Segment { + id: Uuid::new_v4().to_string(), + r#type: "urn:chroma:segment/vector/hnsw-distributed".to_string(), + scope: 0, + collection: collection_id.clone(), + metadata: None, + file_paths: HashMap::new(), + }), + metadata: Some(chroma_proto::Segment { + id: Uuid::new_v4().to_string(), + r#type: "urn:chroma:segment/metadata/blockfile".to_string(), + scope: 1, + collection: collection_id.clone(), + metadata: None, + file_paths: HashMap::new(), + }), + record: Some(chroma_proto::Segment { + id: Uuid::new_v4().to_string(), + r#type: "urn:chroma:segment/record/blockfile".to_string(), + scope: 2, + collection: collection_id.clone(), + metadata: None, + file_paths: HashMap::new(), + }), } } @@ -509,7 +541,19 @@ mod tests { let response = executor.count(request).await; assert_eq!(response.unwrap_err().code(), tonic::Code::NotFound); - scan_operator.metadata_id = "invalid_segment_id".to_string(); + scan_operator.metadata = Some(chroma_proto::Segment { + id: "invalid-metadata-segment-id".to_string(), + r#type: "urn:chroma:segment/metadata/blockfile".to_string(), + scope: 1, + collection: scan_operator + .collection + .as_ref() + .expect("The collection should exist") + .id + .clone(), + metadata: None, + file_paths: HashMap::new(), + }); let request = chroma_proto::CountPlan { scan: Some(scan_operator.clone()), }; @@ -567,13 +611,13 @@ mod tests { assert_eq!(response.unwrap_err().code(), tonic::Code::InvalidArgument); scan_operator.collection = Some(chroma_proto::Collection { - id: "Invalid-Collection-ID".to_string(), - name: "Broken-Collection".to_string(), + id: "invalid-collection-iD".to_string(), + name: "broken-collection".to_string(), configuration_json_str: String::new(), metadata: None, dimension: None, - tenant: "Test-Tenant".to_string(), - database: "Test-Database".to_string(), + tenant: "test-tenant".to_string(), + database: "test-database".to_string(), log_position: 0, version: 0, }); @@ -601,7 +645,9 @@ mod tests { assert_eq!(response.unwrap_err().code(), tonic::Code::InvalidArgument); } - fn gen_knn_request(mut scan_operator: Option) -> chroma_proto::KnnPlan { + fn gen_knn_request( + mut scan_operator: Option, + ) -> chroma_proto::KnnPlan { if scan_operator.is_none() { scan_operator = Some(scan()); } @@ -716,13 +762,13 @@ mod tests { async fn validate_knn_plan_scan_collection() { let mut executor = QueryExecutorClient::connect(run_server()).await.unwrap(); let mut scan = scan(); - scan.collection.as_mut().unwrap().id = "Invalid-Collection-ID".to_string(); + scan.collection.as_mut().unwrap().id = "invalid-collection-id".to_string(); let response = executor.knn(gen_knn_request(Some(scan))).await; assert!(response.is_err()); let err = response.unwrap_err(); assert_eq!(err.code(), tonic::Code::InvalidArgument); assert!( - err.message().to_lowercase().contains("collection uuid"), + err.message().to_lowercase().contains("uuid"), "{}", err.message() ); @@ -733,13 +779,25 @@ mod tests { let mut executor = QueryExecutorClient::connect(run_server()).await.unwrap(); // invalid vector uuid let mut scan_operator = scan(); - scan_operator.knn_id = "invalid_segment_id".to_string(); + scan_operator.knn = Some(chroma_proto::Segment { + id: "invalid-knn-segment-id".to_string(), + r#type: "urn:chroma:segment/vector/hnsw-distributed".to_string(), + scope: 0, + collection: scan_operator + .collection + .as_ref() + .expect("The collection should exist") + .id + .clone(), + metadata: None, + file_paths: HashMap::new(), + }); let response = executor.knn(gen_knn_request(Some(scan_operator))).await; assert!(response.is_err()); let err = response.unwrap_err(); assert_eq!(err.code(), tonic::Code::InvalidArgument); assert!( - err.message().to_lowercase().contains("vector"), + err.message().to_lowercase().contains("uuid"), "{}", err.message() ); @@ -749,13 +807,25 @@ mod tests { async fn validate_knn_plan_scan_record() { let mut executor = QueryExecutorClient::connect(run_server()).await.unwrap(); let mut scan_operator = scan(); - scan_operator.record_id = "invalid_record_id".to_string(); + scan_operator.record = Some(chroma_proto::Segment { + id: "invalid-record-segment-id".to_string(), + r#type: "urn:chroma:segment/record/blockfile".to_string(), + scope: 2, + collection: scan_operator + .collection + .as_ref() + .expect("The collection should exist") + .id + .clone(), + metadata: None, + file_paths: HashMap::new(), + }); let response = executor.knn(gen_knn_request(Some(scan_operator))).await; assert!(response.is_err()); let err = response.unwrap_err(); assert_eq!(err.code(), tonic::Code::InvalidArgument); assert!( - err.message().to_lowercase().contains("record"), + err.message().to_lowercase().contains("uuid"), "{}", err.message() ); @@ -765,13 +835,25 @@ mod tests { async fn validate_knn_plan_scan_metadata() { let mut executor = QueryExecutorClient::connect(run_server()).await.unwrap(); let mut scan_operator = scan(); - scan_operator.metadata_id = "invalid_metadata_id".to_string(); + scan_operator.metadata = Some(chroma_proto::Segment { + id: "invalid-metadata-segment-id".to_string(), + r#type: "urn:chroma:segment/metadata/blockfile".to_string(), + scope: 1, + collection: scan_operator + .collection + .as_ref() + .expect("The collection should exist") + .id + .clone(), + metadata: None, + file_paths: HashMap::new(), + }); let response = executor.knn(gen_knn_request(Some(scan_operator))).await; assert!(response.is_err()); let err = response.unwrap_err(); assert_eq!(err.code(), tonic::Code::InvalidArgument); assert!( - err.message().to_lowercase().contains("metadata"), + err.message().to_lowercase().contains("uuid"), "{}", err.message() );