diff --git a/chromadb/api/__init__.py b/chromadb/api/__init__.py index f2a8099e4fc..aa8650e748c 100644 --- a/chromadb/api/__init__.py +++ b/chromadb/api/__init__.py @@ -396,7 +396,7 @@ class ClientAPI(BaseAPI, ABC): database: str @abstractmethod - def set_tenant_and_database(self, tenant: str, database: str) -> None: + def set_tenant(self, tenant: str, database: str = DEFAULT_DATABASE) -> None: """Set the tenant and database for the client. Raises an error if the tenant or database does not exist. @@ -407,6 +407,16 @@ def set_tenant_and_database(self, tenant: str, database: str) -> None: """ pass + @abstractmethod + def set_database(self, database: str) -> None: + """Set the database for the client. Raises an error if the database does not exist. + + Args: + database: The database to set. + + """ + pass + @staticmethod @abstractmethod def clear_system_cache() -> None: diff --git a/chromadb/api/client.py b/chromadb/api/client.py index c7d9af971fc..e33bc30a451 100644 --- a/chromadb/api/client.py +++ b/chromadb/api/client.py @@ -389,11 +389,16 @@ def max_batch_size(self) -> int: # region ClientAPI Methods @override - def set_tenant_and_database(self, tenant: str, database: str) -> None: + def set_tenant(self, tenant: str, database: str = DEFAULT_DATABASE) -> None: self._validate_tenant_database(tenant=tenant, database=database) self.tenant = tenant self.database = database + @override + def set_database(self, database: str) -> None: + self._validate_tenant_database(tenant=self.tenant, database=database) + self.database = database + def _validate_tenant_database(self, tenant: str, database: str) -> None: try: self._admin_client.get_tenant(name=tenant) diff --git a/chromadb/test/client/test_database_tenant.py b/chromadb/test/client/test_database_tenant.py index 55672d36aac..d8778c536f4 100644 --- a/chromadb/test/client/test_database_tenant.py +++ b/chromadb/test/client/test_database_tenant.py @@ -9,11 +9,11 @@ def test_database_tenant_collections(client: Client) -> None: admin_client.create_database("test_db") # Create collections in this new database - client.set_database("test_db") + client.set_tenant(tenant="default", database="test_db") client.create_collection("collection", metadata={"database": "test_db"}) # Create collections in the default database - client.set_database("default") + client.set_tenant(tenant="default", database="default") client.create_collection("collection", metadata={"database": "default"}) # List collections in the default database @@ -23,37 +23,37 @@ def test_database_tenant_collections(client: Client) -> None: assert collections[0].metadata == {"database": "default"} # List collections in the new database - client.set_database("test_db") + client.set_tenant(tenant="default", database="test_db") collections = client.list_collections() assert len(collections) == 1 assert collections[0].metadata == {"database": "test_db"} # Update the metadata in both databases to different values - client.set_database("default") + client.set_tenant(tenant="default", database="default") client.list_collections()[0].modify(metadata={"database": "default2"}) - client.set_database("test_db") + client.set_tenant(tenant="default", database="test_db") client.list_collections()[0].modify(metadata={"database": "test_db2"}) # Validate that the metadata was updated - client.set_database("default") + client.set_tenant(tenant="default", database="default") collections = client.list_collections() assert len(collections) == 1 assert collections[0].metadata == {"database": "default2"} - client.set_database("test_db") + client.set_tenant(tenant="default", database="test_db") collections = client.list_collections() assert len(collections) == 1 assert collections[0].metadata == {"database": "test_db2"} # Delete the collections and make sure databases are isolated - client.set_database("default") + client.set_tenant(tenant="default", database="default") client.delete_collection("collection") collections = client.list_collections() assert len(collections) == 0 - client.set_database("test_db") + client.set_tenant(tenant="default", database="test_db") collections = client.list_collections() assert len(collections) == 1 diff --git a/chromadb/test/property/test_collections_with_database_tenant.py b/chromadb/test/property/test_collections_with_database_tenant.py index 4d95b2ec97a..28ba14f092a 100644 --- a/chromadb/test/property/test_collections_with_database_tenant.py +++ b/chromadb/test/property/test_collections_with_database_tenant.py @@ -44,7 +44,7 @@ def initialize(self) -> None: self.tenant_to_database_to_model = {} self.curr_tenant = DEFAULT_TENANT self.curr_database = DEFAULT_DATABASE - self.api.set_tenant_and_database(DEFAULT_TENANT, DEFAULT_DATABASE) + self.api.set_tenant(DEFAULT_TENANT, DEFAULT_DATABASE) self.tenant_to_database_to_model[self.curr_tenant] = {} self.tenant_to_database_to_model[self.curr_tenant][self.curr_database] = {} @@ -82,13 +82,13 @@ def set_database_and_tenant(self, database: Tuple[str, str]) -> None: # Get a database and switch to the database and the tenant it belongs to database_name = database[0] tenant_name = database[1] - self.api.set_tenant_and_database(tenant_name, database_name) + self.api.set_tenant(tenant_name, database_name) self.curr_database = database_name self.curr_tenant = tenant_name @rule(tenant=tenants) def set_tenant(self, tenant: str) -> None: - self.api.set_tenant_and_database(tenant, DEFAULT_DATABASE) + self.api.set_tenant(tenant, DEFAULT_DATABASE) self.curr_tenant = tenant self.curr_database = DEFAULT_DATABASE