Skip to content

Commit

Permalink
Add set_database and change set_tenant to take database
Browse files Browse the repository at this point in the history
  • Loading branch information
HammadB committed Oct 23, 2023
1 parent 6bcb1fd commit 19a57f6
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 14 deletions.
12 changes: 11 additions & 1 deletion chromadb/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ class ClientAPI(BaseAPI, ABC):
database: str

@abstractmethod
def set_tenant_and_database(self, tenant: str, database: str) -> None:
def set_tenant(self, tenant: str, database: str = DEFAULT_DATABASE) -> None:
"""Set the tenant and database for the client. Raises an error if the tenant or
database does not exist.
Expand All @@ -407,6 +407,16 @@ def set_tenant_and_database(self, tenant: str, database: str) -> None:
"""
pass

@abstractmethod
def set_database(self, database: str) -> None:
"""Set the database for the client. Raises an error if the database does not exist.
Args:
database: The database to set.
"""
pass

@staticmethod
@abstractmethod
def clear_system_cache() -> None:
Expand Down
7 changes: 6 additions & 1 deletion chromadb/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,11 +389,16 @@ def max_batch_size(self) -> int:
# region ClientAPI Methods

@override
def set_tenant_and_database(self, tenant: str, database: str) -> None:
def set_tenant(self, tenant: str, database: str = DEFAULT_DATABASE) -> None:
self._validate_tenant_database(tenant=tenant, database=database)
self.tenant = tenant
self.database = database

@override
def set_database(self, database: str) -> None:
self._validate_tenant_database(tenant=self.tenant, database=database)
self.database = database

def _validate_tenant_database(self, tenant: str, database: str) -> None:
try:
self._admin_client.get_tenant(name=tenant)
Expand Down
18 changes: 9 additions & 9 deletions chromadb/test/client/test_database_tenant.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ def test_database_tenant_collections(client: Client) -> None:
admin_client.create_database("test_db")

# Create collections in this new database
client.set_database("test_db")
client.set_tenant(tenant="default", database="test_db")
client.create_collection("collection", metadata={"database": "test_db"})

# Create collections in the default database
client.set_database("default")
client.set_tenant(tenant="default", database="default")
client.create_collection("collection", metadata={"database": "default"})

# List collections in the default database
Expand All @@ -23,37 +23,37 @@ def test_database_tenant_collections(client: Client) -> None:
assert collections[0].metadata == {"database": "default"}

# List collections in the new database
client.set_database("test_db")
client.set_tenant(tenant="default", database="test_db")
collections = client.list_collections()
assert len(collections) == 1
assert collections[0].metadata == {"database": "test_db"}

# Update the metadata in both databases to different values
client.set_database("default")
client.set_tenant(tenant="default", database="default")
client.list_collections()[0].modify(metadata={"database": "default2"})

client.set_database("test_db")
client.set_tenant(tenant="default", database="test_db")
client.list_collections()[0].modify(metadata={"database": "test_db2"})

# Validate that the metadata was updated
client.set_database("default")
client.set_tenant(tenant="default", database="default")
collections = client.list_collections()
assert len(collections) == 1
assert collections[0].metadata == {"database": "default2"}

client.set_database("test_db")
client.set_tenant(tenant="default", database="test_db")
collections = client.list_collections()
assert len(collections) == 1
assert collections[0].metadata == {"database": "test_db2"}

# Delete the collections and make sure databases are isolated
client.set_database("default")
client.set_tenant(tenant="default", database="default")
client.delete_collection("collection")

collections = client.list_collections()
assert len(collections) == 0

client.set_database("test_db")
client.set_tenant(tenant="default", database="test_db")
collections = client.list_collections()
assert len(collections) == 1

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def initialize(self) -> None:
self.tenant_to_database_to_model = {}
self.curr_tenant = DEFAULT_TENANT
self.curr_database = DEFAULT_DATABASE
self.api.set_tenant_and_database(DEFAULT_TENANT, DEFAULT_DATABASE)
self.api.set_tenant(DEFAULT_TENANT, DEFAULT_DATABASE)
self.tenant_to_database_to_model[self.curr_tenant] = {}
self.tenant_to_database_to_model[self.curr_tenant][self.curr_database] = {}

Expand Down Expand Up @@ -82,13 +82,13 @@ def set_database_and_tenant(self, database: Tuple[str, str]) -> None:
# Get a database and switch to the database and the tenant it belongs to
database_name = database[0]
tenant_name = database[1]
self.api.set_tenant_and_database(tenant_name, database_name)
self.api.set_tenant(tenant_name, database_name)
self.curr_database = database_name
self.curr_tenant = tenant_name

@rule(tenant=tenants)
def set_tenant(self, tenant: str) -> None:
self.api.set_tenant_and_database(tenant, DEFAULT_DATABASE)
self.api.set_tenant(tenant, DEFAULT_DATABASE)
self.curr_tenant = tenant
self.curr_database = DEFAULT_DATABASE

Expand Down

0 comments on commit 19a57f6

Please sign in to comment.