From 32509ee80916cb27520dad11a597e0f5dfd6f7d5 Mon Sep 17 00:00:00 2001 From: nicolasgere Date: Fri, 26 Apr 2024 15:51:34 -0700 Subject: [PATCH] [ENH]: fix sysdb client grpc (#2068) ## Description of changes *Summarize the changes made by this PR.* - Improvements & Bug fixes - Fix issue with sysdb grpc client --- chromadb/db/impl/grpc/client.py | 34 +++++++++++++++++++++++++-------- chromadb/db/impl/grpc/server.py | 25 ++++++++++++------------ 2 files changed, 38 insertions(+), 21 deletions(-) diff --git a/chromadb/db/impl/grpc/client.py b/chromadb/db/impl/grpc/client.py index 92b0a3f8d42..88693921346 100644 --- a/chromadb/db/impl/grpc/client.py +++ b/chromadb/db/impl/grpc/client.py @@ -246,14 +246,32 @@ def get_collections( offset: Optional[int] = None, ) -> Sequence[Collection]: # TODO: implement limit and offset in the gRPC service - request = GetCollectionsRequest( - id=id.hex if id else None, - name=name, - tenant=tenant, - database=database, - limit=limit, - offset=offset, - ) + request = None + if id is not None: + request = GetCollectionsRequest( + id=id.hex, + limit=limit, + offset=offset, + ) + if name is not None: + if tenant is None and database is None: + raise ValueError( + "If name is specified, tenant and database must also be specified in order to uniquely identify the collection" + ) + request = GetCollectionsRequest( + name=name, + tenant=tenant, + database=database, + limit=limit, + offset=offset, + ) + if id is None and name is None: + request = GetCollectionsRequest( + tenant=tenant, + database=database, + limit=limit, + offset=offset, + ) response: GetCollectionsResponse = self._sys_db_stub.GetCollections(request) results: List[Collection] = [] for collection in response.collections: diff --git a/chromadb/db/impl/grpc/server.py b/chromadb/db/impl/grpc/server.py index cdec22eff0d..72da066ec07 100644 --- a/chromadb/db/impl/grpc/server.py +++ b/chromadb/db/impl/grpc/server.py @@ -359,20 +359,19 @@ def GetCollections( target_id = UUID(hex=request.id) if request.HasField("id") else None target_name = request.name if request.HasField("name") else None - tenant = request.tenant - database = request.database - if tenant not in self._tenants_to_databases_to_collections: - return GetCollectionsResponse( - status=proto.Status(code=404, reason=f"Tenant {tenant} not found") - ) - if database not in self._tenants_to_databases_to_collections[tenant]: - return GetCollectionsResponse( - status=proto.Status(code=404, reason=f"Database {database} not found") - ) - collections = self._tenants_to_databases_to_collections[tenant][database] - + allCollections = {} + for tenant, databases in self._tenants_to_databases_to_collections.items(): + for database, collections in databases.items(): + if request.tenant != "" and tenant != request.tenant: + continue + if request.database != "" and database != request.database: + continue + allCollections.update(collections) + print( + f"Tenant: {tenant}, Database: {database}, Collections: {collections}" + ) found_collections = [] - for collection in collections.values(): + for collection in allCollections.values(): if target_id and collection["id"] != target_id: continue if target_name and collection["name"] != target_name: