From 8dad8e0264fbef357f020ddc03568285ed5b5e07 Mon Sep 17 00:00:00 2001 From: hammadb Date: Wed, 18 Oct 2023 21:23:38 -0700 Subject: [PATCH] Simplify --- chromadb/api/__init__.py | 130 ++------------ chromadb/api/client.py | 114 ++++++------ chromadb/api/fastapi.py | 38 ++-- chromadb/api/models/Collection.py | 2 - chromadb/api/segment.py | 40 ++--- chromadb/db/impl/grpc/client.py | 30 +++- chromadb/db/impl/grpc/server.py | 115 ++++++++++-- chromadb/db/impl/sqlite.py | 6 +- chromadb/db/mixins/sysdb.py | 94 +++++++++- chromadb/db/system.py | 20 ++- .../sysdb/00004-tenants-databases.sqlite.sql | 28 ++- chromadb/proto/coordinator_pb2.py | 52 +++--- chromadb/proto/coordinator_pb2.pyi | 34 +++- chromadb/proto/coordinator_pb2_grpc.py | 45 +++++ chromadb/test/db/test_system.py | 170 ++++++++++++++++++ idl/chromadb/proto/coordinator.proto | 15 +- 16 files changed, 649 insertions(+), 284 deletions(-) diff --git a/chromadb/api/__init__.py b/chromadb/api/__init__.py index 2287bd929d2..28a62fcb5be 100644 --- a/chromadb/api/__init__.py +++ b/chromadb/api/__init__.py @@ -419,7 +419,19 @@ def clear_system_cache() -> None: pass -class ServerAPI(BaseAPI, Component): +class AdminAPI(ABC): + @abstractmethod + def create_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None: + """Create a new database. + + Args: + database: The name of the database to create. + + """ + pass + + +class ServerAPI(BaseAPI, AdminAPI, Component): """An API instance that extends the relevant Base API methods by passing in a tenant and database. This is the root component of the Chroma System""" @@ -473,8 +485,6 @@ def _modify( id: UUID, new_name: Optional[str] = None, new_metadata: Optional[CollectionMetadata] = None, - tenant: str = DEFAULT_TENANT, - database: str = DEFAULT_DATABASE, ) -> None: pass @@ -487,117 +497,3 @@ def delete_collection( database: str = DEFAULT_DATABASE, ) -> None: pass - - # - # ITEM METHODS - # - - @abstractmethod - @override - def _add( - self, - ids: IDs, - collection_id: UUID, - embeddings: Embeddings, - metadatas: Optional[Metadatas] = None, - documents: Optional[Documents] = None, - tenant: str = DEFAULT_TENANT, - database: str = DEFAULT_DATABASE, - ) -> bool: - pass - - @abstractmethod - @override - def _update( - self, - collection_id: UUID, - ids: IDs, - embeddings: Optional[Embeddings] = None, - metadatas: Optional[Metadatas] = None, - documents: Optional[Documents] = None, - tenant: str = DEFAULT_TENANT, - database: str = DEFAULT_DATABASE, - ) -> bool: - pass - - @abstractmethod - @override - def _upsert( - self, - collection_id: UUID, - ids: IDs, - embeddings: Embeddings, - metadatas: Optional[Metadatas] = None, - documents: Optional[Documents] = None, - tenant: str = DEFAULT_TENANT, - database: str = DEFAULT_DATABASE, - ) -> bool: - pass - - @abstractmethod - @override - def _count( - self, - collection_id: UUID, - tenant: str = DEFAULT_TENANT, - database: str = DEFAULT_DATABASE, - ) -> int: - pass - - @abstractmethod - @override - def _peek( - self, - collection_id: UUID, - n: int = 10, - tenant: str = DEFAULT_TENANT, - database: str = DEFAULT_DATABASE, - ) -> GetResult: - pass - - @abstractmethod - @override - def _get( - self, - collection_id: UUID, - ids: Optional[IDs] = None, - where: Optional[Where] = {}, - sort: Optional[str] = None, - limit: Optional[int] = None, - offset: Optional[int] = None, - page: Optional[int] = None, - page_size: Optional[int] = None, - where_document: Optional[WhereDocument] = {}, - include: Include = ["embeddings", "metadatas", "documents"], - tenant: str = DEFAULT_TENANT, - database: str = DEFAULT_DATABASE, - ) -> GetResult: - pass - - @abstractmethod - @override - def _delete( - self, - collection_id: UUID, - ids: Optional[IDs], - where: Optional[Where] = {}, - where_document: Optional[WhereDocument] = {}, - tenant: str = DEFAULT_TENANT, - database: str = DEFAULT_DATABASE, - ) -> IDs: - pass - - @abstractmethod - @override - def _query( - self, - collection_id: UUID, - query_embeddings: Embeddings, - n_results: int = 10, - where: Where = {}, - where_document: WhereDocument = {}, - include: Include = ["embeddings", "metadatas", "documents", "distances"], - tenant: str = DEFAULT_TENANT, - database: str = DEFAULT_DATABASE, - ) -> QueryResult: - pass diff --git a/chromadb/api/client.py b/chromadb/api/client.py index c007495705b..24b087cc30c 100644 --- a/chromadb/api/client.py +++ b/chromadb/api/client.py @@ -2,7 +2,7 @@ from uuid import UUID from overrides import override -from chromadb.api import ClientAPI, ServerAPI +from chromadb.api import AdminAPI, ClientAPI, ServerAPI from chromadb.api.types import ( CollectionMetadata, Documents, @@ -23,51 +23,17 @@ import chromadb.utils.embedding_functions as ef -class Client(ClientAPI): - """A client for Chroma. This is the main entrypoint for interacting with Chroma. - A client internally stores its tenant and database and proxies calls to a - Server API instance of Chroma. It treats the Server API and corresponding System - as a singleton, so multiple clients connecting to the same resource will share the - same API instance. - - Client implementations should be implement their own API-caching strategies. - """ - - tenant: str = DEFAULT_TENANT - database: str = DEFAULT_DATABASE - +class SharedSystemClient: _identifer_to_system: ClassVar[Dict[str, System]] = {} _identifier: str - _server: ServerAPI # region Initialization - def __new__( - cls, - tenant: str = "default", - database: str = "default", - settings: Settings = Settings(), - ) -> "Client": - identifier = cls._get_identifier_from_settings(settings) - cls._create_system_if_not_exists(identifier, settings) - instance = super().__new__(cls) - return instance - def __init__( self, - tenant: str = "default", - database: str = "default", settings: Settings = Settings(), ) -> None: - self.tenant = tenant - self.database = database - self._identifier = self._get_identifier_from_settings(settings) - - # Get the root system component we want to interact with - self._server = self._system.instance(ServerAPI) - - # Submit event for a client start - telemetry_client = self._system.instance(Telemetry) - telemetry_client.capture(ClientStartEvent()) + self._identifier = SharedSystemClient._get_identifier_from_settings(settings) + SharedSystemClient._create_system_if_not_exists(self._identifier, settings) @classmethod def _create_system_if_not_exists( @@ -116,13 +82,48 @@ def _get_identifier_from_settings(settings: Settings) -> str: return identifier @staticmethod - @override def clear_system_cache() -> None: - Client._identifer_to_system = {} + SharedSystemClient._identifer_to_system = {} @property def _system(self) -> System: - return self._identifer_to_system[self._identifier] + return SharedSystemClient._identifer_to_system[self._identifier] + + # endregion + + +class Client(SharedSystemClient, ClientAPI): + """A client for Chroma. This is the main entrypoint for interacting with Chroma. + A client internally stores its tenant and database and proxies calls to a + Server API instance of Chroma. It treats the Server API and corresponding System + as a singleton, so multiple clients connecting to the same resource will share the + same API instance. + + Client implementations should be implement their own API-caching strategies. + """ + + tenant: str = DEFAULT_TENANT + database: str = DEFAULT_DATABASE + + _server: ServerAPI + + # region Initialization + def __init__( + self, + tenant: str = "default", + database: str = "default", + settings: Settings = Settings(), + ) -> None: + super().__init__(settings=settings) + self.tenant = tenant + self.database = database + + # Get the root system component we want to interact with + self._server = self._system.instance(ServerAPI) + + # Submit event for a client start + telemetry_client = self._system.instance(Telemetry) + telemetry_client.capture(ClientStartEvent()) # endregion @@ -150,6 +151,7 @@ def create_collection( embedding_function=embedding_function, tenant=self.tenant, database=self.database, + get_or_create=get_or_create, ) @override @@ -191,8 +193,6 @@ def _modify( id=id, new_name=new_name, new_metadata=new_metadata, - tenant=self.tenant, - database=self.database, ) @override @@ -225,8 +225,6 @@ def _add( embeddings=embeddings, metadatas=metadatas, documents=documents, - tenant=self.tenant, - database=self.database, ) @override @@ -244,8 +242,6 @@ def _update( embeddings=embeddings, metadatas=metadatas, documents=documents, - tenant=self.tenant, - database=self.database, ) @override @@ -263,16 +259,12 @@ def _upsert( embeddings=embeddings, metadatas=metadatas, documents=documents, - tenant=self.tenant, - database=self.database, ) @override def _count(self, collection_id: UUID) -> int: return self._server._count( collection_id=collection_id, - tenant=self.tenant, - database=self.database, ) @override @@ -280,8 +272,6 @@ def _peek(self, collection_id: UUID, n: int = 10) -> GetResult: return self._server._peek( collection_id=collection_id, n=n, - tenant=self.tenant, - database=self.database, ) @override @@ -309,8 +299,6 @@ def _get( page_size=page_size, where_document=where_document, include=include, - tenant=self.tenant, - database=self.database, ) def _delete( @@ -325,8 +313,6 @@ def _delete( ids=ids, where=where, where_document=where_document, - tenant=self.tenant, - database=self.database, ) @override @@ -346,8 +332,6 @@ def _query( where=where, where_document=where_document, include=include, - tenant=self.tenant, - database=self.database, ) @override @@ -380,3 +364,15 @@ def set_tenant(self, tenant: str) -> None: self.tenant = tenant # endregion + + +class AdminClient(AdminAPI, SharedSystemClient): + _server: ServerAPI + + def __init__(self, settings: Settings = Settings()) -> None: + super().__init__(settings) + self._server = self._system.instance(ServerAPI) + + @override + def create_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None: + return self._server.create_database(name=name, tenant=tenant) diff --git a/chromadb/api/fastapi.py b/chromadb/api/fastapi.py index 2ee75612c23..784d9e79dbb 100644 --- a/chromadb/api/fastapi.py +++ b/chromadb/api/fastapi.py @@ -134,6 +134,19 @@ def heartbeat(self) -> int: raise_chroma_error(resp) return int(resp.json()["nanosecond heartbeat"]) + @override + def create_database( + self, + name: str, + tenant: str = DEFAULT_TENANT, + ) -> None: + """Creates a database""" + resp = self._session.post( + self._api_url + "/databases", + data=json.dumps({"name": name, "tenant": tenant}), + ) + raise_chroma_error(resp) + @override def list_collections( self, tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE @@ -205,7 +218,12 @@ def get_or_create_collection( database: str = DEFAULT_DATABASE, ) -> Collection: return self.create_collection( - name, metadata, embedding_function, get_or_create=True + name, + metadata, + embedding_function, + get_or_create=True, + tenant=tenant, + database=database, ) @override @@ -214,8 +232,6 @@ def _modify( id: UUID, new_name: Optional[str] = None, new_metadata: Optional[CollectionMetadata] = None, - tenant: str = DEFAULT_TENANT, - database: str = DEFAULT_DATABASE, ) -> None: """Updates a collection""" resp = self._session.put( @@ -239,8 +255,6 @@ def delete_collection( def _count( self, collection_id: UUID, - tenant: str = DEFAULT_TENANT, - database: str = DEFAULT_DATABASE, ) -> int: """Returns the number of embeddings in the database""" resp = self._session.get( @@ -254,8 +268,6 @@ def _peek( self, collection_id: UUID, n: int = 10, - tenant: str = DEFAULT_TENANT, - database: str = DEFAULT_DATABASE, ) -> GetResult: return self._get( collection_id, @@ -276,8 +288,6 @@ def _get( page_size: Optional[int] = None, where_document: Optional[WhereDocument] = {}, include: Include = ["metadatas", "documents"], - tenant: str = DEFAULT_TENANT, - database: str = DEFAULT_DATABASE, ) -> GetResult: if page and page_size: offset = (page - 1) * page_size @@ -314,8 +324,6 @@ def _delete( ids: Optional[IDs] = None, where: Optional[Where] = {}, where_document: Optional[WhereDocument] = {}, - tenant: str = DEFAULT_TENANT, - database: str = DEFAULT_DATABASE, ) -> IDs: """Deletes embeddings from the database""" resp = self._session.post( @@ -359,8 +367,6 @@ def _add( embeddings: Embeddings, metadatas: Optional[Metadatas] = None, documents: Optional[Documents] = None, - tenant: str = DEFAULT_TENANT, - database: str = DEFAULT_DATABASE, ) -> bool: """ Adds a batch of embeddings to the database @@ -380,8 +386,6 @@ def _update( embeddings: Optional[Embeddings] = None, metadatas: Optional[Metadatas] = None, documents: Optional[Documents] = None, - tenant: str = DEFAULT_TENANT, - database: str = DEFAULT_DATABASE, ) -> bool: """ Updates a batch of embeddings in the database @@ -403,8 +407,6 @@ def _upsert( embeddings: Embeddings, metadatas: Optional[Metadatas] = None, documents: Optional[Documents] = None, - tenant: str = DEFAULT_TENANT, - database: str = DEFAULT_DATABASE, ) -> bool: """ Upserts a batch of embeddings in the database @@ -427,8 +429,6 @@ def _query( where: Optional[Where] = {}, where_document: Optional[WhereDocument] = {}, include: Include = ["metadatas", "documents", "distances"], - tenant: str = DEFAULT_TENANT, - database: str = DEFAULT_DATABASE, ) -> QueryResult: """Gets the nearest neighbors of a single embedding""" resp = self._session.post( diff --git a/chromadb/api/models/Collection.py b/chromadb/api/models/Collection.py index d1f4b296712..f605d9d9d84 100644 --- a/chromadb/api/models/Collection.py +++ b/chromadb/api/models/Collection.py @@ -43,8 +43,6 @@ class Collection(BaseModel): _client: "ServerAPI" = PrivateAttr() _embedding_function: Optional[EmbeddingFunction] = PrivateAttr() - # TODO: STORE THE TENANT AND NAMESPACE IN THE COLLECTION OBJECT - def __init__( self, client: "ServerAPI", diff --git a/chromadb/api/segment.py b/chromadb/api/segment.py index 27b53ea41ef..37887d5432a 100644 --- a/chromadb/api/segment.py +++ b/chromadb/api/segment.py @@ -94,6 +94,14 @@ def __init__(self, system: System): def heartbeat(self) -> int: return int(time.time_ns()) + @override + def create_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None: + self._sysdb.create_database( + id=uuid4(), + name=name, + tenant=tenant, + ) + # TODO: Actually fix CollectionMetadata type to remove type: ignore flags. This is # necessary because changing the value type from `Any` to`` `Union[str, int, float]` # causes the system to somehow convert all values to strings. @@ -121,6 +129,8 @@ def create_collection( metadata=metadata, dimension=None, get_or_create=get_or_create, + tenant=tenant, + database=database, ) if created: @@ -158,6 +168,8 @@ def get_or_create_collection( metadata=metadata, embedding_function=embedding_function, get_or_create=True, + tenant=tenant, + database=database, ) # TODO: Actually fix CollectionMetadata type to remove type: ignore flags. This is @@ -171,7 +183,9 @@ def get_collection( tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE, ) -> Collection: - existing = self._sysdb.get_collections(name=name) + existing = self._sysdb.get_collections( + name=name, tenant=tenant, database=database + ) if existing: return Collection( @@ -191,7 +205,7 @@ def list_collections( database: str = DEFAULT_DATABASE, ) -> Sequence[Collection]: collections = [] - db_collections = self._sysdb.get_collections() + db_collections = self._sysdb.get_collections(tenant=tenant, database=database) for db_collection in db_collections: collections.append( Collection( @@ -209,8 +223,6 @@ def _modify( id: UUID, new_name: Optional[str] = None, new_metadata: Optional[CollectionMetadata] = None, - tenant: str = DEFAULT_TENANT, - database: str = DEFAULT_DATABASE, ) -> None: if new_name: # backwards compatibility in naming requirements (for now) @@ -235,7 +247,9 @@ def delete_collection( tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE, ) -> None: - existing = self._sysdb.get_collections(name=name) + existing = self._sysdb.get_collections( + name=name, tenant=tenant, database=database + ) if existing: self._sysdb.delete_collection(existing[0]["id"]) @@ -254,8 +268,6 @@ def _add( embeddings: Embeddings, metadatas: Optional[Metadatas] = None, documents: Optional[Documents] = None, - tenant: str = DEFAULT_TENANT, - database: str = DEFAULT_DATABASE, ) -> bool: coll = self._get_collection(collection_id) self._manager.hint_use_collection(collection_id, t.Operation.ADD) @@ -293,8 +305,6 @@ def _update( embeddings: Optional[Embeddings] = None, metadatas: Optional[Metadatas] = None, documents: Optional[Documents] = None, - tenant: str = DEFAULT_TENANT, - database: str = DEFAULT_DATABASE, ) -> bool: coll = self._get_collection(collection_id) self._manager.hint_use_collection(collection_id, t.Operation.UPDATE) @@ -334,8 +344,6 @@ def _upsert( embeddings: Embeddings, metadatas: Optional[Metadatas] = None, documents: Optional[Documents] = None, - tenant: str = DEFAULT_TENANT, - database: str = DEFAULT_DATABASE, ) -> bool: coll = self._get_collection(collection_id) self._manager.hint_use_collection(collection_id, t.Operation.UPSERT) @@ -370,8 +378,6 @@ def _get( page_size: Optional[int] = None, where_document: Optional[WhereDocument] = {}, include: Include = ["embeddings", "metadatas", "documents"], - tenant: str = DEFAULT_TENANT, - database: str = DEFAULT_DATABASE, ) -> GetResult: where = validate_where(where) if where is not None and len(where) > 0 else None where_document = ( @@ -439,8 +445,6 @@ def _delete( ids: Optional[IDs] = None, where: Optional[Where] = None, where_document: Optional[WhereDocument] = None, - tenant: str = DEFAULT_TENANT, - database: str = DEFAULT_DATABASE, ) -> IDs: where = validate_where(where) if where is not None and len(where) > 0 else None where_document = ( @@ -499,8 +503,6 @@ def _delete( def _count( self, collection_id: UUID, - tenant: str = DEFAULT_TENANT, - database: str = DEFAULT_DATABASE, ) -> int: metadata_segment = self._manager.get_segment(collection_id, MetadataReader) return metadata_segment.count() @@ -514,8 +516,6 @@ def _query( where: Where = {}, where_document: WhereDocument = {}, include: Include = ["documents", "metadatas", "distances"], - tenant: str = DEFAULT_TENANT, - database: str = DEFAULT_DATABASE, ) -> QueryResult: where = validate_where(where) if where is not None and len(where) > 0 else where where_document = ( @@ -612,8 +612,6 @@ def _peek( self, collection_id: UUID, n: int = 10, - tenant: str = DEFAULT_TENANT, - database: str = DEFAULT_DATABASE, ) -> GetResult: return self._get(collection_id, limit=n) diff --git a/chromadb/db/impl/grpc/client.py b/chromadb/db/impl/grpc/client.py index 04d4302062a..fdf78ee957a 100644 --- a/chromadb/db/impl/grpc/client.py +++ b/chromadb/db/impl/grpc/client.py @@ -1,7 +1,7 @@ from typing import List, Optional, Sequence, Tuple, Union, cast from uuid import UUID from overrides import overrides -from chromadb.config import System +from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, System from chromadb.db.base import NotFoundError, UniqueConstraintError from chromadb.db.system import SysDB from chromadb.proto.convert import ( @@ -13,6 +13,7 @@ ) from chromadb.proto.coordinator_pb2 import ( CreateCollectionRequest, + CreateDatabaseRequest, CreateSegmentRequest, DeleteCollectionRequest, DeleteSegmentRequest, @@ -71,6 +72,15 @@ def reset_state(self) -> None: self._sys_db_stub.ResetState(Empty()) return super().reset_state() + @overrides + def create_database( + self, id: UUID, name: str, tenant: str = DEFAULT_TENANT + ) -> None: + request = CreateDatabaseRequest(id=id.hex, name=name, tenant=tenant) + response = self._sys_db_stub.CreateDatabase(request) + if response.status.code == 409: + raise UniqueConstraintError() + @overrides def create_segment(self, segment: Segment) -> None: proto_segment = to_proto_segment(segment) @@ -164,6 +174,8 @@ def create_collection( metadata: Optional[Metadata] = None, dimension: Optional[int] = None, get_or_create: bool = False, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, ) -> Tuple[Collection, bool]: request = CreateCollectionRequest( id=id.hex, @@ -171,6 +183,8 @@ def create_collection( metadata=to_proto_update_metadata(metadata) if metadata else None, dimension=dimension, get_or_create=get_or_create, + tenant=tenant, + database=database, ) response = self._sys_db_stub.CreateCollection(request) if response.status.code == 409: @@ -179,9 +193,13 @@ def create_collection( return collection, response.created @overrides - def delete_collection(self, id: UUID) -> None: + def delete_collection( + self, id: UUID, tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE + ) -> None: request = DeleteCollectionRequest( id=id.hex, + tenant=tenant, + database=database, ) response = self._sys_db_stub.DeleteCollection(request) if response.status.code == 404: @@ -193,11 +211,15 @@ def get_collections( id: Optional[UUID] = None, topic: Optional[str] = None, name: Optional[str] = None, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, ) -> Sequence[Collection]: request = GetCollectionsRequest( id=id.hex if id else None, topic=topic, name=name, + tenant=tenant, + database=database, ) response: GetCollectionsResponse = self._sys_db_stub.GetCollections(request) results: List[Collection] = [] @@ -243,7 +265,9 @@ def update_collection( request.ClearField("metadata") request.reset_metadata = True - self._sys_db_stub.UpdateCollection(request) + response = self._sys_db_stub.UpdateCollection(request) + if response.status.code == 404: + raise NotFoundError() def reset_and_wait_for_ready(self) -> None: self._sys_db_stub.ResetState(Empty(), wait_for_ready=True) diff --git a/chromadb/db/impl/grpc/server.py b/chromadb/db/impl/grpc/server.py index 436a57fc167..b028ea5a3b4 100644 --- a/chromadb/db/impl/grpc/server.py +++ b/chromadb/db/impl/grpc/server.py @@ -16,6 +16,7 @@ from chromadb.proto.coordinator_pb2 import ( CreateCollectionRequest, CreateCollectionResponse, + CreateDatabaseRequest, CreateSegmentRequest, DeleteCollectionRequest, DeleteSegmentRequest, @@ -43,7 +44,9 @@ class GrpcMockSysDB(SysDBServicer, Component): _server_port: int _assignment_policy: CollectionAssignmentPolicy _segments: Dict[str, Segment] = {} - _collections: Dict[str, Collection] = {} + _tenants_to_databases_to_collections: Dict[ + str, Dict[str, Dict[str, Collection]] + ] = {} def __init__(self, system: System): self._server_port = system.settings.require("chroma_server_grpc_port") @@ -66,9 +69,31 @@ def stop(self) -> None: @overrides def reset_state(self) -> None: self._segments = {} - self._collections = {} + self._tenants_to_databases_to_collections = {} + # Create defaults + self._tenants_to_databases_to_collections["default"] = {} + self._tenants_to_databases_to_collections["default"]["default"] = {} return super().reset_state() + @overrides(check_signature=False) + def CreateDatabase( + self, request: CreateDatabaseRequest, context: grpc.ServicerContext + ) -> proto.ChromaResponse: + tenant = request.tenant + database = request.name + if tenant not in self._tenants_to_databases_to_collections: + return proto.ChromaResponse( + status=proto.Status(code=404, reason=f"Tenant {tenant} not found") + ) + if database in self._tenants_to_databases_to_collections[tenant]: + return proto.ChromaResponse( + status=proto.Status( + code=409, reason=f"Database {database} already exists" + ) + ) + self._tenants_to_databases_to_collections[tenant][database] = {} + return proto.ChromaResponse(status=proto.Status(code=200)) + # We are forced to use check_signature=False because the generated proto code # does not have type annotations for the request and response objects. # TODO: investigate generating types for the request and response objects @@ -171,9 +196,47 @@ def CreateCollection( self, request: CreateCollectionRequest, context: grpc.ServicerContext ) -> CreateCollectionResponse: collection_name = request.name - matches = [ - c for c in self._collections.values() if c["name"] == collection_name - ] + tenant = request.tenant + database = request.database + if tenant not in self._tenants_to_databases_to_collections: + return CreateCollectionResponse( + status=proto.Status(code=404, reason=f"Tenant {tenant} not found") + ) + if database not in self._tenants_to_databases_to_collections[tenant]: + return CreateCollectionResponse( + status=proto.Status(code=404, reason=f"Database {database} not found") + ) + + # Check if the collection already exists globally by id + for ( + search_tenant, + databases, + ) in self._tenants_to_databases_to_collections.items(): + for search_database, search_collections in databases.items(): + if request.id in search_collections: + if ( + search_tenant != request.tenant + or search_database != request.database + ): + return CreateCollectionResponse( + status=proto.Status( + code=409, + reason=f"Collection {request.id} already exists in tenant {search_tenant} database {search_database}", + ) + ) + elif not request.get_or_create: + # If the id exists for this tenant and database, and we are not doing a get_or_create, then + # we should return a 409 + return CreateCollectionResponse( + status=proto.Status( + code=409, + reason=f"Collection {request.id} already exists in tenant {search_tenant} database {search_database}", + ) + ) + + # Check if the collection already exists in this database by name + collections = self._tenants_to_databases_to_collections[tenant][database] + matches = [c for c in collections.values() if c["name"] == collection_name] assert len(matches) <= 1 if len(matches) > 0: if request.get_or_create: @@ -201,7 +264,7 @@ def CreateCollection( dimension=request.dimension, topic=self._assignment_policy.assign_collection(id), ) - self._collections[request.id] = new_collection + collections[request.id] = new_collection return CreateCollectionResponse( status=proto.Status(code=200), collection=to_proto_collection(new_collection), @@ -213,8 +276,19 @@ def DeleteCollection( self, request: DeleteCollectionRequest, context: grpc.ServicerContext ) -> proto.ChromaResponse: collection_id = request.id - if collection_id in self._collections: - del self._collections[collection_id] + tenant = request.tenant + database = request.database + if tenant not in self._tenants_to_databases_to_collections: + return proto.ChromaResponse( + status=proto.Status(code=404, reason=f"Tenant {tenant} not found") + ) + if database not in self._tenants_to_databases_to_collections[tenant]: + return proto.ChromaResponse( + status=proto.Status(code=404, reason=f"Database {database} not found") + ) + collections = self._tenants_to_databases_to_collections[tenant][database] + if collection_id in collections: + del collections[collection_id] return proto.ChromaResponse(status=proto.Status(code=200)) else: return proto.ChromaResponse( @@ -231,8 +305,20 @@ def GetCollections( target_topic = request.topic if request.HasField("topic") else None target_name = request.name if request.HasField("name") else None + tenant = request.tenant + database = request.database + if tenant not in self._tenants_to_databases_to_collections: + return GetCollectionsResponse( + status=proto.Status(code=404, reason=f"Tenant {tenant} not found") + ) + if database not in self._tenants_to_databases_to_collections[tenant]: + return GetCollectionsResponse( + status=proto.Status(code=404, reason=f"Database {database} not found") + ) + collections = self._tenants_to_databases_to_collections[tenant][database] + found_collections = [] - for collection in self._collections.values(): + for collection in collections.values(): if target_id and collection["id"] != target_id: continue if target_topic and collection["topic"] != target_topic: @@ -251,14 +337,21 @@ def UpdateCollection( self, request: UpdateCollectionRequest, context: grpc.ServicerContext ) -> proto.ChromaResponse: id_to_update = UUID(request.id) - if id_to_update.hex not in self._collections: + # Find the collection with this id + collections = {} + for tenant, databases in self._tenants_to_databases_to_collections.items(): + for database, maybe_collections in databases.items(): + if id_to_update.hex in maybe_collections: + collections = maybe_collections + + if id_to_update.hex not in collections: return proto.ChromaResponse( status=proto.Status( code=404, reason=f"Collection {id_to_update} not found" ) ) else: - collection = self._collections[id_to_update.hex] + collection = collections[id_to_update.hex] if request.HasField("topic"): collection["topic"] = request.topic if request.HasField("name"): diff --git a/chromadb/db/impl/sqlite.py b/chromadb/db/impl/sqlite.py index aed14deb8e2..10036d3c239 100644 --- a/chromadb/db/impl/sqlite.py +++ b/chromadb/db/impl/sqlite.py @@ -4,7 +4,6 @@ import chromadb.db.base as base from chromadb.db.mixins.embeddings_queue import SqlEmbeddingsQueue from chromadb.db.mixins.sysdb import SqlSysDB -from chromadb.utils.delete_file import delete_file import sqlite3 from overrides import override import pypika @@ -139,8 +138,9 @@ def reset_state(self) -> None: for row in cur.fetchall(): cur.execute(f"DROP TABLE IF EXISTS {row[0]}") self._conn_pool.close() - if self._is_persistent: - delete_file(self._db_file) + # TODO: clean this up ---- I don't think its correct or needed + # if self._is_persistent: + # delete_file(self._db_file) self.start() super().reset_state() diff --git a/chromadb/db/mixins/sysdb.py b/chromadb/db/mixins/sysdb.py index d105918e700..607d8dbea62 100644 --- a/chromadb/db/mixins/sysdb.py +++ b/chromadb/db/mixins/sysdb.py @@ -4,7 +4,7 @@ from pypika import Table, Column from itertools import groupby -from chromadb.config import System +from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, System from chromadb.db.base import ( Cursor, SqlDB, @@ -41,6 +41,38 @@ def start(self) -> None: super().start() self._producer = self._system.instance(Producer) + @override + def create_database( + self, id: UUID, name: str, tenant: str = DEFAULT_TENANT + ) -> None: + with self.tx() as cur: + # Get the tenant id for the tenant name and then insert the database with the id, name and tenant id + databases = Table("databases") + tenants = Table("tenants") + insert_database = ( + self.querybuilder() + .into(databases) + .columns(databases.id, databases.name, databases.tenant_id) + .insert( + ParameterValue(self.uuid_to_db(id)), + ParameterValue(name), + self.querybuilder() + .select(tenants.id) + .from_(tenants) + .where(tenants.id == ParameterValue(tenant)), + ) + ) + sql, params = get_sql(insert_database, self.parameter_format()) + try: + cur.execute(sql, params) + # TODO: database doesn't exist + # TODO: tenant doesn't exist + # TODO: implement unique constraint error lol... + except self.unique_constraint_error() as e: + raise UniqueConstraintError( + f"Database {name} already exists for tenant {tenant}" + ) from e + @override def create_segment(self, segment: Segment) -> None: with self.tx() as cur: @@ -88,11 +120,13 @@ def create_collection( metadata: Optional[Metadata] = None, dimension: Optional[int] = None, get_or_create: bool = False, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, ) -> Tuple[Collection, bool]: if id is None and not get_or_create: raise ValueError("id must be specified if get_or_create is False") - existing = self.get_collections(name=name) + existing = self.get_collections(name=name, tenant=tenant, database=database) if existing: if get_or_create: collection = existing[0] @@ -101,7 +135,12 @@ def create_collection( collection["id"], metadata=metadata, ) - return self.get_collections(id=collection["id"])[0], False + return ( + self.get_collections( + id=collection["id"], tenant=tenant, database=database + )[0], + False, + ) else: raise UniqueConstraintError(f"Collection {name} already exists") @@ -112,6 +151,8 @@ def create_collection( with self.tx() as cur: collections = Table("collections") + databases = Table("databases") + insert_collection = ( self.querybuilder() .into(collections) @@ -120,12 +161,19 @@ def create_collection( collections.topic, collections.name, collections.dimension, + collections.database_id, ) .insert( ParameterValue(self.uuid_to_db(collection["id"])), ParameterValue(collection["topic"]), ParameterValue(collection["name"]), ParameterValue(collection["dimension"]), + # Get the database id for the database with the given name and tenant + self.querybuilder() + .select(databases.id) + .from_(databases) + .where(databases.name == ParameterValue(database)) + .where(databases.tenant_id == ParameterValue(tenant)), ) ) sql, params = get_sql(insert_collection, self.parameter_format()) @@ -220,8 +268,16 @@ def get_collections( id: Optional[UUID] = None, topic: Optional[str] = None, name: Optional[str] = None, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, ) -> Sequence[Collection]: """Get collections by name, embedding function and/or metadata""" + + if name is not None and (tenant is None or database is None): + raise ValueError( + "If name is specified, tenant and database must also be specified in order to uniquely identify the collection" + ) + collections_t = Table("collections") metadata_t = Table("collection_metadata") q = ( @@ -248,6 +304,17 @@ def get_collections( if name: q = q.where(collections_t.name == ParameterValue(name)) + if tenant and database: + databases_t = Table("databases") + q = q.where( + collections_t.database_id + == self.querybuilder() + .select(databases_t.id) + .from_(databases_t) + .where(databases_t.name == ParameterValue(database)) + .where(databases_t.tenant_id == ParameterValue(tenant)) + ) + with self.tx() as cur: sql, params = get_sql(q, self.parameter_format()) rows = cur.execute(sql, params).fetchall() @@ -291,13 +358,27 @@ def delete_segment(self, id: UUID) -> None: raise NotFoundError(f"Segment {id} not found") @override - def delete_collection(self, id: UUID) -> None: + def delete_collection( + self, + id: UUID, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> None: """Delete a topic and all associated segments from the SysDB""" t = Table("collections") + databases_t = Table("databases") q = ( self.querybuilder() .from_(t) .where(t.id == ParameterValue(self.uuid_to_db(id))) + .where( + t.database_id + == self.querybuilder() + .select(databases_t.id) + .from_(databases_t) + .where(databases_t.name == ParameterValue(database)) + .where(databases_t.tenant_id == ParameterValue(tenant)) + ) .delete() ) with self.tx() as cur: @@ -391,7 +472,10 @@ def update_collection( with self.tx() as cur: sql, params = get_sql(q, self.parameter_format()) if sql: # pypika emits a blank string if nothing to do - cur.execute(sql, params) + sql = sql + " RETURNING id" + result = cur.execute(sql, params) + if not result.fetchone(): + raise NotFoundError(f"Collection {id} not found") # TODO: Update to use better semantics where it's possible to update # individual keys without wiping all the existing metadata. diff --git a/chromadb/db/system.py b/chromadb/db/system.py index b4975339fbc..b3506062563 100644 --- a/chromadb/db/system.py +++ b/chromadb/db/system.py @@ -10,12 +10,20 @@ Unspecified, UpdateMetadata, ) -from chromadb.config import Component +from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Component class SysDB(Component): """Data interface for Chroma's System database""" + @abstractmethod + def create_database( + self, id: UUID, name: str, tenant: str = DEFAULT_TENANT + ) -> None: + """Create a new database in the System database. Raises DuplicateError if the Database + already exists.""" + pass + @abstractmethod def create_segment(self, segment: Segment) -> None: """Create a new segment in the System database. Raises DuplicateError if the ID @@ -60,6 +68,8 @@ def create_collection( metadata: Optional[Metadata] = None, dimension: Optional[int] = None, get_or_create: bool = False, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, ) -> Tuple[Collection, bool]: """Create a new collection any associated resources (Such as the necessary topics) in the SysDB. If get_or_create is True, the @@ -73,7 +83,9 @@ def create_collection( pass @abstractmethod - def delete_collection(self, id: UUID) -> None: + def delete_collection( + self, id: UUID, tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE + ) -> None: """Delete a collection, topic, all associated segments and any associate resources from the SysDB and the system at large.""" pass @@ -84,8 +96,10 @@ def get_collections( id: Optional[UUID] = None, topic: Optional[str] = None, name: Optional[str] = None, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, ) -> Sequence[Collection]: - """Find collections by id, topic or name""" + """Find collections by id, topic or name. If name is provided, tenant and database must also be provided.""" pass @abstractmethod diff --git a/chromadb/migrations/sysdb/00004-tenants-databases.sqlite.sql b/chromadb/migrations/sysdb/00004-tenants-databases.sqlite.sql index 1c40e823480..e1386efba42 100644 --- a/chromadb/migrations/sysdb/00004-tenants-databases.sqlite.sql +++ b/chromadb/migrations/sysdb/00004-tenants-databases.sqlite.sql @@ -1,19 +1,29 @@ -CREATE TABLE tenants ( +CREATE TABLE tenants ( -- todo: make this idempotent by checking if table exists by using CREATE TABLE IF NOT EXISTS id TEXT PRIMARY KEY, - name TEXT NOT NULL, - UNIQUE (name) -- Maybe not needed since we want to support slug ids + UNIQUE (id) -- Maybe not needed since we want to support slug ids ); CREATE TABLE databases ( - id TEXT PRIMARY KEY, - name TEXT NOT NULL, + id TEXT PRIMARY KEY, -- unique globally + name TEXT NOT NULL, -- unique per tenant tenant_id TEXT NOT NULL REFERENCES tenants(id) ON DELETE CASCADE, - UNIQUE (name) + UNIQUE (tenant_id, name) -- Ensure that a tenant has only one database with a given name ); -ALTER TABLE collections - ADD COLUMN database_id TEXT NOT NULL REFERENCES databases(id) DEFAULT 'default'; -- ON DELETE CASCADE not supported by sqlite in ALTER TABLE +CREATE TABLE collections_tmp ( + id TEXT PRIMARY KEY, -- unique globally + name TEXT NOT NULL, -- unique per database + topic TEXT NOT NULL, + dimension INTEGER, + database_id TEXT NOT NULL REFERENCES databases(id) ON DELETE CASCADE, + UNIQUE (name, database_id) +); -- Create default tenant and database -INSERT INTO tenants (id, name) VALUES ('default', 'default'); +INSERT INTO tenants (id) VALUES ('default'); -- should ids be uuids? INSERT INTO databases (id, name, tenant_id) VALUES ('default', 'default', 'default'); + +INSERT INTO collections_tmp (id, name, topic, dimension, database_id) + SELECT id, name, topic, dimension, 'default' FROM collections; +DROP TABLE collections; +ALTER TABLE collections_tmp RENAME TO collections; diff --git a/chromadb/proto/coordinator_pb2.py b/chromadb/proto/coordinator_pb2.py index 118405d423a..35459ff1ae6 100644 --- a/chromadb/proto/coordinator_pb2.py +++ b/chromadb/proto/coordinator_pb2.py @@ -15,35 +15,37 @@ from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n chromadb/proto/coordinator.proto\x12\x06\x63hroma\x1a\x1b\x63hromadb/proto/chroma.proto\x1a\x1bgoogle/protobuf/empty.proto\"8\n\x14\x43reateSegmentRequest\x12 \n\x07segment\x18\x01 \x01(\x0b\x32\x0f.chroma.Segment\"\"\n\x14\x44\x65leteSegmentRequest\x12\n\n\x02id\x18\x01 \x01(\t\"\xc2\x01\n\x12GetSegmentsRequest\x12\x0f\n\x02id\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x11\n\x04type\x18\x02 \x01(\tH\x01\x88\x01\x01\x12(\n\x05scope\x18\x03 \x01(\x0e\x32\x14.chroma.SegmentScopeH\x02\x88\x01\x01\x12\x12\n\x05topic\x18\x04 \x01(\tH\x03\x88\x01\x01\x12\x17\n\ncollection\x18\x05 \x01(\tH\x04\x88\x01\x01\x42\x05\n\x03_idB\x07\n\x05_typeB\x08\n\x06_scopeB\x08\n\x06_topicB\r\n\x0b_collection\"X\n\x13GetSegmentsResponse\x12!\n\x08segments\x18\x01 \x03(\x0b\x32\x0f.chroma.Segment\x12\x1e\n\x06status\x18\x02 \x01(\x0b\x32\x0e.chroma.Status\"\xfa\x01\n\x14UpdateSegmentRequest\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0f\n\x05topic\x18\x02 \x01(\tH\x00\x12\x15\n\x0breset_topic\x18\x03 \x01(\x08H\x00\x12\x14\n\ncollection\x18\x04 \x01(\tH\x01\x12\x1a\n\x10reset_collection\x18\x05 \x01(\x08H\x01\x12*\n\x08metadata\x18\x06 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x02\x12\x18\n\x0ereset_metadata\x18\x07 \x01(\x08H\x02\x42\x0e\n\x0ctopic_updateB\x13\n\x11\x63ollection_updateB\x11\n\x0fmetadata_update\"\xc3\x01\n\x17\x43reateCollectionRequest\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12-\n\x08metadata\x18\x03 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x00\x88\x01\x01\x12\x16\n\tdimension\x18\x04 \x01(\x05H\x01\x88\x01\x01\x12\x1a\n\rget_or_create\x18\x05 \x01(\x08H\x02\x88\x01\x01\x42\x0b\n\t_metadataB\x0c\n\n_dimensionB\x10\n\x0e_get_or_create\"s\n\x18\x43reateCollectionResponse\x12&\n\ncollection\x18\x01 \x01(\x0b\x32\x12.chroma.Collection\x12\x0f\n\x07\x63reated\x18\x02 \x01(\x08\x12\x1e\n\x06status\x18\x03 \x01(\x0b\x32\x0e.chroma.Status\"%\n\x17\x44\x65leteCollectionRequest\x12\n\n\x02id\x18\x01 \x01(\t\"i\n\x15GetCollectionsRequest\x12\x0f\n\x02id\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x11\n\x04name\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x12\n\x05topic\x18\x03 \x01(\tH\x02\x88\x01\x01\x42\x05\n\x03_idB\x07\n\x05_nameB\x08\n\x06_topic\"a\n\x16GetCollectionsResponse\x12\'\n\x0b\x63ollections\x18\x01 \x03(\x0b\x32\x12.chroma.Collection\x12\x1e\n\x06status\x18\x02 \x01(\x0b\x32\x0e.chroma.Status\"\xde\x01\n\x17UpdateCollectionRequest\x12\n\n\x02id\x18\x01 \x01(\t\x12\x12\n\x05topic\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x11\n\x04name\x18\x03 \x01(\tH\x02\x88\x01\x01\x12\x16\n\tdimension\x18\x04 \x01(\x05H\x03\x88\x01\x01\x12*\n\x08metadata\x18\x05 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x00\x12\x18\n\x0ereset_metadata\x18\x06 \x01(\x08H\x00\x42\x11\n\x0fmetadata_updateB\x08\n\x06_topicB\x07\n\x05_nameB\x0c\n\n_dimension2\xb6\x05\n\x05SysDB\x12G\n\rCreateSegment\x12\x1c.chroma.CreateSegmentRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12G\n\rDeleteSegment\x12\x1c.chroma.DeleteSegmentRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12H\n\x0bGetSegments\x12\x1a.chroma.GetSegmentsRequest\x1a\x1b.chroma.GetSegmentsResponse\"\x00\x12G\n\rUpdateSegment\x12\x1c.chroma.UpdateSegmentRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12W\n\x10\x43reateCollection\x12\x1f.chroma.CreateCollectionRequest\x1a .chroma.CreateCollectionResponse\"\x00\x12M\n\x10\x44\x65leteCollection\x12\x1f.chroma.DeleteCollectionRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12Q\n\x0eGetCollections\x12\x1d.chroma.GetCollectionsRequest\x1a\x1e.chroma.GetCollectionsResponse\"\x00\x12M\n\x10UpdateCollection\x12\x1f.chroma.UpdateCollectionRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12>\n\nResetState\x12\x16.google.protobuf.Empty\x1a\x16.chroma.ChromaResponse\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n chromadb/proto/coordinator.proto\x12\x06\x63hroma\x1a\x1b\x63hromadb/proto/chroma.proto\x1a\x1bgoogle/protobuf/empty.proto\"A\n\x15\x43reateDatabaseRequest\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x0e\n\x06tenant\x18\x03 \x01(\t\"8\n\x14\x43reateSegmentRequest\x12 \n\x07segment\x18\x01 \x01(\x0b\x32\x0f.chroma.Segment\"\"\n\x14\x44\x65leteSegmentRequest\x12\n\n\x02id\x18\x01 \x01(\t\"\xc2\x01\n\x12GetSegmentsRequest\x12\x0f\n\x02id\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x11\n\x04type\x18\x02 \x01(\tH\x01\x88\x01\x01\x12(\n\x05scope\x18\x03 \x01(\x0e\x32\x14.chroma.SegmentScopeH\x02\x88\x01\x01\x12\x12\n\x05topic\x18\x04 \x01(\tH\x03\x88\x01\x01\x12\x17\n\ncollection\x18\x05 \x01(\tH\x04\x88\x01\x01\x42\x05\n\x03_idB\x07\n\x05_typeB\x08\n\x06_scopeB\x08\n\x06_topicB\r\n\x0b_collection\"X\n\x13GetSegmentsResponse\x12!\n\x08segments\x18\x01 \x03(\x0b\x32\x0f.chroma.Segment\x12\x1e\n\x06status\x18\x02 \x01(\x0b\x32\x0e.chroma.Status\"\xfa\x01\n\x14UpdateSegmentRequest\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0f\n\x05topic\x18\x02 \x01(\tH\x00\x12\x15\n\x0breset_topic\x18\x03 \x01(\x08H\x00\x12\x14\n\ncollection\x18\x04 \x01(\tH\x01\x12\x1a\n\x10reset_collection\x18\x05 \x01(\x08H\x01\x12*\n\x08metadata\x18\x06 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x02\x12\x18\n\x0ereset_metadata\x18\x07 \x01(\x08H\x02\x42\x0e\n\x0ctopic_updateB\x13\n\x11\x63ollection_updateB\x11\n\x0fmetadata_update\"\xe5\x01\n\x17\x43reateCollectionRequest\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12-\n\x08metadata\x18\x03 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x00\x88\x01\x01\x12\x16\n\tdimension\x18\x04 \x01(\x05H\x01\x88\x01\x01\x12\x1a\n\rget_or_create\x18\x05 \x01(\x08H\x02\x88\x01\x01\x12\x0e\n\x06tenant\x18\x06 \x01(\t\x12\x10\n\x08\x64\x61tabase\x18\x07 \x01(\tB\x0b\n\t_metadataB\x0c\n\n_dimensionB\x10\n\x0e_get_or_create\"s\n\x18\x43reateCollectionResponse\x12&\n\ncollection\x18\x01 \x01(\x0b\x32\x12.chroma.Collection\x12\x0f\n\x07\x63reated\x18\x02 \x01(\x08\x12\x1e\n\x06status\x18\x03 \x01(\x0b\x32\x0e.chroma.Status\"G\n\x17\x44\x65leteCollectionRequest\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0e\n\x06tenant\x18\x02 \x01(\t\x12\x10\n\x08\x64\x61tabase\x18\x03 \x01(\t\"\x8b\x01\n\x15GetCollectionsRequest\x12\x0f\n\x02id\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x11\n\x04name\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x12\n\x05topic\x18\x03 \x01(\tH\x02\x88\x01\x01\x12\x0e\n\x06tenant\x18\x04 \x01(\t\x12\x10\n\x08\x64\x61tabase\x18\x05 \x01(\tB\x05\n\x03_idB\x07\n\x05_nameB\x08\n\x06_topic\"a\n\x16GetCollectionsResponse\x12\'\n\x0b\x63ollections\x18\x01 \x03(\x0b\x32\x12.chroma.Collection\x12\x1e\n\x06status\x18\x02 \x01(\x0b\x32\x0e.chroma.Status\"\xde\x01\n\x17UpdateCollectionRequest\x12\n\n\x02id\x18\x01 \x01(\t\x12\x12\n\x05topic\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x11\n\x04name\x18\x03 \x01(\tH\x02\x88\x01\x01\x12\x16\n\tdimension\x18\x04 \x01(\x05H\x03\x88\x01\x01\x12*\n\x08metadata\x18\x05 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x00\x12\x18\n\x0ereset_metadata\x18\x06 \x01(\x08H\x00\x42\x11\n\x0fmetadata_updateB\x08\n\x06_topicB\x07\n\x05_nameB\x0c\n\n_dimension2\x81\x06\n\x05SysDB\x12I\n\x0e\x43reateDatabase\x12\x1d.chroma.CreateDatabaseRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12G\n\rCreateSegment\x12\x1c.chroma.CreateSegmentRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12G\n\rDeleteSegment\x12\x1c.chroma.DeleteSegmentRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12H\n\x0bGetSegments\x12\x1a.chroma.GetSegmentsRequest\x1a\x1b.chroma.GetSegmentsResponse\"\x00\x12G\n\rUpdateSegment\x12\x1c.chroma.UpdateSegmentRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12W\n\x10\x43reateCollection\x12\x1f.chroma.CreateCollectionRequest\x1a .chroma.CreateCollectionResponse\"\x00\x12M\n\x10\x44\x65leteCollection\x12\x1f.chroma.DeleteCollectionRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12Q\n\x0eGetCollections\x12\x1d.chroma.GetCollectionsRequest\x1a\x1e.chroma.GetCollectionsResponse\"\x00\x12M\n\x10UpdateCollection\x12\x1f.chroma.UpdateCollectionRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12>\n\nResetState\x12\x16.google.protobuf.Empty\x1a\x16.chroma.ChromaResponse\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'chromadb.proto.coordinator_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None - _globals['_CREATESEGMENTREQUEST']._serialized_start=102 - _globals['_CREATESEGMENTREQUEST']._serialized_end=158 - _globals['_DELETESEGMENTREQUEST']._serialized_start=160 - _globals['_DELETESEGMENTREQUEST']._serialized_end=194 - _globals['_GETSEGMENTSREQUEST']._serialized_start=197 - _globals['_GETSEGMENTSREQUEST']._serialized_end=391 - _globals['_GETSEGMENTSRESPONSE']._serialized_start=393 - _globals['_GETSEGMENTSRESPONSE']._serialized_end=481 - _globals['_UPDATESEGMENTREQUEST']._serialized_start=484 - _globals['_UPDATESEGMENTREQUEST']._serialized_end=734 - _globals['_CREATECOLLECTIONREQUEST']._serialized_start=737 - _globals['_CREATECOLLECTIONREQUEST']._serialized_end=932 - _globals['_CREATECOLLECTIONRESPONSE']._serialized_start=934 - _globals['_CREATECOLLECTIONRESPONSE']._serialized_end=1049 - _globals['_DELETECOLLECTIONREQUEST']._serialized_start=1051 - _globals['_DELETECOLLECTIONREQUEST']._serialized_end=1088 - _globals['_GETCOLLECTIONSREQUEST']._serialized_start=1090 - _globals['_GETCOLLECTIONSREQUEST']._serialized_end=1195 - _globals['_GETCOLLECTIONSRESPONSE']._serialized_start=1197 - _globals['_GETCOLLECTIONSRESPONSE']._serialized_end=1294 - _globals['_UPDATECOLLECTIONREQUEST']._serialized_start=1297 - _globals['_UPDATECOLLECTIONREQUEST']._serialized_end=1519 - _globals['_SYSDB']._serialized_start=1522 - _globals['_SYSDB']._serialized_end=2216 + _globals['_CREATEDATABASEREQUEST']._serialized_start=102 + _globals['_CREATEDATABASEREQUEST']._serialized_end=167 + _globals['_CREATESEGMENTREQUEST']._serialized_start=169 + _globals['_CREATESEGMENTREQUEST']._serialized_end=225 + _globals['_DELETESEGMENTREQUEST']._serialized_start=227 + _globals['_DELETESEGMENTREQUEST']._serialized_end=261 + _globals['_GETSEGMENTSREQUEST']._serialized_start=264 + _globals['_GETSEGMENTSREQUEST']._serialized_end=458 + _globals['_GETSEGMENTSRESPONSE']._serialized_start=460 + _globals['_GETSEGMENTSRESPONSE']._serialized_end=548 + _globals['_UPDATESEGMENTREQUEST']._serialized_start=551 + _globals['_UPDATESEGMENTREQUEST']._serialized_end=801 + _globals['_CREATECOLLECTIONREQUEST']._serialized_start=804 + _globals['_CREATECOLLECTIONREQUEST']._serialized_end=1033 + _globals['_CREATECOLLECTIONRESPONSE']._serialized_start=1035 + _globals['_CREATECOLLECTIONRESPONSE']._serialized_end=1150 + _globals['_DELETECOLLECTIONREQUEST']._serialized_start=1152 + _globals['_DELETECOLLECTIONREQUEST']._serialized_end=1223 + _globals['_GETCOLLECTIONSREQUEST']._serialized_start=1226 + _globals['_GETCOLLECTIONSREQUEST']._serialized_end=1365 + _globals['_GETCOLLECTIONSRESPONSE']._serialized_start=1367 + _globals['_GETCOLLECTIONSRESPONSE']._serialized_end=1464 + _globals['_UPDATECOLLECTIONREQUEST']._serialized_start=1467 + _globals['_UPDATECOLLECTIONREQUEST']._serialized_end=1689 + _globals['_SYSDB']._serialized_start=1692 + _globals['_SYSDB']._serialized_end=2461 # @@protoc_insertion_point(module_scope) diff --git a/chromadb/proto/coordinator_pb2.pyi b/chromadb/proto/coordinator_pb2.pyi index 6b9c974e424..a77c0d828a9 100644 --- a/chromadb/proto/coordinator_pb2.pyi +++ b/chromadb/proto/coordinator_pb2.pyi @@ -7,6 +7,16 @@ from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Map DESCRIPTOR: _descriptor.FileDescriptor +class CreateDatabaseRequest(_message.Message): + __slots__ = ["id", "name", "tenant"] + ID_FIELD_NUMBER: _ClassVar[int] + NAME_FIELD_NUMBER: _ClassVar[int] + TENANT_FIELD_NUMBER: _ClassVar[int] + id: str + name: str + tenant: str + def __init__(self, id: _Optional[str] = ..., name: _Optional[str] = ..., tenant: _Optional[str] = ...) -> None: ... + class CreateSegmentRequest(_message.Message): __slots__ = ["segment"] SEGMENT_FIELD_NUMBER: _ClassVar[int] @@ -60,18 +70,22 @@ class UpdateSegmentRequest(_message.Message): def __init__(self, id: _Optional[str] = ..., topic: _Optional[str] = ..., reset_topic: bool = ..., collection: _Optional[str] = ..., reset_collection: bool = ..., metadata: _Optional[_Union[_chroma_pb2.UpdateMetadata, _Mapping]] = ..., reset_metadata: bool = ...) -> None: ... class CreateCollectionRequest(_message.Message): - __slots__ = ["id", "name", "metadata", "dimension", "get_or_create"] + __slots__ = ["id", "name", "metadata", "dimension", "get_or_create", "tenant", "database"] ID_FIELD_NUMBER: _ClassVar[int] NAME_FIELD_NUMBER: _ClassVar[int] METADATA_FIELD_NUMBER: _ClassVar[int] DIMENSION_FIELD_NUMBER: _ClassVar[int] GET_OR_CREATE_FIELD_NUMBER: _ClassVar[int] + TENANT_FIELD_NUMBER: _ClassVar[int] + DATABASE_FIELD_NUMBER: _ClassVar[int] id: str name: str metadata: _chroma_pb2.UpdateMetadata dimension: int get_or_create: bool - def __init__(self, id: _Optional[str] = ..., name: _Optional[str] = ..., metadata: _Optional[_Union[_chroma_pb2.UpdateMetadata, _Mapping]] = ..., dimension: _Optional[int] = ..., get_or_create: bool = ...) -> None: ... + tenant: str + database: str + def __init__(self, id: _Optional[str] = ..., name: _Optional[str] = ..., metadata: _Optional[_Union[_chroma_pb2.UpdateMetadata, _Mapping]] = ..., dimension: _Optional[int] = ..., get_or_create: bool = ..., tenant: _Optional[str] = ..., database: _Optional[str] = ...) -> None: ... class CreateCollectionResponse(_message.Message): __slots__ = ["collection", "created", "status"] @@ -84,20 +98,28 @@ class CreateCollectionResponse(_message.Message): def __init__(self, collection: _Optional[_Union[_chroma_pb2.Collection, _Mapping]] = ..., created: bool = ..., status: _Optional[_Union[_chroma_pb2.Status, _Mapping]] = ...) -> None: ... class DeleteCollectionRequest(_message.Message): - __slots__ = ["id"] + __slots__ = ["id", "tenant", "database"] ID_FIELD_NUMBER: _ClassVar[int] + TENANT_FIELD_NUMBER: _ClassVar[int] + DATABASE_FIELD_NUMBER: _ClassVar[int] id: str - def __init__(self, id: _Optional[str] = ...) -> None: ... + tenant: str + database: str + def __init__(self, id: _Optional[str] = ..., tenant: _Optional[str] = ..., database: _Optional[str] = ...) -> None: ... class GetCollectionsRequest(_message.Message): - __slots__ = ["id", "name", "topic"] + __slots__ = ["id", "name", "topic", "tenant", "database"] ID_FIELD_NUMBER: _ClassVar[int] NAME_FIELD_NUMBER: _ClassVar[int] TOPIC_FIELD_NUMBER: _ClassVar[int] + TENANT_FIELD_NUMBER: _ClassVar[int] + DATABASE_FIELD_NUMBER: _ClassVar[int] id: str name: str topic: str - def __init__(self, id: _Optional[str] = ..., name: _Optional[str] = ..., topic: _Optional[str] = ...) -> None: ... + tenant: str + database: str + def __init__(self, id: _Optional[str] = ..., name: _Optional[str] = ..., topic: _Optional[str] = ..., tenant: _Optional[str] = ..., database: _Optional[str] = ...) -> None: ... class GetCollectionsResponse(_message.Message): __slots__ = ["collections", "status"] diff --git a/chromadb/proto/coordinator_pb2_grpc.py b/chromadb/proto/coordinator_pb2_grpc.py index a3a1e03227b..d8fe2eb147a 100644 --- a/chromadb/proto/coordinator_pb2_grpc.py +++ b/chromadb/proto/coordinator_pb2_grpc.py @@ -16,6 +16,11 @@ def __init__(self, channel): Args: channel: A grpc.Channel. """ + self.CreateDatabase = channel.unary_unary( + "/chroma.SysDB/CreateDatabase", + request_serializer=chromadb_dot_proto_dot_coordinator__pb2.CreateDatabaseRequest.SerializeToString, + response_deserializer=chromadb_dot_proto_dot_chroma__pb2.ChromaResponse.FromString, + ) self.CreateSegment = channel.unary_unary( "/chroma.SysDB/CreateSegment", request_serializer=chromadb_dot_proto_dot_coordinator__pb2.CreateSegmentRequest.SerializeToString, @@ -66,6 +71,12 @@ def __init__(self, channel): class SysDBServicer(object): """Missing associated documentation comment in .proto file.""" + def CreateDatabase(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + def CreateSegment(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) @@ -123,6 +134,11 @@ def ResetState(self, request, context): def add_SysDBServicer_to_server(servicer, server): rpc_method_handlers = { + "CreateDatabase": grpc.unary_unary_rpc_method_handler( + servicer.CreateDatabase, + request_deserializer=chromadb_dot_proto_dot_coordinator__pb2.CreateDatabaseRequest.FromString, + response_serializer=chromadb_dot_proto_dot_chroma__pb2.ChromaResponse.SerializeToString, + ), "CreateSegment": grpc.unary_unary_rpc_method_handler( servicer.CreateSegment, request_deserializer=chromadb_dot_proto_dot_coordinator__pb2.CreateSegmentRequest.FromString, @@ -179,6 +195,35 @@ def add_SysDBServicer_to_server(servicer, server): class SysDB(object): """Missing associated documentation comment in .proto file.""" + @staticmethod + def CreateDatabase( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, + target, + "/chroma.SysDB/CreateDatabase", + chromadb_dot_proto_dot_coordinator__pb2.CreateDatabaseRequest.SerializeToString, + chromadb_dot_proto_dot_chroma__pb2.ChromaResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) + @staticmethod def CreateSegment( request, diff --git a/chromadb/test/db/test_system.py b/chromadb/test/db/test_system.py index 541643a2ff6..b40bc32108e 100644 --- a/chromadb/test/db/test_system.py +++ b/chromadb/test/db/test_system.py @@ -3,6 +3,7 @@ import tempfile import pytest from typing import Generator, List, Callable, Dict, Union + from chromadb.db.impl.grpc.client import GrpcSysDB from chromadb.db.impl.grpc.server import GrpcMockSysDB from chromadb.types import Collection, Segment, SegmentScope @@ -107,6 +108,7 @@ def sysdb(request: FixtureRequest) -> Generator[SysDB, None, None]: yield next(request.param()) +# region Collection tests def test_create_get_delete_collections(sysdb: SysDB) -> None: sysdb.reset_state() @@ -296,6 +298,171 @@ def test_get_or_create_collection(sysdb: SysDB) -> None: assert result["metadata"] == overlayed_metadata +def test_create_get_delete_database_and_collection(sysdb: SysDB) -> None: + sysdb.reset_state() + + # Create a new database + sysdb.create_database(id=uuid.uuid4(), name="new_database") + + # Create a new collection in the new database + sysdb.create_collection( + id=sample_collections[0]["id"], + name=sample_collections[0]["name"], + metadata=sample_collections[0]["metadata"], + dimension=sample_collections[0]["dimension"], + database="new_database", + ) + + # Create a new collection with the same id but different name in the new database + # and expect an error + with pytest.raises(UniqueConstraintError): + sysdb.create_collection( + id=sample_collections[0]["id"], + name="new_name", + metadata=sample_collections[0]["metadata"], + dimension=sample_collections[0]["dimension"], + database="new_database", + get_or_create=False, + ) + + # Create a new collection in the default database + sysdb.create_collection( + id=sample_collections[1]["id"], + name=sample_collections[1]["name"], + metadata=sample_collections[1]["metadata"], + dimension=sample_collections[1]["dimension"], + ) + + # Check that the new database and collections exist + result = sysdb.get_collections( + name=sample_collections[0]["name"], database="new_database" + ) + assert len(result) == 1 + assert result[0] == sample_collections[0] + + # Check that the collection in the default database exists + result = sysdb.get_collections(name=sample_collections[1]["name"]) + assert len(result) == 1 + assert result[0] == sample_collections[1] + + # Get for a database that doesn't exist with a name that exists in the new database and expect no results + assert ( + len( + sysdb.get_collections( + name=sample_collections[0]["name"], database="fake_db" + ) + ) + == 0 + ) + + # Delete the collection in the new database + sysdb.delete_collection(id=sample_collections[0]["id"], database="new_database") + + # Check that the collection in the new database was deleted + result = sysdb.get_collections(database="new_database") + assert len(result) == 0 + + # Check that the collection in the default database still exists + result = sysdb.get_collections(name=sample_collections[1]["name"]) + assert len(result) == 1 + assert result[0] == sample_collections[1] + + # Delete the deleted collection in the default database and expect an error + with pytest.raises(NotFoundError): + sysdb.delete_collection(id=sample_collections[0]["id"]) + + # Delete the existing collection in the new database and expect an error + with pytest.raises(NotFoundError): + sysdb.delete_collection(id=sample_collections[1]["id"], database="new_database") + + +def test_create_update_with_database(sysdb: SysDB) -> None: + sysdb.reset_state() + + # Create a new database + sysdb.create_database(id=uuid.uuid4(), name="new_database") + + # Create a new collection in the new database + sysdb.create_collection( + id=sample_collections[0]["id"], + name=sample_collections[0]["name"], + metadata=sample_collections[0]["metadata"], + dimension=sample_collections[0]["dimension"], + database="new_database", + ) + + # Create a new collection in the default database + sysdb.create_collection( + id=sample_collections[1]["id"], + name=sample_collections[1]["name"], + metadata=sample_collections[1]["metadata"], + dimension=sample_collections[1]["dimension"], + ) + + # Update the collection in the default database + sysdb.update_collection( + id=sample_collections[1]["id"], + name="new_name_1", + ) + + # Check that the collection in the default database was updated + result = sysdb.get_collections(id=sample_collections[1]["id"]) + assert len(result) == 1 + assert result[0]["name"] == "new_name_1" + + # Update the collection in the new database + sysdb.update_collection( + id=sample_collections[0]["id"], + name="new_name_0", + ) + + # Check that the collection in the new database was updated + result = sysdb.get_collections( + id=sample_collections[0]["id"], database="new_database" + ) + assert len(result) == 1 + assert result[0]["name"] == "new_name_0" + + # Try to create the collection in the default database in the new database and expect an error + with pytest.raises(UniqueConstraintError): + sysdb.create_collection( + id=sample_collections[1]["id"], + name=sample_collections[1]["name"], + metadata=sample_collections[1]["metadata"], + dimension=sample_collections[1]["dimension"], + database="new_database", + ) + + +def test_get_multiple_with_database(sysdb: SysDB) -> None: + sysdb.reset_state() + + # Create a new database + sysdb.create_database(id=uuid.uuid4(), name="new_database") + + # Create sample collections in the new database + for collection in sample_collections: + sysdb.create_collection( + id=collection["id"], + name=collection["name"], + metadata=collection["metadata"], + dimension=collection["dimension"], + database="new_database", + ) + + # Get all collections in the new database + result = sysdb.get_collections(database="new_database") + assert len(result) == len(sample_collections) + assert sorted(result, key=lambda c: c["name"]) == sample_collections + + # Get all collections in the default database + result = sysdb.get_collections() + assert len(result) == 0 + + +# endregion + +# region Segment tests sample_segments = [ Segment( id=uuid.UUID("00000000-d7d7-413b-92e1-731098a6e492"), @@ -459,3 +626,6 @@ def test_update_segment(sysdb: SysDB) -> None: sysdb.update_segment(segment["id"], metadata=None) result = sysdb.get_segments(id=segment["id"]) assert result == [segment] + + +# endregion diff --git a/idl/chromadb/proto/coordinator.proto b/idl/chromadb/proto/coordinator.proto index 2a557f99613..4ff45d97bbd 100644 --- a/idl/chromadb/proto/coordinator.proto +++ b/idl/chromadb/proto/coordinator.proto @@ -5,6 +5,12 @@ package chroma; import "chromadb/proto/chroma.proto"; import "google/protobuf/empty.proto"; +message CreateDatabaseRequest { + string id = 1; + string name = 2; + string tenant = 3; +} + message CreateSegmentRequest { Segment segment = 1; } @@ -18,7 +24,7 @@ message GetSegmentsRequest { optional string type = 2; optional SegmentScope scope = 3; optional string topic = 4; - optional string collection = 5; + optional string collection = 5; // Collection ID } message GetSegmentsResponse { @@ -49,6 +55,8 @@ message CreateCollectionRequest { optional UpdateMetadata metadata = 3; optional int32 dimension = 4; optional bool get_or_create = 5; + string tenant = 6; + string database = 7; } message CreateCollectionResponse { @@ -59,12 +67,16 @@ message CreateCollectionResponse { message DeleteCollectionRequest { string id = 1; + string tenant = 2; + string database = 3; } message GetCollectionsRequest { optional string id = 1; optional string name = 2; optional string topic = 3; + string tenant = 4; + string database = 5; } message GetCollectionsResponse { @@ -84,6 +96,7 @@ message UpdateCollectionRequest { } service SysDB { + rpc CreateDatabase(CreateDatabaseRequest) returns (ChromaResponse) {} rpc CreateSegment(CreateSegmentRequest) returns (ChromaResponse) {} rpc DeleteSegment(DeleteSegmentRequest) returns (ChromaResponse) {} rpc GetSegments(GetSegmentsRequest) returns (GetSegmentsResponse) {}