diff --git a/chromadb/__init__.py b/chromadb/__init__.py index 9c0b8000a14..8360a69498c 100644 --- a/chromadb/__init__.py +++ b/chromadb/__init__.py @@ -1,8 +1,9 @@ from typing import Dict import logging +from chromadb.api.client import Client as ClientCreator import chromadb.config -from chromadb.config import Settings, System -from chromadb.api import API +from chromadb.config import Settings +from chromadb.api import ClientAPI from chromadb.api.models.Collection import Collection from chromadb.api.types import ( CollectionMetadata, @@ -35,8 +36,6 @@ "QueryResult", "GetResult", ] -from chromadb.telemetry.events import ClientStartEvent -from chromadb.telemetry import Telemetry logger = logging.getLogger(__name__) @@ -55,13 +54,15 @@ is_client = False try: - from chromadb.is_thin_client import is_thin_client # type: ignore + from chromadb.is_thin_client import is_thin_client + is_client = is_thin_client except ImportError: is_client = False if not is_client: import sqlite3 + if sqlite3.sqlite_version_info < (3, 35, 0): if IN_COLAB: # In Colab, hotswap to pysqlite-binary if it's too old @@ -90,7 +91,7 @@ def get_settings() -> Settings: return __settings -def EphemeralClient(settings: Settings = Settings()) -> API: +def EphemeralClient(settings: Settings = Settings()) -> ClientAPI: """ Creates an in-memory instance of Chroma. This is useful for testing and development, but not recommended for production use. @@ -100,7 +101,12 @@ def EphemeralClient(settings: Settings = Settings()) -> API: return Client(settings) -def PersistentClient(path: str = "./chroma", settings: Settings = Settings()) -> API: +def PersistentClient( + path: str = "./chroma", + tenant: str = "default", + database: str = "default", + settings: Settings = Settings(), +) -> ClientAPI: """ Creates a persistent instance of Chroma that saves to disk. This is useful for testing and development, but not recommended for production use. @@ -111,7 +117,7 @@ def PersistentClient(path: str = "./chroma", settings: Settings = Settings()) -> settings.persist_directory = path settings.is_persistent = True - return Client(settings) + return ClientCreator(tenant=tenant, database=database, settings=settings) def HttpClient( @@ -119,8 +125,10 @@ def HttpClient( port: str = "8000", ssl: bool = False, headers: Dict[str, str] = {}, + tenant: str = "default", + database: str = "default", settings: Settings = Settings(), -) -> API: +) -> ClientAPI: """ Creates a client that connects to a remote Chroma server. This supports many clients connecting to the same server, and is the recommended way to @@ -139,20 +147,47 @@ def HttpClient( settings.chroma_server_ssl_enabled = ssl settings.chroma_server_headers = headers - return Client(settings) + return ClientCreator(tenant=tenant, database=database, settings=settings) -def Client(settings: Settings = __settings) -> API: +# TODO: replace default tenant and database strings with constants +def Client( + settings: Settings = __settings, tenant: str = "default", database: str = "default" +) -> ClientAPI: """Return a running chroma.API instance""" - system = System(settings) - - telemetry_client = system.instance(Telemetry) - api = system.instance(API) - - system.start() - - # Submit event for client start - telemetry_client.capture(ClientStartEvent()) - - return api + # Change this to actually check if an "API" instance already exists, wrap it in a + # tenant/database aware "Client", and return it + # this way we can support multiple clients in the same process but using the same + # chroma instance + + # API is thread safe, so we can just return the same instance + # This way a "Client" will just be a wrapper around an API instance that is + # tenant/database aware + + # To do this we will + # 1. Have a global dict of API instances, keyed by path + # 2. When a client is requested, check if one exists in the dict, and if so check if its + # settings match the requested settings + # 3. If the settings match, construct a new Client that wraps the existing API instance with + # the tenant/database + # 4. If the settings don't match, error out because we don't support changing the settings + # got a given database + # 5. If no client exists in the dict, create a new API instance, wrap it in a Client, and + # add it to the dict + + # The hierarchy then becomes + # For local + # Path -> Tenant -> Namespace -> API + # For remote + # Host -> Tenant -> Namespace -> API + + # A given API for a path is a singleton, and is shared between all tenants and namespaces + # for that path + + # A DB exists at a path or host, and has tenants and namespaces + + # All our tests currently use system.instance(API) assuming thats the root object + # This is likely fine, + + return ClientCreator(tenant=tenant, database=database, settings=settings) diff --git a/chromadb/api/__init__.py b/chromadb/api/__init__.py index 50f2ff1ecef..2287bd929d2 100644 --- a/chromadb/api/__init__.py +++ b/chromadb/api/__init__.py @@ -1,6 +1,9 @@ from abc import ABC, abstractmethod from typing import Sequence, Optional from uuid import UUID + +from overrides import override +from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT from chromadb.api.models.Collection import Collection from chromadb.api.types import ( CollectionMetadata, @@ -19,7 +22,7 @@ import chromadb.utils.embedding_functions as ef -class API(Component, ABC): +class BaseAPI(ABC): @abstractmethod def heartbeat(self) -> int: """Get the current time in nanoseconds since epoch. @@ -371,10 +374,10 @@ def get_version(self) -> str: @abstractmethod def get_settings(self) -> Settings: - """Get the settings used to initialize the client. + """Get the settings used to initialize. Returns: - Settings: The settings used to initialize the client. + Settings: The settings used to initialize. """ pass @@ -385,3 +388,216 @@ def max_batch_size(self) -> int: """Return the maximum number of records that can be submitted in a single call to submit_embeddings.""" pass + + +class ClientAPI(BaseAPI, ABC): + @abstractmethod + def set_database(self, database: str) -> None: + """Set the database for the client. + + Args: + database: The database to set. + + """ + pass + + @abstractmethod + def set_tenant(self, tenant: str) -> None: + """Set the tenant for the client. + + Args: + tenant: The tenant to set. + + """ + pass + + @staticmethod + @abstractmethod + def clear_system_cache() -> None: + """Clear the system cache so that new systems can be created for an existing path. + This should only be used for testing purposes.""" + pass + + +class ServerAPI(BaseAPI, Component): + """An API instance that extends the relevant Base API methods by passing + in a tenant and database. This is the root component of the Chroma System""" + + @abstractmethod + @override + def list_collections( + self, tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE + ) -> Sequence[Collection]: + pass + + @abstractmethod + @override + def create_collection( + self, + name: str, + metadata: Optional[CollectionMetadata] = None, + embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), + get_or_create: bool = False, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> Collection: + pass + + @abstractmethod + @override + def get_collection( + self, + name: str, + embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> Collection: + pass + + @abstractmethod + @override + def get_or_create_collection( + self, + name: str, + metadata: Optional[CollectionMetadata] = None, + embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> Collection: + pass + + @abstractmethod + @override + def _modify( + self, + id: UUID, + new_name: Optional[str] = None, + new_metadata: Optional[CollectionMetadata] = None, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> None: + pass + + @abstractmethod + @override + def delete_collection( + self, + name: str, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> None: + pass + + # + # ITEM METHODS + # + + @abstractmethod + @override + def _add( + self, + ids: IDs, + collection_id: UUID, + embeddings: Embeddings, + metadatas: Optional[Metadatas] = None, + documents: Optional[Documents] = None, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> bool: + pass + + @abstractmethod + @override + def _update( + self, + collection_id: UUID, + ids: IDs, + embeddings: Optional[Embeddings] = None, + metadatas: Optional[Metadatas] = None, + documents: Optional[Documents] = None, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> bool: + pass + + @abstractmethod + @override + def _upsert( + self, + collection_id: UUID, + ids: IDs, + embeddings: Embeddings, + metadatas: Optional[Metadatas] = None, + documents: Optional[Documents] = None, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> bool: + pass + + @abstractmethod + @override + def _count( + self, + collection_id: UUID, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> int: + pass + + @abstractmethod + @override + def _peek( + self, + collection_id: UUID, + n: int = 10, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> GetResult: + pass + + @abstractmethod + @override + def _get( + self, + collection_id: UUID, + ids: Optional[IDs] = None, + where: Optional[Where] = {}, + sort: Optional[str] = None, + limit: Optional[int] = None, + offset: Optional[int] = None, + page: Optional[int] = None, + page_size: Optional[int] = None, + where_document: Optional[WhereDocument] = {}, + include: Include = ["embeddings", "metadatas", "documents"], + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> GetResult: + pass + + @abstractmethod + @override + def _delete( + self, + collection_id: UUID, + ids: Optional[IDs], + where: Optional[Where] = {}, + where_document: Optional[WhereDocument] = {}, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> IDs: + pass + + @abstractmethod + @override + def _query( + self, + collection_id: UUID, + query_embeddings: Embeddings, + n_results: int = 10, + where: Where = {}, + where_document: WhereDocument = {}, + include: Include = ["embeddings", "metadatas", "documents", "distances"], + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> QueryResult: + pass diff --git a/chromadb/api/client.py b/chromadb/api/client.py new file mode 100644 index 00000000000..c007495705b --- /dev/null +++ b/chromadb/api/client.py @@ -0,0 +1,382 @@ +from typing import ClassVar, Dict, Optional, Sequence +from uuid import UUID + +from overrides import override +from chromadb.api import ClientAPI, ServerAPI +from chromadb.api.types import ( + CollectionMetadata, + Documents, + EmbeddingFunction, + Embeddings, + GetResult, + IDs, + Include, + Metadatas, + QueryResult, +) +from chromadb.config import Settings, System +from chromadb.telemetry import Telemetry +from chromadb.telemetry.events import ClientStartEvent +from chromadb.config import DEFAULT_TENANT, DEFAULT_DATABASE +from chromadb.api.models.Collection import Collection +from chromadb.types import Where, WhereDocument +import chromadb.utils.embedding_functions as ef + + +class Client(ClientAPI): + """A client for Chroma. This is the main entrypoint for interacting with Chroma. + A client internally stores its tenant and database and proxies calls to a + Server API instance of Chroma. It treats the Server API and corresponding System + as a singleton, so multiple clients connecting to the same resource will share the + same API instance. + + Client implementations should be implement their own API-caching strategies. + """ + + tenant: str = DEFAULT_TENANT + database: str = DEFAULT_DATABASE + + _identifer_to_system: ClassVar[Dict[str, System]] = {} + _identifier: str + _server: ServerAPI + + # region Initialization + def __new__( + cls, + tenant: str = "default", + database: str = "default", + settings: Settings = Settings(), + ) -> "Client": + identifier = cls._get_identifier_from_settings(settings) + cls._create_system_if_not_exists(identifier, settings) + instance = super().__new__(cls) + return instance + + def __init__( + self, + tenant: str = "default", + database: str = "default", + settings: Settings = Settings(), + ) -> None: + self.tenant = tenant + self.database = database + self._identifier = self._get_identifier_from_settings(settings) + + # Get the root system component we want to interact with + self._server = self._system.instance(ServerAPI) + + # Submit event for a client start + telemetry_client = self._system.instance(Telemetry) + telemetry_client.capture(ClientStartEvent()) + + @classmethod + def _create_system_if_not_exists( + cls, identifier: str, settings: Settings + ) -> System: + if identifier not in cls._identifer_to_system: + new_system = System(settings) + cls._identifer_to_system[identifier] = new_system + + new_system.instance(Telemetry) + new_system.instance(ServerAPI) + + new_system.start() + else: + previous_system = cls._identifer_to_system[identifier] + + # For now, the settings must match + if previous_system.settings != settings: + raise ValueError( + f"An instance of Chroma already exists for {identifier} with different settings" + ) + + return cls._identifer_to_system[identifier] + + @staticmethod + def _get_identifier_from_settings(settings: Settings) -> str: + identifier = "" + api_impl = settings.chroma_api_impl + + if api_impl is None: + raise ValueError("Chroma API implementation must be set in settings") + elif api_impl == "chromadb.api.segment.SegmentAPI": + if settings.is_persistent: + identifier = settings.persist_directory + else: + identifier = ( + "ephemeral" # TODO: support pathing and multiple ephemeral clients + ) + elif api_impl == "chromadb.api.fastapi.FastAPI": + identifier = ( + f"{settings.chroma_server_host}:{settings.chroma_server_http_port}" + ) + else: + raise ValueError(f"Unsupported Chroma API implementation {api_impl}") + + return identifier + + @staticmethod + @override + def clear_system_cache() -> None: + Client._identifer_to_system = {} + + @property + def _system(self) -> System: + return self._identifer_to_system[self._identifier] + + # endregion + + # region BaseAPI Methods + # Note - we could do this in less verbose ways, but they break type checking + @override + def heartbeat(self) -> int: + return self._server.heartbeat() + + @override + def list_collections(self) -> Sequence[Collection]: + return self._server.list_collections(tenant=self.tenant, database=self.database) + + @override + def create_collection( + self, + name: str, + metadata: Optional[CollectionMetadata] = None, + embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), + get_or_create: bool = False, + ) -> Collection: + return self._server.create_collection( + name=name, + metadata=metadata, + embedding_function=embedding_function, + tenant=self.tenant, + database=self.database, + ) + + @override + def get_collection( + self, + name: str, + embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), + ) -> Collection: + return self._server.get_collection( + name=name, + embedding_function=embedding_function, + tenant=self.tenant, + database=self.database, + ) + + @override + def get_or_create_collection( + self, + name: str, + metadata: Optional[CollectionMetadata] = None, + embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), + ) -> Collection: + return self._server.get_or_create_collection( + name=name, + metadata=metadata, + embedding_function=embedding_function, + tenant=self.tenant, + database=self.database, + ) + + @override + def _modify( + self, + id: UUID, + new_name: Optional[str] = None, + new_metadata: Optional[CollectionMetadata] = None, + ) -> None: + return self._server._modify( + id=id, + new_name=new_name, + new_metadata=new_metadata, + tenant=self.tenant, + database=self.database, + ) + + @override + def delete_collection( + self, + name: str, + ) -> None: + return self._server.delete_collection( + name=name, + tenant=self.tenant, + database=self.database, + ) + + # + # ITEM METHODS + # + + @override + def _add( + self, + ids: IDs, + collection_id: UUID, + embeddings: Embeddings, + metadatas: Optional[Metadatas] = None, + documents: Optional[Documents] = None, + ) -> bool: + return self._server._add( + ids=ids, + collection_id=collection_id, + embeddings=embeddings, + metadatas=metadatas, + documents=documents, + tenant=self.tenant, + database=self.database, + ) + + @override + def _update( + self, + collection_id: UUID, + ids: IDs, + embeddings: Optional[Embeddings] = None, + metadatas: Optional[Metadatas] = None, + documents: Optional[Documents] = None, + ) -> bool: + return self._server._update( + collection_id=collection_id, + ids=ids, + embeddings=embeddings, + metadatas=metadatas, + documents=documents, + tenant=self.tenant, + database=self.database, + ) + + @override + def _upsert( + self, + collection_id: UUID, + ids: IDs, + embeddings: Embeddings, + metadatas: Optional[Metadatas] = None, + documents: Optional[Documents] = None, + ) -> bool: + return self._server._upsert( + collection_id=collection_id, + ids=ids, + embeddings=embeddings, + metadatas=metadatas, + documents=documents, + tenant=self.tenant, + database=self.database, + ) + + @override + def _count(self, collection_id: UUID) -> int: + return self._server._count( + collection_id=collection_id, + tenant=self.tenant, + database=self.database, + ) + + @override + def _peek(self, collection_id: UUID, n: int = 10) -> GetResult: + return self._server._peek( + collection_id=collection_id, + n=n, + tenant=self.tenant, + database=self.database, + ) + + @override + def _get( + self, + collection_id: UUID, + ids: Optional[IDs] = None, + where: Optional[Where] = {}, + sort: Optional[str] = None, + limit: Optional[int] = None, + offset: Optional[int] = None, + page: Optional[int] = None, + page_size: Optional[int] = None, + where_document: Optional[WhereDocument] = {}, + include: Include = ["embeddings", "metadatas", "documents"], + ) -> GetResult: + return self._server._get( + collection_id=collection_id, + ids=ids, + where=where, + sort=sort, + limit=limit, + offset=offset, + page=page, + page_size=page_size, + where_document=where_document, + include=include, + tenant=self.tenant, + database=self.database, + ) + + def _delete( + self, + collection_id: UUID, + ids: Optional[IDs], + where: Optional[Where] = {}, + where_document: Optional[WhereDocument] = {}, + ) -> IDs: + return self._server._delete( + collection_id=collection_id, + ids=ids, + where=where, + where_document=where_document, + tenant=self.tenant, + database=self.database, + ) + + @override + def _query( + self, + collection_id: UUID, + query_embeddings: Embeddings, + n_results: int = 10, + where: Where = {}, + where_document: WhereDocument = {}, + include: Include = ["embeddings", "metadatas", "documents", "distances"], + ) -> QueryResult: + return self._server._query( + collection_id=collection_id, + query_embeddings=query_embeddings, + n_results=n_results, + where=where, + where_document=where_document, + include=include, + tenant=self.tenant, + database=self.database, + ) + + @override + def reset(self) -> bool: + return self._server.reset() + + @override + def get_version(self) -> str: + return self._server.get_version() + + @override + def get_settings(self) -> Settings: + return self._server.get_settings() + + @property + @override + def max_batch_size(self) -> int: + return self._server.max_batch_size + + # endregion + + # region ClientAPI Methods + + @override + def set_database(self, database: str) -> None: + self.database = database + + @override + def set_tenant(self, tenant: str) -> None: + self.tenant = tenant + + # endregion diff --git a/chromadb/api/fastapi.py b/chromadb/api/fastapi.py index 2ddd537ebff..2ee75612c23 100644 --- a/chromadb/api/fastapi.py +++ b/chromadb/api/fastapi.py @@ -9,7 +9,7 @@ import chromadb.errors as errors import chromadb.utils.embedding_functions as ef -from chromadb.api import API +from chromadb.api import ServerAPI from chromadb.api.models.Collection import Collection from chromadb.api.types import ( Documents, @@ -30,14 +30,14 @@ ) from chromadb.auth.providers import RequestsClientAuthProtocolAdapter from chromadb.auth.registry import resolve_provider -from chromadb.config import Settings, System +from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings, System from chromadb.telemetry import Telemetry from urllib.parse import urlparse, urlunparse, quote logger = logging.getLogger(__name__) -class FastAPI(API): +class FastAPI(ServerAPI): _settings: Settings _max_batch_size: int = -1 @@ -135,7 +135,9 @@ def heartbeat(self) -> int: return int(resp.json()["nanosecond heartbeat"]) @override - def list_collections(self) -> Sequence[Collection]: + def list_collections( + self, tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE + ) -> Sequence[Collection]: """Returns a list of all collections""" resp = self._session.get(self._api_url + "/collections") raise_chroma_error(resp) @@ -153,6 +155,8 @@ def create_collection( metadata: Optional[CollectionMetadata] = None, embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), get_or_create: bool = False, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, ) -> Collection: """Creates a collection""" resp = self._session.post( @@ -176,6 +180,8 @@ def get_collection( self, name: str, embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, ) -> Collection: """Returns a collection""" resp = self._session.get(self._api_url + "/collections/" + name) @@ -195,6 +201,8 @@ def get_or_create_collection( name: str, metadata: Optional[CollectionMetadata] = None, embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, ) -> Collection: return self.create_collection( name, metadata, embedding_function, get_or_create=True @@ -206,6 +214,8 @@ def _modify( id: UUID, new_name: Optional[str] = None, new_metadata: Optional[CollectionMetadata] = None, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, ) -> None: """Updates a collection""" resp = self._session.put( @@ -215,13 +225,23 @@ def _modify( raise_chroma_error(resp) @override - def delete_collection(self, name: str) -> None: + def delete_collection( + self, + name: str, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> None: """Deletes a collection""" resp = self._session.delete(self._api_url + "/collections/" + name) raise_chroma_error(resp) @override - def _count(self, collection_id: UUID) -> int: + def _count( + self, + collection_id: UUID, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> int: """Returns the number of embeddings in the database""" resp = self._session.get( self._api_url + "/collections/" + str(collection_id) + "/count" @@ -230,7 +250,13 @@ def _count(self, collection_id: UUID) -> int: return cast(int, resp.json()) @override - def _peek(self, collection_id: UUID, n: int = 10) -> GetResult: + def _peek( + self, + collection_id: UUID, + n: int = 10, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> GetResult: return self._get( collection_id, limit=n, @@ -250,6 +276,8 @@ def _get( page_size: Optional[int] = None, where_document: Optional[WhereDocument] = {}, include: Include = ["metadatas", "documents"], + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, ) -> GetResult: if page and page_size: offset = (page - 1) * page_size @@ -286,6 +314,8 @@ def _delete( ids: Optional[IDs] = None, where: Optional[Where] = {}, where_document: Optional[WhereDocument] = {}, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, ) -> IDs: """Deletes embeddings from the database""" resp = self._session.post( @@ -329,6 +359,8 @@ def _add( embeddings: Embeddings, metadatas: Optional[Metadatas] = None, documents: Optional[Documents] = None, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, ) -> bool: """ Adds a batch of embeddings to the database @@ -348,6 +380,8 @@ def _update( embeddings: Optional[Embeddings] = None, metadatas: Optional[Metadatas] = None, documents: Optional[Documents] = None, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, ) -> bool: """ Updates a batch of embeddings in the database @@ -369,6 +403,8 @@ def _upsert( embeddings: Embeddings, metadatas: Optional[Metadatas] = None, documents: Optional[Documents] = None, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, ) -> bool: """ Upserts a batch of embeddings in the database @@ -391,6 +427,8 @@ def _query( where: Optional[Where] = {}, where_document: Optional[WhereDocument] = {}, include: Include = ["metadatas", "documents", "distances"], + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, ) -> QueryResult: """Gets the nearest neighbors of a single embedding""" resp = self._session.post( diff --git a/chromadb/api/models/Collection.py b/chromadb/api/models/Collection.py index c11a04b1fa4..d1f4b296712 100644 --- a/chromadb/api/models/Collection.py +++ b/chromadb/api/models/Collection.py @@ -33,19 +33,21 @@ logger = logging.getLogger(__name__) if TYPE_CHECKING: - from chromadb.api import API + from chromadb.api import ServerAPI class Collection(BaseModel): name: str id: UUID metadata: Optional[CollectionMetadata] = None - _client: "API" = PrivateAttr() + _client: "ServerAPI" = PrivateAttr() _embedding_function: Optional[EmbeddingFunction] = PrivateAttr() + # TODO: STORE THE TENANT AND NAMESPACE IN THE COLLECTION OBJECT + def __init__( self, - client: "API", + client: "ServerAPI", name: str, id: UUID, embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), diff --git a/chromadb/api/segment.py b/chromadb/api/segment.py index cfe1300e76e..27b53ea41ef 100644 --- a/chromadb/api/segment.py +++ b/chromadb/api/segment.py @@ -1,5 +1,5 @@ -from chromadb.api import API -from chromadb.config import Settings, System +from chromadb.api import ServerAPI +from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings, System from chromadb.db.system import SysDB from chromadb.segment import SegmentManager, MetadataReader, VectorReader from chromadb.telemetry import Telemetry @@ -71,7 +71,7 @@ def check_index_name(index_name: str) -> None: raise ValueError(msg) -class SegmentAPI(API): +class SegmentAPI(ServerAPI): """API implementation utilizing the new segment-based internal architecture""" _settings: Settings @@ -104,6 +104,8 @@ def create_collection( metadata: Optional[CollectionMetadata] = None, embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), get_or_create: bool = False, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, ) -> Collection: if metadata is not None: validate_metadata(metadata) @@ -148,6 +150,8 @@ def get_or_create_collection( name: str, metadata: Optional[CollectionMetadata] = None, embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, ) -> Collection: return self.create_collection( name=name, @@ -164,6 +168,8 @@ def get_collection( self, name: str, embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, ) -> Collection: existing = self._sysdb.get_collections(name=name) @@ -179,7 +185,11 @@ def get_collection( raise ValueError(f"Collection {name} does not exist.") @override - def list_collections(self) -> Sequence[Collection]: + def list_collections( + self, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> Sequence[Collection]: collections = [] db_collections = self._sysdb.get_collections() for db_collection in db_collections: @@ -199,6 +209,8 @@ def _modify( id: UUID, new_name: Optional[str] = None, new_metadata: Optional[CollectionMetadata] = None, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, ) -> None: if new_name: # backwards compatibility in naming requirements (for now) @@ -217,7 +229,12 @@ def _modify( self._sysdb.update_collection(id, metadata=new_metadata) @override - def delete_collection(self, name: str) -> None: + def delete_collection( + self, + name: str, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> None: existing = self._sysdb.get_collections(name=name) if existing: @@ -237,6 +254,8 @@ def _add( embeddings: Embeddings, metadatas: Optional[Metadatas] = None, documents: Optional[Documents] = None, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, ) -> bool: coll = self._get_collection(collection_id) self._manager.hint_use_collection(collection_id, t.Operation.ADD) @@ -274,6 +293,8 @@ def _update( embeddings: Optional[Embeddings] = None, metadatas: Optional[Metadatas] = None, documents: Optional[Documents] = None, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, ) -> bool: coll = self._get_collection(collection_id) self._manager.hint_use_collection(collection_id, t.Operation.UPDATE) @@ -313,6 +334,8 @@ def _upsert( embeddings: Embeddings, metadatas: Optional[Metadatas] = None, documents: Optional[Documents] = None, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, ) -> bool: coll = self._get_collection(collection_id) self._manager.hint_use_collection(collection_id, t.Operation.UPSERT) @@ -347,6 +370,8 @@ def _get( page_size: Optional[int] = None, where_document: Optional[WhereDocument] = {}, include: Include = ["embeddings", "metadatas", "documents"], + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, ) -> GetResult: where = validate_where(where) if where is not None and len(where) > 0 else None where_document = ( @@ -414,6 +439,8 @@ def _delete( ids: Optional[IDs] = None, where: Optional[Where] = None, where_document: Optional[WhereDocument] = None, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, ) -> IDs: where = validate_where(where) if where is not None and len(where) > 0 else None where_document = ( @@ -469,7 +496,12 @@ def _delete( return ids_to_delete @override - def _count(self, collection_id: UUID) -> int: + def _count( + self, + collection_id: UUID, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> int: metadata_segment = self._manager.get_segment(collection_id, MetadataReader) return metadata_segment.count() @@ -482,6 +514,8 @@ def _query( where: Where = {}, where_document: WhereDocument = {}, include: Include = ["documents", "metadatas", "distances"], + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, ) -> QueryResult: where = validate_where(where) if where is not None and len(where) > 0 else where where_document = ( @@ -574,7 +608,13 @@ def _query( ) @override - def _peek(self, collection_id: UUID, n: int = 10) -> GetResult: + def _peek( + self, + collection_id: UUID, + n: int = 10, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> GetResult: return self._get(collection_id, limit=n) @override diff --git a/chromadb/config.py b/chromadb/config.py index eb7bca93ef5..eb575696049 100644 --- a/chromadb/config.py +++ b/chromadb/config.py @@ -63,7 +63,8 @@ # TODO: Don't use concrete types here to avoid circular deps. Strings are fine for right here! _abstract_type_keys: Dict[str, str] = { - "chromadb.api.API": "chroma_api_impl", + "chromadb.api.API": "chroma_api_impl", # NOTE: this is to support legacy api construction. Use ServerAPI instead + "chromadb.api.ServerAPI": "chroma_api_impl", "chromadb.telemetry.Telemetry": "chroma_telemetry_impl", "chromadb.ingest.Producer": "chroma_producer_impl", "chromadb.ingest.Consumer": "chroma_consumer_impl", @@ -74,6 +75,9 @@ "chromadb.segment.distributed.MemberlistProvider": "chroma_memberlist_provider_impl", } +DEFAULT_TENANT = "default" +DEFAULT_DATABASE = "default" + class Settings(BaseSettings): # type: ignore environment: str = "" diff --git a/chromadb/migrations/sysdb/00004-tenants-databases.sqlite.sql b/chromadb/migrations/sysdb/00004-tenants-databases.sqlite.sql new file mode 100644 index 00000000000..1c40e823480 --- /dev/null +++ b/chromadb/migrations/sysdb/00004-tenants-databases.sqlite.sql @@ -0,0 +1,19 @@ +CREATE TABLE tenants ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL, + UNIQUE (name) -- Maybe not needed since we want to support slug ids +); + +CREATE TABLE databases ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL, + tenant_id TEXT NOT NULL REFERENCES tenants(id) ON DELETE CASCADE, + UNIQUE (name) +); + +ALTER TABLE collections + ADD COLUMN database_id TEXT NOT NULL REFERENCES databases(id) DEFAULT 'default'; -- ON DELETE CASCADE not supported by sqlite in ALTER TABLE + +-- Create default tenant and database +INSERT INTO tenants (id, name) VALUES ('default', 'default'); +INSERT INTO databases (id, name, tenant_id) VALUES ('default', 'default', 'default'); diff --git a/chromadb/server/fastapi/__init__.py b/chromadb/server/fastapi/__init__.py index e92d16d63ba..ec8bcac7ea1 100644 --- a/chromadb/server/fastapi/__init__.py +++ b/chromadb/server/fastapi/__init__.py @@ -15,9 +15,10 @@ FastAPIChromaAuthMiddleware, FastAPIChromaAuthMiddlewareWrapper, ) -from chromadb.config import Settings +from chromadb.config import Settings, System import chromadb.server import chromadb.api +from chromadb.api import ServerAPI from chromadb.errors import ( ChromaError, InvalidUUIDError, @@ -99,12 +100,15 @@ def include_in_schema(path: str) -> bool: super().add_api_route(path, *args, **kwargs) +# TODO: add tenant/namespace to all routes class FastAPI(chromadb.server.Server): def __init__(self, settings: Settings): super().__init__(settings) Telemetry.SERVER_CONTEXT = ServerContext.FASTAPI self._app = fastapi.FastAPI(debug=True) - self._api: chromadb.api.API = chromadb.Client(settings) + self._system = System(settings) + self._api: ServerAPI = self._system.instance(ServerAPI) + self._system.start() self._app.middleware("http")(catch_exceptions_middleware) self._app.add_middleware( diff --git a/chromadb/test/auth/test_token_auth.py b/chromadb/test/auth/test_token_auth.py index 4e99baae306..50e88e296a9 100644 --- a/chromadb/test/auth/test_token_auth.py +++ b/chromadb/test/auth/test_token_auth.py @@ -5,7 +5,7 @@ import pytest from hypothesis import given, settings -from chromadb.api import API +from chromadb.api import ServerAPI from chromadb.config import System from chromadb.test.conftest import _fastapi_fixture @@ -64,7 +64,7 @@ def test_fastapi_server_token_auth(token_config: Dict[str, Any]) -> None: ) _sys: System = next(api) _sys.reset_state() - _api = _sys.instance(API) + _api = _sys.instance(ServerAPI) _api.heartbeat() assert _api.list_collections() == [] @@ -103,7 +103,7 @@ def test_invalid_token(tconf: Dict[str, Any], inval_tok: str) -> None: with pytest.raises(Exception) as e: _sys: System = next(api) _sys.reset_state() - _sys.instance(API) + _sys.instance(ServerAPI) assert "Invalid token" in str(e) @@ -131,7 +131,7 @@ def test_fastapi_server_token_auth_wrong_token( ) _sys: System = next(api) _sys.reset_state() - _api = _sys.instance(API) + _api = _sys.instance(ServerAPI) _api.heartbeat() with pytest.raises(Exception) as e: _api.list_collections() diff --git a/chromadb/test/conftest.py b/chromadb/test/conftest.py index af66ef2513f..238748bcbbb 100644 --- a/chromadb/test/conftest.py +++ b/chromadb/test/conftest.py @@ -22,7 +22,7 @@ from typing_extensions import Protocol import chromadb.server.fastapi -from chromadb.api import API +from chromadb.api import ServerAPI from chromadb.config import Settings, System from chromadb.db.mixins import embeddings_queue from chromadb.ingest import Producer @@ -98,7 +98,7 @@ def _run_server( uvicorn.run(server.app(), host="0.0.0.0", port=port, log_level="error") -def _await_server(api: API, attempts: int = 0) -> None: +def _await_server(api: ServerAPI, attempts: int = 0) -> None: try: api.heartbeat() except ConnectionError as e: @@ -172,7 +172,7 @@ def _fastapi_fixture( chroma_client_auth_token_transport_header=chroma_client_auth_token_transport_header, ) system = System(settings) - api = system.instance(API) + api = system.instance(ServerAPI) system.start() _await_server(api) yield system @@ -198,7 +198,7 @@ def basic_http_client() -> Generator[System, None, None]: allow_reset=True, ) system = System(settings) - api = system.instance(API) + api = system.instance(ServerAPI) _await_server(api) system.start() yield system @@ -361,41 +361,43 @@ def system_fixtures_wrong_auth() -> List[Callable[[], Generator[System, None, No @pytest.fixture(scope="module", params=system_fixtures_wrong_auth()) -def system_wrong_auth(request: pytest.FixtureRequest) -> Generator[API, None, None]: +def system_wrong_auth( + request: pytest.FixtureRequest, +) -> Generator[ServerAPI, None, None]: yield next(request.param()) @pytest.fixture(scope="module", params=system_fixtures()) -def system(request: pytest.FixtureRequest) -> Generator[API, None, None]: +def system(request: pytest.FixtureRequest) -> Generator[ServerAPI, None, None]: yield next(request.param()) @pytest.fixture(scope="module", params=system_fixtures_auth()) -def system_auth(request: pytest.FixtureRequest) -> Generator[API, None, None]: +def system_auth(request: pytest.FixtureRequest) -> Generator[ServerAPI, None, None]: yield next(request.param()) @pytest.fixture(scope="function") -def api(system: System) -> Generator[API, None, None]: +def api(system: System) -> Generator[ServerAPI, None, None]: system.reset_state() - api = system.instance(API) + api = system.instance(ServerAPI) yield api @pytest.fixture(scope="function") def api_wrong_cred( system_wrong_auth: System, -) -> Generator[API, None, None]: +) -> Generator[ServerAPI, None, None]: system_wrong_auth.reset_state() - api = system_wrong_auth.instance(API) + api = system_wrong_auth.instance(ServerAPI) yield api @pytest.fixture(scope="function") -def api_with_server_auth(system_auth: System) -> Generator[API, None, None]: +def api_with_server_auth(system_auth: System) -> Generator[ServerAPI, None, None]: _sys = system_auth _sys.reset_state() - api = _sys.instance(API) + api = _sys.instance(ServerAPI) yield api diff --git a/chromadb/test/property/test_add.py b/chromadb/test/property/test_add.py index 1980ed2a9d9..5f8991b00ed 100644 --- a/chromadb/test/property/test_add.py +++ b/chromadb/test/property/test_add.py @@ -5,7 +5,7 @@ import pytest import hypothesis.strategies as st from hypothesis import given, settings -from chromadb.api import API +from chromadb.api import ServerAPI from chromadb.api.types import Embeddings, Metadatas import chromadb.test.property.strategies as strategies import chromadb.test.property.invariants as invariants @@ -17,7 +17,7 @@ @given(collection=collection_st, record_set=strategies.recordsets(collection_st)) @settings(deadline=None) def test_add( - api: API, + api: ServerAPI, collection: strategies.Collection, record_set: strategies.RecordSet, ) -> None: @@ -69,7 +69,7 @@ def create_large_recordset( @given(collection=collection_st) @settings(deadline=None, max_examples=1) -def test_add_large(api: API, collection: strategies.Collection) -> None: +def test_add_large(api: ServerAPI, collection: strategies.Collection) -> None: api.reset() record_set = create_large_recordset( min_size=api.max_batch_size, @@ -99,7 +99,7 @@ def test_add_large(api: API, collection: strategies.Collection) -> None: @given(collection=collection_st) @settings(deadline=None, max_examples=1) -def test_add_large_exceeding(api: API, collection: strategies.Collection) -> None: +def test_add_large_exceeding(api: ServerAPI, collection: strategies.Collection) -> None: api.reset() record_set = create_large_recordset( min_size=api.max_batch_size, @@ -126,7 +126,7 @@ def test_add_large_exceeding(api: API, collection: strategies.Collection) -> Non reason="This is expected to fail right now. We should change the API to sort the \ ids by input order." ) -def test_out_of_order_ids(api: API) -> None: +def test_out_of_order_ids(api: ServerAPI) -> None: api.reset() ooo_ids = [ "40", @@ -165,7 +165,7 @@ def test_out_of_order_ids(api: API) -> None: assert get_ids == ooo_ids -def test_add_partial(api: API) -> None: +def test_add_partial(api: ServerAPI) -> None: """Tests adding a record set with some of the fields set to None.""" api.reset() diff --git a/chromadb/test/property/test_collections.py b/chromadb/test/property/test_collections.py index 60e3de7592c..2d41c62ef80 100644 --- a/chromadb/test/property/test_collections.py +++ b/chromadb/test/property/test_collections.py @@ -2,7 +2,7 @@ import logging import hypothesis.strategies as st import chromadb.test.property.strategies as strategies -from chromadb.api import API +from chromadb.api import ClientAPI import chromadb.api.types as types from hypothesis.stateful import ( Bundle, @@ -23,7 +23,7 @@ class CollectionStateMachine(RuleBasedStateMachine): collections = Bundle("collections") - def __init__(self, api: API): + def __init__(self, api: ClientAPI): super().__init__() self.model = {} self.api = api @@ -203,6 +203,6 @@ def modify_coll( return multiple(coll) -def test_collections(caplog: pytest.LogCaptureFixture, api: API) -> None: +def test_collections(caplog: pytest.LogCaptureFixture, api: ClientAPI) -> None: caplog.set_level(logging.ERROR) run_state_machine_as_test(lambda: CollectionStateMachine(api)) # type: ignore diff --git a/chromadb/test/property/test_cross_version_persist.py b/chromadb/test/property/test_cross_version_persist.py index 529fe02dda7..11780d4d675 100644 --- a/chromadb/test/property/test_cross_version_persist.py +++ b/chromadb/test/property/test_cross_version_persist.py @@ -5,14 +5,14 @@ import subprocess import tempfile from types import ModuleType -from typing import Generator, List, Tuple, Dict, Any, Callable +from typing import Generator, List, Tuple, Dict, Any, Callable, Type from hypothesis import given, settings import hypothesis.strategies as st import pytest import json from urllib import request from chromadb import config -from chromadb.api import API +from chromadb.api import ServerAPI from chromadb.api.types import Documents, EmbeddingFunction, Embeddings import chromadb.test.property.strategies as strategies import chromadb.test.property.invariants as invariants @@ -84,6 +84,12 @@ def patch_for_version( patch(collection, embeddings) +def api_import_for_version(module: Any, version: str) -> Type: # type: ignore + if packaging_version.Version(version) <= packaging_version.Version("0.4.14"): + return module.api.API # type: ignore + return module.api.ServerAPI # type: ignore + + def configurations(versions: List[str]) -> List[Tuple[str, Settings]]: return [ ( @@ -197,13 +203,13 @@ def persist_generated_data_with_old_version( try: old_module = switch_to_version(version) system = old_module.config.System(settings) - api: API = system.instance(API) + api = system.instance(api_import_for_version(old_module, version)) system.start() api.reset() coll = api.create_collection( name=collection_strategy.name, - metadata=collection_strategy.metadata, # type: ignore + metadata=collection_strategy.metadata, # In order to test old versions, we can't rely on the not_implemented function embedding_function=not_implemented_ef(), ) @@ -288,7 +294,7 @@ def test_cycle_versions( # Switch to the current version (local working directory) and check the invariants # are preserved for the collection system = config.System(settings) - api = system.instance(API) + api = system.instance(ServerAPI) system.start() coll = api.get_collection( name=collection_strategy.name, diff --git a/chromadb/test/property/test_embeddings.py b/chromadb/test/property/test_embeddings.py index 0e402cca1a8..7fc2491c14b 100644 --- a/chromadb/test/property/test_embeddings.py +++ b/chromadb/test/property/test_embeddings.py @@ -5,7 +5,7 @@ from dataclasses import dataclass from chromadb.api.types import ID, Include, IDs import chromadb.errors as errors -from chromadb.api import API +from chromadb.api import ServerAPI from chromadb.api.models.Collection import Collection import chromadb.test.property.strategies as strategies from hypothesis.stateful import ( @@ -64,7 +64,7 @@ class EmbeddingStateMachine(RuleBasedStateMachine): collection: Collection embedding_ids: Bundle[ID] = Bundle("embedding_ids") - def __init__(self, api: API): + def __init__(self, api: ServerAPI): super().__init__() self.api = api self._rules_strategy = strategies.DeterministicRuleStrategy(self) # type: ignore @@ -294,13 +294,13 @@ def on_state_change(self, new_state: str) -> None: pass -def test_embeddings_state(caplog: pytest.LogCaptureFixture, api: API) -> None: +def test_embeddings_state(caplog: pytest.LogCaptureFixture, api: ServerAPI) -> None: caplog.set_level(logging.ERROR) run_state_machine_as_test(lambda: EmbeddingStateMachine(api)) # type: ignore print_traces() -def test_multi_add(api: API) -> None: +def test_multi_add(api: ServerAPI) -> None: api.reset() coll = api.create_collection(name="foo") coll.add(ids=["a"], embeddings=[[0.0]]) @@ -319,7 +319,7 @@ def test_multi_add(api: API) -> None: assert coll.count() == 0 -def test_dup_add(api: API) -> None: +def test_dup_add(api: ServerAPI) -> None: api.reset() coll = api.create_collection(name="foo") with pytest.raises(errors.DuplicateIDError): @@ -328,7 +328,7 @@ def test_dup_add(api: API) -> None: coll.upsert(ids=["a", "a"], embeddings=[[0.0], [1.1]]) -def test_query_without_add(api: API) -> None: +def test_query_without_add(api: ServerAPI) -> None: api.reset() coll = api.create_collection(name="foo") fields: Include = ["documents", "metadatas", "embeddings", "distances"] @@ -343,7 +343,7 @@ def test_query_without_add(api: API) -> None: assert all([len(result) == 0 for result in field_results]) -def test_get_non_existent(api: API) -> None: +def test_get_non_existent(api: ServerAPI) -> None: api.reset() coll = api.create_collection(name="foo") result = coll.get(ids=["a"], include=["documents", "metadatas", "embeddings"]) @@ -355,7 +355,7 @@ def test_get_non_existent(api: API) -> None: # TODO: Use SQL escaping correctly internally @pytest.mark.xfail(reason="We don't properly escape SQL internally, causing problems") -def test_escape_chars_in_ids(api: API) -> None: +def test_escape_chars_in_ids(api: ServerAPI) -> None: api.reset() id = "\x1f" coll = api.create_collection(name="foo") @@ -375,7 +375,7 @@ def test_escape_chars_in_ids(api: API) -> None: {"where_document": {}, "where": {}}, ], ) -def test_delete_empty_fails(api: API, kwargs: dict): +def test_delete_empty_fails(api: ServerAPI, kwargs: dict): api.reset() coll = api.create_collection(name="foo") with pytest.raises(Exception) as e: @@ -398,7 +398,7 @@ def test_delete_empty_fails(api: API, kwargs: dict): }, ], ) -def test_delete_success(api: API, kwargs: dict): +def test_delete_success(api: ServerAPI, kwargs: dict): api.reset() coll = api.create_collection(name="foo") # Should not raise diff --git a/chromadb/test/property/test_filtering.py b/chromadb/test/property/test_filtering.py index ddcdefb0ed3..e55e5d18cf5 100644 --- a/chromadb/test/property/test_filtering.py +++ b/chromadb/test/property/test_filtering.py @@ -1,7 +1,7 @@ from typing import Any, Dict, List, cast from hypothesis import given, settings, HealthCheck import pytest -from chromadb.api import API +from chromadb.api import ServerAPI from chromadb.test.property import invariants from chromadb.api.types import ( Document, @@ -165,7 +165,7 @@ def _filter_embedding_set( filters=st.lists(strategies.filters(collection_st, recordset_st), min_size=1), ) def test_filterable_metadata_get( - caplog, api: API, collection: strategies.Collection, record_set, filters + caplog, api: ServerAPI, collection: strategies.Collection, record_set, filters ) -> None: caplog.set_level(logging.ERROR) @@ -204,7 +204,7 @@ def test_filterable_metadata_get( ) def test_filterable_metadata_query( caplog: pytest.LogCaptureFixture, - api: API, + api: ServerAPI, collection: strategies.Collection, record_set: strategies.RecordSet, filters: List[strategies.Filter], @@ -257,7 +257,7 @@ def test_filterable_metadata_query( assert len(result_ids.intersection(expected_ids)) == len(result_ids) -def test_empty_filter(api: API) -> None: +def test_empty_filter(api: ServerAPI) -> None: """Test that a filter where no document matches returns an empty result""" api.reset() coll = api.create_collection(name="test") @@ -291,7 +291,7 @@ def test_empty_filter(api: API) -> None: assert res["metadatas"] == [[], []] -def test_boolean_metadata(api: API) -> None: +def test_boolean_metadata(api: ServerAPI) -> None: """Test that metadata with boolean values is correctly filtered""" api.reset() coll = api.create_collection(name="test") diff --git a/chromadb/test/property/test_persist.py b/chromadb/test/property/test_persist.py index ea95f684f60..e7b1f7017d1 100644 --- a/chromadb/test/property/test_persist.py +++ b/chromadb/test/property/test_persist.py @@ -6,7 +6,7 @@ import hypothesis.strategies as st import pytest import chromadb -from chromadb.api import API +from chromadb.api import ClientAPI, ServerAPI from chromadb.config import Settings, System import chromadb.test.property.strategies as strategies import chromadb.test.property.invariants as invariants @@ -26,7 +26,7 @@ import shutil import tempfile -CreatePersistAPI = Callable[[], API] +CreatePersistAPI = Callable[[], ServerAPI] configurations = [ Settings( @@ -71,7 +71,7 @@ def test_persist( embeddings_strategy: strategies.RecordSet, ) -> None: system_1 = System(settings) - api_1 = system_1.instance(API) + api_1 = system_1.instance(ServerAPI) system_1.start() api_1.reset() @@ -103,7 +103,7 @@ def test_persist( del system_1 system_2 = System(settings) - api_2 = system_2.instance(API) + api_2 = system_2.instance(ServerAPI) system_2.start() coll = api_2.get_collection( @@ -133,7 +133,7 @@ def load_and_check( ) -> None: try: system = System(settings) - api = system.instance(API) + api = system.instance(ServerAPI) system.start() coll = api.get_collection( @@ -157,7 +157,7 @@ class PersistEmbeddingsStateMachineStates(EmbeddingStateMachineStates): class PersistEmbeddingsStateMachine(EmbeddingStateMachine): - def __init__(self, api: API, settings: Settings): + def __init__(self, api: ClientAPI, settings: Settings): self.api = api self.settings = settings self.last_persist_delay = 10 diff --git a/chromadb/test/stress/test_many_collections.py b/chromadb/test/stress/test_many_collections.py index 7e65c4b790d..29951fa452a 100644 --- a/chromadb/test/stress/test_many_collections.py +++ b/chromadb/test/stress/test_many_collections.py @@ -1,11 +1,11 @@ from typing import List import numpy as np -from chromadb.api import API +from chromadb.api import ServerAPI from chromadb.api.models.Collection import Collection -def test_many_collections(api: API) -> None: +def test_many_collections(api: ServerAPI) -> None: """Test that we can create a large number of collections and that the system # remains responsive.""" api.reset() diff --git a/chromadb/test/test_api.py b/chromadb/test/test_api.py index 8a12a1d9735..ed3c87ee682 100644 --- a/chromadb/test/test_api.py +++ b/chromadb/test/test_api.py @@ -21,7 +21,7 @@ @pytest.fixture def local_persist_api(): - yield chromadb.Client( + client = chromadb.Client( Settings( chroma_api_impl="chromadb.api.segment.SegmentAPI", chroma_sysdb_impl="chromadb.db.impl.sqlite.SqliteDB", @@ -33,6 +33,8 @@ def local_persist_api(): persist_directory=persist_dir, ), ) + yield client + client.clear_system_cache() if os.path.exists(persist_dir): shutil.rmtree(persist_dir, ignore_errors=True) @@ -40,7 +42,7 @@ def local_persist_api(): # https://docs.pytest.org/en/6.2.x/fixture.html#fixtures-can-be-requested-more-than-once-per-test-return-values-are-cached @pytest.fixture def local_persist_api_cache_bust(): - yield chromadb.Client( + client = chromadb.Client( Settings( chroma_api_impl="chromadb.api.segment.SegmentAPI", chroma_sysdb_impl="chromadb.db.impl.sqlite.SqliteDB", @@ -52,6 +54,8 @@ def local_persist_api_cache_bust(): persist_directory=persist_dir, ), ) + yield client + client.clear_system_cache() if os.path.exists(persist_dir): shutil.rmtree(persist_dir, ignore_errors=True) diff --git a/chromadb/test/test_chroma.py b/chromadb/test/test_chroma.py index 42b14411519..9d88ea8cc49 100644 --- a/chromadb/test/test_chroma.py +++ b/chromadb/test/test_chroma.py @@ -47,19 +47,21 @@ class GetAPITest(unittest.TestCase): @patch("chromadb.api.segment.SegmentAPI", autospec=True) @patch.dict(os.environ, {}, clear=True) def test_local(self, mock_api: Mock) -> None: - chromadb.Client(chromadb.config.Settings(persist_directory="./foo")) + client = chromadb.Client(chromadb.config.Settings(persist_directory="./foo")) assert mock_api.called + client.clear_system_cache() @patch("chromadb.db.impl.sqlite.SqliteDB", autospec=True) @patch.dict(os.environ, {}, clear=True) def test_local_db(self, mock_db: Mock) -> None: - chromadb.Client(chromadb.config.Settings(persist_directory="./foo")) + client = chromadb.Client(chromadb.config.Settings(persist_directory="./foo")) assert mock_db.called + client.clear_system_cache() @patch("chromadb.api.fastapi.FastAPI", autospec=True) @patch.dict(os.environ, {}, clear=True) def test_fastapi(self, mock: Mock) -> None: - chromadb.Client( + client = chromadb.Client( chromadb.config.Settings( chroma_api_impl="chromadb.api.fastapi.FastAPI", persist_directory="./foo", @@ -68,6 +70,7 @@ def test_fastapi(self, mock: Mock) -> None: ) ) assert mock.called + client.clear_system_cache() @patch("chromadb.api.fastapi.FastAPI", autospec=True) @patch.dict(os.environ, {}, clear=True) @@ -78,7 +81,7 @@ def test_settings_pass_to_fastapi(self, mock: Mock) -> None: chroma_server_http_port="80", chroma_server_headers={"foo": "bar"}, ) - chromadb.Client(settings) + client = chromadb.Client(settings) # Check that the mock was called assert mock.called @@ -93,11 +96,12 @@ def test_settings_pass_to_fastapi(self, mock: Mock) -> None: # Check if the settings passed to the mock match the settings we used # raise Exception(passed_settings.settings) assert passed_settings.settings == settings + client.clear_system_cache() def test_legacy_values() -> None: with pytest.raises(ValueError): - chromadb.Client( + client = chromadb.Client( chromadb.config.Settings( chroma_api_impl="chromadb.api.local.LocalAPI", persist_directory="./foo", @@ -105,3 +109,4 @@ def test_legacy_values() -> None: chroma_server_http_port="80", ) ) + client.clear_system_cache() diff --git a/chromadb/test/test_client.py b/chromadb/test/test_client.py index 1164e1e699d..d4f1de9ae9f 100644 --- a/chromadb/test/test_client.py +++ b/chromadb/test/test_client.py @@ -1,37 +1,44 @@ +from typing import Generator import chromadb -from chromadb.api import API +from chromadb.api import ClientAPI import chromadb.server.fastapi import pytest import tempfile @pytest.fixture -def ephemeral_api() -> API: - return chromadb.EphemeralClient() +def ephemeral_api() -> Generator[ClientAPI, None, None]: + client = chromadb.EphemeralClient() + yield client + client.clear_system_cache() @pytest.fixture -def persistent_api() -> API: - return chromadb.PersistentClient( +def persistent_api() -> Generator[ClientAPI, None, None]: + client = chromadb.PersistentClient( path=tempfile.gettempdir() + "/test_server", ) + yield client + client.clear_system_cache() @pytest.fixture -def http_api() -> API: - return chromadb.HttpClient() +def http_api() -> Generator[ClientAPI, None, None]: + client = chromadb.HttpClient() + yield client + client.clear_system_cache() -def test_ephemeral_client(ephemeral_api: API) -> None: +def test_ephemeral_client(ephemeral_api: ClientAPI) -> None: settings = ephemeral_api.get_settings() assert settings.is_persistent is False -def test_persistent_client(persistent_api: API) -> None: +def test_persistent_client(persistent_api: ClientAPI) -> None: settings = persistent_api.get_settings() assert settings.is_persistent is True -def test_http_client(http_api: API) -> None: +def test_http_client(http_api: ClientAPI) -> None: settings = http_api.get_settings() assert settings.chroma_api_impl == "chromadb.api.fastapi.FastAPI" diff --git a/chromadb/test/test_multithreaded.py b/chromadb/test/test_multithreaded.py index 57c259dad99..c0b05e88324 100644 --- a/chromadb/test/test_multithreaded.py +++ b/chromadb/test/test_multithreaded.py @@ -5,7 +5,7 @@ from typing import Any, Dict, List, Optional, Set, Tuple, cast import numpy as np -from chromadb.api import API +from chromadb.api import ServerAPI import chromadb.test.property.invariants as invariants from chromadb.test.property.strategies import RecordSet from chromadb.test.property.strategies import test_hnsw_config @@ -37,7 +37,7 @@ def generate_record_set(N: int, D: int) -> RecordSet: # Hypothesis is bad at generating large datasets so we manually generate data in # this test to test multithreaded add with larger datasets -def _test_multithreaded_add(api: API, N: int, D: int, num_workers: int) -> None: +def _test_multithreaded_add(api: ServerAPI, N: int, D: int, num_workers: int) -> None: records_set = generate_record_set(N, D) ids = records_set["ids"] embeddings = records_set["embeddings"] @@ -95,7 +95,9 @@ def _test_multithreaded_add(api: API, N: int, D: int, num_workers: int) -> None: ) -def _test_interleaved_add_query(api: API, N: int, D: int, num_workers: int) -> None: +def _test_interleaved_add_query( + api: ServerAPI, N: int, D: int, num_workers: int +) -> None: """Test that will use multiple threads to interleave operations on the db and verify they work correctly""" api.reset() @@ -207,14 +209,14 @@ def perform_operation( ) -def test_multithreaded_add(api: API) -> None: +def test_multithreaded_add(api: ServerAPI) -> None: for i in range(3): num_workers = random.randint(2, multiprocessing.cpu_count() * 2) N, D = generate_data_shape() _test_multithreaded_add(api, N, D, num_workers) -def test_interleaved_add_query(api: API) -> None: +def test_interleaved_add_query(api: ServerAPI) -> None: for i in range(3): num_workers = random.randint(2, multiprocessing.cpu_count() * 2) N, D = generate_data_shape() diff --git a/chromadb/utils/batch_utils.py b/chromadb/utils/batch_utils.py index c8c1ac1e476..9c588270f25 100644 --- a/chromadb/utils/batch_utils.py +++ b/chromadb/utils/batch_utils.py @@ -1,5 +1,5 @@ from typing import Optional, Tuple, List -from chromadb.api import API +from chromadb.api import BaseAPI from chromadb.api.types import ( Documents, Embeddings, @@ -9,7 +9,7 @@ def create_batches( - api: API, + api: BaseAPI, ids: IDs, embeddings: Optional[Embeddings] = None, metadatas: Optional[Metadatas] = None,