From 2bd9df122fee2a653c5dc5c3030c190ac88bd240 Mon Sep 17 00:00:00 2001 From: hammadb Date: Thu, 19 Oct 2023 11:54:17 -0700 Subject: [PATCH] plumbing + concurrency test --- chromadb/api/__init__.py | 3 + chromadb/api/client.py | 42 ++++++++++++- chromadb/api/fastapi.py | 25 ++++++-- chromadb/api/segment.py | 4 +- chromadb/server/fastapi/__init__.py | 55 +++++++++++++---- chromadb/server/fastapi/types.py | 4 ++ chromadb/test/client/test_database_tenant.py | 60 +++++++++++++++++++ .../test_multiple_clients_concurrency.py | 43 +++++++++++++ chromadb/test/conftest.py | 10 +++- 9 files changed, 227 insertions(+), 19 deletions(-) create mode 100644 chromadb/test/client/test_database_tenant.py create mode 100644 chromadb/test/client/test_multiple_clients_concurrency.py diff --git a/chromadb/api/__init__.py b/chromadb/api/__init__.py index 28a62fcb5be..6c0e56e5c8f 100644 --- a/chromadb/api/__init__.py +++ b/chromadb/api/__init__.py @@ -391,6 +391,9 @@ def max_batch_size(self) -> int: class ClientAPI(BaseAPI, ABC): + tenant: str + database: str + @abstractmethod def set_database(self, database: str) -> None: """Set the database for the client. diff --git a/chromadb/api/client.py b/chromadb/api/client.py index 24b087cc30c..3589af0192f 100644 --- a/chromadb/api/client.py +++ b/chromadb/api/client.py @@ -1,4 +1,4 @@ -from typing import ClassVar, Dict, Optional, Sequence +from typing import ClassVar, Dict, Optional, Sequence, TypeVar from uuid import UUID from overrides import override @@ -22,6 +22,8 @@ from chromadb.types import Where, WhereDocument import chromadb.utils.embedding_functions as ef +C = TypeVar("C", "SharedSystemClient", "Client", "AdminClient") + class SharedSystemClient: _identifer_to_system: ClassVar[Dict[str, System]] = {} @@ -81,6 +83,20 @@ def _get_identifier_from_settings(settings: Settings) -> str: return identifier + @staticmethod + def _populate_data_from_system(system: System) -> str: + identifier = SharedSystemClient._get_identifier_from_settings(system.settings) + SharedSystemClient._identifer_to_system[identifier] = system + return identifier + + @classmethod + def from_system(cls, system: System) -> "SharedSystemClient": + """Create a client from an existing system. This is useful for testing and debugging.""" + + SharedSystemClient._populate_data_from_system(system) + instance = cls(system.settings) + return instance + @staticmethod def clear_system_cache() -> None: SharedSystemClient._identifer_to_system = {} @@ -125,6 +141,18 @@ def __init__( telemetry_client = self._system.instance(Telemetry) telemetry_client.capture(ClientStartEvent()) + @classmethod + @override + def from_system( + cls, + system: System, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> "Client": + SharedSystemClient._populate_data_from_system(system) + instance = cls(tenant=tenant, database=database, settings=system.settings) + return instance + # endregion # region BaseAPI Methods @@ -366,7 +394,7 @@ def set_tenant(self, tenant: str) -> None: # endregion -class AdminClient(AdminAPI, SharedSystemClient): +class AdminClient(SharedSystemClient, AdminAPI): _server: ServerAPI def __init__(self, settings: Settings = Settings()) -> None: @@ -376,3 +404,13 @@ def __init__(self, settings: Settings = Settings()) -> None: @override def create_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None: return self._server.create_database(name=name, tenant=tenant) + + @classmethod + @override + def from_system( + cls, + system: System, + ) -> "AdminClient": + SharedSystemClient._populate_data_from_system(system) + instance = cls(settings=system.settings) + return instance diff --git a/chromadb/api/fastapi.py b/chromadb/api/fastapi.py index 784d9e79dbb..150edaf7dad 100644 --- a/chromadb/api/fastapi.py +++ b/chromadb/api/fastapi.py @@ -143,7 +143,8 @@ def create_database( """Creates a database""" resp = self._session.post( self._api_url + "/databases", - data=json.dumps({"name": name, "tenant": tenant}), + data=json.dumps({"name": name}), + params={"tenant": tenant}, ) raise_chroma_error(resp) @@ -152,7 +153,10 @@ def list_collections( self, tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE ) -> Sequence[Collection]: """Returns a list of all collections""" - resp = self._session.get(self._api_url + "/collections") + resp = self._session.get( + self._api_url + "/collections", + params={"tenant": tenant, "database": database}, + ) raise_chroma_error(resp) json_collections = resp.json() collections = [] @@ -175,8 +179,13 @@ def create_collection( resp = self._session.post( self._api_url + "/collections", data=json.dumps( - {"name": name, "metadata": metadata, "get_or_create": get_or_create} + { + "name": name, + "metadata": metadata, + "get_or_create": get_or_create, + } ), + params={"tenant": tenant, "database": database}, ) raise_chroma_error(resp) resp_json = resp.json() @@ -197,7 +206,10 @@ def get_collection( database: str = DEFAULT_DATABASE, ) -> Collection: """Returns a collection""" - resp = self._session.get(self._api_url + "/collections/" + name) + resp = self._session.get( + self._api_url + "/collections/" + name, + params={"tenant": tenant, "database": database}, + ) raise_chroma_error(resp) resp_json = resp.json() return Collection( @@ -248,7 +260,10 @@ def delete_collection( database: str = DEFAULT_DATABASE, ) -> None: """Deletes a collection""" - resp = self._session.delete(self._api_url + "/collections/" + name) + resp = self._session.delete( + self._api_url + "/collections/" + name, + params={"tenant": tenant, "database": database}, + ) raise_chroma_error(resp) @override diff --git a/chromadb/api/segment.py b/chromadb/api/segment.py index 37887d5432a..a959f7a2be3 100644 --- a/chromadb/api/segment.py +++ b/chromadb/api/segment.py @@ -252,7 +252,9 @@ def delete_collection( ) if existing: - self._sysdb.delete_collection(existing[0]["id"]) + self._sysdb.delete_collection( + existing[0]["id"], tenant=tenant, database=database + ) for s in self._manager.delete_segments(existing[0]["id"]): self._sysdb.delete_segment(s) if existing and existing[0]["id"] in self._collection_cache: diff --git a/chromadb/server/fastapi/__init__.py b/chromadb/server/fastapi/__init__.py index ec8bcac7ea1..ee0e74eb2ea 100644 --- a/chromadb/server/fastapi/__init__.py +++ b/chromadb/server/fastapi/__init__.py @@ -15,7 +15,7 @@ FastAPIChromaAuthMiddleware, FastAPIChromaAuthMiddlewareWrapper, ) -from chromadb.config import Settings, System +from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings, System import chromadb.server import chromadb.api from chromadb.api import ServerAPI @@ -26,6 +26,7 @@ ) from chromadb.server.fastapi.types import ( AddEmbedding, + CreateDatabase, DeleteEmbedding, GetEmbedding, QueryEmbedding, @@ -100,7 +101,6 @@ def include_in_schema(path: str) -> bool: super().add_api_route(path, *args, **kwargs) -# TODO: add tenant/namespace to all routes class FastAPI(chromadb.server.Server): def __init__(self, settings: Settings): super().__init__(settings) @@ -134,6 +134,13 @@ def __init__(self, settings: Settings): "/api/v1/pre-flight-checks", self.pre_flight_checks, methods=["GET"] ) + self.router.add_api_route( + "/api/v1/databases", + self.create_database, + methods=["POST"], + response_model=None, + ) + self.router.add_api_route( "/api/v1/collections", self.list_collections, @@ -225,18 +232,39 @@ def heartbeat(self) -> Dict[str, int]: def version(self) -> str: return self._api.get_version() - def list_collections(self) -> Sequence[Collection]: - return self._api.list_collections() - - def create_collection(self, collection: CreateCollection) -> Collection: + def create_database( + self, database: CreateDatabase, tenant: str = DEFAULT_TENANT + ) -> None: + return self._api.create_database(database.name, tenant) + + def list_collections( + self, tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE + ) -> Sequence[Collection]: + return self._api.list_collections(tenant=tenant, database=database) + + def create_collection( + self, + collection: CreateCollection, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> Collection: return self._api.create_collection( name=collection.name, metadata=collection.metadata, get_or_create=collection.get_or_create, + tenant=tenant, + database=database, ) - def get_collection(self, collection_name: str) -> Collection: - return self._api.get_collection(collection_name) + def get_collection( + self, + collection_name: str, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> Collection: + return self._api.get_collection( + collection_name, tenant=tenant, database=database + ) def update_collection( self, collection_id: str, collection: UpdateCollection @@ -247,8 +275,15 @@ def update_collection( new_metadata=collection.new_metadata, ) - def delete_collection(self, collection_name: str) -> None: - return self._api.delete_collection(collection_name) + def delete_collection( + self, + collection_name: str, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> None: + return self._api.delete_collection( + collection_name, tenant=tenant, database=database + ) def add(self, collection_id: str, add: AddEmbedding) -> None: try: diff --git a/chromadb/server/fastapi/types.py b/chromadb/server/fastapi/types.py index 306f0e5fcb3..a27b6162df6 100644 --- a/chromadb/server/fastapi/types.py +++ b/chromadb/server/fastapi/types.py @@ -59,3 +59,7 @@ class CreateCollection(BaseModel): # type: ignore class UpdateCollection(BaseModel): # type: ignore new_name: Optional[str] = None new_metadata: Optional[CollectionMetadata] = None + + +class CreateDatabase(BaseModel): + name: str diff --git a/chromadb/test/client/test_database_tenant.py b/chromadb/test/client/test_database_tenant.py new file mode 100644 index 00000000000..8a3f4aff479 --- /dev/null +++ b/chromadb/test/client/test_database_tenant.py @@ -0,0 +1,60 @@ +from chromadb.api.client import AdminClient, Client + + +def test_database_tenant_collections(client: Client) -> None: + # Create a new database in the default tenant + admin_client = AdminClient.from_system(client._system) + admin_client.create_database("test_db") + + # Create collections in this new database + client.set_database("test_db") + client.create_collection("collection", metadata={"database": "test_db"}) + + # Create collections in the default database + client.set_database("default") + client.create_collection("collection", metadata={"database": "default"}) + + # List collections in the default database + collections = client.list_collections() + assert len(collections) == 1 + assert collections[0].name == "collection" + assert collections[0].metadata == {"database": "default"} + + # List collections in the new database + client.set_database("test_db") + collections = client.list_collections() + assert len(collections) == 1 + assert collections[0].metadata == {"database": "test_db"} + + # Update the metadata in both databases to different values + client.set_database("default") + client.list_collections()[0].modify(metadata={"database": "default2"}) + + client.set_database("test_db") + client.list_collections()[0].modify(metadata={"database": "test_db2"}) + + # Validate that the metadata was updated + client.set_database("default") + collections = client.list_collections() + assert len(collections) == 1 + assert collections[0].metadata == {"database": "default2"} + + client.set_database("test_db") + collections = client.list_collections() + assert len(collections) == 1 + assert collections[0].metadata == {"database": "test_db2"} + + # Delete the collections and make sure databases are isolated + client.set_database("default") + client.delete_collection("collection") + + collections = client.list_collections() + assert len(collections) == 0 + + client.set_database("test_db") + collections = client.list_collections() + assert len(collections) == 1 + + client.delete_collection("collection") + collections = client.list_collections() + assert len(collections) == 0 diff --git a/chromadb/test/client/test_multiple_clients_concurrency.py b/chromadb/test/client/test_multiple_clients_concurrency.py new file mode 100644 index 00000000000..15e467bd8b3 --- /dev/null +++ b/chromadb/test/client/test_multiple_clients_concurrency.py @@ -0,0 +1,43 @@ +from concurrent.futures import ThreadPoolExecutor +from chromadb.api.client import AdminClient, Client + + +def test_multiple_clients_concurrently(client: Client) -> None: + """Tests running multiple clients, each against their own database, concurrently.""" + admin_client = AdminClient.from_system(client._system) + admin_client.create_database("test_db") + + CLIENT_COUNT = 100 + COLLECTION_COUNT = 500 + + # Each database will create the same collections by name, with differing metadata + databases = [f"db{i}" for i in range(CLIENT_COUNT)] + for database in databases: + admin_client.create_database(database) + + collections = [f"collection{i}" for i in range(COLLECTION_COUNT)] + + # Create N clients, each on a seperate thread, each with their own database + def run_target(n: int) -> None: + thread_client = Client( + tenant="default", database=databases[n], settings=client._system.settings + ) + for collection in collections: + thread_client.create_collection( + collection, metadata={"database": databases[n]} + ) + + with ThreadPoolExecutor(max_workers=CLIENT_COUNT) as executor: + executor.map(run_target, range(CLIENT_COUNT)) + + # Create a final client, which will be used to verify the collections were created + client = Client(settings=client._system.settings) + + # Verify that the collections were created + for database in databases: + client.set_database(database) + seen_collections = client.list_collections() + assert len(seen_collections) == COLLECTION_COUNT + for collection in seen_collections: + assert collection.name in collections + assert collection.metadata == {"database": database} diff --git a/chromadb/test/conftest.py b/chromadb/test/conftest.py index 238748bcbbb..b27e1af7c01 100644 --- a/chromadb/test/conftest.py +++ b/chromadb/test/conftest.py @@ -22,11 +22,12 @@ from typing_extensions import Protocol import chromadb.server.fastapi -from chromadb.api import ServerAPI +from chromadb.api import ClientAPI, ServerAPI from chromadb.config import Settings, System from chromadb.db.mixins import embeddings_queue from chromadb.ingest import Producer from chromadb.types import SeqId, SubmitEmbeddingRecord +from chromadb.api.client import Client as ClientCreator root_logger = logging.getLogger() root_logger.setLevel(logging.DEBUG) # This will only run when testing @@ -384,6 +385,13 @@ def api(system: System) -> Generator[ServerAPI, None, None]: yield api +@pytest.fixture(scope="function") +def client(system: System) -> Generator[ClientAPI, None, None]: + system.reset_state() + client = ClientCreator.from_system(system) + yield client + + @pytest.fixture(scope="function") def api_wrong_cred( system_wrong_auth: System,