Skip to content

Commit

Permalink
Implement mock sysdb endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
Sicheng Pan committed Dec 10, 2024
1 parent a140fbd commit db74461
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 4 deletions.
28 changes: 27 additions & 1 deletion chromadb/db/impl/grpc/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
DeleteSegmentResponse,
GetCollectionsRequest,
GetCollectionsResponse,
GetCollectionWithSegmentsRequest,
GetCollectionWithSegmentsResponse,
GetDatabaseRequest,
GetDatabaseResponse,
GetSegmentsRequest,
Expand All @@ -46,7 +48,7 @@
)
import grpc
from google.protobuf.empty_pb2 import Empty
from chromadb.types import Collection, Metadata, Segment
from chromadb.types import Collection, Metadata, Segment, SegmentScope


class GrpcMockSysDB(SysDBServicer, Component):
Expand Down Expand Up @@ -370,6 +372,30 @@ def GetCollections(
]
)

@overrides(check_signature=False)
def GetCollectionWithSegments(
self, request: GetCollectionWithSegmentsRequest, context: grpc.ServicerContext
) -> GetCollectionWithSegmentsResponse:
allCollections = {}
for tenant, databases in self._tenants_to_databases_to_collections.items():
for database, collections in databases.items():
allCollections.update(collections)
print(
f"Tenant: {tenant}, Database: {database}, Collections: {collections}"
)
collection = allCollections.get(request.id, None)
if collection is None:
context.abort(grpc.StatusCode.NOT_FOUND, f"Collection with id {request.id} not found")
collection_unique_key = f"{collection.tenant}:{collection.database}:{request.id}"
segments = [self._segments[id] for id in self._collection_to_segments[collection_unique_key]]
if {segment["scope"] for segment in segments} != {SegmentScope.METADATA, SegmentScope.RECORD, SegmentScope.VECTOR}:
context.abort(grpc.StatusCode.INTERNAL, f"Incomplete segments for collection {collection}: {segments}")

return GetCollectionWithSegmentsResponse(
collection=to_proto_collection(collection),
segments=[to_proto_segment(segment) for segment in segments]
)

@overrides(check_signature=False)
def UpdateCollection(
self, request: UpdateCollectionRequest, context: grpc.ServicerContext
Expand Down
16 changes: 13 additions & 3 deletions chromadb/test/db/test_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,13 +185,17 @@ def test_create_get_delete_collections(sysdb: SysDB) -> None:
segments_created_with_collection = []
for collection in sample_collections:
logger.debug(f"Creating collection: {collection.name}")
segment = sample_segment(collection_id=collection.id)
segments_created_with_collection.append(segment)
segments = [
sample_segment(collection_id=collection.id, scope=SegmentScope.METADATA),
sample_segment(collection_id=collection.id, scope=SegmentScope.RECORD),
sample_segment(collection_id=collection.id, scope=SegmentScope.VECTOR),
]
segments_created_with_collection.extend(segments)
sysdb.create_collection(
id=collection.id,
name=collection.name,
configuration=collection.get_configuration(),
segments=[segment],
segments=segments,
metadata=collection["metadata"],
dimension=collection["dimension"],
)
Expand Down Expand Up @@ -223,6 +227,12 @@ def test_create_get_delete_collections(sysdb: SysDB) -> None:
result = sysdb.get_collections(id=collection["id"])
assert result == [collection]

# Verify segment information
for collection in sample_collections:
collection_with_segments_result = sysdb.get_collection_with_segments(collection.id)
assert collection_with_segments_result["collection"] == collection
assert all([segment["id"] == collection.id for segment in collection_with_segments_result["segments"]])

# Delete
c1 = sample_collections[0]
sysdb.delete_collection(id=c1.id)
Expand Down

0 comments on commit db74461

Please sign in to comment.