From 5519de7a25afd8d95663acd64895ebc6cc1bde30 Mon Sep 17 00:00:00 2001
From: Sicheng Pan
Date: Mon, 9 Dec 2024 15:56:37 -0800
Subject: [PATCH] Implement mock sysdb endpoint

---
 chromadb/db/impl/grpc/server.py | 33 ++++++++++++++++++++++++++++++++-
 chromadb/test/db/test_system.py |  6 ++++++
 2 files changed, 38 insertions(+), 1 deletion(-)

diff --git a/chromadb/db/impl/grpc/server.py b/chromadb/db/impl/grpc/server.py
index 8edbc0eeccb8..b6948ee9615f 100644
--- a/chromadb/db/impl/grpc/server.py
+++ b/chromadb/db/impl/grpc/server.py
@@ -28,6 +28,8 @@
     DeleteSegmentResponse,
     GetCollectionsRequest,
     GetCollectionsResponse,
+    GetCollectionWithSegmentsRequest,
+    GetCollectionWithSegmentsResponse,
     GetDatabaseRequest,
     GetDatabaseResponse,
     GetSegmentsRequest,
@@ -46,7 +48,7 @@
 )
 import grpc
 from google.protobuf.empty_pb2 import Empty
-from chromadb.types import Collection, Metadata, Segment
+from chromadb.types import Collection, Metadata, Segment, SegmentScope
 
 
 class GrpcMockSysDB(SysDBServicer, Component):
@@ -370,6 +372,35 @@ def GetCollections(
             ]
         )
 
+    @overrides(check_signature=False)
+    def GetCollectionWithSegments(
+        self, request: GetCollectionWithSegmentsRequest, context: grpc.ServicerContext
+    ) -> GetCollectionWithSegmentsResponse:
+        """Return a collection together with all of its segments.
+
+        Aborts with NOT_FOUND if the id is unknown, and with INTERNAL unless the
+        segments cover exactly the METADATA, RECORD and VECTOR scopes.
+        """
+        # Only the id is known here, so search every tenant and database.
+        all_collections = {}
+        for databases in self._tenants_to_databases_to_collections.values():
+            for collections in databases.values():
+                all_collections.update(collections)
+        collection = all_collections.get(request.id)
+        if collection is None:
+            context.abort(grpc.StatusCode.NOT_FOUND, f"Collection with id {request.id} not found")
+        collection_unique_key = f"{collection.tenant}:{collection.database}:{request.id}"
+        # Missing segment entries reach the INTERNAL abort instead of a KeyError.
+        segment_ids = self._collection_to_segments.get(collection_unique_key, [])
+        segments = [self._segments[segment_id] for segment_id in segment_ids]
+        if {segment["scope"] for segment in segments} != {SegmentScope.METADATA, SegmentScope.RECORD, SegmentScope.VECTOR}:
+            context.abort(grpc.StatusCode.INTERNAL, f"Incomplete segments for collection {collection}: {segments}")
+
+        return GetCollectionWithSegmentsResponse(
+            collection=to_proto_collection(collection),
+            segments=[to_proto_segment(segment) for segment in segments],
+        )
+
     @overrides(check_signature=False)
     def UpdateCollection(
         self, request: UpdateCollectionRequest, context: grpc.ServicerContext
diff --git a/chromadb/test/db/test_system.py b/chromadb/test/db/test_system.py
index f49a17226cb0..e09caab26b92 100644
--- a/chromadb/test/db/test_system.py
+++ b/chromadb/test/db/test_system.py
@@ -223,6 +223,12 @@ def test_create_get_delete_collections(sysdb: SysDB) -> None:
         result = sysdb.get_collections(id=collection["id"])
         assert result == [collection]
 
+    # Verify segment information
+    for collection in sample_collections:
+        collection_with_segments_result = sysdb.get_collection_with_segments(collection.id)
+        assert collection_with_segments_result["collection"] == collection
+        assert all([segment["collection"] == collection.id for segment in collection_with_segments_result["segments"]])
+
     # Delete
     c1 = sample_collections[0]
     sysdb.delete_collection(id=c1.id)