From 0907a8905800ae32b7beaf566e4852d359814a15 Mon Sep 17 00:00:00 2001 From: Macronova <60079945+Sicheng-Pan@users.noreply.github.com> Date: Mon, 16 Dec 2024 18:52:33 -0800 Subject: [PATCH] [ENH] Propagate segment information from frontend to query node (#3255) ## Description of changes *Summarize the changes made by this PR.* - Improvements & Bug fixes - N/A - New functionality - Propagate full segment information from frontend to query node ## Test plan *How are these changes tested?* - [x] Tests pass locally with `pytest` for python, `yarn test` for js, `cargo test` for rust ## Documentation Changes *Are all docstrings for user-facing APIs updated if required? Do we need to make documentation changes in the [docs repository](https://github.com/chroma-core/docs)?* N/A --- chromadb/proto/convert.py | 3 + chromadb/proto/query_executor_pb2.py | 66 +++++------ chromadb/proto/query_executor_pb2.pyi | 10 +- idl/chromadb/proto/query_executor.proto | 6 + rust/types/src/segment.rs | 7 ++ rust/worker/src/server.rs | 141 +++++++++++++++++++----- 6 files changed, 168 insertions(+), 65 deletions(-) diff --git a/chromadb/proto/convert.py b/chromadb/proto/convert.py index 2f2a16cf20b..2a18475cf98 100644 --- a/chromadb/proto/convert.py +++ b/chromadb/proto/convert.py @@ -576,6 +576,9 @@ def to_proto_scan(scan: Scan) -> query_pb.ScanOperator: knn_id=scan.knn["id"].hex, metadata_id=scan.metadata["id"].hex, record_id=scan.record["id"].hex, + knn=to_proto_segment(scan.knn), + metadata=to_proto_segment(scan.metadata), + record=to_proto_segment(scan.record), ) diff --git a/chromadb/proto/query_executor_pb2.py b/chromadb/proto/query_executor_pb2.py index d89a21c189a..8131f121198 100644 --- a/chromadb/proto/query_executor_pb2.py +++ b/chromadb/proto/query_executor_pb2.py @@ -14,7 +14,7 @@ from chromadb.proto import chroma_pb2 as chromadb_dot_proto_dot_chroma__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n#chromadb/proto/query_executor.proto\x12\x06\x63hroma\x1a\x1b\x63hromadb/proto/chroma.proto\"n\n\x0cScanOperator\x12&\n\ncollection\x18\x01 \x01(\x0b\x32\x12.chroma.Collection\x12\x0e\n\x06knn_id\x18\x02 \x01(\t\x12\x13\n\x0bmetadata_id\x18\x03 \x01(\t\x12\x11\n\trecord_id\x18\x04 \x01(\t\"\xaf\x01\n\x0e\x46ilterOperator\x12!\n\x03ids\x18\x01 \x01(\x0b\x32\x0f.chroma.UserIdsH\x00\x88\x01\x01\x12!\n\x05where\x18\x02 \x01(\x0b\x32\r.chroma.WhereH\x01\x88\x01\x01\x12\x32\n\x0ewhere_document\x18\x03 \x01(\x0b\x32\x15.chroma.WhereDocumentH\x02\x88\x01\x01\x42\x06\n\x04_idsB\x08\n\x06_whereB\x11\n\x0f_where_document\"@\n\x0bKNNOperator\x12\"\n\nembeddings\x18\x01 \x03(\x0b\x32\x0e.chroma.Vector\x12\r\n\x05\x66\x65tch\x18\x02 \x01(\r\";\n\rLimitOperator\x12\x0c\n\x04skip\x18\x01 \x01(\r\x12\x12\n\x05\x66\x65tch\x18\x02 \x01(\rH\x00\x88\x01\x01\x42\x08\n\x06_fetch\"K\n\x12ProjectionOperator\x12\x10\n\x08\x64ocument\x18\x01 \x01(\x08\x12\x11\n\tembedding\x18\x02 \x01(\x08\x12\x10\n\x08metadata\x18\x03 \x01(\x08\"Y\n\x15KNNProjectionOperator\x12.\n\nprojection\x18\x01 \x01(\x0b\x32\x1a.chroma.ProjectionOperator\x12\x10\n\x08\x64istance\x18\x02 \x01(\x08\"/\n\tCountPlan\x12\"\n\x04scan\x18\x01 \x01(\x0b\x32\x14.chroma.ScanOperator\"\x1c\n\x0b\x43ountResult\x12\r\n\x05\x63ount\x18\x01 \x01(\r\"\xab\x01\n\x07GetPlan\x12\"\n\x04scan\x18\x01 \x01(\x0b\x32\x14.chroma.ScanOperator\x12&\n\x06\x66ilter\x18\x02 \x01(\x0b\x32\x16.chroma.FilterOperator\x12$\n\x05limit\x18\x03 \x01(\x0b\x32\x15.chroma.LimitOperator\x12.\n\nprojection\x18\x04 \x01(\x0b\x32\x1a.chroma.ProjectionOperator\"\xb4\x01\n\x10ProjectionRecord\x12\n\n\x02id\x18\x01 \x01(\t\x12\x15\n\x08\x64ocument\x18\x02 \x01(\tH\x00\x88\x01\x01\x12&\n\tembedding\x18\x03 \x01(\x0b\x32\x0e.chroma.VectorH\x01\x88\x01\x01\x12-\n\x08metadata\x18\x04 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x02\x88\x01\x01\x42\x0b\n\t_documentB\x0c\n\n_embeddingB\x0b\n\t_metadata\"6\n\tGetResult\x12)\n\x07records\x18\x01 \x03(\x0b\x32\x18.chroma.ProjectionRecord\"\xaa\x01\n\x07KNNPlan\x12\"\n\x04scan\x18\x01 \x01(\x0b\x32\x14.chroma.ScanOperator\x12&\n\x06\x66ilter\x18\x02 \x01(\x0b\x32\x16.chroma.FilterOperator\x12 \n\x03knn\x18\x03 \x01(\x0b\x32\x13.chroma.KNNOperator\x12\x31\n\nprojection\x18\x04 \x01(\x0b\x32\x1d.chroma.KNNProjectionOperator\"c\n\x13KNNProjectionRecord\x12(\n\x06record\x18\x01 \x01(\x0b\x32\x18.chroma.ProjectionRecord\x12\x15\n\x08\x64istance\x18\x02 \x01(\x02H\x00\x88\x01\x01\x42\x0b\n\t_distance\"9\n\tKNNResult\x12,\n\x07records\x18\x01 \x03(\x0b\x32\x1b.chroma.KNNProjectionRecord\"4\n\x0eKNNBatchResult\x12\"\n\x07results\x18\x01 \x03(\x0b\x32\x11.chroma.KNNResult2\xa1\x01\n\rQueryExecutor\x12\x31\n\x05\x43ount\x12\x11.chroma.CountPlan\x1a\x13.chroma.CountResult\"\x00\x12+\n\x03Get\x12\x0f.chroma.GetPlan\x1a\x11.chroma.GetResult\"\x00\x12\x30\n\x03KNN\x12\x0f.chroma.KNNPlan\x1a\x16.chroma.KNNBatchResult\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n#chromadb/proto/query_executor.proto\x12\x06\x63hroma\x1a\x1b\x63hromadb/proto/chroma.proto\"\xd0\x01\n\x0cScanOperator\x12&\n\ncollection\x18\x01 \x01(\x0b\x32\x12.chroma.Collection\x12\x0e\n\x06knn_id\x18\x02 \x01(\t\x12\x13\n\x0bmetadata_id\x18\x03 \x01(\t\x12\x11\n\trecord_id\x18\x04 \x01(\t\x12\x1c\n\x03knn\x18\x05 \x01(\x0b\x32\x0f.chroma.Segment\x12!\n\x08metadata\x18\x06 \x01(\x0b\x32\x0f.chroma.Segment\x12\x1f\n\x06record\x18\x07 \x01(\x0b\x32\x0f.chroma.Segment\"\xaf\x01\n\x0e\x46ilterOperator\x12!\n\x03ids\x18\x01 \x01(\x0b\x32\x0f.chroma.UserIdsH\x00\x88\x01\x01\x12!\n\x05where\x18\x02 \x01(\x0b\x32\r.chroma.WhereH\x01\x88\x01\x01\x12\x32\n\x0ewhere_document\x18\x03 \x01(\x0b\x32\x15.chroma.WhereDocumentH\x02\x88\x01\x01\x42\x06\n\x04_idsB\x08\n\x06_whereB\x11\n\x0f_where_document\"@\n\x0bKNNOperator\x12\"\n\nembeddings\x18\x01 \x03(\x0b\x32\x0e.chroma.Vector\x12\r\n\x05\x66\x65tch\x18\x02 \x01(\r\";\n\rLimitOperator\x12\x0c\n\x04skip\x18\x01 \x01(\r\x12\x12\n\x05\x66\x65tch\x18\x02 \x01(\rH\x00\x88\x01\x01\x42\x08\n\x06_fetch\"K\n\x12ProjectionOperator\x12\x10\n\x08\x64ocument\x18\x01 \x01(\x08\x12\x11\n\tembedding\x18\x02 \x01(\x08\x12\x10\n\x08metadata\x18\x03 \x01(\x08\"Y\n\x15KNNProjectionOperator\x12.\n\nprojection\x18\x01 \x01(\x0b\x32\x1a.chroma.ProjectionOperator\x12\x10\n\x08\x64istance\x18\x02 \x01(\x08\"/\n\tCountPlan\x12\"\n\x04scan\x18\x01 \x01(\x0b\x32\x14.chroma.ScanOperator\"\x1c\n\x0b\x43ountResult\x12\r\n\x05\x63ount\x18\x01 \x01(\r\"\xab\x01\n\x07GetPlan\x12\"\n\x04scan\x18\x01 \x01(\x0b\x32\x14.chroma.ScanOperator\x12&\n\x06\x66ilter\x18\x02 \x01(\x0b\x32\x16.chroma.FilterOperator\x12$\n\x05limit\x18\x03 \x01(\x0b\x32\x15.chroma.LimitOperator\x12.\n\nprojection\x18\x04 \x01(\x0b\x32\x1a.chroma.ProjectionOperator\"\xb4\x01\n\x10ProjectionRecord\x12\n\n\x02id\x18\x01 \x01(\t\x12\x15\n\x08\x64ocument\x18\x02 \x01(\tH\x00\x88\x01\x01\x12&\n\tembedding\x18\x03 \x01(\x0b\x32\x0e.chroma.VectorH\x01\x88\x01\x01\x12-\n\x08metadata\x18\x04 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x02\x88\x01\x01\x42\x0b\n\t_documentB\x0c\n\n_embeddingB\x0b\n\t_metadata\"6\n\tGetResult\x12)\n\x07records\x18\x01 \x03(\x0b\x32\x18.chroma.ProjectionRecord\"\xaa\x01\n\x07KNNPlan\x12\"\n\x04scan\x18\x01 \x01(\x0b\x32\x14.chroma.ScanOperator\x12&\n\x06\x66ilter\x18\x02 \x01(\x0b\x32\x16.chroma.FilterOperator\x12 \n\x03knn\x18\x03 \x01(\x0b\x32\x13.chroma.KNNOperator\x12\x31\n\nprojection\x18\x04 \x01(\x0b\x32\x1d.chroma.KNNProjectionOperator\"c\n\x13KNNProjectionRecord\x12(\n\x06record\x18\x01 \x01(\x0b\x32\x18.chroma.ProjectionRecord\x12\x15\n\x08\x64istance\x18\x02 \x01(\x02H\x00\x88\x01\x01\x42\x0b\n\t_distance\"9\n\tKNNResult\x12,\n\x07records\x18\x01 \x03(\x0b\x32\x1b.chroma.KNNProjectionRecord\"4\n\x0eKNNBatchResult\x12\"\n\x07results\x18\x01 \x03(\x0b\x32\x11.chroma.KNNResult2\xa1\x01\n\rQueryExecutor\x12\x31\n\x05\x43ount\x12\x11.chroma.CountPlan\x1a\x13.chroma.CountResult\"\x00\x12+\n\x03Get\x12\x0f.chroma.GetPlan\x1a\x11.chroma.GetResult\"\x00\x12\x30\n\x03KNN\x12\x0f.chroma.KNNPlan\x1a\x16.chroma.KNNBatchResult\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -22,36 +22,36 @@ if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None - _globals['_SCANOPERATOR']._serialized_start=76 - _globals['_SCANOPERATOR']._serialized_end=186 - _globals['_FILTEROPERATOR']._serialized_start=189 - _globals['_FILTEROPERATOR']._serialized_end=364 - _globals['_KNNOPERATOR']._serialized_start=366 - _globals['_KNNOPERATOR']._serialized_end=430 - _globals['_LIMITOPERATOR']._serialized_start=432 - _globals['_LIMITOPERATOR']._serialized_end=491 - _globals['_PROJECTIONOPERATOR']._serialized_start=493 - _globals['_PROJECTIONOPERATOR']._serialized_end=568 - _globals['_KNNPROJECTIONOPERATOR']._serialized_start=570 - _globals['_KNNPROJECTIONOPERATOR']._serialized_end=659 - _globals['_COUNTPLAN']._serialized_start=661 - _globals['_COUNTPLAN']._serialized_end=708 - _globals['_COUNTRESULT']._serialized_start=710 - _globals['_COUNTRESULT']._serialized_end=738 - _globals['_GETPLAN']._serialized_start=741 - _globals['_GETPLAN']._serialized_end=912 - _globals['_PROJECTIONRECORD']._serialized_start=915 - _globals['_PROJECTIONRECORD']._serialized_end=1095 - _globals['_GETRESULT']._serialized_start=1097 - _globals['_GETRESULT']._serialized_end=1151 - _globals['_KNNPLAN']._serialized_start=1154 - _globals['_KNNPLAN']._serialized_end=1324 - _globals['_KNNPROJECTIONRECORD']._serialized_start=1326 - _globals['_KNNPROJECTIONRECORD']._serialized_end=1425 - _globals['_KNNRESULT']._serialized_start=1427 - _globals['_KNNRESULT']._serialized_end=1484 - _globals['_KNNBATCHRESULT']._serialized_start=1486 - _globals['_KNNBATCHRESULT']._serialized_end=1538 - _globals['_QUERYEXECUTOR']._serialized_start=1541 - _globals['_QUERYEXECUTOR']._serialized_end=1702 + _globals['_SCANOPERATOR']._serialized_start=77 + _globals['_SCANOPERATOR']._serialized_end=285 + _globals['_FILTEROPERATOR']._serialized_start=288 + _globals['_FILTEROPERATOR']._serialized_end=463 + _globals['_KNNOPERATOR']._serialized_start=465 + _globals['_KNNOPERATOR']._serialized_end=529 + _globals['_LIMITOPERATOR']._serialized_start=531 + _globals['_LIMITOPERATOR']._serialized_end=590 + _globals['_PROJECTIONOPERATOR']._serialized_start=592 + _globals['_PROJECTIONOPERATOR']._serialized_end=667 + _globals['_KNNPROJECTIONOPERATOR']._serialized_start=669 + _globals['_KNNPROJECTIONOPERATOR']._serialized_end=758 + _globals['_COUNTPLAN']._serialized_start=760 + _globals['_COUNTPLAN']._serialized_end=807 + _globals['_COUNTRESULT']._serialized_start=809 + _globals['_COUNTRESULT']._serialized_end=837 + _globals['_GETPLAN']._serialized_start=840 + _globals['_GETPLAN']._serialized_end=1011 + _globals['_PROJECTIONRECORD']._serialized_start=1014 + _globals['_PROJECTIONRECORD']._serialized_end=1194 + _globals['_GETRESULT']._serialized_start=1196 + _globals['_GETRESULT']._serialized_end=1250 + _globals['_KNNPLAN']._serialized_start=1253 + _globals['_KNNPLAN']._serialized_end=1423 + _globals['_KNNPROJECTIONRECORD']._serialized_start=1425 + _globals['_KNNPROJECTIONRECORD']._serialized_end=1524 + _globals['_KNNRESULT']._serialized_start=1526 + _globals['_KNNRESULT']._serialized_end=1583 + _globals['_KNNBATCHRESULT']._serialized_start=1585 + _globals['_KNNBATCHRESULT']._serialized_end=1637 + _globals['_QUERYEXECUTOR']._serialized_start=1640 + _globals['_QUERYEXECUTOR']._serialized_end=1801 # @@protoc_insertion_point(module_scope) diff --git a/chromadb/proto/query_executor_pb2.pyi b/chromadb/proto/query_executor_pb2.pyi index 53483f8445f..33b8bb63e1a 100644 --- a/chromadb/proto/query_executor_pb2.pyi +++ b/chromadb/proto/query_executor_pb2.pyi @@ -7,16 +7,22 @@ from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Map DESCRIPTOR: _descriptor.FileDescriptor class ScanOperator(_message.Message): - __slots__ = ["collection", "knn_id", "metadata_id", "record_id"] + __slots__ = ["collection", "knn_id", "metadata_id", "record_id", "knn", "metadata", "record"] COLLECTION_FIELD_NUMBER: _ClassVar[int] KNN_ID_FIELD_NUMBER: _ClassVar[int] METADATA_ID_FIELD_NUMBER: _ClassVar[int] RECORD_ID_FIELD_NUMBER: _ClassVar[int] + KNN_FIELD_NUMBER: _ClassVar[int] + METADATA_FIELD_NUMBER: _ClassVar[int] + RECORD_FIELD_NUMBER: _ClassVar[int] collection: _chroma_pb2.Collection knn_id: str metadata_id: str record_id: str - def __init__(self, collection: _Optional[_Union[_chroma_pb2.Collection, _Mapping]] = ..., knn_id: _Optional[str] = ..., metadata_id: _Optional[str] = ..., record_id: _Optional[str] = ...) -> None: ... + knn: _chroma_pb2.Segment + metadata: _chroma_pb2.Segment + record: _chroma_pb2.Segment + def __init__(self, collection: _Optional[_Union[_chroma_pb2.Collection, _Mapping]] = ..., knn_id: _Optional[str] = ..., metadata_id: _Optional[str] = ..., record_id: _Optional[str] = ..., knn: _Optional[_Union[_chroma_pb2.Segment, _Mapping]] = ..., metadata: _Optional[_Union[_chroma_pb2.Segment, _Mapping]] = ..., record: _Optional[_Union[_chroma_pb2.Segment, _Mapping]] = ...) -> None: ... class FilterOperator(_message.Message): __slots__ = ["ids", "where", "where_document"] diff --git a/idl/chromadb/proto/query_executor.proto b/idl/chromadb/proto/query_executor.proto index 0149767217c..53aa37a8b52 100644 --- a/idl/chromadb/proto/query_executor.proto +++ b/idl/chromadb/proto/query_executor.proto @@ -6,9 +6,15 @@ import "chromadb/proto/chroma.proto"; message ScanOperator { Collection collection = 1; + // Deprecated string knn_id = 2; + // Deprecated string metadata_id = 3; + // Deprecated string record_id = 4; + Segment knn = 5; + Segment metadata = 6; + Segment record = 7; } message FilterOperator { diff --git a/rust/types/src/segment.rs b/rust/types/src/segment.rs index 29fba56af84..8858365554a 100644 --- a/rust/types/src/segment.rs +++ b/rust/types/src/segment.rs @@ -6,6 +6,7 @@ use crate::chroma_proto; use chroma_error::{ChromaError, ErrorCodes}; use std::{collections::HashMap, str::FromStr}; use thiserror::Error; +use tonic::Status; use uuid::Uuid; /// SegmentUuid is a wrapper around Uuid to provide a type for the segment id. @@ -106,6 +107,12 @@ impl ChromaError for SegmentConversionError { } } +impl From for Status { + fn from(value: SegmentConversionError) -> Self { + Status::invalid_argument(value.to_string()) + } +} + impl TryFrom for Segment { type Error = SegmentConversionError; diff --git a/rust/worker/src/server.rs b/rust/worker/src/server.rs index bda9641b300..da550483fc7 100644 --- a/rust/worker/src/server.rs +++ b/rust/worker/src/server.rs @@ -15,7 +15,6 @@ use futures::{stream, StreamExt, TryStreamExt}; use tokio::signal::unix::{signal, SignalKind}; use tonic::{transport::Server, Request, Response, Status}; use tracing::{trace_span, Instrument}; -use uuid::Uuid; use crate::{ config::QueryServiceConfig, @@ -25,7 +24,12 @@ use crate::{ fetch_log::FetchLogOperator, fetch_segment::FetchSegmentOperator, knn_projection::KnnProjectionOperator, }, - orchestration::{get::GetOrchestrator, knn::KnnOrchestrator, knn_filter::{KnnError, KnnFilterOrchestrator}, CountQueryOrchestrator}, + orchestration::{ + get::GetOrchestrator, + knn::KnnOrchestrator, + knn_filter::{KnnError, KnnFilterOrchestrator}, + CountQueryOrchestrator, + }, }, log::log::Log, sysdb::sysdb::SysDb, @@ -143,14 +147,12 @@ impl WorkerServer { let collection_uuid = CollectionUuid::from_str(&collection.id) .map_err(|_| Status::invalid_argument("Invalid Collection UUID"))?; - let vector_uuid = SegmentUuid::from_str(&scan.knn_id) - .map_err(|_| Status::invalid_argument("Invalid UUID for Vector segment"))?; - - let metadata_uuid = SegmentUuid::from_str(&scan.metadata_id) - .map_err(|_| Status::invalid_argument("Invalid UUID for Metadata segment"))?; - - let record_uuid = SegmentUuid::from_str(&scan.record_id) - .map_err(|_| Status::invalid_argument("Invalid UUID for Record segment"))?; + let metadata_uuid = + SegmentUuid::from_str(&scan.metadata.map(|seg| seg.id).unwrap_or(scan.metadata_id))?; + let record_uuid = + SegmentUuid::from_str(&scan.record.map(|seg| seg.id).unwrap_or(scan.record_id))?; + let vector_uuid = + SegmentUuid::from_str(&scan.knn.map(|seg| seg.id).unwrap_or(scan.knn_id))?; Ok(( FetchLogOperator { @@ -189,8 +191,7 @@ impl WorkerServer { let count_orchestrator = CountQueryOrchestrator::new( self.clone_system()?, - &Uuid::parse_str(&scan.metadata_id) - .map_err(|e| Status::invalid_argument(e.to_string()))?, + &SegmentUuid::from_str(&scan.metadata.map(|seg| seg.id).unwrap_or(scan.metadata_id))?.0, &CollectionUuid::from_str(&collection.id) .map_err(|e| Status::invalid_argument(e.to_string()))?, self.log.clone(), @@ -422,6 +423,8 @@ impl chroma_proto::debug_server::Debug for WorkerServer { #[cfg(test)] mod tests { + use std::collections::HashMap; + use super::*; use crate::execution::dispatcher; use crate::log::log::InMemoryLog; @@ -465,9 +468,10 @@ mod tests { } fn scan() -> chroma_proto::ScanOperator { + let collection_id = Uuid::new_v4().to_string(); chroma_proto::ScanOperator { collection: Some(chroma_proto::Collection { - id: Uuid::new_v4().to_string(), + id: collection_id.clone(), name: "Test-Collection".to_string(), configuration_json_str: String::new(), metadata: None, @@ -477,9 +481,36 @@ mod tests { log_position: 0, version: 0, }), - knn_id: Uuid::new_v4().to_string(), - metadata_id: Uuid::new_v4().to_string(), - record_id: Uuid::new_v4().to_string(), + // Deprecated + knn_id: "".to_string(), + // Deprecated + metadata_id: "".to_string(), + // Deprecated + record_id: "".to_string(), + knn: Some(chroma_proto::Segment { + id: Uuid::new_v4().to_string(), + r#type: "urn:chroma:segment/vector/hnsw-distributed".to_string(), + scope: 0, + collection: collection_id.clone(), + metadata: None, + file_paths: HashMap::new(), + }), + metadata: Some(chroma_proto::Segment { + id: Uuid::new_v4().to_string(), + r#type: "urn:chroma:segment/metadata/blockfile".to_string(), + scope: 1, + collection: collection_id.clone(), + metadata: None, + file_paths: HashMap::new(), + }), + record: Some(chroma_proto::Segment { + id: Uuid::new_v4().to_string(), + r#type: "urn:chroma:segment/record/blockfile".to_string(), + scope: 2, + collection: collection_id.clone(), + metadata: None, + file_paths: HashMap::new(), + }), } } @@ -509,7 +540,19 @@ mod tests { let response = executor.count(request).await; assert_eq!(response.unwrap_err().code(), tonic::Code::NotFound); - scan_operator.metadata_id = "invalid_segment_id".to_string(); + scan_operator.metadata = Some(chroma_proto::Segment { + id: "invalid-metadata-segment-id".to_string(), + r#type: "urn:chroma:segment/metadata/blockfile".to_string(), + scope: 1, + collection: scan_operator + .collection + .as_ref() + .expect("The collection should exist") + .id + .clone(), + metadata: None, + file_paths: HashMap::new(), + }); let request = chroma_proto::CountPlan { scan: Some(scan_operator.clone()), }; @@ -567,13 +610,13 @@ mod tests { assert_eq!(response.unwrap_err().code(), tonic::Code::InvalidArgument); scan_operator.collection = Some(chroma_proto::Collection { - id: "Invalid-Collection-ID".to_string(), - name: "Broken-Collection".to_string(), + id: "invalid-collection-iD".to_string(), + name: "broken-collection".to_string(), configuration_json_str: String::new(), metadata: None, dimension: None, - tenant: "Test-Tenant".to_string(), - database: "Test-Database".to_string(), + tenant: "test-tenant".to_string(), + database: "test-database".to_string(), log_position: 0, version: 0, }); @@ -601,7 +644,9 @@ mod tests { assert_eq!(response.unwrap_err().code(), tonic::Code::InvalidArgument); } - fn gen_knn_request(mut scan_operator: Option) -> chroma_proto::KnnPlan { + fn gen_knn_request( + mut scan_operator: Option, + ) -> chroma_proto::KnnPlan { if scan_operator.is_none() { scan_operator = Some(scan()); } @@ -716,13 +761,13 @@ mod tests { async fn validate_knn_plan_scan_collection() { let mut executor = QueryExecutorClient::connect(run_server()).await.unwrap(); let mut scan = scan(); - scan.collection.as_mut().unwrap().id = "Invalid-Collection-ID".to_string(); + scan.collection.as_mut().unwrap().id = "invalid-collection-id".to_string(); let response = executor.knn(gen_knn_request(Some(scan))).await; assert!(response.is_err()); let err = response.unwrap_err(); assert_eq!(err.code(), tonic::Code::InvalidArgument); assert!( - err.message().to_lowercase().contains("collection uuid"), + err.message().to_lowercase().contains("uuid"), "{}", err.message() ); @@ -733,13 +778,25 @@ mod tests { let mut executor = QueryExecutorClient::connect(run_server()).await.unwrap(); // invalid vector uuid let mut scan_operator = scan(); - scan_operator.knn_id = "invalid_segment_id".to_string(); + scan_operator.knn = Some(chroma_proto::Segment { + id: "invalid-knn-segment-id".to_string(), + r#type: "urn:chroma:segment/vector/hnsw-distributed".to_string(), + scope: 0, + collection: scan_operator + .collection + .as_ref() + .expect("The collection should exist") + .id + .clone(), + metadata: None, + file_paths: HashMap::new(), + }); let response = executor.knn(gen_knn_request(Some(scan_operator))).await; assert!(response.is_err()); let err = response.unwrap_err(); assert_eq!(err.code(), tonic::Code::InvalidArgument); assert!( - err.message().to_lowercase().contains("vector"), + err.message().to_lowercase().contains("uuid"), "{}", err.message() ); @@ -749,13 +806,25 @@ mod tests { async fn validate_knn_plan_scan_record() { let mut executor = QueryExecutorClient::connect(run_server()).await.unwrap(); let mut scan_operator = scan(); - scan_operator.record_id = "invalid_record_id".to_string(); + scan_operator.record = Some(chroma_proto::Segment { + id: "invalid-record-segment-id".to_string(), + r#type: "urn:chroma:segment/record/blockfile".to_string(), + scope: 2, + collection: scan_operator + .collection + .as_ref() + .expect("The collection should exist") + .id + .clone(), + metadata: None, + file_paths: HashMap::new(), + }); let response = executor.knn(gen_knn_request(Some(scan_operator))).await; assert!(response.is_err()); let err = response.unwrap_err(); assert_eq!(err.code(), tonic::Code::InvalidArgument); assert!( - err.message().to_lowercase().contains("record"), + err.message().to_lowercase().contains("uuid"), "{}", err.message() ); @@ -765,13 +834,25 @@ mod tests { async fn validate_knn_plan_scan_metadata() { let mut executor = QueryExecutorClient::connect(run_server()).await.unwrap(); let mut scan_operator = scan(); - scan_operator.metadata_id = "invalid_metadata_id".to_string(); + scan_operator.metadata = Some(chroma_proto::Segment { + id: "invalid-metadata-segment-id".to_string(), + r#type: "urn:chroma:segment/metadata/blockfile".to_string(), + scope: 1, + collection: scan_operator + .collection + .as_ref() + .expect("The collection should exist") + .id + .clone(), + metadata: None, + file_paths: HashMap::new(), + }); let response = executor.knn(gen_knn_request(Some(scan_operator))).await; assert!(response.is_err()); let err = response.unwrap_err(); assert_eq!(err.code(), tonic::Code::InvalidArgument); assert!( - err.message().to_lowercase().contains("metadata"), + err.message().to_lowercase().contains("uuid"), "{}", err.message() );