Skip to content

Commit

Permalink
Simplify
Browse files Browse the repository at this point in the history
  • Loading branch information
HammadB committed Oct 19, 2023
1 parent 0321fe8 commit 8dad8e0
Show file tree
Hide file tree
Showing 16 changed files with 649 additions and 284 deletions.
130 changes: 13 additions & 117 deletions chromadb/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,19 @@ def clear_system_cache() -> None:
pass


class ServerAPI(BaseAPI, Component):
class AdminAPI(ABC):
    """Abstract interface for administrative operations.

    Currently the only operation it declares is database creation.
    """

    @abstractmethod
    def create_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None:
        """Create a new database.

        Args:
            name: The name of the database to create.
            tenant: The tenant the database is created for. Defaults to
                DEFAULT_TENANT.
        """
        pass


class ServerAPI(BaseAPI, AdminAPI, Component):
"""An API instance that extends the relevant Base API methods by passing
in a tenant and database. This is the root component of the Chroma System"""

Expand Down Expand Up @@ -473,8 +485,6 @@ def _modify(
id: UUID,
new_name: Optional[str] = None,
new_metadata: Optional[CollectionMetadata] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> None:
pass

Expand All @@ -487,117 +497,3 @@ def delete_collection(
database: str = DEFAULT_DATABASE,
) -> None:
pass

#
# ITEM METHODS
#

@abstractmethod
@override
def _add(
self,
ids: IDs,
collection_id: UUID,
embeddings: Embeddings,
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> bool:
pass

@abstractmethod
@override
def _update(
self,
collection_id: UUID,
ids: IDs,
embeddings: Optional[Embeddings] = None,
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> bool:
pass

@abstractmethod
@override
def _upsert(
self,
collection_id: UUID,
ids: IDs,
embeddings: Embeddings,
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> bool:
pass

@abstractmethod
@override
def _count(
self,
collection_id: UUID,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> int:
pass

@abstractmethod
@override
def _peek(
self,
collection_id: UUID,
n: int = 10,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> GetResult:
pass

@abstractmethod
@override
def _get(
self,
collection_id: UUID,
ids: Optional[IDs] = None,
where: Optional[Where] = {},
sort: Optional[str] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
page: Optional[int] = None,
page_size: Optional[int] = None,
where_document: Optional[WhereDocument] = {},
include: Include = ["embeddings", "metadatas", "documents"],
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> GetResult:
pass

@abstractmethod
@override
def _delete(
self,
collection_id: UUID,
ids: Optional[IDs],
where: Optional[Where] = {},
where_document: Optional[WhereDocument] = {},
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> IDs:
pass

@abstractmethod
@override
def _query(
self,
collection_id: UUID,
query_embeddings: Embeddings,
n_results: int = 10,
where: Where = {},
where_document: WhereDocument = {},
include: Include = ["embeddings", "metadatas", "documents", "distances"],
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> QueryResult:
pass
114 changes: 55 additions & 59 deletions chromadb/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from uuid import UUID

from overrides import override
from chromadb.api import ClientAPI, ServerAPI
from chromadb.api import AdminAPI, ClientAPI, ServerAPI
from chromadb.api.types import (
CollectionMetadata,
Documents,
Expand All @@ -23,51 +23,17 @@
import chromadb.utils.embedding_functions as ef


class Client(ClientAPI):
"""A client for Chroma. This is the main entrypoint for interacting with Chroma.
A client internally stores its tenant and database and proxies calls to a
Server API instance of Chroma. It treats the Server API and corresponding System
as a singleton, so multiple clients connecting to the same resource will share the
same API instance.
Client implementations should implement their own API-caching strategies.
"""

tenant: str = DEFAULT_TENANT
database: str = DEFAULT_DATABASE

class SharedSystemClient:
_identifer_to_system: ClassVar[Dict[str, System]] = {}
_identifier: str
_server: ServerAPI

# region Initialization
def __new__(
cls,
tenant: str = "default",
database: str = "default",
settings: Settings = Settings(),
) -> "Client":
identifier = cls._get_identifier_from_settings(settings)
cls._create_system_if_not_exists(identifier, settings)
instance = super().__new__(cls)
return instance

def __init__(
self,
tenant: str = "default",
database: str = "default",
settings: Settings = Settings(),
) -> None:
self.tenant = tenant
self.database = database
self._identifier = self._get_identifier_from_settings(settings)

# Get the root system component we want to interact with
self._server = self._system.instance(ServerAPI)

# Submit event for a client start
telemetry_client = self._system.instance(Telemetry)
telemetry_client.capture(ClientStartEvent())
self._identifier = SharedSystemClient._get_identifier_from_settings(settings)
SharedSystemClient._create_system_if_not_exists(self._identifier, settings)

@classmethod
def _create_system_if_not_exists(
Expand Down Expand Up @@ -116,13 +82,48 @@ def _get_identifier_from_settings(settings: Settings) -> str:
return identifier

@staticmethod
@override
def clear_system_cache() -> None:
Client._identifer_to_system = {}
SharedSystemClient._identifer_to_system = {}

@property
def _system(self) -> System:
return self._identifer_to_system[self._identifier]
return SharedSystemClient._identifer_to_system[self._identifier]

# endregion


class Client(SharedSystemClient, ClientAPI):
"""A client for Chroma. This is the main entrypoint for interacting with Chroma.
A client internally stores its tenant and database and proxies calls to a
Server API instance of Chroma. It treats the Server API and corresponding System
as a singleton, so multiple clients connecting to the same resource will share the
same API instance.
Client implementations should implement their own API-caching strategies.
"""

tenant: str = DEFAULT_TENANT
database: str = DEFAULT_DATABASE

_server: ServerAPI

# region Initialization
def __init__(
self,
tenant: str = "default",
database: str = "default",
settings: Settings = Settings(),
) -> None:
super().__init__(settings=settings)
self.tenant = tenant
self.database = database

# Get the root system component we want to interact with
self._server = self._system.instance(ServerAPI)

# Submit event for a client start
telemetry_client = self._system.instance(Telemetry)
telemetry_client.capture(ClientStartEvent())

# endregion

Expand Down Expand Up @@ -150,6 +151,7 @@ def create_collection(
embedding_function=embedding_function,
tenant=self.tenant,
database=self.database,
get_or_create=get_or_create,
)

@override
Expand Down Expand Up @@ -191,8 +193,6 @@ def _modify(
id=id,
new_name=new_name,
new_metadata=new_metadata,
tenant=self.tenant,
database=self.database,
)

@override
Expand Down Expand Up @@ -225,8 +225,6 @@ def _add(
embeddings=embeddings,
metadatas=metadatas,
documents=documents,
tenant=self.tenant,
database=self.database,
)

@override
Expand All @@ -244,8 +242,6 @@ def _update(
embeddings=embeddings,
metadatas=metadatas,
documents=documents,
tenant=self.tenant,
database=self.database,
)

@override
Expand All @@ -263,25 +259,19 @@ def _upsert(
embeddings=embeddings,
metadatas=metadatas,
documents=documents,
tenant=self.tenant,
database=self.database,
)

@override
def _count(self, collection_id: UUID) -> int:
return self._server._count(
collection_id=collection_id,
tenant=self.tenant,
database=self.database,
)

@override
def _peek(self, collection_id: UUID, n: int = 10) -> GetResult:
return self._server._peek(
collection_id=collection_id,
n=n,
tenant=self.tenant,
database=self.database,
)

@override
Expand Down Expand Up @@ -309,8 +299,6 @@ def _get(
page_size=page_size,
where_document=where_document,
include=include,
tenant=self.tenant,
database=self.database,
)

def _delete(
Expand All @@ -325,8 +313,6 @@ def _delete(
ids=ids,
where=where,
where_document=where_document,
tenant=self.tenant,
database=self.database,
)

@override
Expand All @@ -346,8 +332,6 @@ def _query(
where=where,
where_document=where_document,
include=include,
tenant=self.tenant,
database=self.database,
)

@override
Expand Down Expand Up @@ -380,3 +364,15 @@ def set_tenant(self, tenant: str) -> None:
self.tenant = tenant

# endregion


class AdminClient(AdminAPI, SharedSystemClient):
    """Client exposing the AdminAPI operations.

    Each call is forwarded to the ServerAPI instance obtained from the
    system associated with the given settings.
    """

    _server: ServerAPI

    def __init__(self, settings: Settings = Settings()) -> None:
        """Set up the shared system for ``settings`` and look up its ServerAPI.

        Args:
            settings: Settings handed to the parent initializer, which
                derives the system identifier from them.
        """
        super().__init__(settings=settings)
        # The system is resolved through the identifier set by the parent
        # initializer, so this must come after the super().__init__ call.
        system = self._system
        self._server = system.instance(ServerAPI)

    @override
    def create_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None:
        """Create a database by delegating to the server.

        Args:
            name: The name of the database to create.
            tenant: Tenant forwarded to the server call. Defaults to
                DEFAULT_TENANT.
        """
        server = self._server
        return server.create_database(name=name, tenant=tenant)
Loading

0 comments on commit 8dad8e0

Please sign in to comment.