Skip to content

Commit

Permalink
plumbing + concurrency test
Browse files Browse the repository at this point in the history
  • Loading branch information
HammadB committed Oct 19, 2023
1 parent 8dad8e0 commit 2bd9df1
Show file tree
Hide file tree
Showing 9 changed files with 227 additions and 19 deletions.
3 changes: 3 additions & 0 deletions chromadb/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,9 @@ def max_batch_size(self) -> int:


class ClientAPI(BaseAPI, ABC):
tenant: str
database: str

@abstractmethod
def set_database(self, database: str) -> None:
"""Set the database for the client.
Expand Down
42 changes: 40 additions & 2 deletions chromadb/api/client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import ClassVar, Dict, Optional, Sequence
from typing import ClassVar, Dict, Optional, Sequence, TypeVar
from uuid import UUID

from overrides import override
Expand All @@ -22,6 +22,8 @@
from chromadb.types import Where, WhereDocument
import chromadb.utils.embedding_functions as ef

C = TypeVar("C", "SharedSystemClient", "Client", "AdminClient")


class SharedSystemClient:
_identifer_to_system: ClassVar[Dict[str, System]] = {}
Expand Down Expand Up @@ -81,6 +83,20 @@ def _get_identifier_from_settings(settings: Settings) -> str:

return identifier

@staticmethod
def _populate_data_from_system(system: System) -> str:
    """Register ``system`` in the shared cache and return its identifier.

    The identifier is derived from the system's settings; any system already
    stored under that identifier is replaced.
    """
    key = SharedSystemClient._get_identifier_from_settings(system.settings)
    SharedSystemClient._identifer_to_system[key] = system
    return key

@classmethod
def from_system(cls, system: System) -> "SharedSystemClient":
    """Create a client from an existing system. This is useful for testing and debugging.

    The system is stored in the shared identifier-to-system cache before the
    client is constructed from the system's settings.
    """
    SharedSystemClient._populate_data_from_system(system)
    return cls(system.settings)

@staticmethod
def clear_system_cache() -> None:
SharedSystemClient._identifer_to_system = {}
Expand Down Expand Up @@ -125,6 +141,18 @@ def __init__(
telemetry_client = self._system.instance(Telemetry)
telemetry_client.capture(ClientStartEvent())

@classmethod
@override
def from_system(
    cls,
    system: System,
    tenant: str = DEFAULT_TENANT,
    database: str = DEFAULT_DATABASE,
) -> "Client":
    """Create a client for ``tenant``/``database`` from an existing system.

    The system is stored in the shared identifier-to-system cache first; the
    client is then constructed from the system's settings.
    """
    SharedSystemClient._populate_data_from_system(system)
    return cls(tenant=tenant, database=database, settings=system.settings)

# endregion

# region BaseAPI Methods
Expand Down Expand Up @@ -366,7 +394,7 @@ def set_tenant(self, tenant: str) -> None:
# endregion


class AdminClient(AdminAPI, SharedSystemClient):
class AdminClient(SharedSystemClient, AdminAPI):
_server: ServerAPI

def __init__(self, settings: Settings = Settings()) -> None:
Expand All @@ -376,3 +404,13 @@ def __init__(self, settings: Settings = Settings()) -> None:
@override
def create_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None:
return self._server.create_database(name=name, tenant=tenant)

@classmethod
@override
def from_system(cls, system: System) -> "AdminClient":
    """Create an admin client from an existing system.

    The system is stored in the shared identifier-to-system cache first; the
    admin client is then constructed from the system's settings.
    """
    SharedSystemClient._populate_data_from_system(system)
    return cls(settings=system.settings)
25 changes: 20 additions & 5 deletions chromadb/api/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,8 @@ def create_database(
"""Creates a database"""
resp = self._session.post(
self._api_url + "/databases",
data=json.dumps({"name": name, "tenant": tenant}),
data=json.dumps({"name": name}),
params={"tenant": tenant},
)
raise_chroma_error(resp)

Expand All @@ -152,7 +153,10 @@ def list_collections(
self, tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE
) -> Sequence[Collection]:
"""Returns a list of all collections"""
resp = self._session.get(self._api_url + "/collections")
resp = self._session.get(
self._api_url + "/collections",
params={"tenant": tenant, "database": database},
)
raise_chroma_error(resp)
json_collections = resp.json()
collections = []
Expand All @@ -175,8 +179,13 @@ def create_collection(
resp = self._session.post(
self._api_url + "/collections",
data=json.dumps(
{"name": name, "metadata": metadata, "get_or_create": get_or_create}
{
"name": name,
"metadata": metadata,
"get_or_create": get_or_create,
}
),
params={"tenant": tenant, "database": database},
)
raise_chroma_error(resp)
resp_json = resp.json()
Expand All @@ -197,7 +206,10 @@ def get_collection(
database: str = DEFAULT_DATABASE,
) -> Collection:
"""Returns a collection"""
resp = self._session.get(self._api_url + "/collections/" + name)
resp = self._session.get(
self._api_url + "/collections/" + name,
params={"tenant": tenant, "database": database},
)
raise_chroma_error(resp)
resp_json = resp.json()
return Collection(
Expand Down Expand Up @@ -248,7 +260,10 @@ def delete_collection(
database: str = DEFAULT_DATABASE,
) -> None:
"""Deletes a collection"""
resp = self._session.delete(self._api_url + "/collections/" + name)
resp = self._session.delete(
self._api_url + "/collections/" + name,
params={"tenant": tenant, "database": database},
)
raise_chroma_error(resp)

@override
Expand Down
4 changes: 3 additions & 1 deletion chromadb/api/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,9 @@ def delete_collection(
)

if existing:
self._sysdb.delete_collection(existing[0]["id"])
self._sysdb.delete_collection(
existing[0]["id"], tenant=tenant, database=database
)
for s in self._manager.delete_segments(existing[0]["id"]):
self._sysdb.delete_segment(s)
if existing and existing[0]["id"] in self._collection_cache:
Expand Down
55 changes: 45 additions & 10 deletions chromadb/server/fastapi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
FastAPIChromaAuthMiddleware,
FastAPIChromaAuthMiddlewareWrapper,
)
from chromadb.config import Settings, System
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings, System
import chromadb.server
import chromadb.api
from chromadb.api import ServerAPI
Expand All @@ -26,6 +26,7 @@
)
from chromadb.server.fastapi.types import (
AddEmbedding,
CreateDatabase,
DeleteEmbedding,
GetEmbedding,
QueryEmbedding,
Expand Down Expand Up @@ -100,7 +101,6 @@ def include_in_schema(path: str) -> bool:
super().add_api_route(path, *args, **kwargs)


# TODO: add tenant/namespace to all routes
class FastAPI(chromadb.server.Server):
def __init__(self, settings: Settings):
super().__init__(settings)
Expand Down Expand Up @@ -134,6 +134,13 @@ def __init__(self, settings: Settings):
"/api/v1/pre-flight-checks", self.pre_flight_checks, methods=["GET"]
)

self.router.add_api_route(
"/api/v1/databases",
self.create_database,
methods=["POST"],
response_model=None,
)

self.router.add_api_route(
"/api/v1/collections",
self.list_collections,
Expand Down Expand Up @@ -225,18 +232,39 @@ def heartbeat(self) -> Dict[str, int]:
def version(self) -> str:
return self._api.get_version()

def list_collections(self) -> Sequence[Collection]:
return self._api.list_collections()

def create_collection(self, collection: CreateCollection) -> Collection:
def create_database(
    self, database: CreateDatabase, tenant: str = DEFAULT_TENANT
) -> None:
    # The database name comes from the request body; the tenant is a separate
    # parameter that falls back to DEFAULT_TENANT.
    db_name = database.name
    return self._api.create_database(db_name, tenant)

def list_collections(
    self, tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE
) -> Sequence[Collection]:
    # Forward the tenant/database scope to the API unchanged.
    found = self._api.list_collections(tenant=tenant, database=database)
    return found

def create_collection(
    self,
    collection: CreateCollection,
    tenant: str = DEFAULT_TENANT,
    database: str = DEFAULT_DATABASE,
) -> Collection:
    # Unpack the request body and forward it together with the tenant/database
    # scope supplied alongside it.
    created = self._api.create_collection(
        name=collection.name,
        metadata=collection.metadata,
        get_or_create=collection.get_or_create,
        tenant=tenant,
        database=database,
    )
    return created

def get_collection(self, collection_name: str) -> Collection:
return self._api.get_collection(collection_name)
def get_collection(
    self,
    collection_name: str,
    tenant: str = DEFAULT_TENANT,
    database: str = DEFAULT_DATABASE,
) -> Collection:
    # Look the collection up by name within the given tenant/database.
    found = self._api.get_collection(
        collection_name, tenant=tenant, database=database
    )
    return found

def update_collection(
self, collection_id: str, collection: UpdateCollection
Expand All @@ -247,8 +275,15 @@ def update_collection(
new_metadata=collection.new_metadata,
)

def delete_collection(self, collection_name: str) -> None:
return self._api.delete_collection(collection_name)
def delete_collection(
    self,
    collection_name: str,
    tenant: str = DEFAULT_TENANT,
    database: str = DEFAULT_DATABASE,
) -> None:
    # Delete the named collection within the given tenant/database; whatever
    # the API call returns is passed straight back.
    outcome = self._api.delete_collection(
        collection_name, tenant=tenant, database=database
    )
    return outcome

def add(self, collection_id: str, add: AddEmbedding) -> None:
try:
Expand Down
4 changes: 4 additions & 0 deletions chromadb/server/fastapi/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,7 @@ class CreateCollection(BaseModel): # type: ignore
class UpdateCollection(BaseModel): # type: ignore
new_name: Optional[str] = None
new_metadata: Optional[CollectionMetadata] = None


# Request body for the POST /api/v1/databases route. Only the database name is
# carried in the body; the server handler takes the tenant as a separate
# parameter.
class CreateDatabase(BaseModel):  # type: ignore
    name: str
60 changes: 60 additions & 0 deletions chromadb/test/client/test_database_tenant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from chromadb.api.client import AdminClient, Client


def test_database_tenant_collections(client: Client) -> None:
    """Check that same-named collections in different databases stay isolated.

    A collection called "collection" is created in both a new database and the
    default one; listing, modifying and deleting it in one database must not
    affect the other.
    """
    # Create a new database in the default tenant
    admin = AdminClient.from_system(client._system)
    admin.create_database("test_db")

    # Create a same-named collection in the new database, then in the default
    # one, tagging each with the database it belongs to
    for db in ("test_db", "default"):
        client.set_database(db)
        client.create_collection("collection", metadata={"database": db})

    # The client is still pointed at the default database here
    listed = client.list_collections()
    assert len(listed) == 1
    only = listed[0]
    assert only.name == "collection"
    assert only.metadata == {"database": "default"}

    # List collections in the new database
    client.set_database("test_db")
    listed = client.list_collections()
    assert len(listed) == 1
    assert listed[0].metadata == {"database": "test_db"}

    # Update the metadata in both databases to different values
    updates = (
        ("default", {"database": "default2"}),
        ("test_db", {"database": "test_db2"}),
    )
    for db, new_metadata in updates:
        client.set_database(db)
        client.list_collections()[0].modify(metadata=new_metadata)

    # Validate that the metadata was updated
    for db, new_metadata in updates:
        client.set_database(db)
        listed = client.list_collections()
        assert len(listed) == 1
        assert listed[0].metadata == new_metadata

    # Delete the collections and make sure databases are isolated
    client.set_database("default")
    client.delete_collection("collection")
    assert len(client.list_collections()) == 0

    # The copy in the other database must survive the delete above
    client.set_database("test_db")
    assert len(client.list_collections()) == 1

    client.delete_collection("collection")
    assert len(client.list_collections()) == 0
43 changes: 43 additions & 0 deletions chromadb/test/client/test_multiple_clients_concurrency.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from concurrent.futures import ThreadPoolExecutor
from chromadb.api.client import AdminClient, Client


def test_multiple_clients_concurrently(client: Client) -> None:
    """Tests running multiple clients, each against their own database, concurrently.

    Every worker thread builds its own ``Client`` bound to a distinct database
    and creates the same set of collection names there. Afterwards a fresh
    client checks that each database holds exactly those collections, tagged
    with that database's own metadata.
    """
    admin_client = AdminClient.from_system(client._system)
    admin_client.create_database("test_db")

    CLIENT_COUNT = 100
    COLLECTION_COUNT = 500

    # Each database will create the same collections by name, with differing metadata
    databases = [f"db{i}" for i in range(CLIENT_COUNT)]
    for database in databases:
        admin_client.create_database(database)

    collections = [f"collection{i}" for i in range(COLLECTION_COUNT)]

    # Create N clients, each on a separate thread, each with their own database
    def run_target(n: int) -> None:
        thread_client = Client(
            tenant="default", database=databases[n], settings=client._system.settings
        )
        for collection in collections:
            thread_client.create_collection(
                collection, metadata={"database": databases[n]}
            )

    with ThreadPoolExecutor(max_workers=CLIENT_COUNT) as executor:
        # Executor.map only re-raises a worker's exception when its result is
        # retrieved, so drain the iterator; otherwise a failure inside
        # run_target would be silently dropped.
        list(executor.map(run_target, range(CLIENT_COUNT)))

    # Create a final client, which will be used to verify the collections were created
    verify_client = Client(settings=client._system.settings)

    # Set membership keeps the name check O(1) per collection
    expected_names = set(collections)

    # Verify that the collections were created
    for database in databases:
        verify_client.set_database(database)
        seen_collections = verify_client.list_collections()
        assert len(seen_collections) == COLLECTION_COUNT
        for collection in seen_collections:
            assert collection.name in expected_names
            assert collection.metadata == {"database": database}
Loading

0 comments on commit 2bd9df1

Please sign in to comment.