From 0552704915985849047289948093735d0592faf4 Mon Sep 17 00:00:00 2001
From: Hammad Bashir
Date: Mon, 23 Oct 2023 17:58:46 -0700
Subject: [PATCH] [STACKED #1255] [ENH] Add multitenancy (#1244)

## Description of changes

*Summarize the changes made by this PR.*
 - Improvements & Bug fixes
   - ...
 - New functionality
   - Adds multitenancy and databases as first-class concepts by migrating the db and plumbing it through the API into the sysdb
   - We now treat the "System" as a singleton-per-path and create a wrapper API object that proxies to it with context on the tenant/database. In this way the server itself stays unaware of the connection's tenant/database context.

## Test plan

*How are these changes tested?*

Unit tests were added for the new client tenant/database behavior.
Property tests were added for the new tenant/database behavior by subclassing the collection state machine and switching the tenant/database as a state machine transition.

- [ ] Tests pass locally with `pytest` for python, `yarn test` for js

## Documentation Changes

I will add a section to the docs about multitenancy and how to use it. We can remove warnings about the client being a singleton.
---
 chromadb/__init__.py | 68 ++-
 chromadb/api/__init__.py | 145 +++-
 chromadb/api/client.py | 449 ++++++++++++++++++
 chromadb/api/fastapi.py | 116 ++++-
 chromadb/api/models/Collection.py | 6 +-
 chromadb/api/segment.py | 71 ++-
 chromadb/config.py | 6 +-
 chromadb/db/impl/grpc/client.py | 64 ++-
 chromadb/db/impl/grpc/server.py | 172 ++++++-
 chromadb/db/impl/sqlite.py | 3 -
 chromadb/db/mixins/sysdb.py | 150 +++++-
 chromadb/db/system.py | 41 +-
 .../sysdb/00004-tenants-databases.sqlite.sql | 29 ++
 chromadb/proto/chroma_pb2.py | 74 +--
 chromadb/proto/chroma_pb2.pyi | 16 +
 chromadb/proto/coordinator_pb2.py | 62 ++-
 chromadb/proto/coordinator_pb2.pyi | 70 ++-
 chromadb/proto/coordinator_pb2_grpc.py | 180 +++++++
 chromadb/server/fastapi/__init__.py | 93 +++-
 chromadb/server/fastapi/types.py | 8 +
 chromadb/test/auth/test_token_auth.py | 8 +-
 chromadb/test/client/test_database_tenant.py | 77 +++
 .../test_multiple_clients_concurrency.py | 47 ++
 chromadb/test/conftest.py | 36 +-
 chromadb/test/db/test_system.py | 282 ++++++++++-
 chromadb/test/property/strategies.py | 2 +
 chromadb/test/property/test_add.py | 12 +-
 chromadb/test/property/test_collections.py | 38 +-
 .../test_collections_with_database_tenant.py | 102 ++++
 .../property/test_cross_version_persist.py | 16 +-
 chromadb/test/property/test_embeddings.py | 20 +-
 chromadb/test/property/test_filtering.py | 10 +-
 chromadb/test/property/test_persist.py | 12 +-
 chromadb/test/stress/test_many_collections.py | 4 +-
 chromadb/test/test_api.py | 8 +-
 chromadb/test/test_chroma.py | 15 +-
 chromadb/test/test_client.py | 29 +-
 chromadb/test/test_multithreaded.py | 12 +-
 chromadb/types.py | 10 +
 chromadb/utils/batch_utils.py | 4 +-
 idl/chromadb/proto/chroma.proto | 10 +
 idl/chromadb/proto/coordinator.proto | 41 +-
 requirements.txt | 2 +-
 43 files changed, 2380 insertions(+), 240 deletions(-)
 create mode 100644 chromadb/api/client.py
 create mode 100644 chromadb/migrations/sysdb/00004-tenants-databases.sqlite.sql
 create mode 100644 chromadb/test/client/test_database_tenant.py
 create mode 100644 chromadb/test/client/test_multiple_clients_concurrency.py
 create mode 100644 chromadb/test/property/test_collections_with_database_tenant.py

diff --git a/chromadb/__init__.py b/chromadb/__init__.py
index 599fb94dd45..eeb5eaf5899 100644
--- a/chromadb/__init__.py
+++ b/chromadb/__init__.py
@@ -1,8 +1,10 @@
 from typing import Dict
 import logging
+from 
chromadb.api.client import Client as ClientCreator +from chromadb.api.client import AdminClient as AdminClientCreator import chromadb.config -from chromadb.config import Settings, System -from chromadb.api import API +from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings +from chromadb.api import AdminAPI, ClientAPI from chromadb.api.models.Collection import Collection from chromadb.api.types import ( CollectionMetadata, @@ -35,9 +37,6 @@ "QueryResult", "GetResult", ] -from chromadb.telemetry.product.events import ClientStartEvent -from chromadb.telemetry.product import ProductTelemetryClient - logger = logging.getLogger(__name__) @@ -55,7 +54,7 @@ is_client = False try: - from chromadb.is_thin_client import is_thin_client # type: ignore + from chromadb.is_thin_client import is_thin_client is_client = is_thin_client except ImportError: @@ -95,28 +94,43 @@ def get_settings() -> Settings: return __settings -def EphemeralClient(settings: Settings = Settings()) -> API: +def EphemeralClient( + settings: Settings = Settings(), + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, +) -> ClientAPI: """ Creates an in-memory instance of Chroma. This is useful for testing and development, but not recommended for production use. + + Args: + tenant: The tenant to use for this client. Defaults to the default tenant. + database: The database to use for this client. Defaults to the default database. """ settings.is_persistent = False - return Client(settings) + return ClientCreator(settings=settings, tenant=tenant, database=database) -def PersistentClient(path: str = "./chroma", settings: Settings = Settings()) -> API: +def PersistentClient( + path: str = "./chroma", + settings: Settings = Settings(), + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, +) -> ClientAPI: """ Creates a persistent instance of Chroma that saves to disk. This is useful for testing and development, but not recommended for production use. Args: path: The directory to save Chroma's data to. Defaults to "./chroma". + tenant: The tenant to use for this client. Defaults to the default tenant. + database: The database to use for this client. Defaults to the default database. """ settings.persist_directory = path settings.is_persistent = True - return Client(settings) + return ClientCreator(tenant=tenant, database=database, settings=settings) def HttpClient( @@ -125,7 +139,9 @@ def HttpClient( ssl: bool = False, headers: Dict[str, str] = {}, settings: Settings = Settings(), -) -> API: + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, +) -> ClientAPI: """ Creates a client that connects to a remote Chroma server. This supports many clients connecting to the same server, and is the recommended way to @@ -136,6 +152,8 @@ def HttpClient( port: The port of the Chroma server. Defaults to "8000". ssl: Whether to use SSL to connect to the Chroma server. Defaults to False. headers: A dictionary of headers to send to the Chroma server. Defaults to {}. + tenant: The tenant to use for this client. Defaults to the default tenant. + database: The database to use for this client. Defaults to the default database. 
""" settings.chroma_api_impl = "chromadb.api.fastapi.FastAPI" @@ -144,19 +162,29 @@ def HttpClient( settings.chroma_server_ssl_enabled = ssl settings.chroma_server_headers = headers - return Client(settings) + return ClientCreator(tenant=tenant, database=database, settings=settings) -def Client(settings: Settings = __settings) -> API: - """Return a running chroma.API instance""" +def Client( + settings: Settings = __settings, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, +) -> ClientAPI: + """ + Return a running chroma.API instance - system = System(settings) + tenant: The tenant to use for this client. Defaults to the default tenant. + database: The database to use for this client. Defaults to the default database. - product_telemetry_client = system.instance(ProductTelemetryClient) - api = system.instance(API) + """ + + return ClientCreator(tenant=tenant, database=database, settings=settings) - system.start() - product_telemetry_client.capture(ClientStartEvent()) +def AdminClient(settings: Settings = Settings()) -> AdminAPI: + """ - return api + Creates an admin client that can be used to create tenants and databases. + + """ + return AdminClientCreator(settings=settings) diff --git a/chromadb/api/__init__.py b/chromadb/api/__init__.py index 50f2ff1ecef..ae39b7b13f1 100644 --- a/chromadb/api/__init__.py +++ b/chromadb/api/__init__.py @@ -1,6 +1,9 @@ from abc import ABC, abstractmethod from typing import Sequence, Optional from uuid import UUID + +from overrides import override +from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT from chromadb.api.models.Collection import Collection from chromadb.api.types import ( CollectionMetadata, @@ -16,10 +19,11 @@ WhereDocument, ) from chromadb.config import Component, Settings +from chromadb.types import Database, Tenant import chromadb.utils.embedding_functions as ef -class API(Component, ABC): +class BaseAPI(ABC): @abstractmethod def heartbeat(self) -> int: """Get the current time in nanoseconds since epoch. @@ -371,10 +375,10 @@ def get_version(self) -> str: @abstractmethod def get_settings(self) -> Settings: - """Get the settings used to initialize the client. + """Get the settings used to initialize. Returns: - Settings: The settings used to initialize the client. + Settings: The settings used to initialize. """ pass @@ -385,3 +389,138 @@ def max_batch_size(self) -> int: """Return the maximum number of records that can be submitted in a single call to submit_embeddings.""" pass + + +class ClientAPI(BaseAPI, ABC): + tenant: str + database: str + + @abstractmethod + def set_tenant(self, tenant: str, database: str = DEFAULT_DATABASE) -> None: + """Set the tenant and database for the client. Raises an error if the tenant or + database does not exist. + + Args: + tenant: The tenant to set. + database: The database to set. + + """ + pass + + @abstractmethod + def set_database(self, database: str) -> None: + """Set the database for the client. Raises an error if the database does not exist. + + Args: + database: The database to set. + + """ + pass + + @staticmethod + @abstractmethod + def clear_system_cache() -> None: + """Clear the system cache so that new systems can be created for an existing path. + This should only be used for testing purposes.""" + pass + + +class AdminAPI(ABC): + @abstractmethod + def create_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None: + """Create a new database. Raises an error if the database already exists. + + Args: + database: The name of the database to create. 
+ + """ + pass + + @abstractmethod + def get_database(self, name: str, tenant: str = DEFAULT_TENANT) -> Database: + """Get a database. Raises an error if the database does not exist. + + Args: + database: The name of the database to get. + tenant: The tenant of the database to get. + + """ + pass + + @abstractmethod + def create_tenant(self, name: str) -> None: + """Create a new tenant. Raises an error if the tenant already exists. + + Args: + tenant: The name of the tenant to create. + + """ + pass + + @abstractmethod + def get_tenant(self, name: str) -> Tenant: + """Get a tenant. Raises an error if the tenant does not exist. + + Args: + tenant: The name of the tenant to get. + + """ + pass + + +class ServerAPI(BaseAPI, AdminAPI, Component): + """An API instance that extends the relevant Base API methods by passing + in a tenant and database. This is the root component of the Chroma System""" + + @abstractmethod + @override + def list_collections( + self, tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE + ) -> Sequence[Collection]: + pass + + @abstractmethod + @override + def create_collection( + self, + name: str, + metadata: Optional[CollectionMetadata] = None, + embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), + get_or_create: bool = False, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> Collection: + pass + + @abstractmethod + @override + def get_collection( + self, + name: str, + embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> Collection: + pass + + @abstractmethod + @override + def get_or_create_collection( + self, + name: str, + metadata: Optional[CollectionMetadata] = None, + embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> Collection: + pass + + @abstractmethod + @override + def delete_collection( + self, + name: str, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> None: + pass diff --git a/chromadb/api/client.py b/chromadb/api/client.py new file mode 100644 index 00000000000..6e436cf499f --- /dev/null +++ b/chromadb/api/client.py @@ -0,0 +1,449 @@ +from typing import ClassVar, Dict, Optional, Sequence +from uuid import UUID + +from overrides import override +from chromadb.api import AdminAPI, ClientAPI, ServerAPI +from chromadb.api.types import ( + CollectionMetadata, + Documents, + EmbeddingFunction, + Embeddings, + GetResult, + IDs, + Include, + Metadatas, + QueryResult, +) +from chromadb.config import Settings, System +from chromadb.config import DEFAULT_TENANT, DEFAULT_DATABASE +from chromadb.api.models.Collection import Collection +from chromadb.telemetry.product import ProductTelemetryClient +from chromadb.telemetry.product.events import ClientStartEvent +from chromadb.types import Database, Tenant, Where, WhereDocument +import chromadb.utils.embedding_functions as ef + + +class SharedSystemClient: + _identifer_to_system: ClassVar[Dict[str, System]] = {} + _identifier: str + + # region Initialization + def __init__( + self, + settings: Settings = Settings(), + ) -> None: + self._identifier = SharedSystemClient._get_identifier_from_settings(settings) + SharedSystemClient._create_system_if_not_exists(self._identifier, settings) + + @classmethod + def _create_system_if_not_exists( + cls, identifier: str, settings: Settings + ) -> System: + if identifier not in 
cls._identifer_to_system: + new_system = System(settings) + cls._identifer_to_system[identifier] = new_system + + new_system.instance(ProductTelemetryClient) + new_system.instance(ServerAPI) + + new_system.start() + else: + previous_system = cls._identifer_to_system[identifier] + + # For now, the settings must match + if previous_system.settings != settings: + raise ValueError( + f"An instance of Chroma already exists for {identifier} with different settings" + ) + + return cls._identifer_to_system[identifier] + + @staticmethod + def _get_identifier_from_settings(settings: Settings) -> str: + identifier = "" + api_impl = settings.chroma_api_impl + + if api_impl is None: + raise ValueError("Chroma API implementation must be set in settings") + elif api_impl == "chromadb.api.segment.SegmentAPI": + if settings.is_persistent: + identifier = settings.persist_directory + else: + identifier = ( + "ephemeral" # TODO: support pathing and multiple ephemeral clients + ) + elif api_impl == "chromadb.api.fastapi.FastAPI": + identifier = ( + f"{settings.chroma_server_host}:{settings.chroma_server_http_port}" + ) + else: + raise ValueError(f"Unsupported Chroma API implementation {api_impl}") + + return identifier + + @staticmethod + def _populate_data_from_system(system: System) -> str: + identifier = SharedSystemClient._get_identifier_from_settings(system.settings) + SharedSystemClient._identifer_to_system[identifier] = system + return identifier + + @classmethod + def from_system(cls, system: System) -> "SharedSystemClient": + """Create a client from an existing system. This is useful for testing and debugging.""" + + SharedSystemClient._populate_data_from_system(system) + instance = cls(system.settings) + return instance + + @staticmethod + def clear_system_cache() -> None: + SharedSystemClient._identifer_to_system = {} + + @property + def _system(self) -> System: + return SharedSystemClient._identifer_to_system[self._identifier] + + # endregion + + +class Client(SharedSystemClient, ClientAPI): + """A client for Chroma. This is the main entrypoint for interacting with Chroma. + A client internally stores its tenant and database and proxies calls to a + Server API instance of Chroma. It treats the Server API and corresponding System + as a singleton, so multiple clients connecting to the same resource will share the + same API instance. + + Client implementations should be implement their own API-caching strategies. 
+ """ + + tenant: str = DEFAULT_TENANT + database: str = DEFAULT_DATABASE + + _server: ServerAPI + # An internal admin client for verifying that databases and tenants exist + _admin_client: AdminAPI + + # region Initialization + def __init__( + self, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + settings: Settings = Settings(), + ) -> None: + super().__init__(settings=settings) + self.tenant = tenant + self.database = database + # Create an admin client for verifying that databases and tenants exist + self._admin_client = AdminClient.from_system(self._system) + self._validate_tenant_database(tenant=tenant, database=database) + + # Get the root system component we want to interact with + self._server = self._system.instance(ServerAPI) + + # Submit event for a client start + telemetry_client = self._system.instance(ProductTelemetryClient) + telemetry_client.capture(ClientStartEvent()) + + @classmethod + @override + def from_system( + cls, + system: System, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> "Client": + SharedSystemClient._populate_data_from_system(system) + instance = cls(tenant=tenant, database=database, settings=system.settings) + return instance + + # endregion + + # region BaseAPI Methods + # Note - we could do this in less verbose ways, but they break type checking + @override + def heartbeat(self) -> int: + return self._server.heartbeat() + + @override + def list_collections(self) -> Sequence[Collection]: + return self._server.list_collections(tenant=self.tenant, database=self.database) + + @override + def create_collection( + self, + name: str, + metadata: Optional[CollectionMetadata] = None, + embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), + get_or_create: bool = False, + ) -> Collection: + return self._server.create_collection( + name=name, + metadata=metadata, + embedding_function=embedding_function, + tenant=self.tenant, + database=self.database, + get_or_create=get_or_create, + ) + + @override + def get_collection( + self, + name: str, + embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), + ) -> Collection: + return self._server.get_collection( + name=name, + embedding_function=embedding_function, + tenant=self.tenant, + database=self.database, + ) + + @override + def get_or_create_collection( + self, + name: str, + metadata: Optional[CollectionMetadata] = None, + embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), + ) -> Collection: + return self._server.get_or_create_collection( + name=name, + metadata=metadata, + embedding_function=embedding_function, + tenant=self.tenant, + database=self.database, + ) + + @override + def _modify( + self, + id: UUID, + new_name: Optional[str] = None, + new_metadata: Optional[CollectionMetadata] = None, + ) -> None: + return self._server._modify( + id=id, + new_name=new_name, + new_metadata=new_metadata, + ) + + @override + def delete_collection( + self, + name: str, + ) -> None: + return self._server.delete_collection( + name=name, + tenant=self.tenant, + database=self.database, + ) + + # + # ITEM METHODS + # + + @override + def _add( + self, + ids: IDs, + collection_id: UUID, + embeddings: Embeddings, + metadatas: Optional[Metadatas] = None, + documents: Optional[Documents] = None, + ) -> bool: + return self._server._add( + ids=ids, + collection_id=collection_id, + embeddings=embeddings, + metadatas=metadatas, + documents=documents, + ) + + @override + def _update( + self, + collection_id: 
UUID, + ids: IDs, + embeddings: Optional[Embeddings] = None, + metadatas: Optional[Metadatas] = None, + documents: Optional[Documents] = None, + ) -> bool: + return self._server._update( + collection_id=collection_id, + ids=ids, + embeddings=embeddings, + metadatas=metadatas, + documents=documents, + ) + + @override + def _upsert( + self, + collection_id: UUID, + ids: IDs, + embeddings: Embeddings, + metadatas: Optional[Metadatas] = None, + documents: Optional[Documents] = None, + ) -> bool: + return self._server._upsert( + collection_id=collection_id, + ids=ids, + embeddings=embeddings, + metadatas=metadatas, + documents=documents, + ) + + @override + def _count(self, collection_id: UUID) -> int: + return self._server._count( + collection_id=collection_id, + ) + + @override + def _peek(self, collection_id: UUID, n: int = 10) -> GetResult: + return self._server._peek( + collection_id=collection_id, + n=n, + ) + + @override + def _get( + self, + collection_id: UUID, + ids: Optional[IDs] = None, + where: Optional[Where] = {}, + sort: Optional[str] = None, + limit: Optional[int] = None, + offset: Optional[int] = None, + page: Optional[int] = None, + page_size: Optional[int] = None, + where_document: Optional[WhereDocument] = {}, + include: Include = ["embeddings", "metadatas", "documents"], + ) -> GetResult: + return self._server._get( + collection_id=collection_id, + ids=ids, + where=where, + sort=sort, + limit=limit, + offset=offset, + page=page, + page_size=page_size, + where_document=where_document, + include=include, + ) + + def _delete( + self, + collection_id: UUID, + ids: Optional[IDs], + where: Optional[Where] = {}, + where_document: Optional[WhereDocument] = {}, + ) -> IDs: + return self._server._delete( + collection_id=collection_id, + ids=ids, + where=where, + where_document=where_document, + ) + + @override + def _query( + self, + collection_id: UUID, + query_embeddings: Embeddings, + n_results: int = 10, + where: Where = {}, + where_document: WhereDocument = {}, + include: Include = ["embeddings", "metadatas", "documents", "distances"], + ) -> QueryResult: + return self._server._query( + collection_id=collection_id, + query_embeddings=query_embeddings, + n_results=n_results, + where=where, + where_document=where_document, + include=include, + ) + + @override + def reset(self) -> bool: + return self._server.reset() + + @override + def get_version(self) -> str: + return self._server.get_version() + + @override + def get_settings(self) -> Settings: + return self._server.get_settings() + + @property + @override + def max_batch_size(self) -> int: + return self._server.max_batch_size + + # endregion + + # region ClientAPI Methods + + @override + def set_tenant(self, tenant: str, database: str = DEFAULT_DATABASE) -> None: + self._validate_tenant_database(tenant=tenant, database=database) + self.tenant = tenant + self.database = database + + @override + def set_database(self, database: str) -> None: + self._validate_tenant_database(tenant=self.tenant, database=database) + self.database = database + + def _validate_tenant_database(self, tenant: str, database: str) -> None: + try: + self._admin_client.get_tenant(name=tenant) + except Exception: + raise ValueError( + f"Could not connect to tenant {tenant}. Are you sure it exists?" + ) + + try: + self._admin_client.get_database(name=database, tenant=tenant) + except Exception: + raise ValueError( + f"Could not connect to database {database} for tenant {tenant}. Are you sure it exists?" 
+ ) + + # endregion + + +class AdminClient(SharedSystemClient, AdminAPI): + _server: ServerAPI + + def __init__(self, settings: Settings = Settings()) -> None: + super().__init__(settings) + self._server = self._system.instance(ServerAPI) + + @override + def create_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None: + return self._server.create_database(name=name, tenant=tenant) + + @override + def get_database(self, name: str, tenant: str = DEFAULT_TENANT) -> Database: + return self._server.get_database(name=name, tenant=tenant) + + @override + def create_tenant(self, name: str) -> None: + return self._server.create_tenant(name=name) + + @override + def get_tenant(self, name: str) -> Tenant: + return self._server.get_tenant(name=name) + + @classmethod + @override + def from_system( + cls, + system: System, + ) -> "AdminClient": + SharedSystemClient._populate_data_from_system(system) + instance = cls(settings=system.settings) + return instance diff --git a/chromadb/api/fastapi.py b/chromadb/api/fastapi.py index 8db5bf889f7..fc5298c0720 100644 --- a/chromadb/api/fastapi.py +++ b/chromadb/api/fastapi.py @@ -8,8 +8,9 @@ from overrides import override import chromadb.errors as errors +from chromadb.types import Database, Tenant import chromadb.utils.embedding_functions as ef -from chromadb.api import API +from chromadb.api import ServerAPI from chromadb.api.models.Collection import Collection from chromadb.api.types import ( Documents, @@ -30,7 +31,7 @@ ) from chromadb.auth.providers import RequestsClientAuthProtocolAdapter from chromadb.auth.registry import resolve_provider -from chromadb.config import Settings, System +from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings, System from chromadb.telemetry.opentelemetry import ( OpenTelemetryClient, OpenTelemetryGranularity, @@ -42,7 +43,7 @@ logger = logging.getLogger(__name__) -class FastAPI(API): +class FastAPI(ServerAPI): _settings: Settings _max_batch_size: int = -1 @@ -142,11 +143,68 @@ def heartbeat(self) -> int: raise_chroma_error(resp) return int(resp.json()["nanosecond heartbeat"]) + @trace_method("FastAPI.create_database", OpenTelemetryGranularity.OPERATION) + @override + def create_database( + self, + name: str, + tenant: str = DEFAULT_TENANT, + ) -> None: + """Creates a database""" + resp = self._session.post( + self._api_url + "/databases", + data=json.dumps({"name": name}), + params={"tenant": tenant}, + ) + raise_chroma_error(resp) + + @trace_method("FastAPI.get_database", OpenTelemetryGranularity.OPERATION) + @override + def get_database( + self, + name: str, + tenant: str = DEFAULT_TENANT, + ) -> Database: + """Returns a database""" + resp = self._session.get( + self._api_url + "/databases/" + name, + params={"tenant": tenant}, + ) + raise_chroma_error(resp) + resp_json = resp.json() + return Database( + id=resp_json["id"], name=resp_json["name"], tenant=resp_json["tenant"] + ) + + @trace_method("FastAPI.create_tenant", OpenTelemetryGranularity.OPERATION) + @override + def create_tenant(self, name: str) -> None: + resp = self._session.post( + self._api_url + "/tenants", + data=json.dumps({"name": name}), + ) + raise_chroma_error(resp) + + @trace_method("FastAPI.get_tenant", OpenTelemetryGranularity.OPERATION) + @override + def get_tenant(self, name: str) -> Tenant: + resp = self._session.get( + self._api_url + "/tenants/" + name, + ) + raise_chroma_error(resp) + resp_json = resp.json() + return Tenant(name=resp_json["name"]) + @trace_method("FastAPI.list_collections", 
OpenTelemetryGranularity.OPERATION) @override - def list_collections(self) -> Sequence[Collection]: + def list_collections( + self, tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE + ) -> Sequence[Collection]: """Returns a list of all collections""" - resp = self._session.get(self._api_url + "/collections") + resp = self._session.get( + self._api_url + "/collections", + params={"tenant": tenant, "database": database}, + ) raise_chroma_error(resp) json_collections = resp.json() collections = [] @@ -163,13 +221,20 @@ def create_collection( metadata: Optional[CollectionMetadata] = None, embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), get_or_create: bool = False, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, ) -> Collection: """Creates a collection""" resp = self._session.post( self._api_url + "/collections", data=json.dumps( - {"name": name, "metadata": metadata, "get_or_create": get_or_create} + { + "name": name, + "metadata": metadata, + "get_or_create": get_or_create, + } ), + params={"tenant": tenant, "database": database}, ) raise_chroma_error(resp) resp_json = resp.json() @@ -187,9 +252,14 @@ def get_collection( self, name: str, embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, ) -> Collection: """Returns a collection""" - resp = self._session.get(self._api_url + "/collections/" + name) + resp = self._session.get( + self._api_url + "/collections/" + name, + params={"tenant": tenant, "database": database}, + ) raise_chroma_error(resp) resp_json = resp.json() return Collection( @@ -209,9 +279,16 @@ def get_or_create_collection( name: str, metadata: Optional[CollectionMetadata] = None, embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, ) -> Collection: return self.create_collection( - name, metadata, embedding_function, get_or_create=True + name, + metadata, + embedding_function, + get_or_create=True, + tenant=tenant, + database=database, ) @trace_method("FastAPI._modify", OpenTelemetryGranularity.OPERATION) @@ -231,14 +308,25 @@ def _modify( @trace_method("FastAPI.delete_collection", OpenTelemetryGranularity.OPERATION) @override - def delete_collection(self, name: str) -> None: + def delete_collection( + self, + name: str, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> None: """Deletes a collection""" - resp = self._session.delete(self._api_url + "/collections/" + name) + resp = self._session.delete( + self._api_url + "/collections/" + name, + params={"tenant": tenant, "database": database}, + ) raise_chroma_error(resp) @trace_method("FastAPI._count", OpenTelemetryGranularity.OPERATION) @override - def _count(self, collection_id: UUID) -> int: + def _count( + self, + collection_id: UUID, + ) -> int: """Returns the number of embeddings in the database""" resp = self._session.get( self._api_url + "/collections/" + str(collection_id) + "/count" @@ -248,7 +336,11 @@ def _count(self, collection_id: UUID) -> int: @trace_method("FastAPI._peek", OpenTelemetryGranularity.OPERATION) @override - def _peek(self, collection_id: UUID, n: int = 10) -> GetResult: + def _peek( + self, + collection_id: UUID, + n: int = 10, + ) -> GetResult: return self._get( collection_id, limit=n, diff --git a/chromadb/api/models/Collection.py b/chromadb/api/models/Collection.py index c11a04b1fa4..f605d9d9d84 100644 --- 
a/chromadb/api/models/Collection.py +++ b/chromadb/api/models/Collection.py @@ -33,19 +33,19 @@ logger = logging.getLogger(__name__) if TYPE_CHECKING: - from chromadb.api import API + from chromadb.api import ServerAPI class Collection(BaseModel): name: str id: UUID metadata: Optional[CollectionMetadata] = None - _client: "API" = PrivateAttr() + _client: "ServerAPI" = PrivateAttr() _embedding_function: Optional[EmbeddingFunction] = PrivateAttr() def __init__( self, - client: "API", + client: "ServerAPI", name: str, id: UUID, embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), diff --git a/chromadb/api/segment.py b/chromadb/api/segment.py index 45dcefc6697..a411a125fce 100644 --- a/chromadb/api/segment.py +++ b/chromadb/api/segment.py @@ -1,5 +1,5 @@ -from chromadb.api import API -from chromadb.config import Settings, System +from chromadb.api import ServerAPI +from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings, System from chromadb.db.system import SysDB from chromadb.segment import SegmentManager, MetadataReader, VectorReader from chromadb.telemetry.opentelemetry import ( @@ -77,7 +77,7 @@ def check_index_name(index_name: str) -> None: raise ValueError(msg) -class SegmentAPI(API): +class SegmentAPI(ServerAPI): """API implementation utilizing the new segment-based internal architecture""" _settings: Settings @@ -104,6 +104,34 @@ def __init__(self, system: System): def heartbeat(self) -> int: return int(time.time_ns()) + @override + def create_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None: + if len(name) < 3: + raise ValueError("Database name must be at least 3 characters long") + + self._sysdb.create_database( + id=uuid4(), + name=name, + tenant=tenant, + ) + + @override + def get_database(self, name: str, tenant: str = DEFAULT_TENANT) -> t.Database: + return self._sysdb.get_database(name=name, tenant=tenant) + + @override + def create_tenant(self, name: str) -> None: + if len(name) < 3: + raise ValueError("Tenant name must be at least 3 characters long") + + self._sysdb.create_tenant( + name=name, + ) + + @override + def get_tenant(self, name: str) -> t.Tenant: + return self._sysdb.get_tenant(name=name) + # TODO: Actually fix CollectionMetadata type to remove type: ignore flags. This is # necessary because changing the value type from `Any` to`` `Union[str, int, float]` # causes the system to somehow convert all values to strings. @@ -115,6 +143,8 @@ def create_collection( metadata: Optional[CollectionMetadata] = None, embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), get_or_create: bool = False, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, ) -> Collection: if metadata is not None: validate_metadata(metadata) @@ -130,6 +160,8 @@ def create_collection( metadata=metadata, dimension=None, get_or_create=get_or_create, + tenant=tenant, + database=database, ) if created: @@ -163,12 +195,16 @@ def get_or_create_collection( name: str, metadata: Optional[CollectionMetadata] = None, embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, ) -> Collection: return self.create_collection( # type: ignore name=name, metadata=metadata, embedding_function=embedding_function, get_or_create=True, + tenant=tenant, + database=database, ) # TODO: Actually fix CollectionMetadata type to remove type: ignore flags. 
This is @@ -180,8 +216,12 @@ def get_collection( self, name: str, embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, ) -> Collection: - existing = self._sysdb.get_collections(name=name) + existing = self._sysdb.get_collections( + name=name, tenant=tenant, database=database + ) if existing: return Collection( @@ -196,9 +236,13 @@ def get_collection( @trace_method("SegmentAPI.list_collection", OpenTelemetryGranularity.OPERATION) @override - def list_collections(self) -> Sequence[Collection]: + def list_collections( + self, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> Sequence[Collection]: collections = [] - db_collections = self._sysdb.get_collections() + db_collections = self._sysdb.get_collections(tenant=tenant, database=database) for db_collection in db_collections: collections.append( Collection( @@ -236,11 +280,20 @@ def _modify( @trace_method("SegmentAPI.delete_collection", OpenTelemetryGranularity.OPERATION) @override - def delete_collection(self, name: str) -> None: - existing = self._sysdb.get_collections(name=name) + def delete_collection( + self, + name: str, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> None: + existing = self._sysdb.get_collections( + name=name, tenant=tenant, database=database + ) if existing: - self._sysdb.delete_collection(existing[0]["id"]) + self._sysdb.delete_collection( + existing[0]["id"], tenant=tenant, database=database + ) for s in self._manager.delete_segments(existing[0]["id"]): self._sysdb.delete_segment(s) if existing and existing[0]["id"] in self._collection_cache: diff --git a/chromadb/config.py b/chromadb/config.py index 0a3e4864673..993b2a33d03 100644 --- a/chromadb/config.py +++ b/chromadb/config.py @@ -63,7 +63,8 @@ # TODO: Don't use concrete types here to avoid circular deps. Strings are fine for right here! _abstract_type_keys: Dict[str, str] = { - "chromadb.api.API": "chroma_api_impl", + "chromadb.api.API": "chroma_api_impl", # NOTE: this is to support legacy api construction. 
Use ServerAPI instead + "chromadb.api.ServerAPI": "chroma_api_impl", "chromadb.telemetry.product.ProductTelemetryClient": "chroma_product_telemetry_impl", "chromadb.ingest.Producer": "chroma_producer_impl", "chromadb.ingest.Consumer": "chroma_consumer_impl", @@ -74,6 +75,9 @@ "chromadb.segment.distributed.MemberlistProvider": "chroma_memberlist_provider_impl", } +DEFAULT_TENANT = "default_tenant" +DEFAULT_DATABASE = "default_database" + class Settings(BaseSettings): # type: ignore environment: str = "" diff --git a/chromadb/db/impl/grpc/client.py b/chromadb/db/impl/grpc/client.py index 04d4302062a..e1b279528f0 100644 --- a/chromadb/db/impl/grpc/client.py +++ b/chromadb/db/impl/grpc/client.py @@ -1,7 +1,7 @@ from typing import List, Optional, Sequence, Tuple, Union, cast from uuid import UUID from overrides import overrides -from chromadb.config import System +from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, System from chromadb.db.base import NotFoundError, UniqueConstraintError from chromadb.db.system import SysDB from chromadb.proto.convert import ( @@ -13,22 +13,28 @@ ) from chromadb.proto.coordinator_pb2 import ( CreateCollectionRequest, + CreateDatabaseRequest, CreateSegmentRequest, + CreateTenantRequest, DeleteCollectionRequest, DeleteSegmentRequest, GetCollectionsRequest, GetCollectionsResponse, + GetDatabaseRequest, GetSegmentsRequest, + GetTenantRequest, UpdateCollectionRequest, UpdateSegmentRequest, ) from chromadb.proto.coordinator_pb2_grpc import SysDBStub from chromadb.types import ( Collection, + Database, Metadata, OptionalArgument, Segment, SegmentScope, + Tenant, Unspecified, UpdateMetadata, ) @@ -71,6 +77,44 @@ def reset_state(self) -> None: self._sys_db_stub.ResetState(Empty()) return super().reset_state() + @overrides + def create_database( + self, id: UUID, name: str, tenant: str = DEFAULT_TENANT + ) -> None: + request = CreateDatabaseRequest(id=id.hex, name=name, tenant=tenant) + response = self._sys_db_stub.CreateDatabase(request) + if response.status.code == 409: + raise UniqueConstraintError() + + @overrides + def get_database(self, name: str, tenant: str = DEFAULT_TENANT) -> Database: + request = GetDatabaseRequest(name=name, tenant=tenant) + response = self._sys_db_stub.GetDatabase(request) + if response.status.code == 404: + raise NotFoundError() + return Database( + id=UUID(hex=response.database.id), + name=response.database.name, + tenant=response.database.tenant, + ) + + @overrides + def create_tenant(self, name: str) -> None: + request = CreateTenantRequest(name=name) + response = self._sys_db_stub.CreateTenant(request) + if response.status.code == 409: + raise UniqueConstraintError() + + @overrides + def get_tenant(self, name: str) -> Tenant: + request = GetTenantRequest(name=name) + response = self._sys_db_stub.GetTenant(request) + if response.status.code == 404: + raise NotFoundError() + return Tenant( + name=response.tenant.name, + ) + @overrides def create_segment(self, segment: Segment) -> None: proto_segment = to_proto_segment(segment) @@ -164,6 +208,8 @@ def create_collection( metadata: Optional[Metadata] = None, dimension: Optional[int] = None, get_or_create: bool = False, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, ) -> Tuple[Collection, bool]: request = CreateCollectionRequest( id=id.hex, @@ -171,6 +217,8 @@ def create_collection( metadata=to_proto_update_metadata(metadata) if metadata else None, dimension=dimension, get_or_create=get_or_create, + tenant=tenant, + database=database, ) response = 
self._sys_db_stub.CreateCollection(request) if response.status.code == 409: @@ -179,9 +227,13 @@ def create_collection( return collection, response.created @overrides - def delete_collection(self, id: UUID) -> None: + def delete_collection( + self, id: UUID, tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE + ) -> None: request = DeleteCollectionRequest( id=id.hex, + tenant=tenant, + database=database, ) response = self._sys_db_stub.DeleteCollection(request) if response.status.code == 404: @@ -193,11 +245,15 @@ def get_collections( id: Optional[UUID] = None, topic: Optional[str] = None, name: Optional[str] = None, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, ) -> Sequence[Collection]: request = GetCollectionsRequest( id=id.hex if id else None, topic=topic, name=name, + tenant=tenant, + database=database, ) response: GetCollectionsResponse = self._sys_db_stub.GetCollections(request) results: List[Collection] = [] @@ -243,7 +299,9 @@ def update_collection( request.ClearField("metadata") request.reset_metadata = True - self._sys_db_stub.UpdateCollection(request) + response = self._sys_db_stub.UpdateCollection(request) + if response.status.code == 404: + raise NotFoundError() def reset_and_wait_for_ready(self) -> None: self._sys_db_stub.ResetState(Empty(), wait_for_ready=True) diff --git a/chromadb/db/impl/grpc/server.py b/chromadb/db/impl/grpc/server.py index 436a57fc167..1a71929214e 100644 --- a/chromadb/db/impl/grpc/server.py +++ b/chromadb/db/impl/grpc/server.py @@ -3,7 +3,7 @@ from uuid import UUID from overrides import overrides from chromadb.ingest import CollectionAssignmentPolicy -from chromadb.config import Component, System +from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Component, System from chromadb.proto.convert import ( from_proto_metadata, from_proto_update_metadata, @@ -16,13 +16,18 @@ from chromadb.proto.coordinator_pb2 import ( CreateCollectionRequest, CreateCollectionResponse, + CreateDatabaseRequest, CreateSegmentRequest, DeleteCollectionRequest, DeleteSegmentRequest, GetCollectionsRequest, GetCollectionsResponse, + GetDatabaseRequest, + GetDatabaseResponse, GetSegmentsRequest, GetSegmentsResponse, + GetTenantRequest, + GetTenantResponse, UpdateCollectionRequest, UpdateSegmentRequest, ) @@ -43,7 +48,10 @@ class GrpcMockSysDB(SysDBServicer, Component): _server_port: int _assignment_policy: CollectionAssignmentPolicy _segments: Dict[str, Segment] = {} - _collections: Dict[str, Collection] = {} + _tenants_to_databases_to_collections: Dict[ + str, Dict[str, Dict[str, Collection]] + ] = {} + _tenants_to_database_to_id: Dict[str, Dict[str, UUID]] = {} def __init__(self, system: System): self._server_port = system.settings.require("chroma_server_grpc_port") @@ -66,9 +74,81 @@ def stop(self) -> None: @overrides def reset_state(self) -> None: self._segments = {} - self._collections = {} + self._tenants_to_databases_to_collections = {} + # Create defaults + self._tenants_to_databases_to_collections[DEFAULT_TENANT] = {} + self._tenants_to_databases_to_collections[DEFAULT_TENANT][DEFAULT_DATABASE] = {} + self._tenants_to_database_to_id[DEFAULT_TENANT] = {} + self._tenants_to_database_to_id[DEFAULT_TENANT][DEFAULT_DATABASE] = UUID(int=0) return super().reset_state() + @overrides(check_signature=False) + def CreateDatabase( + self, request: CreateDatabaseRequest, context: grpc.ServicerContext + ) -> proto.ChromaResponse: + tenant = request.tenant + database = request.name + if tenant not in self._tenants_to_databases_to_collections: + 
return proto.ChromaResponse( + status=proto.Status(code=404, reason=f"Tenant {tenant} not found") + ) + if database in self._tenants_to_databases_to_collections[tenant]: + return proto.ChromaResponse( + status=proto.Status( + code=409, reason=f"Database {database} already exists" + ) + ) + self._tenants_to_databases_to_collections[tenant][database] = {} + self._tenants_to_database_to_id[tenant][database] = UUID(hex=request.id) + return proto.ChromaResponse(status=proto.Status(code=200)) + + @overrides(check_signature=False) + def GetDatabase( + self, request: GetDatabaseRequest, context: grpc.ServicerContext + ) -> GetDatabaseResponse: + tenant = request.tenant + database = request.name + if tenant not in self._tenants_to_databases_to_collections: + return GetDatabaseResponse( + status=proto.Status(code=404, reason=f"Tenant {tenant} not found") + ) + if database not in self._tenants_to_databases_to_collections[tenant]: + return GetDatabaseResponse( + status=proto.Status(code=404, reason=f"Database {database} not found") + ) + id = self._tenants_to_database_to_id[tenant][database] + return GetDatabaseResponse( + status=proto.Status(code=200), + database=proto.Database(id=id.hex, name=database, tenant=tenant), + ) + + @overrides(check_signature=False) + def CreateTenant( + self, request: CreateDatabaseRequest, context: grpc.ServicerContext + ) -> proto.ChromaResponse: + tenant = request.name + if tenant in self._tenants_to_databases_to_collections: + return proto.ChromaResponse( + status=proto.Status(code=409, reason=f"Tenant {tenant} already exists") + ) + self._tenants_to_databases_to_collections[tenant] = {} + self._tenants_to_database_to_id[tenant] = {} + return proto.ChromaResponse(status=proto.Status(code=200)) + + @overrides(check_signature=False) + def GetTenant( + self, request: GetTenantRequest, context: grpc.ServicerContext + ) -> GetTenantResponse: + tenant = request.name + if tenant not in self._tenants_to_databases_to_collections: + return GetTenantResponse( + status=proto.Status(code=404, reason=f"Tenant {tenant} not found") + ) + return GetTenantResponse( + status=proto.Status(code=200), + tenant=proto.Tenant(name=tenant), + ) + # We are forced to use check_signature=False because the generated proto code # does not have type annotations for the request and response objects. 
# TODO: investigate generating types for the request and response objects @@ -171,9 +251,47 @@ def CreateCollection( self, request: CreateCollectionRequest, context: grpc.ServicerContext ) -> CreateCollectionResponse: collection_name = request.name - matches = [ - c for c in self._collections.values() if c["name"] == collection_name - ] + tenant = request.tenant + database = request.database + if tenant not in self._tenants_to_databases_to_collections: + return CreateCollectionResponse( + status=proto.Status(code=404, reason=f"Tenant {tenant} not found") + ) + if database not in self._tenants_to_databases_to_collections[tenant]: + return CreateCollectionResponse( + status=proto.Status(code=404, reason=f"Database {database} not found") + ) + + # Check if the collection already exists globally by id + for ( + search_tenant, + databases, + ) in self._tenants_to_databases_to_collections.items(): + for search_database, search_collections in databases.items(): + if request.id in search_collections: + if ( + search_tenant != request.tenant + or search_database != request.database + ): + return CreateCollectionResponse( + status=proto.Status( + code=409, + reason=f"Collection {request.id} already exists in tenant {search_tenant} database {search_database}", + ) + ) + elif not request.get_or_create: + # If the id exists for this tenant and database, and we are not doing a get_or_create, then + # we should return a 409 + return CreateCollectionResponse( + status=proto.Status( + code=409, + reason=f"Collection {request.id} already exists in tenant {search_tenant} database {search_database}", + ) + ) + + # Check if the collection already exists in this database by name + collections = self._tenants_to_databases_to_collections[tenant][database] + matches = [c for c in collections.values() if c["name"] == collection_name] assert len(matches) <= 1 if len(matches) > 0: if request.get_or_create: @@ -201,7 +319,7 @@ def CreateCollection( dimension=request.dimension, topic=self._assignment_policy.assign_collection(id), ) - self._collections[request.id] = new_collection + collections[request.id] = new_collection return CreateCollectionResponse( status=proto.Status(code=200), collection=to_proto_collection(new_collection), @@ -213,8 +331,19 @@ def DeleteCollection( self, request: DeleteCollectionRequest, context: grpc.ServicerContext ) -> proto.ChromaResponse: collection_id = request.id - if collection_id in self._collections: - del self._collections[collection_id] + tenant = request.tenant + database = request.database + if tenant not in self._tenants_to_databases_to_collections: + return proto.ChromaResponse( + status=proto.Status(code=404, reason=f"Tenant {tenant} not found") + ) + if database not in self._tenants_to_databases_to_collections[tenant]: + return proto.ChromaResponse( + status=proto.Status(code=404, reason=f"Database {database} not found") + ) + collections = self._tenants_to_databases_to_collections[tenant][database] + if collection_id in collections: + del collections[collection_id] return proto.ChromaResponse(status=proto.Status(code=200)) else: return proto.ChromaResponse( @@ -231,8 +360,20 @@ def GetCollections( target_topic = request.topic if request.HasField("topic") else None target_name = request.name if request.HasField("name") else None + tenant = request.tenant + database = request.database + if tenant not in self._tenants_to_databases_to_collections: + return GetCollectionsResponse( + status=proto.Status(code=404, reason=f"Tenant {tenant} not found") + ) + if database not in 
self._tenants_to_databases_to_collections[tenant]: + return GetCollectionsResponse( + status=proto.Status(code=404, reason=f"Database {database} not found") + ) + collections = self._tenants_to_databases_to_collections[tenant][database] + found_collections = [] - for collection in self._collections.values(): + for collection in collections.values(): if target_id and collection["id"] != target_id: continue if target_topic and collection["topic"] != target_topic: @@ -251,14 +392,21 @@ def UpdateCollection( self, request: UpdateCollectionRequest, context: grpc.ServicerContext ) -> proto.ChromaResponse: id_to_update = UUID(request.id) - if id_to_update.hex not in self._collections: + # Find the collection with this id + collections = {} + for tenant, databases in self._tenants_to_databases_to_collections.items(): + for database, maybe_collections in databases.items(): + if id_to_update.hex in maybe_collections: + collections = maybe_collections + + if id_to_update.hex not in collections: return proto.ChromaResponse( status=proto.Status( code=404, reason=f"Collection {id_to_update} not found" ) ) else: - collection = self._collections[id_to_update.hex] + collection = collections[id_to_update.hex] if request.HasField("topic"): collection["topic"] = request.topic if request.HasField("name"): diff --git a/chromadb/db/impl/sqlite.py b/chromadb/db/impl/sqlite.py index 6652d21333a..52d6deb08fd 100644 --- a/chromadb/db/impl/sqlite.py +++ b/chromadb/db/impl/sqlite.py @@ -9,7 +9,6 @@ OpenTelemetryGranularity, trace_method, ) -from chromadb.utils.delete_file import delete_file import sqlite3 from overrides import override import pypika @@ -148,8 +147,6 @@ def reset_state(self) -> None: for row in cur.fetchall(): cur.execute(f"DROP TABLE IF EXISTS {row[0]}") self._conn_pool.close() - if self._is_persistent: - delete_file(self._db_file) self.start() super().reset_state() diff --git a/chromadb/db/mixins/sysdb.py b/chromadb/db/mixins/sysdb.py index d9deb144f66..bfae0d07692 100644 --- a/chromadb/db/mixins/sysdb.py +++ b/chromadb/db/mixins/sysdb.py @@ -4,7 +4,7 @@ from pypika import Table, Column from itertools import groupby -from chromadb.config import System +from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, System from chromadb.db.base import ( Cursor, SqlDB, @@ -22,11 +22,13 @@ ) from chromadb.ingest import CollectionAssignmentPolicy, Producer from chromadb.types import ( + Database, OptionalArgument, Segment, Metadata, Collection, SegmentScope, + Tenant, Unspecified, UpdateMetadata, ) @@ -49,6 +51,91 @@ def start(self) -> None: super().start() self._producer = self._system.instance(Producer) + @override + def create_database( + self, id: UUID, name: str, tenant: str = DEFAULT_TENANT + ) -> None: + with self.tx() as cur: + # Get the tenant id for the tenant name and then insert the database with the id, name and tenant id + databases = Table("databases") + tenants = Table("tenants") + insert_database = ( + self.querybuilder() + .into(databases) + .columns(databases.id, databases.name, databases.tenant_id) + .insert( + ParameterValue(self.uuid_to_db(id)), + ParameterValue(name), + self.querybuilder() + .select(tenants.id) + .from_(tenants) + .where(tenants.id == ParameterValue(tenant)), + ) + ) + sql, params = get_sql(insert_database, self.parameter_format()) + try: + cur.execute(sql, params) + except self.unique_constraint_error() as e: + raise UniqueConstraintError( + f"Database {name} already exists for tenant {tenant}" + ) from e + + @override + def get_database(self, name: str, tenant: str 
= DEFAULT_TENANT) -> Database: + with self.tx() as cur: + databases = Table("databases") + q = ( + self.querybuilder() + .from_(databases) + .select(databases.id, databases.name) + .where(databases.name == ParameterValue(name)) + .where(databases.tenant_id == ParameterValue(tenant)) + ) + sql, params = get_sql(q, self.parameter_format()) + row = cur.execute(sql, params).fetchone() + if not row: + raise NotFoundError(f"Database {name} not found for tenant {tenant}") + if row[0] is None: + raise NotFoundError(f"Database {name} not found for tenant {tenant}") + id: UUID = cast(UUID, self.uuid_from_db(row[0])) + return Database( + id=id, + name=row[1], + tenant=tenant, + ) + + @override + def create_tenant(self, name: str) -> None: + with self.tx() as cur: + tenants = Table("tenants") + insert_tenant = ( + self.querybuilder() + .into(tenants) + .columns(tenants.id) + .insert(ParameterValue(name)) + ) + sql, params = get_sql(insert_tenant, self.parameter_format()) + try: + cur.execute(sql, params) + except self.unique_constraint_error() as e: + raise UniqueConstraintError(f"Tenant {name} already exists") from e + + @override + def get_tenant(self, name: str) -> Tenant: + with self.tx() as cur: + tenants = Table("tenants") + q = ( + self.querybuilder() + .from_(tenants) + .select(tenants.id) + .where(tenants.id == ParameterValue(name)) + ) + sql, params = get_sql(q, self.parameter_format()) + row = cur.execute(sql, params).fetchone() + if not row: + raise NotFoundError(f"Tenant {name} not found") + return Tenant(name=name) + @override def create_segment(self, segment: Segment) -> None: add_attributes_to_current_span( @@ -106,6 +193,8 @@ def create_collection( metadata: Optional[Metadata] = None, dimension: Optional[int] = None, get_or_create: bool = False, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, ) -> Tuple[Collection, bool]: if id is None and not get_or_create: raise ValueError("id must be specified if get_or_create is False") @@ -117,7 +206,7 @@ def create_collection( } ) - existing = self.get_collections(name=name) + existing = self.get_collections(name=name, tenant=tenant, database=database) if existing: if get_or_create: collection = existing[0] @@ -126,7 +215,12 @@ def create_collection( collection["id"], metadata=metadata, ) - return self.get_collections(id=collection["id"])[0], False + return ( + self.get_collections( + id=collection["id"], tenant=tenant, database=database + )[0], + False, + ) else: raise UniqueConstraintError(f"Collection {name} already exists") @@ -137,6 +231,8 @@ def create_collection( with self.tx() as cur: collections = Table("collections") + databases = Table("databases") + insert_collection = ( self.querybuilder() .into(collections) @@ -145,12 +241,19 @@ def create_collection( collections.topic, collections.name, collections.dimension, + collections.database_id, ) .insert( ParameterValue(self.uuid_to_db(collection["id"])), ParameterValue(collection["topic"]), ParameterValue(collection["name"]), ParameterValue(collection["dimension"]), + # Get the database id for the database with the given name and tenant + self.querybuilder() + .select(databases.id) + .from_(databases) + .where(databases.name == ParameterValue(database)) + .where(databases.tenant_id == ParameterValue(tenant)), ) ) sql, params = get_sql(insert_collection, self.parameter_format()) @@ -256,8 +359,16 @@ def get_collections( id: Optional[UUID] = None, topic: Optional[str] = None, name: Optional[str] = None, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, 
) -> Sequence[Collection]: """Get collections by name, embedding function and/or metadata""" + + if name is not None and (tenant is None or database is None): + raise ValueError( + "If name is specified, tenant and database must also be specified in order to uniquely identify the collection" + ) + add_attributes_to_current_span( { "collection_id": str(id), @@ -265,6 +376,7 @@ def get_collections( "collection_name": name if name else "", } ) + collections_t = Table("collections") metadata_t = Table("collection_metadata") q = ( @@ -291,6 +403,17 @@ def get_collections( if name: q = q.where(collections_t.name == ParameterValue(name)) + if tenant and database: + databases_t = Table("databases") + q = q.where( + collections_t.database_id + == self.querybuilder() + .select(databases_t.id) + .from_(databases_t) + .where(databases_t.name == ParameterValue(database)) + .where(databases_t.tenant_id == ParameterValue(tenant)) + ) + with self.tx() as cur: sql, params = get_sql(q, self.parameter_format()) rows = cur.execute(sql, params).fetchall() @@ -341,7 +464,12 @@ def delete_segment(self, id: UUID) -> None: @trace_method("SqlSysDB.delete_collection", OpenTelemetryGranularity.ALL) @override - def delete_collection(self, id: UUID) -> None: + def delete_collection( + self, + id: UUID, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> None: """Delete a topic and all associated segments from the SysDB""" add_attributes_to_current_span( { @@ -349,10 +477,19 @@ def delete_collection(self, id: UUID) -> None: } ) t = Table("collections") + databases_t = Table("databases") q = ( self.querybuilder() .from_(t) .where(t.id == ParameterValue(self.uuid_to_db(id))) + .where( + t.database_id + == self.querybuilder() + .select(databases_t.id) + .from_(databases_t) + .where(databases_t.name == ParameterValue(database)) + .where(databases_t.tenant_id == ParameterValue(tenant)) + ) .delete() ) with self.tx() as cur: @@ -459,7 +596,10 @@ def update_collection( with self.tx() as cur: sql, params = get_sql(q, self.parameter_format()) if sql: # pypika emits a blank string if nothing to do - cur.execute(sql, params) + sql = sql + " RETURNING id" + result = cur.execute(sql, params) + if not result.fetchone(): + raise NotFoundError(f"Collection {id} not found") # TODO: Update to use better semantics where it's possible to update # individual keys without wiping all the existing metadata. diff --git a/chromadb/db/system.py b/chromadb/db/system.py index b4975339fbc..17606b740b6 100644 --- a/chromadb/db/system.py +++ b/chromadb/db/system.py @@ -3,6 +3,8 @@ from uuid import UUID from chromadb.types import ( Collection, + Database, + Tenant, Metadata, Segment, SegmentScope, @@ -10,15 +12,40 @@ Unspecified, UpdateMetadata, ) -from chromadb.config import Component +from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Component class SysDB(Component): """Data interface for Chroma's System database""" + @abstractmethod + def create_database( + self, id: UUID, name: str, tenant: str = DEFAULT_TENANT + ) -> None: + """Create a new database in the System database. Raises an Error if the Database + already exists.""" + pass + + @abstractmethod + def get_database(self, name: str, tenant: str = DEFAULT_TENANT) -> Database: + """Get a database by name and tenant. Raises an Error if the Database does not + exist.""" + pass + + @abstractmethod + def create_tenant(self, name: str) -> None: + """Create a new tenant in the System database. The name must be unique. 
+ Raises an Error if the Tenant already exists.""" + pass + + @abstractmethod + def get_tenant(self, name: str) -> Tenant: + """Get a tenant by name. Raises an Error if the Tenant does not exist.""" + pass + @abstractmethod def create_segment(self, segment: Segment) -> None: - """Create a new segment in the System database. Raises DuplicateError if the ID + """Create a new segment in the System database. Raises an Error if the ID already exists.""" pass @@ -60,6 +87,8 @@ def create_collection( metadata: Optional[Metadata] = None, dimension: Optional[int] = None, get_or_create: bool = False, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, ) -> Tuple[Collection, bool]: """Create a new collection any associated resources (Such as the necessary topics) in the SysDB. If get_or_create is True, the @@ -73,7 +102,9 @@ def create_collection( pass @abstractmethod - def delete_collection(self, id: UUID) -> None: + def delete_collection( + self, id: UUID, tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE + ) -> None: """Delete a collection, topic, all associated segments and any associate resources from the SysDB and the system at large.""" pass @@ -84,8 +115,10 @@ def get_collections( id: Optional[UUID] = None, topic: Optional[str] = None, name: Optional[str] = None, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, ) -> Sequence[Collection]: - """Find collections by id, topic or name""" + """Find collections by id, topic or name. If name is provided, tenant and database must also be provided.""" pass @abstractmethod diff --git a/chromadb/migrations/sysdb/00004-tenants-databases.sqlite.sql b/chromadb/migrations/sysdb/00004-tenants-databases.sqlite.sql new file mode 100644 index 00000000000..43372bf97a8 --- /dev/null +++ b/chromadb/migrations/sysdb/00004-tenants-databases.sqlite.sql @@ -0,0 +1,29 @@ +CREATE TABLE IF NOT EXISTS tenants ( + id TEXT PRIMARY KEY, + UNIQUE (id) +); + +CREATE TABLE IF NOT EXISTS databases ( + id TEXT PRIMARY KEY, -- unique globally + name TEXT NOT NULL, -- unique per tenant + tenant_id TEXT NOT NULL REFERENCES tenants(id) ON DELETE CASCADE, + UNIQUE (tenant_id, name) -- Ensure that a tenant has only one database with a given name +); + +CREATE TABLE IF NOT EXISTS collections_tmp ( + id TEXT PRIMARY KEY, -- unique globally + name TEXT NOT NULL, -- unique per database + topic TEXT NOT NULL, + dimension INTEGER, + database_id TEXT NOT NULL REFERENCES databases(id) ON DELETE CASCADE, + UNIQUE (name, database_id) +); + +-- Create default tenant and database +INSERT OR REPLACE INTO tenants (id) VALUES ('default_tenant'); -- The default tenant id is 'default_tenant' others are UUIDs +INSERT OR REPLACE INTO databases (id, name, tenant_id) VALUES ('00000000-0000-0000-0000-000000000000', 'default_database', 'default_tenant'); + +INSERT OR REPLACE INTO collections_tmp (id, name, topic, dimension, database_id) + SELECT id, name, topic, dimension, '00000000-0000-0000-0000-000000000000' FROM collections; +DROP TABLE collections; +ALTER TABLE collections_tmp RENAME TO collections; diff --git a/chromadb/proto/chroma_pb2.py b/chromadb/proto/chroma_pb2.py index bd069cc74f4..3644a70c74d 100644 --- a/chromadb/proto/chroma_pb2.py +++ b/chromadb/proto/chroma_pb2.py @@ -13,7 +13,7 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1b\x63hromadb/proto/chroma.proto\x12\x06\x63hroma\"&\n\x06Status\x12\x0e\n\x06reason\x18\x01 \x01(\t\x12\x0c\n\x04\x63ode\x18\x02 \x01(\x05\"0\n\x0e\x43hromaResponse\x12\x1e\n\x06status\x18\x01 
\x01(\x0b\x32\x0e.chroma.Status\"U\n\x06Vector\x12\x11\n\tdimension\x18\x01 \x01(\x05\x12\x0e\n\x06vector\x18\x02 \x01(\x0c\x12(\n\x08\x65ncoding\x18\x03 \x01(\x0e\x32\x16.chroma.ScalarEncoding\"\xca\x01\n\x07Segment\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0c\n\x04type\x18\x02 \x01(\t\x12#\n\x05scope\x18\x03 \x01(\x0e\x32\x14.chroma.SegmentScope\x12\x12\n\x05topic\x18\x04 \x01(\tH\x00\x88\x01\x01\x12\x17\n\ncollection\x18\x05 \x01(\tH\x01\x88\x01\x01\x12-\n\x08metadata\x18\x06 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x02\x88\x01\x01\x42\x08\n\x06_topicB\r\n\x0b_collectionB\x0b\n\t_metadata\"\x97\x01\n\nCollection\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\r\n\x05topic\x18\x03 \x01(\t\x12-\n\x08metadata\x18\x04 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x00\x88\x01\x01\x12\x16\n\tdimension\x18\x05 \x01(\x05H\x01\x88\x01\x01\x42\x0b\n\t_metadataB\x0c\n\n_dimension\"b\n\x13UpdateMetadataValue\x12\x16\n\x0cstring_value\x18\x01 \x01(\tH\x00\x12\x13\n\tint_value\x18\x02 \x01(\x03H\x00\x12\x15\n\x0b\x66loat_value\x18\x03 \x01(\x01H\x00\x42\x07\n\x05value\"\x96\x01\n\x0eUpdateMetadata\x12\x36\n\x08metadata\x18\x01 \x03(\x0b\x32$.chroma.UpdateMetadata.MetadataEntry\x1aL\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12*\n\x05value\x18\x02 \x01(\x0b\x32\x1b.chroma.UpdateMetadataValue:\x02\x38\x01\"\xb5\x01\n\x15SubmitEmbeddingRecord\x12\n\n\x02id\x18\x01 \x01(\t\x12#\n\x06vector\x18\x02 \x01(\x0b\x32\x0e.chroma.VectorH\x00\x88\x01\x01\x12-\n\x08metadata\x18\x03 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x01\x88\x01\x01\x12$\n\toperation\x18\x04 \x01(\x0e\x32\x11.chroma.OperationB\t\n\x07_vectorB\x0b\n\t_metadata\"S\n\x15VectorEmbeddingRecord\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0e\n\x06seq_id\x18\x02 \x01(\x0c\x12\x1e\n\x06vector\x18\x03 \x01(\x0b\x32\x0e.chroma.Vector\"q\n\x11VectorQueryResult\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0e\n\x06seq_id\x18\x02 \x01(\x0c\x12\x10\n\x08\x64istance\x18\x03 \x01(\x01\x12#\n\x06vector\x18\x04 \x01(\x0b\x32\x0e.chroma.VectorH\x00\x88\x01\x01\x42\t\n\x07_vector\"@\n\x12VectorQueryResults\x12*\n\x07results\x18\x01 \x03(\x0b\x32\x19.chroma.VectorQueryResult\"(\n\x15SegmentServerResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\"4\n\x11GetVectorsRequest\x12\x0b\n\x03ids\x18\x01 \x03(\t\x12\x12\n\nsegment_id\x18\x02 \x01(\t\"D\n\x12GetVectorsResponse\x12.\n\x07records\x18\x01 \x03(\x0b\x32\x1d.chroma.VectorEmbeddingRecord\"\x86\x01\n\x13QueryVectorsRequest\x12\x1f\n\x07vectors\x18\x01 \x03(\x0b\x32\x0e.chroma.Vector\x12\t\n\x01k\x18\x02 \x01(\x05\x12\x13\n\x0b\x61llowed_ids\x18\x03 \x03(\t\x12\x1a\n\x12include_embeddings\x18\x04 \x01(\x08\x12\x12\n\nsegment_id\x18\x05 \x01(\t\"C\n\x14QueryVectorsResponse\x12+\n\x07results\x18\x01 \x03(\x0b\x32\x1a.chroma.VectorQueryResults*8\n\tOperation\x12\x07\n\x03\x41\x44\x44\x10\x00\x12\n\n\x06UPDATE\x10\x01\x12\n\n\x06UPSERT\x10\x02\x12\n\n\x06\x44\x45LETE\x10\x03*(\n\x0eScalarEncoding\x12\x0b\n\x07\x46LOAT32\x10\x00\x12\t\n\x05INT32\x10\x01*(\n\x0cSegmentScope\x12\n\n\x06VECTOR\x10\x00\x12\x0c\n\x08METADATA\x10\x01\x32\x94\x01\n\rSegmentServer\x12?\n\x0bLoadSegment\x12\x0f.chroma.Segment\x1a\x1d.chroma.SegmentServerResponse\"\x00\x12\x42\n\x0eReleaseSegment\x12\x0f.chroma.Segment\x1a\x1d.chroma.SegmentServerResponse\"\x00\x32\xa2\x01\n\x0cVectorReader\x12\x45\n\nGetVectors\x12\x19.chroma.GetVectorsRequest\x1a\x1a.chroma.GetVectorsResponse\"\x00\x12K\n\x0cQueryVectors\x12\x1b.chroma.QueryVectorsRequest\x1a\x1c.chroma.QueryVectorsResponse\"\x00\x62\x06proto3') +DESCRIPTOR = 
_descriptor_pool.Default().AddSerializedFile(b'\n\x1b\x63hromadb/proto/chroma.proto\x12\x06\x63hroma\"&\n\x06Status\x12\x0e\n\x06reason\x18\x01 \x01(\t\x12\x0c\n\x04\x63ode\x18\x02 \x01(\x05\"0\n\x0e\x43hromaResponse\x12\x1e\n\x06status\x18\x01 \x01(\x0b\x32\x0e.chroma.Status\"U\n\x06Vector\x12\x11\n\tdimension\x18\x01 \x01(\x05\x12\x0e\n\x06vector\x18\x02 \x01(\x0c\x12(\n\x08\x65ncoding\x18\x03 \x01(\x0e\x32\x16.chroma.ScalarEncoding\"\xca\x01\n\x07Segment\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0c\n\x04type\x18\x02 \x01(\t\x12#\n\x05scope\x18\x03 \x01(\x0e\x32\x14.chroma.SegmentScope\x12\x12\n\x05topic\x18\x04 \x01(\tH\x00\x88\x01\x01\x12\x17\n\ncollection\x18\x05 \x01(\tH\x01\x88\x01\x01\x12-\n\x08metadata\x18\x06 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x02\x88\x01\x01\x42\x08\n\x06_topicB\r\n\x0b_collectionB\x0b\n\t_metadata\"\x97\x01\n\nCollection\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\r\n\x05topic\x18\x03 \x01(\t\x12-\n\x08metadata\x18\x04 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x00\x88\x01\x01\x12\x16\n\tdimension\x18\x05 \x01(\x05H\x01\x88\x01\x01\x42\x0b\n\t_metadataB\x0c\n\n_dimension\"4\n\x08\x44\x61tabase\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x0e\n\x06tenant\x18\x03 \x01(\t\"\x16\n\x06Tenant\x12\x0c\n\x04name\x18\x01 \x01(\t\"b\n\x13UpdateMetadataValue\x12\x16\n\x0cstring_value\x18\x01 \x01(\tH\x00\x12\x13\n\tint_value\x18\x02 \x01(\x03H\x00\x12\x15\n\x0b\x66loat_value\x18\x03 \x01(\x01H\x00\x42\x07\n\x05value\"\x96\x01\n\x0eUpdateMetadata\x12\x36\n\x08metadata\x18\x01 \x03(\x0b\x32$.chroma.UpdateMetadata.MetadataEntry\x1aL\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12*\n\x05value\x18\x02 \x01(\x0b\x32\x1b.chroma.UpdateMetadataValue:\x02\x38\x01\"\xb5\x01\n\x15SubmitEmbeddingRecord\x12\n\n\x02id\x18\x01 \x01(\t\x12#\n\x06vector\x18\x02 \x01(\x0b\x32\x0e.chroma.VectorH\x00\x88\x01\x01\x12-\n\x08metadata\x18\x03 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x01\x88\x01\x01\x12$\n\toperation\x18\x04 \x01(\x0e\x32\x11.chroma.OperationB\t\n\x07_vectorB\x0b\n\t_metadata\"S\n\x15VectorEmbeddingRecord\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0e\n\x06seq_id\x18\x02 \x01(\x0c\x12\x1e\n\x06vector\x18\x03 \x01(\x0b\x32\x0e.chroma.Vector\"q\n\x11VectorQueryResult\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0e\n\x06seq_id\x18\x02 \x01(\x0c\x12\x10\n\x08\x64istance\x18\x03 \x01(\x01\x12#\n\x06vector\x18\x04 \x01(\x0b\x32\x0e.chroma.VectorH\x00\x88\x01\x01\x42\t\n\x07_vector\"@\n\x12VectorQueryResults\x12*\n\x07results\x18\x01 \x03(\x0b\x32\x19.chroma.VectorQueryResult\"(\n\x15SegmentServerResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\"4\n\x11GetVectorsRequest\x12\x0b\n\x03ids\x18\x01 \x03(\t\x12\x12\n\nsegment_id\x18\x02 \x01(\t\"D\n\x12GetVectorsResponse\x12.\n\x07records\x18\x01 \x03(\x0b\x32\x1d.chroma.VectorEmbeddingRecord\"\x86\x01\n\x13QueryVectorsRequest\x12\x1f\n\x07vectors\x18\x01 \x03(\x0b\x32\x0e.chroma.Vector\x12\t\n\x01k\x18\x02 \x01(\x05\x12\x13\n\x0b\x61llowed_ids\x18\x03 \x03(\t\x12\x1a\n\x12include_embeddings\x18\x04 \x01(\x08\x12\x12\n\nsegment_id\x18\x05 \x01(\t\"C\n\x14QueryVectorsResponse\x12+\n\x07results\x18\x01 
\x03(\x0b\x32\x1a.chroma.VectorQueryResults*8\n\tOperation\x12\x07\n\x03\x41\x44\x44\x10\x00\x12\n\n\x06UPDATE\x10\x01\x12\n\n\x06UPSERT\x10\x02\x12\n\n\x06\x44\x45LETE\x10\x03*(\n\x0eScalarEncoding\x12\x0b\n\x07\x46LOAT32\x10\x00\x12\t\n\x05INT32\x10\x01*(\n\x0cSegmentScope\x12\n\n\x06VECTOR\x10\x00\x12\x0c\n\x08METADATA\x10\x01\x32\x94\x01\n\rSegmentServer\x12?\n\x0bLoadSegment\x12\x0f.chroma.Segment\x1a\x1d.chroma.SegmentServerResponse\"\x00\x12\x42\n\x0eReleaseSegment\x12\x0f.chroma.Segment\x1a\x1d.chroma.SegmentServerResponse\"\x00\x32\xa2\x01\n\x0cVectorReader\x12\x45\n\nGetVectors\x12\x19.chroma.GetVectorsRequest\x1a\x1a.chroma.GetVectorsResponse\"\x00\x12K\n\x0cQueryVectors\x12\x1b.chroma.QueryVectorsRequest\x1a\x1c.chroma.QueryVectorsResponse\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -22,12 +22,12 @@ DESCRIPTOR._options = None _UPDATEMETADATA_METADATAENTRY._options = None _UPDATEMETADATA_METADATAENTRY._serialized_options = b'8\001' - _globals['_OPERATION']._serialized_start=1650 - _globals['_OPERATION']._serialized_end=1706 - _globals['_SCALARENCODING']._serialized_start=1708 - _globals['_SCALARENCODING']._serialized_end=1748 - _globals['_SEGMENTSCOPE']._serialized_start=1750 - _globals['_SEGMENTSCOPE']._serialized_end=1790 + _globals['_OPERATION']._serialized_start=1728 + _globals['_OPERATION']._serialized_end=1784 + _globals['_SCALARENCODING']._serialized_start=1786 + _globals['_SCALARENCODING']._serialized_end=1826 + _globals['_SEGMENTSCOPE']._serialized_start=1828 + _globals['_SEGMENTSCOPE']._serialized_end=1868 _globals['_STATUS']._serialized_start=39 _globals['_STATUS']._serialized_end=77 _globals['_CHROMARESPONSE']._serialized_start=79 @@ -38,32 +38,36 @@ _globals['_SEGMENT']._serialized_end=419 _globals['_COLLECTION']._serialized_start=422 _globals['_COLLECTION']._serialized_end=573 - _globals['_UPDATEMETADATAVALUE']._serialized_start=575 - _globals['_UPDATEMETADATAVALUE']._serialized_end=673 - _globals['_UPDATEMETADATA']._serialized_start=676 - _globals['_UPDATEMETADATA']._serialized_end=826 - _globals['_UPDATEMETADATA_METADATAENTRY']._serialized_start=750 - _globals['_UPDATEMETADATA_METADATAENTRY']._serialized_end=826 - _globals['_SUBMITEMBEDDINGRECORD']._serialized_start=829 - _globals['_SUBMITEMBEDDINGRECORD']._serialized_end=1010 - _globals['_VECTOREMBEDDINGRECORD']._serialized_start=1012 - _globals['_VECTOREMBEDDINGRECORD']._serialized_end=1095 - _globals['_VECTORQUERYRESULT']._serialized_start=1097 - _globals['_VECTORQUERYRESULT']._serialized_end=1210 - _globals['_VECTORQUERYRESULTS']._serialized_start=1212 - _globals['_VECTORQUERYRESULTS']._serialized_end=1276 - _globals['_SEGMENTSERVERRESPONSE']._serialized_start=1278 - _globals['_SEGMENTSERVERRESPONSE']._serialized_end=1318 - _globals['_GETVECTORSREQUEST']._serialized_start=1320 - _globals['_GETVECTORSREQUEST']._serialized_end=1372 - _globals['_GETVECTORSRESPONSE']._serialized_start=1374 - _globals['_GETVECTORSRESPONSE']._serialized_end=1442 - _globals['_QUERYVECTORSREQUEST']._serialized_start=1445 - _globals['_QUERYVECTORSREQUEST']._serialized_end=1579 - _globals['_QUERYVECTORSRESPONSE']._serialized_start=1581 - _globals['_QUERYVECTORSRESPONSE']._serialized_end=1648 - _globals['_SEGMENTSERVER']._serialized_start=1793 - _globals['_SEGMENTSERVER']._serialized_end=1941 - _globals['_VECTORREADER']._serialized_start=1944 - _globals['_VECTORREADER']._serialized_end=2106 + _globals['_DATABASE']._serialized_start=575 + 
_globals['_DATABASE']._serialized_end=627 + _globals['_TENANT']._serialized_start=629 + _globals['_TENANT']._serialized_end=651 + _globals['_UPDATEMETADATAVALUE']._serialized_start=653 + _globals['_UPDATEMETADATAVALUE']._serialized_end=751 + _globals['_UPDATEMETADATA']._serialized_start=754 + _globals['_UPDATEMETADATA']._serialized_end=904 + _globals['_UPDATEMETADATA_METADATAENTRY']._serialized_start=828 + _globals['_UPDATEMETADATA_METADATAENTRY']._serialized_end=904 + _globals['_SUBMITEMBEDDINGRECORD']._serialized_start=907 + _globals['_SUBMITEMBEDDINGRECORD']._serialized_end=1088 + _globals['_VECTOREMBEDDINGRECORD']._serialized_start=1090 + _globals['_VECTOREMBEDDINGRECORD']._serialized_end=1173 + _globals['_VECTORQUERYRESULT']._serialized_start=1175 + _globals['_VECTORQUERYRESULT']._serialized_end=1288 + _globals['_VECTORQUERYRESULTS']._serialized_start=1290 + _globals['_VECTORQUERYRESULTS']._serialized_end=1354 + _globals['_SEGMENTSERVERRESPONSE']._serialized_start=1356 + _globals['_SEGMENTSERVERRESPONSE']._serialized_end=1396 + _globals['_GETVECTORSREQUEST']._serialized_start=1398 + _globals['_GETVECTORSREQUEST']._serialized_end=1450 + _globals['_GETVECTORSRESPONSE']._serialized_start=1452 + _globals['_GETVECTORSRESPONSE']._serialized_end=1520 + _globals['_QUERYVECTORSREQUEST']._serialized_start=1523 + _globals['_QUERYVECTORSREQUEST']._serialized_end=1657 + _globals['_QUERYVECTORSRESPONSE']._serialized_start=1659 + _globals['_QUERYVECTORSRESPONSE']._serialized_end=1726 + _globals['_SEGMENTSERVER']._serialized_start=1871 + _globals['_SEGMENTSERVER']._serialized_end=2019 + _globals['_VECTORREADER']._serialized_start=2022 + _globals['_VECTORREADER']._serialized_end=2184 # @@protoc_insertion_point(module_scope) diff --git a/chromadb/proto/chroma_pb2.pyi b/chromadb/proto/chroma_pb2.pyi index 733cae0a273..386f253e776 100644 --- a/chromadb/proto/chroma_pb2.pyi +++ b/chromadb/proto/chroma_pb2.pyi @@ -85,6 +85,22 @@ class Collection(_message.Message): dimension: int def __init__(self, id: _Optional[str] = ..., name: _Optional[str] = ..., topic: _Optional[str] = ..., metadata: _Optional[_Union[UpdateMetadata, _Mapping]] = ..., dimension: _Optional[int] = ...) -> None: ... +class Database(_message.Message): + __slots__ = ["id", "name", "tenant"] + ID_FIELD_NUMBER: _ClassVar[int] + NAME_FIELD_NUMBER: _ClassVar[int] + TENANT_FIELD_NUMBER: _ClassVar[int] + id: str + name: str + tenant: str + def __init__(self, id: _Optional[str] = ..., name: _Optional[str] = ..., tenant: _Optional[str] = ...) -> None: ... + +class Tenant(_message.Message): + __slots__ = ["name"] + NAME_FIELD_NUMBER: _ClassVar[int] + name: str + def __init__(self, name: _Optional[str] = ...) -> None: ... 
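For reference only, the two new protobuf messages behave like any other generated type. A minimal sketch, using the default tenant and database names seeded by the sqlite migration earlier in this patch (the fixed UUID below is the default-database id from that migration):

from chromadb.proto.chroma_pb2 import Database, Tenant

# Default names/id as created by the 00004-tenants-databases migration.
tenant = Tenant(name="default_tenant")
database = Database(
    id="00000000-0000-0000-0000-000000000000",
    name="default_database",
    tenant="default_tenant",
)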
+ class UpdateMetadataValue(_message.Message): __slots__ = ["string_value", "int_value", "float_value"] STRING_VALUE_FIELD_NUMBER: _ClassVar[int] diff --git a/chromadb/proto/coordinator_pb2.py b/chromadb/proto/coordinator_pb2.py index 118405d423a..42039c2d23f 100644 --- a/chromadb/proto/coordinator_pb2.py +++ b/chromadb/proto/coordinator_pb2.py @@ -15,35 +15,47 @@ from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n chromadb/proto/coordinator.proto\x12\x06\x63hroma\x1a\x1b\x63hromadb/proto/chroma.proto\x1a\x1bgoogle/protobuf/empty.proto\"8\n\x14\x43reateSegmentRequest\x12 \n\x07segment\x18\x01 \x01(\x0b\x32\x0f.chroma.Segment\"\"\n\x14\x44\x65leteSegmentRequest\x12\n\n\x02id\x18\x01 \x01(\t\"\xc2\x01\n\x12GetSegmentsRequest\x12\x0f\n\x02id\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x11\n\x04type\x18\x02 \x01(\tH\x01\x88\x01\x01\x12(\n\x05scope\x18\x03 \x01(\x0e\x32\x14.chroma.SegmentScopeH\x02\x88\x01\x01\x12\x12\n\x05topic\x18\x04 \x01(\tH\x03\x88\x01\x01\x12\x17\n\ncollection\x18\x05 \x01(\tH\x04\x88\x01\x01\x42\x05\n\x03_idB\x07\n\x05_typeB\x08\n\x06_scopeB\x08\n\x06_topicB\r\n\x0b_collection\"X\n\x13GetSegmentsResponse\x12!\n\x08segments\x18\x01 \x03(\x0b\x32\x0f.chroma.Segment\x12\x1e\n\x06status\x18\x02 \x01(\x0b\x32\x0e.chroma.Status\"\xfa\x01\n\x14UpdateSegmentRequest\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0f\n\x05topic\x18\x02 \x01(\tH\x00\x12\x15\n\x0breset_topic\x18\x03 \x01(\x08H\x00\x12\x14\n\ncollection\x18\x04 \x01(\tH\x01\x12\x1a\n\x10reset_collection\x18\x05 \x01(\x08H\x01\x12*\n\x08metadata\x18\x06 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x02\x12\x18\n\x0ereset_metadata\x18\x07 \x01(\x08H\x02\x42\x0e\n\x0ctopic_updateB\x13\n\x11\x63ollection_updateB\x11\n\x0fmetadata_update\"\xc3\x01\n\x17\x43reateCollectionRequest\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12-\n\x08metadata\x18\x03 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x00\x88\x01\x01\x12\x16\n\tdimension\x18\x04 \x01(\x05H\x01\x88\x01\x01\x12\x1a\n\rget_or_create\x18\x05 \x01(\x08H\x02\x88\x01\x01\x42\x0b\n\t_metadataB\x0c\n\n_dimensionB\x10\n\x0e_get_or_create\"s\n\x18\x43reateCollectionResponse\x12&\n\ncollection\x18\x01 \x01(\x0b\x32\x12.chroma.Collection\x12\x0f\n\x07\x63reated\x18\x02 \x01(\x08\x12\x1e\n\x06status\x18\x03 \x01(\x0b\x32\x0e.chroma.Status\"%\n\x17\x44\x65leteCollectionRequest\x12\n\n\x02id\x18\x01 \x01(\t\"i\n\x15GetCollectionsRequest\x12\x0f\n\x02id\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x11\n\x04name\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x12\n\x05topic\x18\x03 \x01(\tH\x02\x88\x01\x01\x42\x05\n\x03_idB\x07\n\x05_nameB\x08\n\x06_topic\"a\n\x16GetCollectionsResponse\x12\'\n\x0b\x63ollections\x18\x01 \x03(\x0b\x32\x12.chroma.Collection\x12\x1e\n\x06status\x18\x02 \x01(\x0b\x32\x0e.chroma.Status\"\xde\x01\n\x17UpdateCollectionRequest\x12\n\n\x02id\x18\x01 \x01(\t\x12\x12\n\x05topic\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x11\n\x04name\x18\x03 \x01(\tH\x02\x88\x01\x01\x12\x16\n\tdimension\x18\x04 \x01(\x05H\x03\x88\x01\x01\x12*\n\x08metadata\x18\x05 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x00\x12\x18\n\x0ereset_metadata\x18\x06 
\x01(\x08H\x00\x42\x11\n\x0fmetadata_updateB\x08\n\x06_topicB\x07\n\x05_nameB\x0c\n\n_dimension2\xb6\x05\n\x05SysDB\x12G\n\rCreateSegment\x12\x1c.chroma.CreateSegmentRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12G\n\rDeleteSegment\x12\x1c.chroma.DeleteSegmentRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12H\n\x0bGetSegments\x12\x1a.chroma.GetSegmentsRequest\x1a\x1b.chroma.GetSegmentsResponse\"\x00\x12G\n\rUpdateSegment\x12\x1c.chroma.UpdateSegmentRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12W\n\x10\x43reateCollection\x12\x1f.chroma.CreateCollectionRequest\x1a .chroma.CreateCollectionResponse\"\x00\x12M\n\x10\x44\x65leteCollection\x12\x1f.chroma.DeleteCollectionRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12Q\n\x0eGetCollections\x12\x1d.chroma.GetCollectionsRequest\x1a\x1e.chroma.GetCollectionsResponse\"\x00\x12M\n\x10UpdateCollection\x12\x1f.chroma.UpdateCollectionRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12>\n\nResetState\x12\x16.google.protobuf.Empty\x1a\x16.chroma.ChromaResponse\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n chromadb/proto/coordinator.proto\x12\x06\x63hroma\x1a\x1b\x63hromadb/proto/chroma.proto\x1a\x1bgoogle/protobuf/empty.proto\"A\n\x15\x43reateDatabaseRequest\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x0e\n\x06tenant\x18\x03 \x01(\t\"2\n\x12GetDatabaseRequest\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0e\n\x06tenant\x18\x02 \x01(\t\"Y\n\x13GetDatabaseResponse\x12\"\n\x08\x64\x61tabase\x18\x01 \x01(\x0b\x32\x10.chroma.Database\x12\x1e\n\x06status\x18\x02 \x01(\x0b\x32\x0e.chroma.Status\"#\n\x13\x43reateTenantRequest\x12\x0c\n\x04name\x18\x02 \x01(\t\" \n\x10GetTenantRequest\x12\x0c\n\x04name\x18\x01 \x01(\t\"S\n\x11GetTenantResponse\x12\x1e\n\x06tenant\x18\x01 \x01(\x0b\x32\x0e.chroma.Tenant\x12\x1e\n\x06status\x18\x02 \x01(\x0b\x32\x0e.chroma.Status\"8\n\x14\x43reateSegmentRequest\x12 \n\x07segment\x18\x01 \x01(\x0b\x32\x0f.chroma.Segment\"\"\n\x14\x44\x65leteSegmentRequest\x12\n\n\x02id\x18\x01 \x01(\t\"\xc2\x01\n\x12GetSegmentsRequest\x12\x0f\n\x02id\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x11\n\x04type\x18\x02 \x01(\tH\x01\x88\x01\x01\x12(\n\x05scope\x18\x03 \x01(\x0e\x32\x14.chroma.SegmentScopeH\x02\x88\x01\x01\x12\x12\n\x05topic\x18\x04 \x01(\tH\x03\x88\x01\x01\x12\x17\n\ncollection\x18\x05 \x01(\tH\x04\x88\x01\x01\x42\x05\n\x03_idB\x07\n\x05_typeB\x08\n\x06_scopeB\x08\n\x06_topicB\r\n\x0b_collection\"X\n\x13GetSegmentsResponse\x12!\n\x08segments\x18\x01 \x03(\x0b\x32\x0f.chroma.Segment\x12\x1e\n\x06status\x18\x02 \x01(\x0b\x32\x0e.chroma.Status\"\xfa\x01\n\x14UpdateSegmentRequest\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0f\n\x05topic\x18\x02 \x01(\tH\x00\x12\x15\n\x0breset_topic\x18\x03 \x01(\x08H\x00\x12\x14\n\ncollection\x18\x04 \x01(\tH\x01\x12\x1a\n\x10reset_collection\x18\x05 \x01(\x08H\x01\x12*\n\x08metadata\x18\x06 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x02\x12\x18\n\x0ereset_metadata\x18\x07 \x01(\x08H\x02\x42\x0e\n\x0ctopic_updateB\x13\n\x11\x63ollection_updateB\x11\n\x0fmetadata_update\"\xe5\x01\n\x17\x43reateCollectionRequest\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12-\n\x08metadata\x18\x03 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x00\x88\x01\x01\x12\x16\n\tdimension\x18\x04 \x01(\x05H\x01\x88\x01\x01\x12\x1a\n\rget_or_create\x18\x05 \x01(\x08H\x02\x88\x01\x01\x12\x0e\n\x06tenant\x18\x06 \x01(\t\x12\x10\n\x08\x64\x61tabase\x18\x07 \x01(\tB\x0b\n\t_metadataB\x0c\n\n_dimensionB\x10\n\x0e_get_or_create\"s\n\x18\x43reateCollectionResponse\x12&\n\ncollection\x18\x01 
\x01(\x0b\x32\x12.chroma.Collection\x12\x0f\n\x07\x63reated\x18\x02 \x01(\x08\x12\x1e\n\x06status\x18\x03 \x01(\x0b\x32\x0e.chroma.Status\"G\n\x17\x44\x65leteCollectionRequest\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0e\n\x06tenant\x18\x02 \x01(\t\x12\x10\n\x08\x64\x61tabase\x18\x03 \x01(\t\"\x8b\x01\n\x15GetCollectionsRequest\x12\x0f\n\x02id\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x11\n\x04name\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x12\n\x05topic\x18\x03 \x01(\tH\x02\x88\x01\x01\x12\x0e\n\x06tenant\x18\x04 \x01(\t\x12\x10\n\x08\x64\x61tabase\x18\x05 \x01(\tB\x05\n\x03_idB\x07\n\x05_nameB\x08\n\x06_topic\"a\n\x16GetCollectionsResponse\x12\'\n\x0b\x63ollections\x18\x01 \x03(\x0b\x32\x12.chroma.Collection\x12\x1e\n\x06status\x18\x02 \x01(\x0b\x32\x0e.chroma.Status\"\xde\x01\n\x17UpdateCollectionRequest\x12\n\n\x02id\x18\x01 \x01(\t\x12\x12\n\x05topic\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x11\n\x04name\x18\x03 \x01(\tH\x02\x88\x01\x01\x12\x16\n\tdimension\x18\x04 \x01(\x05H\x03\x88\x01\x01\x12*\n\x08metadata\x18\x05 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x00\x12\x18\n\x0ereset_metadata\x18\x06 \x01(\x08H\x00\x42\x11\n\x0fmetadata_updateB\x08\n\x06_topicB\x07\n\x05_nameB\x0c\n\n_dimension2\xd6\x07\n\x05SysDB\x12I\n\x0e\x43reateDatabase\x12\x1d.chroma.CreateDatabaseRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12H\n\x0bGetDatabase\x12\x1a.chroma.GetDatabaseRequest\x1a\x1b.chroma.GetDatabaseResponse\"\x00\x12\x45\n\x0c\x43reateTenant\x12\x1b.chroma.CreateTenantRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12\x42\n\tGetTenant\x12\x18.chroma.GetTenantRequest\x1a\x19.chroma.GetTenantResponse\"\x00\x12G\n\rCreateSegment\x12\x1c.chroma.CreateSegmentRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12G\n\rDeleteSegment\x12\x1c.chroma.DeleteSegmentRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12H\n\x0bGetSegments\x12\x1a.chroma.GetSegmentsRequest\x1a\x1b.chroma.GetSegmentsResponse\"\x00\x12G\n\rUpdateSegment\x12\x1c.chroma.UpdateSegmentRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12W\n\x10\x43reateCollection\x12\x1f.chroma.CreateCollectionRequest\x1a .chroma.CreateCollectionResponse\"\x00\x12M\n\x10\x44\x65leteCollection\x12\x1f.chroma.DeleteCollectionRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12Q\n\x0eGetCollections\x12\x1d.chroma.GetCollectionsRequest\x1a\x1e.chroma.GetCollectionsResponse\"\x00\x12M\n\x10UpdateCollection\x12\x1f.chroma.UpdateCollectionRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12>\n\nResetState\x12\x16.google.protobuf.Empty\x1a\x16.chroma.ChromaResponse\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'chromadb.proto.coordinator_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None - _globals['_CREATESEGMENTREQUEST']._serialized_start=102 - _globals['_CREATESEGMENTREQUEST']._serialized_end=158 - _globals['_DELETESEGMENTREQUEST']._serialized_start=160 - _globals['_DELETESEGMENTREQUEST']._serialized_end=194 - _globals['_GETSEGMENTSREQUEST']._serialized_start=197 - _globals['_GETSEGMENTSREQUEST']._serialized_end=391 - _globals['_GETSEGMENTSRESPONSE']._serialized_start=393 - _globals['_GETSEGMENTSRESPONSE']._serialized_end=481 - _globals['_UPDATESEGMENTREQUEST']._serialized_start=484 - _globals['_UPDATESEGMENTREQUEST']._serialized_end=734 - _globals['_CREATECOLLECTIONREQUEST']._serialized_start=737 - _globals['_CREATECOLLECTIONREQUEST']._serialized_end=932 - _globals['_CREATECOLLECTIONRESPONSE']._serialized_start=934 - 
_globals['_CREATECOLLECTIONRESPONSE']._serialized_end=1049 - _globals['_DELETECOLLECTIONREQUEST']._serialized_start=1051 - _globals['_DELETECOLLECTIONREQUEST']._serialized_end=1088 - _globals['_GETCOLLECTIONSREQUEST']._serialized_start=1090 - _globals['_GETCOLLECTIONSREQUEST']._serialized_end=1195 - _globals['_GETCOLLECTIONSRESPONSE']._serialized_start=1197 - _globals['_GETCOLLECTIONSRESPONSE']._serialized_end=1294 - _globals['_UPDATECOLLECTIONREQUEST']._serialized_start=1297 - _globals['_UPDATECOLLECTIONREQUEST']._serialized_end=1519 - _globals['_SYSDB']._serialized_start=1522 - _globals['_SYSDB']._serialized_end=2216 + _globals['_CREATEDATABASEREQUEST']._serialized_start=102 + _globals['_CREATEDATABASEREQUEST']._serialized_end=167 + _globals['_GETDATABASEREQUEST']._serialized_start=169 + _globals['_GETDATABASEREQUEST']._serialized_end=219 + _globals['_GETDATABASERESPONSE']._serialized_start=221 + _globals['_GETDATABASERESPONSE']._serialized_end=310 + _globals['_CREATETENANTREQUEST']._serialized_start=312 + _globals['_CREATETENANTREQUEST']._serialized_end=347 + _globals['_GETTENANTREQUEST']._serialized_start=349 + _globals['_GETTENANTREQUEST']._serialized_end=381 + _globals['_GETTENANTRESPONSE']._serialized_start=383 + _globals['_GETTENANTRESPONSE']._serialized_end=466 + _globals['_CREATESEGMENTREQUEST']._serialized_start=468 + _globals['_CREATESEGMENTREQUEST']._serialized_end=524 + _globals['_DELETESEGMENTREQUEST']._serialized_start=526 + _globals['_DELETESEGMENTREQUEST']._serialized_end=560 + _globals['_GETSEGMENTSREQUEST']._serialized_start=563 + _globals['_GETSEGMENTSREQUEST']._serialized_end=757 + _globals['_GETSEGMENTSRESPONSE']._serialized_start=759 + _globals['_GETSEGMENTSRESPONSE']._serialized_end=847 + _globals['_UPDATESEGMENTREQUEST']._serialized_start=850 + _globals['_UPDATESEGMENTREQUEST']._serialized_end=1100 + _globals['_CREATECOLLECTIONREQUEST']._serialized_start=1103 + _globals['_CREATECOLLECTIONREQUEST']._serialized_end=1332 + _globals['_CREATECOLLECTIONRESPONSE']._serialized_start=1334 + _globals['_CREATECOLLECTIONRESPONSE']._serialized_end=1449 + _globals['_DELETECOLLECTIONREQUEST']._serialized_start=1451 + _globals['_DELETECOLLECTIONREQUEST']._serialized_end=1522 + _globals['_GETCOLLECTIONSREQUEST']._serialized_start=1525 + _globals['_GETCOLLECTIONSREQUEST']._serialized_end=1664 + _globals['_GETCOLLECTIONSRESPONSE']._serialized_start=1666 + _globals['_GETCOLLECTIONSRESPONSE']._serialized_end=1763 + _globals['_UPDATECOLLECTIONREQUEST']._serialized_start=1766 + _globals['_UPDATECOLLECTIONREQUEST']._serialized_end=1988 + _globals['_SYSDB']._serialized_start=1991 + _globals['_SYSDB']._serialized_end=2973 # @@protoc_insertion_point(module_scope) diff --git a/chromadb/proto/coordinator_pb2.pyi b/chromadb/proto/coordinator_pb2.pyi index 6b9c974e424..81545e4e283 100644 --- a/chromadb/proto/coordinator_pb2.pyi +++ b/chromadb/proto/coordinator_pb2.pyi @@ -7,6 +7,52 @@ from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Map DESCRIPTOR: _descriptor.FileDescriptor +class CreateDatabaseRequest(_message.Message): + __slots__ = ["id", "name", "tenant"] + ID_FIELD_NUMBER: _ClassVar[int] + NAME_FIELD_NUMBER: _ClassVar[int] + TENANT_FIELD_NUMBER: _ClassVar[int] + id: str + name: str + tenant: str + def __init__(self, id: _Optional[str] = ..., name: _Optional[str] = ..., tenant: _Optional[str] = ...) -> None: ... 
+ +class GetDatabaseRequest(_message.Message): + __slots__ = ["name", "tenant"] + NAME_FIELD_NUMBER: _ClassVar[int] + TENANT_FIELD_NUMBER: _ClassVar[int] + name: str + tenant: str + def __init__(self, name: _Optional[str] = ..., tenant: _Optional[str] = ...) -> None: ... + +class GetDatabaseResponse(_message.Message): + __slots__ = ["database", "status"] + DATABASE_FIELD_NUMBER: _ClassVar[int] + STATUS_FIELD_NUMBER: _ClassVar[int] + database: _chroma_pb2.Database + status: _chroma_pb2.Status + def __init__(self, database: _Optional[_Union[_chroma_pb2.Database, _Mapping]] = ..., status: _Optional[_Union[_chroma_pb2.Status, _Mapping]] = ...) -> None: ... + +class CreateTenantRequest(_message.Message): + __slots__ = ["name"] + NAME_FIELD_NUMBER: _ClassVar[int] + name: str + def __init__(self, name: _Optional[str] = ...) -> None: ... + +class GetTenantRequest(_message.Message): + __slots__ = ["name"] + NAME_FIELD_NUMBER: _ClassVar[int] + name: str + def __init__(self, name: _Optional[str] = ...) -> None: ... + +class GetTenantResponse(_message.Message): + __slots__ = ["tenant", "status"] + TENANT_FIELD_NUMBER: _ClassVar[int] + STATUS_FIELD_NUMBER: _ClassVar[int] + tenant: _chroma_pb2.Tenant + status: _chroma_pb2.Status + def __init__(self, tenant: _Optional[_Union[_chroma_pb2.Tenant, _Mapping]] = ..., status: _Optional[_Union[_chroma_pb2.Status, _Mapping]] = ...) -> None: ... + class CreateSegmentRequest(_message.Message): __slots__ = ["segment"] SEGMENT_FIELD_NUMBER: _ClassVar[int] @@ -60,18 +106,22 @@ class UpdateSegmentRequest(_message.Message): def __init__(self, id: _Optional[str] = ..., topic: _Optional[str] = ..., reset_topic: bool = ..., collection: _Optional[str] = ..., reset_collection: bool = ..., metadata: _Optional[_Union[_chroma_pb2.UpdateMetadata, _Mapping]] = ..., reset_metadata: bool = ...) -> None: ... class CreateCollectionRequest(_message.Message): - __slots__ = ["id", "name", "metadata", "dimension", "get_or_create"] + __slots__ = ["id", "name", "metadata", "dimension", "get_or_create", "tenant", "database"] ID_FIELD_NUMBER: _ClassVar[int] NAME_FIELD_NUMBER: _ClassVar[int] METADATA_FIELD_NUMBER: _ClassVar[int] DIMENSION_FIELD_NUMBER: _ClassVar[int] GET_OR_CREATE_FIELD_NUMBER: _ClassVar[int] + TENANT_FIELD_NUMBER: _ClassVar[int] + DATABASE_FIELD_NUMBER: _ClassVar[int] id: str name: str metadata: _chroma_pb2.UpdateMetadata dimension: int get_or_create: bool - def __init__(self, id: _Optional[str] = ..., name: _Optional[str] = ..., metadata: _Optional[_Union[_chroma_pb2.UpdateMetadata, _Mapping]] = ..., dimension: _Optional[int] = ..., get_or_create: bool = ...) -> None: ... + tenant: str + database: str + def __init__(self, id: _Optional[str] = ..., name: _Optional[str] = ..., metadata: _Optional[_Union[_chroma_pb2.UpdateMetadata, _Mapping]] = ..., dimension: _Optional[int] = ..., get_or_create: bool = ..., tenant: _Optional[str] = ..., database: _Optional[str] = ...) -> None: ... class CreateCollectionResponse(_message.Message): __slots__ = ["collection", "created", "status"] @@ -84,20 +134,28 @@ class CreateCollectionResponse(_message.Message): def __init__(self, collection: _Optional[_Union[_chroma_pb2.Collection, _Mapping]] = ..., created: bool = ..., status: _Optional[_Union[_chroma_pb2.Status, _Mapping]] = ...) -> None: ... 
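The request messages above are what the gRPC sysdb client sends to the coordinator. A rough sketch of driving the new RPCs directly through the generated stub; the stub class name (SysDBStub) and the target address are assumptions, and the UUIDs are placeholders:

import uuid

import grpc

from chromadb.proto import coordinator_pb2, coordinator_pb2_grpc

# Assumed coordinator address; SysDBStub is the name grpc codegen conventionally
# emits for the SysDB service.
channel = grpc.insecure_channel("localhost:50051")
sysdb = coordinator_pb2_grpc.SysDBStub(channel)

# Create a tenant, then a database owned by it.
sysdb.CreateTenant(coordinator_pb2.CreateTenantRequest(name="tenant1"))
sysdb.CreateDatabase(
    coordinator_pb2.CreateDatabaseRequest(
        id=str(uuid.uuid4()),  # placeholder database id
        name="db1",
        tenant="tenant1",
    )
)

# Collections are now scoped to a tenant/database pair via the new fields.
sysdb.CreateCollection(
    coordinator_pb2.CreateCollectionRequest(
        id=str(uuid.uuid4()),  # placeholder collection id
        name="my_collection",
        tenant="tenant1",
        database="db1",
    )
)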
class DeleteCollectionRequest(_message.Message): - __slots__ = ["id"] + __slots__ = ["id", "tenant", "database"] ID_FIELD_NUMBER: _ClassVar[int] + TENANT_FIELD_NUMBER: _ClassVar[int] + DATABASE_FIELD_NUMBER: _ClassVar[int] id: str - def __init__(self, id: _Optional[str] = ...) -> None: ... + tenant: str + database: str + def __init__(self, id: _Optional[str] = ..., tenant: _Optional[str] = ..., database: _Optional[str] = ...) -> None: ... class GetCollectionsRequest(_message.Message): - __slots__ = ["id", "name", "topic"] + __slots__ = ["id", "name", "topic", "tenant", "database"] ID_FIELD_NUMBER: _ClassVar[int] NAME_FIELD_NUMBER: _ClassVar[int] TOPIC_FIELD_NUMBER: _ClassVar[int] + TENANT_FIELD_NUMBER: _ClassVar[int] + DATABASE_FIELD_NUMBER: _ClassVar[int] id: str name: str topic: str - def __init__(self, id: _Optional[str] = ..., name: _Optional[str] = ..., topic: _Optional[str] = ...) -> None: ... + tenant: str + database: str + def __init__(self, id: _Optional[str] = ..., name: _Optional[str] = ..., topic: _Optional[str] = ..., tenant: _Optional[str] = ..., database: _Optional[str] = ...) -> None: ... class GetCollectionsResponse(_message.Message): __slots__ = ["collections", "status"] diff --git a/chromadb/proto/coordinator_pb2_grpc.py b/chromadb/proto/coordinator_pb2_grpc.py index a3a1e03227b..117c568c715 100644 --- a/chromadb/proto/coordinator_pb2_grpc.py +++ b/chromadb/proto/coordinator_pb2_grpc.py @@ -16,6 +16,26 @@ def __init__(self, channel): Args: channel: A grpc.Channel. """ + self.CreateDatabase = channel.unary_unary( + "/chroma.SysDB/CreateDatabase", + request_serializer=chromadb_dot_proto_dot_coordinator__pb2.CreateDatabaseRequest.SerializeToString, + response_deserializer=chromadb_dot_proto_dot_chroma__pb2.ChromaResponse.FromString, + ) + self.GetDatabase = channel.unary_unary( + "/chroma.SysDB/GetDatabase", + request_serializer=chromadb_dot_proto_dot_coordinator__pb2.GetDatabaseRequest.SerializeToString, + response_deserializer=chromadb_dot_proto_dot_coordinator__pb2.GetDatabaseResponse.FromString, + ) + self.CreateTenant = channel.unary_unary( + "/chroma.SysDB/CreateTenant", + request_serializer=chromadb_dot_proto_dot_coordinator__pb2.CreateTenantRequest.SerializeToString, + response_deserializer=chromadb_dot_proto_dot_chroma__pb2.ChromaResponse.FromString, + ) + self.GetTenant = channel.unary_unary( + "/chroma.SysDB/GetTenant", + request_serializer=chromadb_dot_proto_dot_coordinator__pb2.GetTenantRequest.SerializeToString, + response_deserializer=chromadb_dot_proto_dot_coordinator__pb2.GetTenantResponse.FromString, + ) self.CreateSegment = channel.unary_unary( "/chroma.SysDB/CreateSegment", request_serializer=chromadb_dot_proto_dot_coordinator__pb2.CreateSegmentRequest.SerializeToString, @@ -66,6 +86,30 @@ def __init__(self, channel): class SysDBServicer(object): """Missing associated documentation comment in .proto file.""" + def CreateDatabase(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + + def GetDatabase(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + + def CreateTenant(self, request, context): + """Missing associated documentation comment in .proto file.""" + 
context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + + def GetTenant(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + def CreateSegment(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) @@ -123,6 +167,26 @@ def ResetState(self, request, context): def add_SysDBServicer_to_server(servicer, server): rpc_method_handlers = { + "CreateDatabase": grpc.unary_unary_rpc_method_handler( + servicer.CreateDatabase, + request_deserializer=chromadb_dot_proto_dot_coordinator__pb2.CreateDatabaseRequest.FromString, + response_serializer=chromadb_dot_proto_dot_chroma__pb2.ChromaResponse.SerializeToString, + ), + "GetDatabase": grpc.unary_unary_rpc_method_handler( + servicer.GetDatabase, + request_deserializer=chromadb_dot_proto_dot_coordinator__pb2.GetDatabaseRequest.FromString, + response_serializer=chromadb_dot_proto_dot_coordinator__pb2.GetDatabaseResponse.SerializeToString, + ), + "CreateTenant": grpc.unary_unary_rpc_method_handler( + servicer.CreateTenant, + request_deserializer=chromadb_dot_proto_dot_coordinator__pb2.CreateTenantRequest.FromString, + response_serializer=chromadb_dot_proto_dot_chroma__pb2.ChromaResponse.SerializeToString, + ), + "GetTenant": grpc.unary_unary_rpc_method_handler( + servicer.GetTenant, + request_deserializer=chromadb_dot_proto_dot_coordinator__pb2.GetTenantRequest.FromString, + response_serializer=chromadb_dot_proto_dot_coordinator__pb2.GetTenantResponse.SerializeToString, + ), "CreateSegment": grpc.unary_unary_rpc_method_handler( servicer.CreateSegment, request_deserializer=chromadb_dot_proto_dot_coordinator__pb2.CreateSegmentRequest.FromString, @@ -179,6 +243,122 @@ def add_SysDBServicer_to_server(servicer, server): class SysDB(object): """Missing associated documentation comment in .proto file.""" + @staticmethod + def CreateDatabase( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, + target, + "/chroma.SysDB/CreateDatabase", + chromadb_dot_proto_dot_coordinator__pb2.CreateDatabaseRequest.SerializeToString, + chromadb_dot_proto_dot_chroma__pb2.ChromaResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) + + @staticmethod + def GetDatabase( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, + target, + "/chroma.SysDB/GetDatabase", + chromadb_dot_proto_dot_coordinator__pb2.GetDatabaseRequest.SerializeToString, + chromadb_dot_proto_dot_coordinator__pb2.GetDatabaseResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) + + @staticmethod + def CreateTenant( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return 
grpc.experimental.unary_unary( + request, + target, + "/chroma.SysDB/CreateTenant", + chromadb_dot_proto_dot_coordinator__pb2.CreateTenantRequest.SerializeToString, + chromadb_dot_proto_dot_chroma__pb2.ChromaResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) + + @staticmethod + def GetTenant( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, + target, + "/chroma.SysDB/GetTenant", + chromadb_dot_proto_dot_coordinator__pb2.GetTenantRequest.SerializeToString, + chromadb_dot_proto_dot_coordinator__pb2.GetTenantResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) + @staticmethod def CreateSegment( request, diff --git a/chromadb/server/fastapi/__init__.py b/chromadb/server/fastapi/__init__.py index 4921392d3ee..b66bf33bda6 100644 --- a/chromadb/server/fastapi/__init__.py +++ b/chromadb/server/fastapi/__init__.py @@ -15,9 +15,10 @@ FastAPIChromaAuthMiddleware, FastAPIChromaAuthMiddlewareWrapper, ) -from chromadb.config import Settings +from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings, System import chromadb.server import chromadb.api +from chromadb.api import ServerAPI from chromadb.errors import ( ChromaError, InvalidUUIDError, @@ -25,6 +26,8 @@ ) from chromadb.server.fastapi.types import ( AddEmbedding, + CreateDatabase, + CreateTenant, DeleteEmbedding, GetEmbedding, QueryEmbedding, @@ -35,6 +38,7 @@ from starlette.requests import Request import logging +from chromadb.types import Database, Tenant from chromadb.telemetry.product import ServerContext, ProductTelemetryClient from chromadb.telemetry.opentelemetry import ( OpenTelemetryClient, @@ -109,8 +113,10 @@ def __init__(self, settings: Settings): super().__init__(settings) ProductTelemetryClient.SERVER_CONTEXT = ServerContext.FASTAPI self._app = fastapi.FastAPI(debug=True) - self._api: chromadb.api.API = chromadb.Client(settings) + self._system = System(settings) + self._api: ServerAPI = self._system.instance(ServerAPI) self._opentelemetry_client = self._api.require(OpenTelemetryClient) + self._system.start() self._app.middleware("http")(catch_exceptions_middleware) self._app.add_middleware( @@ -136,6 +142,34 @@ def __init__(self, settings: Settings): "/api/v1/pre-flight-checks", self.pre_flight_checks, methods=["GET"] ) + self.router.add_api_route( + "/api/v1/databases", + self.create_database, + methods=["POST"], + response_model=None, + ) + + self.router.add_api_route( + "/api/v1/databases/{database}", + self.get_database, + methods=["GET"], + response_model=None, + ) + + self.router.add_api_route( + "/api/v1/tenants", + self.create_tenant, + methods=["POST"], + response_model=None, + ) + + self.router.add_api_route( + "/api/v1/tenants/{tenant}", + self.get_tenant, + methods=["GET"], + response_model=None, + ) + self.router.add_api_route( "/api/v1/collections", self.list_collections, @@ -227,21 +261,55 @@ def heartbeat(self) -> Dict[str, int]: def version(self) -> str: return self._api.get_version() + @trace_method("FastAPI.create_database", OpenTelemetryGranularity.OPERATION) + def create_database( + self, database: CreateDatabase, tenant: str = DEFAULT_TENANT + ) -> None: + return self._api.create_database(database.name, tenant) + + 
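The routes registered above expose tenants and databases as plain REST endpoints: the CreateDatabase and CreateTenant pydantic models carry only a name, and the tenant rides along as a query parameter. A hedged sketch of a raw HTTP interaction, assuming a server reachable at localhost:8000 and FastAPI's default body/path/query binding:

import requests

BASE = "http://localhost:8000"  # assumed server address

# Create a tenant, then a database owned by it (tenant is a query parameter).
requests.post(f"{BASE}/api/v1/tenants", json={"name": "tenant1"})
requests.post(
    f"{BASE}/api/v1/databases",
    params={"tenant": "tenant1"},
    json={"name": "db1"},
)

# Read them back.
print(requests.get(f"{BASE}/api/v1/tenants/tenant1").json())
print(requests.get(f"{BASE}/api/v1/databases/db1", params={"tenant": "tenant1"}).json())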
@trace_method("FastAPI.get_database", OpenTelemetryGranularity.OPERATION) + def get_database(self, database: str, tenant: str = DEFAULT_TENANT) -> Database: + return self._api.get_database(database, tenant) + + @trace_method("FastAPI.create_tenant", OpenTelemetryGranularity.OPERATION) + def create_tenant(self, tenant: CreateTenant) -> None: + return self._api.create_tenant(tenant.name) + + @trace_method("FastAPI.get_tenant", OpenTelemetryGranularity.OPERATION) + def get_tenant(self, tenant: str) -> Tenant: + return self._api.get_tenant(tenant) + @trace_method("FastAPI.list_collections", OpenTelemetryGranularity.OPERATION) - def list_collections(self) -> Sequence[Collection]: - return self._api.list_collections() + def list_collections( + self, tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE + ) -> Sequence[Collection]: + return self._api.list_collections(tenant=tenant, database=database) @trace_method("FastAPI.create_collection", OpenTelemetryGranularity.OPERATION) - def create_collection(self, collection: CreateCollection) -> Collection: + def create_collection( + self, + collection: CreateCollection, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> Collection: return self._api.create_collection( name=collection.name, metadata=collection.metadata, get_or_create=collection.get_or_create, + tenant=tenant, + database=database, ) @trace_method("FastAPI.get_collection", OpenTelemetryGranularity.OPERATION) - def get_collection(self, collection_name: str) -> Collection: - return self._api.get_collection(collection_name) + def get_collection( + self, + collection_name: str, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> Collection: + return self._api.get_collection( + collection_name, tenant=tenant, database=database + ) @trace_method("FastAPI.update_collection", OpenTelemetryGranularity.OPERATION) def update_collection( @@ -254,8 +322,15 @@ def update_collection( ) @trace_method("FastAPI.delete_collection", OpenTelemetryGranularity.OPERATION) - def delete_collection(self, collection_name: str) -> None: - return self._api.delete_collection(collection_name) + def delete_collection( + self, + collection_name: str, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> None: + return self._api.delete_collection( + collection_name, tenant=tenant, database=database + ) @trace_method("FastAPI.add", OpenTelemetryGranularity.OPERATION) def add(self, collection_id: str, add: AddEmbedding) -> None: diff --git a/chromadb/server/fastapi/types.py b/chromadb/server/fastapi/types.py index 306f0e5fcb3..5f1665c91bd 100644 --- a/chromadb/server/fastapi/types.py +++ b/chromadb/server/fastapi/types.py @@ -59,3 +59,11 @@ class CreateCollection(BaseModel): # type: ignore class UpdateCollection(BaseModel): # type: ignore new_name: Optional[str] = None new_metadata: Optional[CollectionMetadata] = None + + +class CreateDatabase(BaseModel): + name: str + + +class CreateTenant(BaseModel): + name: str diff --git a/chromadb/test/auth/test_token_auth.py b/chromadb/test/auth/test_token_auth.py index 4e99baae306..50e88e296a9 100644 --- a/chromadb/test/auth/test_token_auth.py +++ b/chromadb/test/auth/test_token_auth.py @@ -5,7 +5,7 @@ import pytest from hypothesis import given, settings -from chromadb.api import API +from chromadb.api import ServerAPI from chromadb.config import System from chromadb.test.conftest import _fastapi_fixture @@ -64,7 +64,7 @@ def test_fastapi_server_token_auth(token_config: Dict[str, Any]) -> None: ) _sys: System 
= next(api) _sys.reset_state() - _api = _sys.instance(API) + _api = _sys.instance(ServerAPI) _api.heartbeat() assert _api.list_collections() == [] @@ -103,7 +103,7 @@ def test_invalid_token(tconf: Dict[str, Any], inval_tok: str) -> None: with pytest.raises(Exception) as e: _sys: System = next(api) _sys.reset_state() - _sys.instance(API) + _sys.instance(ServerAPI) assert "Invalid token" in str(e) @@ -131,7 +131,7 @@ def test_fastapi_server_token_auth_wrong_token( ) _sys: System = next(api) _sys.reset_state() - _api = _sys.instance(API) + _api = _sys.instance(ServerAPI) _api.heartbeat() with pytest.raises(Exception) as e: _api.list_collections() diff --git a/chromadb/test/client/test_database_tenant.py b/chromadb/test/client/test_database_tenant.py new file mode 100644 index 00000000000..00c502bb0ea --- /dev/null +++ b/chromadb/test/client/test_database_tenant.py @@ -0,0 +1,77 @@ +import pytest +from chromadb.api.client import AdminClient, Client +from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT + + +def test_database_tenant_collections(client: Client) -> None: + client.reset() + # Create a new database in the default tenant + admin_client = AdminClient.from_system(client._system) + admin_client.create_database("test_db") + + # Create collections in this new database + client.set_tenant(tenant=DEFAULT_TENANT, database="test_db") + client.create_collection("collection", metadata={"database": "test_db"}) + + # Create collections in the default database + client.set_tenant(tenant=DEFAULT_TENANT, database=DEFAULT_DATABASE) + client.create_collection("collection", metadata={"database": DEFAULT_DATABASE}) + + # List collections in the default database + collections = client.list_collections() + assert len(collections) == 1 + assert collections[0].name == "collection" + assert collections[0].metadata == {"database": DEFAULT_DATABASE} + + # List collections in the new database + client.set_tenant(tenant=DEFAULT_TENANT, database="test_db") + collections = client.list_collections() + assert len(collections) == 1 + assert collections[0].metadata == {"database": "test_db"} + + # Update the metadata in both databases to different values + client.set_tenant(tenant=DEFAULT_TENANT, database=DEFAULT_DATABASE) + client.list_collections()[0].modify(metadata={"database": "default2"}) + + client.set_tenant(tenant=DEFAULT_TENANT, database="test_db") + client.list_collections()[0].modify(metadata={"database": "test_db2"}) + + # Validate that the metadata was updated + client.set_tenant(tenant=DEFAULT_TENANT, database=DEFAULT_DATABASE) + collections = client.list_collections() + assert len(collections) == 1 + assert collections[0].metadata == {"database": "default2"} + + client.set_tenant(tenant=DEFAULT_TENANT, database="test_db") + collections = client.list_collections() + assert len(collections) == 1 + assert collections[0].metadata == {"database": "test_db2"} + + # Delete the collections and make sure databases are isolated + client.set_tenant(tenant=DEFAULT_TENANT, database=DEFAULT_DATABASE) + client.delete_collection("collection") + + collections = client.list_collections() + assert len(collections) == 0 + + client.set_tenant(tenant=DEFAULT_TENANT, database="test_db") + collections = client.list_collections() + assert len(collections) == 1 + + client.delete_collection("collection") + collections = client.list_collections() + assert len(collections) == 0 + + +def test_min_len_name(client: Client) -> None: + client.reset() + + # Create a new database in the default tenant with a name of length 1 + # 
and expect an error + admin_client = AdminClient.from_system(client._system) + with pytest.raises(Exception): + admin_client.create_database("a") + + # Create a tenant with a name of length 1 and expect an error + with pytest.raises(Exception): + admin_client.create_tenant("a") diff --git a/chromadb/test/client/test_multiple_clients_concurrency.py b/chromadb/test/client/test_multiple_clients_concurrency.py new file mode 100644 index 00000000000..14054214cbf --- /dev/null +++ b/chromadb/test/client/test_multiple_clients_concurrency.py @@ -0,0 +1,47 @@ +from concurrent.futures import ThreadPoolExecutor +from chromadb.api.client import AdminClient, Client +from chromadb.config import DEFAULT_TENANT + + +def test_multiple_clients_concurrently(client: Client) -> None: + """Tests running multiple clients, each against their own database, concurrently.""" + client.reset() + admin_client = AdminClient.from_system(client._system) + admin_client.create_database("test_db") + + CLIENT_COUNT = 50 + COLLECTION_COUNT = 10 + + # Each database will create the same collections by name, with differing metadata + databases = [f"db{i}" for i in range(CLIENT_COUNT)] + for database in databases: + admin_client.create_database(database) + + collections = [f"collection{i}" for i in range(COLLECTION_COUNT)] + + # Create N clients, each on a separate thread, each with their own database + def run_target(n: int) -> None: + thread_client = Client( + tenant=DEFAULT_TENANT, + database=databases[n], + settings=client._system.settings, + ) + for collection in collections: + thread_client.create_collection( + collection, metadata={"database": databases[n]} + ) + + with ThreadPoolExecutor(max_workers=CLIENT_COUNT) as executor: + executor.map(run_target, range(CLIENT_COUNT)) + + # Create a final client, which will be used to verify the collections were created + client = Client(settings=client._system.settings) + + # Verify that the collections were created + for database in databases: + client.set_database(database) + seen_collections = client.list_collections() + assert len(seen_collections) == COLLECTION_COUNT + for collection in seen_collections: + assert collection.name in collections + assert collection.metadata == {"database": database} diff --git a/chromadb/test/conftest.py b/chromadb/test/conftest.py index af66ef2513f..b27e1af7c01 100644 --- a/chromadb/test/conftest.py +++ b/chromadb/test/conftest.py @@ -22,11 +22,12 @@ from typing_extensions import Protocol import chromadb.server.fastapi -from chromadb.api import API +from chromadb.api import ClientAPI, ServerAPI from chromadb.config import Settings, System from chromadb.db.mixins import embeddings_queue from chromadb.ingest import Producer from chromadb.types import SeqId, SubmitEmbeddingRecord +from chromadb.api.client import Client as ClientCreator root_logger = logging.getLogger() root_logger.setLevel(logging.DEBUG) # This will only run when testing @@ -98,7 +99,7 @@ def _run_server( uvicorn.run(server.app(), host="0.0.0.0", port=port, log_level="error") -def _await_server(api: API, attempts: int = 0) -> None: +def _await_server(api: ServerAPI, attempts: int = 0) -> None: try: api.heartbeat() except ConnectionError as e: @@ -172,7 +173,7 @@ def _fastapi_fixture( chroma_client_auth_token_transport_header=chroma_client_auth_token_transport_header, ) system = System(settings) - api = system.instance(API) + api = system.instance(ServerAPI) system.start() _await_server(api) yield system @@ -198,7 +199,7 @@ def basic_http_client() -> Generator[System, None, None]:
allow_reset=True, ) system = System(settings) - api = system.instance(API) + api = system.instance(ServerAPI) _await_server(api) system.start() yield system @@ -361,41 +362,50 @@ def system_fixtures_wrong_auth() -> List[Callable[[], Generator[System, None, No @pytest.fixture(scope="module", params=system_fixtures_wrong_auth()) -def system_wrong_auth(request: pytest.FixtureRequest) -> Generator[API, None, None]: +def system_wrong_auth( + request: pytest.FixtureRequest, +) -> Generator[ServerAPI, None, None]: yield next(request.param()) @pytest.fixture(scope="module", params=system_fixtures()) -def system(request: pytest.FixtureRequest) -> Generator[API, None, None]: +def system(request: pytest.FixtureRequest) -> Generator[ServerAPI, None, None]: yield next(request.param()) @pytest.fixture(scope="module", params=system_fixtures_auth()) -def system_auth(request: pytest.FixtureRequest) -> Generator[API, None, None]: +def system_auth(request: pytest.FixtureRequest) -> Generator[ServerAPI, None, None]: yield next(request.param()) @pytest.fixture(scope="function") -def api(system: System) -> Generator[API, None, None]: +def api(system: System) -> Generator[ServerAPI, None, None]: system.reset_state() - api = system.instance(API) + api = system.instance(ServerAPI) yield api +@pytest.fixture(scope="function") +def client(system: System) -> Generator[ClientAPI, None, None]: + system.reset_state() + client = ClientCreator.from_system(system) + yield client + + @pytest.fixture(scope="function") def api_wrong_cred( system_wrong_auth: System, -) -> Generator[API, None, None]: +) -> Generator[ServerAPI, None, None]: system_wrong_auth.reset_state() - api = system_wrong_auth.instance(API) + api = system_wrong_auth.instance(ServerAPI) yield api @pytest.fixture(scope="function") -def api_with_server_auth(system_auth: System) -> Generator[API, None, None]: +def api_with_server_auth(system_auth: System) -> Generator[ServerAPI, None, None]: _sys = system_auth _sys.reset_state() - api = _sys.instance(API) + api = _sys.instance(ServerAPI) yield api diff --git a/chromadb/test/db/test_system.py b/chromadb/test/db/test_system.py index 541643a2ff6..e801618b4bc 100644 --- a/chromadb/test/db/test_system.py +++ b/chromadb/test/db/test_system.py @@ -3,11 +3,18 @@ import tempfile import pytest from typing import Generator, List, Callable, Dict, Union + from chromadb.db.impl.grpc.client import GrpcSysDB from chromadb.db.impl.grpc.server import GrpcMockSysDB from chromadb.types import Collection, Segment, SegmentScope from chromadb.db.impl.sqlite import SqliteDB -from chromadb.config import Component, System, Settings +from chromadb.config import ( + DEFAULT_DATABASE, + DEFAULT_TENANT, + Component, + System, + Settings, +) from chromadb.db.system import SysDB from chromadb.db.base import NotFoundError, UniqueConstraintError from pytest import FixtureRequest @@ -107,6 +114,7 @@ def sysdb(request: FixtureRequest) -> Generator[SysDB, None, None]: yield next(request.param()) +# region Collection tests def test_create_get_delete_collections(sysdb: SysDB) -> None: sysdb.reset_state() @@ -296,6 +304,275 @@ def test_get_or_create_collection(sysdb: SysDB) -> None: assert result["metadata"] == overlayed_metadata +def test_create_get_delete_database_and_collection(sysdb: SysDB) -> None: + sysdb.reset_state() + + # Create a new database + sysdb.create_database(id=uuid.uuid4(), name="new_database") + + # Create a new collection in the new database + sysdb.create_collection( + id=sample_collections[0]["id"], + 
name=sample_collections[0]["name"], + metadata=sample_collections[0]["metadata"], + dimension=sample_collections[0]["dimension"], + database="new_database", + ) + + # Create a new collection with the same id but different name in the new database + # and expect an error + with pytest.raises(UniqueConstraintError): + sysdb.create_collection( + id=sample_collections[0]["id"], + name="new_name", + metadata=sample_collections[0]["metadata"], + dimension=sample_collections[0]["dimension"], + database="new_database", + get_or_create=False, + ) + + # Create a new collection in the default database + sysdb.create_collection( + id=sample_collections[1]["id"], + name=sample_collections[1]["name"], + metadata=sample_collections[1]["metadata"], + dimension=sample_collections[1]["dimension"], + ) + + # Check that the new database and collections exist + result = sysdb.get_collections( + name=sample_collections[0]["name"], database="new_database" + ) + assert len(result) == 1 + assert result[0] == sample_collections[0] + + # Check that the collection in the default database exists + result = sysdb.get_collections(name=sample_collections[1]["name"]) + assert len(result) == 1 + assert result[0] == sample_collections[1] + + # Get for a database that doesn't exist with a name that exists in the new database and expect no results + assert ( + len( + sysdb.get_collections( + name=sample_collections[0]["name"], database="fake_db" + ) + ) + == 0 + ) + + # Delete the collection in the new database + sysdb.delete_collection(id=sample_collections[0]["id"], database="new_database") + + # Check that the collection in the new database was deleted + result = sysdb.get_collections(database="new_database") + assert len(result) == 0 + + # Check that the collection in the default database still exists + result = sysdb.get_collections(name=sample_collections[1]["name"]) + assert len(result) == 1 + assert result[0] == sample_collections[1] + + # Delete the deleted collection in the default database and expect an error + with pytest.raises(NotFoundError): + sysdb.delete_collection(id=sample_collections[0]["id"]) + + # Delete the existing collection in the new database and expect an error + with pytest.raises(NotFoundError): + sysdb.delete_collection(id=sample_collections[1]["id"], database="new_database") + + +def test_create_update_with_database(sysdb: SysDB) -> None: + sysdb.reset_state() + + # Create a new database + sysdb.create_database(id=uuid.uuid4(), name="new_database") + + # Create a new collection in the new database + sysdb.create_collection( + id=sample_collections[0]["id"], + name=sample_collections[0]["name"], + metadata=sample_collections[0]["metadata"], + dimension=sample_collections[0]["dimension"], + database="new_database", + ) + + # Create a new collection in the default database + sysdb.create_collection( + id=sample_collections[1]["id"], + name=sample_collections[1]["name"], + metadata=sample_collections[1]["metadata"], + dimension=sample_collections[1]["dimension"], + ) + + # Update the collection in the default database + sysdb.update_collection( + id=sample_collections[1]["id"], + name="new_name_1", + ) + + # Check that the collection in the default database was updated + result = sysdb.get_collections(id=sample_collections[1]["id"]) + assert len(result) == 1 + assert result[0]["name"] == "new_name_1" + + # Update the collection in the new database + sysdb.update_collection( + id=sample_collections[0]["id"], + name="new_name_0", + ) + + # Check that the collection in the new database was updated + 
result = sysdb.get_collections(
+        id=sample_collections[0]["id"], database="new_database"
+    )
+    assert len(result) == 1
+    assert result[0]["name"] == "new_name_0"
+
+    # Try to create the collection in the default database in the new database and expect an error
+    with pytest.raises(UniqueConstraintError):
+        sysdb.create_collection(
+            id=sample_collections[1]["id"],
+            name=sample_collections[1]["name"],
+            metadata=sample_collections[1]["metadata"],
+            dimension=sample_collections[1]["dimension"],
+            database="new_database",
+        )
+
+
+def test_get_multiple_with_database(sysdb: SysDB) -> None:
+    sysdb.reset_state()
+
+    # Create a new database
+    sysdb.create_database(id=uuid.uuid4(), name="new_database")
+
+    # Create sample collections in the new database
+    for collection in sample_collections:
+        sysdb.create_collection(
+            id=collection["id"],
+            name=collection["name"],
+            metadata=collection["metadata"],
+            dimension=collection["dimension"],
+            database="new_database",
+        )
+
+    # Get all collections in the new database
+    result = sysdb.get_collections(database="new_database")
+    assert len(result) == len(sample_collections)
+    assert sorted(result, key=lambda c: c["name"]) == sample_collections
+
+    # Get all collections in the default database
+    result = sysdb.get_collections()
+    assert len(result) == 0
+
+
+def test_create_database_with_tenants(sysdb: SysDB) -> None:
+    sysdb.reset_state()
+
+    # Create a new tenant
+    sysdb.create_tenant(name="tenant1")
+
+    # Create a tenant that already exists and expect an error
+    with pytest.raises(UniqueConstraintError):
+        sysdb.create_tenant(name="tenant1")
+
+    with pytest.raises(UniqueConstraintError):
+        sysdb.create_tenant(name=DEFAULT_TENANT)
+
+    # Create a new database within this tenant and also in the default tenant
+    sysdb.create_database(id=uuid.uuid4(), name="new_database", tenant="tenant1")
+    sysdb.create_database(id=uuid.uuid4(), name="new_database")
+
+    # Create a new collection in the new tenant
+    sysdb.create_collection(
+        id=sample_collections[0]["id"],
+        name=sample_collections[0]["name"],
+        metadata=sample_collections[0]["metadata"],
+        dimension=sample_collections[0]["dimension"],
+        database="new_database",
+        tenant="tenant1",
+    )
+
+    # Create a new collection in the default tenant
+    sysdb.create_collection(
+        id=sample_collections[1]["id"],
+        name=sample_collections[1]["name"],
+        metadata=sample_collections[1]["metadata"],
+        dimension=sample_collections[1]["dimension"],
+        database="new_database",
+    )
+
+    # Check that both tenants have the correct collections
+    result = sysdb.get_collections(database="new_database", tenant="tenant1")
+    assert len(result) == 1
+    assert result[0] == sample_collections[0]
+
+    result = sysdb.get_collections(database="new_database")
+    assert len(result) == 1
+    assert result[0] == sample_collections[1]
+
+    # Creating a collection id that already exists in a tenant that does not have it
+    # should error
+    with pytest.raises(UniqueConstraintError):
+        sysdb.create_collection(
+            id=sample_collections[0]["id"],
+            name=sample_collections[0]["name"],
+            metadata=sample_collections[0]["metadata"],
+            dimension=sample_collections[0]["dimension"],
+            database="new_database",
+        )
+
+    with pytest.raises(UniqueConstraintError):
+        sysdb.create_collection(
+            id=sample_collections[1]["id"],
+            name=sample_collections[1]["name"],
+            metadata=sample_collections[1]["metadata"],
+            dimension=sample_collections[1]["dimension"],
+            database="new_database",
+            tenant="tenant1",
+        )
+
+    # A new tenant DOES NOT have a default database. 
This does not error, instead 0 + # results are returned + result = sysdb.get_collections(database=DEFAULT_DATABASE, tenant="tenant1") + assert len(result) == 0 + + +def test_get_database_with_tenants(sysdb: SysDB) -> None: + sysdb.reset_state() + + # Create a new tenant + sysdb.create_tenant(name="tenant1") + + # Get the tenant and check that it exists + result = sysdb.get_tenant(name="tenant1") + assert result["name"] == "tenant1" + + # Get a tenant that does not exist and expect an error + with pytest.raises(NotFoundError): + sysdb.get_tenant(name="tenant2") + + # Create a new database within this tenant + sysdb.create_database(id=uuid.uuid4(), name="new_database", tenant="tenant1") + + # Get the database and check that it exists + result = sysdb.get_database(name="new_database", tenant="tenant1") + assert result["name"] == "new_database" + assert result["tenant"] == "tenant1" + + # Get a database that does not exist in a tenant that does exist and expect an error + with pytest.raises(NotFoundError): + sysdb.get_database(name="new_database1", tenant="tenant1") + + # Get a database that does not exist in a tenant that does not exist and expect an + # error + with pytest.raises(NotFoundError): + sysdb.get_database(name="new_database1", tenant="tenant2") + + +# endregion + +# region Segment tests sample_segments = [ Segment( id=uuid.UUID("00000000-d7d7-413b-92e1-731098a6e492"), @@ -459,3 +736,6 @@ def test_update_segment(sysdb: SysDB) -> None: sysdb.update_segment(segment["id"], metadata=None) result = sysdb.get_segments(id=segment["id"]) assert result == [segment] + + +# endregion diff --git a/chromadb/test/property/strategies.py b/chromadb/test/property/strategies.py index e8540ef37aa..3583dadfba9 100644 --- a/chromadb/test/property/strategies.py +++ b/chromadb/test/property/strategies.py @@ -95,10 +95,12 @@ class Record(TypedDict): # TODO: support empty strings everywhere sql_alphabet = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_" safe_text = st.text(alphabet=sql_alphabet, min_size=1) +tenant_database_name = st.text(alphabet=sql_alphabet, min_size=3) # Workaround for FastAPI json encoding peculiarities # https://github.com/tiangolo/fastapi/blob/8ac8d70d52bb0dd9eb55ba4e22d3e383943da05c/fastapi/encoders.py#L104 safe_text = safe_text.filter(lambda s: not s.startswith("_sa")) +tenant_database_name = tenant_database_name.filter(lambda s: not s.startswith("_sa")) safe_integers = st.integers( min_value=-(2**31), max_value=2**31 - 1 diff --git a/chromadb/test/property/test_add.py b/chromadb/test/property/test_add.py index 1980ed2a9d9..5f8991b00ed 100644 --- a/chromadb/test/property/test_add.py +++ b/chromadb/test/property/test_add.py @@ -5,7 +5,7 @@ import pytest import hypothesis.strategies as st from hypothesis import given, settings -from chromadb.api import API +from chromadb.api import ServerAPI from chromadb.api.types import Embeddings, Metadatas import chromadb.test.property.strategies as strategies import chromadb.test.property.invariants as invariants @@ -17,7 +17,7 @@ @given(collection=collection_st, record_set=strategies.recordsets(collection_st)) @settings(deadline=None) def test_add( - api: API, + api: ServerAPI, collection: strategies.Collection, record_set: strategies.RecordSet, ) -> None: @@ -69,7 +69,7 @@ def create_large_recordset( @given(collection=collection_st) @settings(deadline=None, max_examples=1) -def test_add_large(api: API, collection: strategies.Collection) -> None: +def test_add_large(api: ServerAPI, collection: strategies.Collection) -> None: 
api.reset() record_set = create_large_recordset( min_size=api.max_batch_size, @@ -99,7 +99,7 @@ def test_add_large(api: API, collection: strategies.Collection) -> None: @given(collection=collection_st) @settings(deadline=None, max_examples=1) -def test_add_large_exceeding(api: API, collection: strategies.Collection) -> None: +def test_add_large_exceeding(api: ServerAPI, collection: strategies.Collection) -> None: api.reset() record_set = create_large_recordset( min_size=api.max_batch_size, @@ -126,7 +126,7 @@ def test_add_large_exceeding(api: API, collection: strategies.Collection) -> Non reason="This is expected to fail right now. We should change the API to sort the \ ids by input order." ) -def test_out_of_order_ids(api: API) -> None: +def test_out_of_order_ids(api: ServerAPI) -> None: api.reset() ooo_ids = [ "40", @@ -165,7 +165,7 @@ def test_out_of_order_ids(api: API) -> None: assert get_ids == ooo_ids -def test_add_partial(api: API) -> None: +def test_add_partial(api: ServerAPI) -> None: """Tests adding a record set with some of the fields set to None.""" api.reset() diff --git a/chromadb/test/property/test_collections.py b/chromadb/test/property/test_collections.py index 60e3de7592c..d7cd8492117 100644 --- a/chromadb/test/property/test_collections.py +++ b/chromadb/test/property/test_collections.py @@ -2,7 +2,7 @@ import logging import hypothesis.strategies as st import chromadb.test.property.strategies as strategies -from chromadb.api import API +from chromadb.api import ClientAPI import chromadb.api.types as types from hypothesis.stateful import ( Bundle, @@ -19,19 +19,19 @@ class CollectionStateMachine(RuleBasedStateMachine): collections: Bundle[strategies.Collection] - model: Dict[str, Optional[types.CollectionMetadata]] + _model: Dict[str, Optional[types.CollectionMetadata]] collections = Bundle("collections") - def __init__(self, api: API): + def __init__(self, api: ClientAPI): super().__init__() - self.model = {} + self._model = {} self.api = api @initialize() def initialize(self) -> None: self.api.reset() - self.model = {} + self._model = {} @rule(target=collections, coll=strategies.collections()) def create_coll( @@ -54,7 +54,7 @@ def create_coll( metadata=coll.metadata, embedding_function=coll.embedding_function, ) - self.model[coll.name] = coll.metadata + self.set_model(coll.name, coll.metadata) assert c.name == coll.name assert c.metadata == coll.metadata @@ -74,7 +74,7 @@ def get_coll(self, coll: strategies.Collection) -> None: def delete_coll(self, coll: strategies.Collection) -> None: if coll.name in self.model: self.api.delete_collection(name=coll.name) - del self.model[coll.name] + self.delete_from_model(coll.name) else: with pytest.raises(Exception): self.api.delete_collection(name=coll.name) @@ -140,7 +140,7 @@ def get_or_create_coll( coll.metadata = ( self.model[coll.name] if new_metadata is None else new_metadata ) - self.model[coll.name] = coll.metadata + self.set_model(coll.name, coll.metadata) # Update API c = self.api.get_or_create_collection( @@ -183,7 +183,7 @@ def modify_coll( ) return multiple() coll.metadata = new_metadata - self.model[coll.name] = coll.metadata + self.set_model(coll.name, coll.metadata) if new_name is not None: if new_name in self.model and new_name != coll.name: @@ -191,8 +191,8 @@ def modify_coll( c.modify(metadata=new_metadata, name=new_name) return multiple() - del self.model[coll.name] - self.model[new_name] = coll.metadata + self.delete_from_model(coll.name) + self.set_model(new_name, coll.metadata) coll.name = new_name 
c.modify(metadata=new_metadata, name=new_name) @@ -202,7 +202,21 @@ def modify_coll( assert c.metadata == coll.metadata return multiple(coll) + def set_model( + self, name: str, metadata: Optional[types.CollectionMetadata] + ) -> None: + model = self.model + model[name] = metadata -def test_collections(caplog: pytest.LogCaptureFixture, api: API) -> None: + def delete_from_model(self, name: str) -> None: + model = self.model + del model[name] + + @property + def model(self) -> Dict[str, Optional[types.CollectionMetadata]]: + return self._model + + +def test_collections(caplog: pytest.LogCaptureFixture, api: ClientAPI) -> None: caplog.set_level(logging.ERROR) run_state_machine_as_test(lambda: CollectionStateMachine(api)) # type: ignore diff --git a/chromadb/test/property/test_collections_with_database_tenant.py b/chromadb/test/property/test_collections_with_database_tenant.py new file mode 100644 index 00000000000..28ba14f092a --- /dev/null +++ b/chromadb/test/property/test_collections_with_database_tenant.py @@ -0,0 +1,102 @@ +import logging +from typing import Dict, Optional, Tuple +import pytest +from chromadb.api import AdminAPI +import chromadb.api.types as types +from chromadb.api.client import AdminClient, Client +from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT +from chromadb.test.property.test_collections import CollectionStateMachine +from hypothesis.stateful import ( + Bundle, + rule, + initialize, + multiple, + run_state_machine_as_test, + MultipleResults, +) +import chromadb.test.property.strategies as strategies + + +class TenantDatabaseCollectionStateMachine(CollectionStateMachine): + """A collection state machine test that includes tenant and database information, + and switches between them.""" + + tenants: Bundle[str] + databases: Bundle[Tuple[str, str]] # database to tenant it belongs to + tenant_to_database_to_model: Dict[ + str, Dict[str, Dict[str, Optional[types.CollectionMetadata]]] + ] + admin_client: AdminAPI + curr_tenant: str + curr_database: str + + tenants = Bundle("tenants") + databases = Bundle("databases") + + def __init__(self, client: Client): + super().__init__(client) + self.api = client + self.admin_client = AdminClient.from_system(client._system) + + @initialize() + def initialize(self) -> None: + self.api.reset() + self.tenant_to_database_to_model = {} + self.curr_tenant = DEFAULT_TENANT + self.curr_database = DEFAULT_DATABASE + self.api.set_tenant(DEFAULT_TENANT, DEFAULT_DATABASE) + self.tenant_to_database_to_model[self.curr_tenant] = {} + self.tenant_to_database_to_model[self.curr_tenant][self.curr_database] = {} + + @rule(target=tenants, name=strategies.tenant_database_name) + def create_tenant(self, name: str) -> MultipleResults[str]: + # Check if tenant already exists + if name in self.tenant_to_database_to_model: + with pytest.raises(Exception): + self.admin_client.create_tenant(name) + return multiple() + + self.admin_client.create_tenant(name) + # When we create a tenant, create a default database for it just for testing + # since the state machine could call collection operations before creating a + # database + self.admin_client.create_database(DEFAULT_DATABASE, tenant=name) + self.tenant_to_database_to_model[name] = {} + self.tenant_to_database_to_model[name][DEFAULT_DATABASE] = {} + return multiple(name) + + @rule(target=databases, name=strategies.tenant_database_name) + def create_database(self, name: str) -> MultipleResults[Tuple[str, str]]: + # If database already exists in current tenant, raise an error + if name in 
self.tenant_to_database_to_model[self.curr_tenant]: + with pytest.raises(Exception): + self.admin_client.create_database(name, tenant=self.curr_tenant) + return multiple() + + self.admin_client.create_database(name, tenant=self.curr_tenant) + self.tenant_to_database_to_model[self.curr_tenant][name] = {} + return multiple((name, self.curr_tenant)) + + @rule(database=databases) + def set_database_and_tenant(self, database: Tuple[str, str]) -> None: + # Get a database and switch to the database and the tenant it belongs to + database_name = database[0] + tenant_name = database[1] + self.api.set_tenant(tenant_name, database_name) + self.curr_database = database_name + self.curr_tenant = tenant_name + + @rule(tenant=tenants) + def set_tenant(self, tenant: str) -> None: + self.api.set_tenant(tenant, DEFAULT_DATABASE) + self.curr_tenant = tenant + self.curr_database = DEFAULT_DATABASE + + @property + def model(self) -> Dict[str, Optional[types.CollectionMetadata]]: + return self.tenant_to_database_to_model[self.curr_tenant][self.curr_database] + + +def test_collections(caplog: pytest.LogCaptureFixture, client: Client) -> None: + caplog.set_level(logging.ERROR) + run_state_machine_as_test(lambda: TenantDatabaseCollectionStateMachine(client)) # type: ignore diff --git a/chromadb/test/property/test_cross_version_persist.py b/chromadb/test/property/test_cross_version_persist.py index 3bd83231b32..b5320dfe7a9 100644 --- a/chromadb/test/property/test_cross_version_persist.py +++ b/chromadb/test/property/test_cross_version_persist.py @@ -5,14 +5,14 @@ import subprocess import tempfile from types import ModuleType -from typing import Generator, List, Tuple, Dict, Any, Callable +from typing import Generator, List, Tuple, Dict, Any, Callable, Type from hypothesis import given, settings import hypothesis.strategies as st import pytest import json from urllib import request from chromadb import config -from chromadb.api import API +from chromadb.api import ServerAPI from chromadb.api.types import Documents, EmbeddingFunction, Embeddings import chromadb.test.property.strategies as strategies import chromadb.test.property.invariants as invariants @@ -100,6 +100,12 @@ def patch_for_version( patch(collection, embeddings, settings) +def api_import_for_version(module: Any, version: str) -> Type: # type: ignore + if packaging_version.Version(version) <= packaging_version.Version("0.4.14"): + return module.api.API # type: ignore + return module.api.ServerAPI # type: ignore + + def configurations(versions: List[str]) -> List[Tuple[str, Settings]]: return [ ( @@ -213,13 +219,13 @@ def persist_generated_data_with_old_version( try: old_module = switch_to_version(version) system = old_module.config.System(settings) - api: API = system.instance(API) + api = system.instance(api_import_for_version(old_module, version)) system.start() api.reset() coll = api.create_collection( name=collection_strategy.name, - metadata=collection_strategy.metadata, # type: ignore + metadata=collection_strategy.metadata, # In order to test old versions, we can't rely on the not_implemented function embedding_function=not_implemented_ef(), ) @@ -304,7 +310,7 @@ def test_cycle_versions( # Switch to the current version (local working directory) and check the invariants # are preserved for the collection system = config.System(settings) - api = system.instance(API) + api = system.instance(ServerAPI) system.start() coll = api.get_collection( name=collection_strategy.name, diff --git a/chromadb/test/property/test_embeddings.py 
b/chromadb/test/property/test_embeddings.py index 0e402cca1a8..7fc2491c14b 100644 --- a/chromadb/test/property/test_embeddings.py +++ b/chromadb/test/property/test_embeddings.py @@ -5,7 +5,7 @@ from dataclasses import dataclass from chromadb.api.types import ID, Include, IDs import chromadb.errors as errors -from chromadb.api import API +from chromadb.api import ServerAPI from chromadb.api.models.Collection import Collection import chromadb.test.property.strategies as strategies from hypothesis.stateful import ( @@ -64,7 +64,7 @@ class EmbeddingStateMachine(RuleBasedStateMachine): collection: Collection embedding_ids: Bundle[ID] = Bundle("embedding_ids") - def __init__(self, api: API): + def __init__(self, api: ServerAPI): super().__init__() self.api = api self._rules_strategy = strategies.DeterministicRuleStrategy(self) # type: ignore @@ -294,13 +294,13 @@ def on_state_change(self, new_state: str) -> None: pass -def test_embeddings_state(caplog: pytest.LogCaptureFixture, api: API) -> None: +def test_embeddings_state(caplog: pytest.LogCaptureFixture, api: ServerAPI) -> None: caplog.set_level(logging.ERROR) run_state_machine_as_test(lambda: EmbeddingStateMachine(api)) # type: ignore print_traces() -def test_multi_add(api: API) -> None: +def test_multi_add(api: ServerAPI) -> None: api.reset() coll = api.create_collection(name="foo") coll.add(ids=["a"], embeddings=[[0.0]]) @@ -319,7 +319,7 @@ def test_multi_add(api: API) -> None: assert coll.count() == 0 -def test_dup_add(api: API) -> None: +def test_dup_add(api: ServerAPI) -> None: api.reset() coll = api.create_collection(name="foo") with pytest.raises(errors.DuplicateIDError): @@ -328,7 +328,7 @@ def test_dup_add(api: API) -> None: coll.upsert(ids=["a", "a"], embeddings=[[0.0], [1.1]]) -def test_query_without_add(api: API) -> None: +def test_query_without_add(api: ServerAPI) -> None: api.reset() coll = api.create_collection(name="foo") fields: Include = ["documents", "metadatas", "embeddings", "distances"] @@ -343,7 +343,7 @@ def test_query_without_add(api: API) -> None: assert all([len(result) == 0 for result in field_results]) -def test_get_non_existent(api: API) -> None: +def test_get_non_existent(api: ServerAPI) -> None: api.reset() coll = api.create_collection(name="foo") result = coll.get(ids=["a"], include=["documents", "metadatas", "embeddings"]) @@ -355,7 +355,7 @@ def test_get_non_existent(api: API) -> None: # TODO: Use SQL escaping correctly internally @pytest.mark.xfail(reason="We don't properly escape SQL internally, causing problems") -def test_escape_chars_in_ids(api: API) -> None: +def test_escape_chars_in_ids(api: ServerAPI) -> None: api.reset() id = "\x1f" coll = api.create_collection(name="foo") @@ -375,7 +375,7 @@ def test_escape_chars_in_ids(api: API) -> None: {"where_document": {}, "where": {}}, ], ) -def test_delete_empty_fails(api: API, kwargs: dict): +def test_delete_empty_fails(api: ServerAPI, kwargs: dict): api.reset() coll = api.create_collection(name="foo") with pytest.raises(Exception) as e: @@ -398,7 +398,7 @@ def test_delete_empty_fails(api: API, kwargs: dict): }, ], ) -def test_delete_success(api: API, kwargs: dict): +def test_delete_success(api: ServerAPI, kwargs: dict): api.reset() coll = api.create_collection(name="foo") # Should not raise diff --git a/chromadb/test/property/test_filtering.py b/chromadb/test/property/test_filtering.py index ddcdefb0ed3..e55e5d18cf5 100644 --- a/chromadb/test/property/test_filtering.py +++ b/chromadb/test/property/test_filtering.py @@ -1,7 +1,7 @@ from typing import Any, 
Dict, List, cast from hypothesis import given, settings, HealthCheck import pytest -from chromadb.api import API +from chromadb.api import ServerAPI from chromadb.test.property import invariants from chromadb.api.types import ( Document, @@ -165,7 +165,7 @@ def _filter_embedding_set( filters=st.lists(strategies.filters(collection_st, recordset_st), min_size=1), ) def test_filterable_metadata_get( - caplog, api: API, collection: strategies.Collection, record_set, filters + caplog, api: ServerAPI, collection: strategies.Collection, record_set, filters ) -> None: caplog.set_level(logging.ERROR) @@ -204,7 +204,7 @@ def test_filterable_metadata_get( ) def test_filterable_metadata_query( caplog: pytest.LogCaptureFixture, - api: API, + api: ServerAPI, collection: strategies.Collection, record_set: strategies.RecordSet, filters: List[strategies.Filter], @@ -257,7 +257,7 @@ def test_filterable_metadata_query( assert len(result_ids.intersection(expected_ids)) == len(result_ids) -def test_empty_filter(api: API) -> None: +def test_empty_filter(api: ServerAPI) -> None: """Test that a filter where no document matches returns an empty result""" api.reset() coll = api.create_collection(name="test") @@ -291,7 +291,7 @@ def test_empty_filter(api: API) -> None: assert res["metadatas"] == [[], []] -def test_boolean_metadata(api: API) -> None: +def test_boolean_metadata(api: ServerAPI) -> None: """Test that metadata with boolean values is correctly filtered""" api.reset() coll = api.create_collection(name="test") diff --git a/chromadb/test/property/test_persist.py b/chromadb/test/property/test_persist.py index ea95f684f60..e7b1f7017d1 100644 --- a/chromadb/test/property/test_persist.py +++ b/chromadb/test/property/test_persist.py @@ -6,7 +6,7 @@ import hypothesis.strategies as st import pytest import chromadb -from chromadb.api import API +from chromadb.api import ClientAPI, ServerAPI from chromadb.config import Settings, System import chromadb.test.property.strategies as strategies import chromadb.test.property.invariants as invariants @@ -26,7 +26,7 @@ import shutil import tempfile -CreatePersistAPI = Callable[[], API] +CreatePersistAPI = Callable[[], ServerAPI] configurations = [ Settings( @@ -71,7 +71,7 @@ def test_persist( embeddings_strategy: strategies.RecordSet, ) -> None: system_1 = System(settings) - api_1 = system_1.instance(API) + api_1 = system_1.instance(ServerAPI) system_1.start() api_1.reset() @@ -103,7 +103,7 @@ def test_persist( del system_1 system_2 = System(settings) - api_2 = system_2.instance(API) + api_2 = system_2.instance(ServerAPI) system_2.start() coll = api_2.get_collection( @@ -133,7 +133,7 @@ def load_and_check( ) -> None: try: system = System(settings) - api = system.instance(API) + api = system.instance(ServerAPI) system.start() coll = api.get_collection( @@ -157,7 +157,7 @@ class PersistEmbeddingsStateMachineStates(EmbeddingStateMachineStates): class PersistEmbeddingsStateMachine(EmbeddingStateMachine): - def __init__(self, api: API, settings: Settings): + def __init__(self, api: ClientAPI, settings: Settings): self.api = api self.settings = settings self.last_persist_delay = 10 diff --git a/chromadb/test/stress/test_many_collections.py b/chromadb/test/stress/test_many_collections.py index 7e65c4b790d..29951fa452a 100644 --- a/chromadb/test/stress/test_many_collections.py +++ b/chromadb/test/stress/test_many_collections.py @@ -1,11 +1,11 @@ from typing import List import numpy as np -from chromadb.api import API +from chromadb.api import ServerAPI from 
chromadb.api.models.Collection import Collection -def test_many_collections(api: API) -> None: +def test_many_collections(api: ServerAPI) -> None: """Test that we can create a large number of collections and that the system # remains responsive.""" api.reset() diff --git a/chromadb/test/test_api.py b/chromadb/test/test_api.py index 8a12a1d9735..ed3c87ee682 100644 --- a/chromadb/test/test_api.py +++ b/chromadb/test/test_api.py @@ -21,7 +21,7 @@ @pytest.fixture def local_persist_api(): - yield chromadb.Client( + client = chromadb.Client( Settings( chroma_api_impl="chromadb.api.segment.SegmentAPI", chroma_sysdb_impl="chromadb.db.impl.sqlite.SqliteDB", @@ -33,6 +33,8 @@ def local_persist_api(): persist_directory=persist_dir, ), ) + yield client + client.clear_system_cache() if os.path.exists(persist_dir): shutil.rmtree(persist_dir, ignore_errors=True) @@ -40,7 +42,7 @@ def local_persist_api(): # https://docs.pytest.org/en/6.2.x/fixture.html#fixtures-can-be-requested-more-than-once-per-test-return-values-are-cached @pytest.fixture def local_persist_api_cache_bust(): - yield chromadb.Client( + client = chromadb.Client( Settings( chroma_api_impl="chromadb.api.segment.SegmentAPI", chroma_sysdb_impl="chromadb.db.impl.sqlite.SqliteDB", @@ -52,6 +54,8 @@ def local_persist_api_cache_bust(): persist_directory=persist_dir, ), ) + yield client + client.clear_system_cache() if os.path.exists(persist_dir): shutil.rmtree(persist_dir, ignore_errors=True) diff --git a/chromadb/test/test_chroma.py b/chromadb/test/test_chroma.py index 42b14411519..9d88ea8cc49 100644 --- a/chromadb/test/test_chroma.py +++ b/chromadb/test/test_chroma.py @@ -47,19 +47,21 @@ class GetAPITest(unittest.TestCase): @patch("chromadb.api.segment.SegmentAPI", autospec=True) @patch.dict(os.environ, {}, clear=True) def test_local(self, mock_api: Mock) -> None: - chromadb.Client(chromadb.config.Settings(persist_directory="./foo")) + client = chromadb.Client(chromadb.config.Settings(persist_directory="./foo")) assert mock_api.called + client.clear_system_cache() @patch("chromadb.db.impl.sqlite.SqliteDB", autospec=True) @patch.dict(os.environ, {}, clear=True) def test_local_db(self, mock_db: Mock) -> None: - chromadb.Client(chromadb.config.Settings(persist_directory="./foo")) + client = chromadb.Client(chromadb.config.Settings(persist_directory="./foo")) assert mock_db.called + client.clear_system_cache() @patch("chromadb.api.fastapi.FastAPI", autospec=True) @patch.dict(os.environ, {}, clear=True) def test_fastapi(self, mock: Mock) -> None: - chromadb.Client( + client = chromadb.Client( chromadb.config.Settings( chroma_api_impl="chromadb.api.fastapi.FastAPI", persist_directory="./foo", @@ -68,6 +70,7 @@ def test_fastapi(self, mock: Mock) -> None: ) ) assert mock.called + client.clear_system_cache() @patch("chromadb.api.fastapi.FastAPI", autospec=True) @patch.dict(os.environ, {}, clear=True) @@ -78,7 +81,7 @@ def test_settings_pass_to_fastapi(self, mock: Mock) -> None: chroma_server_http_port="80", chroma_server_headers={"foo": "bar"}, ) - chromadb.Client(settings) + client = chromadb.Client(settings) # Check that the mock was called assert mock.called @@ -93,11 +96,12 @@ def test_settings_pass_to_fastapi(self, mock: Mock) -> None: # Check if the settings passed to the mock match the settings we used # raise Exception(passed_settings.settings) assert passed_settings.settings == settings + client.clear_system_cache() def test_legacy_values() -> None: with pytest.raises(ValueError): - chromadb.Client( + client = chromadb.Client( 
chromadb.config.Settings( chroma_api_impl="chromadb.api.local.LocalAPI", persist_directory="./foo", @@ -105,3 +109,4 @@ def test_legacy_values() -> None: chroma_server_http_port="80", ) ) + client.clear_system_cache() diff --git a/chromadb/test/test_client.py b/chromadb/test/test_client.py index 1164e1e699d..2a8aec2764e 100644 --- a/chromadb/test/test_client.py +++ b/chromadb/test/test_client.py @@ -1,37 +1,46 @@ +from typing import Generator +from unittest.mock import patch import chromadb -from chromadb.api import API +from chromadb.api import ClientAPI import chromadb.server.fastapi import pytest import tempfile @pytest.fixture -def ephemeral_api() -> API: - return chromadb.EphemeralClient() +def ephemeral_api() -> Generator[ClientAPI, None, None]: + client = chromadb.EphemeralClient() + yield client + client.clear_system_cache() @pytest.fixture -def persistent_api() -> API: - return chromadb.PersistentClient( +def persistent_api() -> Generator[ClientAPI, None, None]: + client = chromadb.PersistentClient( path=tempfile.gettempdir() + "/test_server", ) + yield client + client.clear_system_cache() @pytest.fixture -def http_api() -> API: - return chromadb.HttpClient() +def http_api() -> Generator[ClientAPI, None, None]: + with patch("chromadb.api.client.Client._validate_tenant_database"): + client = chromadb.HttpClient() + yield client + client.clear_system_cache() -def test_ephemeral_client(ephemeral_api: API) -> None: +def test_ephemeral_client(ephemeral_api: ClientAPI) -> None: settings = ephemeral_api.get_settings() assert settings.is_persistent is False -def test_persistent_client(persistent_api: API) -> None: +def test_persistent_client(persistent_api: ClientAPI) -> None: settings = persistent_api.get_settings() assert settings.is_persistent is True -def test_http_client(http_api: API) -> None: +def test_http_client(http_api: ClientAPI) -> None: settings = http_api.get_settings() assert settings.chroma_api_impl == "chromadb.api.fastapi.FastAPI" diff --git a/chromadb/test/test_multithreaded.py b/chromadb/test/test_multithreaded.py index 57c259dad99..c0b05e88324 100644 --- a/chromadb/test/test_multithreaded.py +++ b/chromadb/test/test_multithreaded.py @@ -5,7 +5,7 @@ from typing import Any, Dict, List, Optional, Set, Tuple, cast import numpy as np -from chromadb.api import API +from chromadb.api import ServerAPI import chromadb.test.property.invariants as invariants from chromadb.test.property.strategies import RecordSet from chromadb.test.property.strategies import test_hnsw_config @@ -37,7 +37,7 @@ def generate_record_set(N: int, D: int) -> RecordSet: # Hypothesis is bad at generating large datasets so we manually generate data in # this test to test multithreaded add with larger datasets -def _test_multithreaded_add(api: API, N: int, D: int, num_workers: int) -> None: +def _test_multithreaded_add(api: ServerAPI, N: int, D: int, num_workers: int) -> None: records_set = generate_record_set(N, D) ids = records_set["ids"] embeddings = records_set["embeddings"] @@ -95,7 +95,9 @@ def _test_multithreaded_add(api: API, N: int, D: int, num_workers: int) -> None: ) -def _test_interleaved_add_query(api: API, N: int, D: int, num_workers: int) -> None: +def _test_interleaved_add_query( + api: ServerAPI, N: int, D: int, num_workers: int +) -> None: """Test that will use multiple threads to interleave operations on the db and verify they work correctly""" api.reset() @@ -207,14 +209,14 @@ def perform_operation( ) -def test_multithreaded_add(api: API) -> None: +def test_multithreaded_add(api: 
ServerAPI) -> None: for i in range(3): num_workers = random.randint(2, multiprocessing.cpu_count() * 2) N, D = generate_data_shape() _test_multithreaded_add(api, N, D, num_workers) -def test_interleaved_add_query(api: API) -> None: +def test_interleaved_add_query(api: ServerAPI) -> None: for i in range(3): num_workers = random.randint(2, multiprocessing.cpu_count() * 2) N, D = generate_data_shape() diff --git a/chromadb/types.py b/chromadb/types.py index 713cab7757c..ed74b4f5e20 100644 --- a/chromadb/types.py +++ b/chromadb/types.py @@ -29,6 +29,16 @@ class Collection(TypedDict): dimension: Optional[int] +class Database(TypedDict): + id: UUID + name: str + tenant: str + + +class Tenant(TypedDict): + name: str + + class Segment(TypedDict): id: UUID type: NamespacedName diff --git a/chromadb/utils/batch_utils.py b/chromadb/utils/batch_utils.py index c8c1ac1e476..9c588270f25 100644 --- a/chromadb/utils/batch_utils.py +++ b/chromadb/utils/batch_utils.py @@ -1,5 +1,5 @@ from typing import Optional, Tuple, List -from chromadb.api import API +from chromadb.api import BaseAPI from chromadb.api.types import ( Documents, Embeddings, @@ -9,7 +9,7 @@ def create_batches( - api: API, + api: BaseAPI, ids: IDs, embeddings: Optional[Embeddings] = None, metadatas: Optional[Metadatas] = None, diff --git a/idl/chromadb/proto/chroma.proto b/idl/chromadb/proto/chroma.proto index 7b1f10f18ea..5aaff218176 100644 --- a/idl/chromadb/proto/chroma.proto +++ b/idl/chromadb/proto/chroma.proto @@ -55,6 +55,16 @@ message Collection { optional int32 dimension = 5; } +message Database { + string id = 1; + string name = 2; + string tenant = 3; +} + +message Tenant { + string name = 1; +} + message UpdateMetadataValue { oneof value { string string_value = 1; diff --git a/idl/chromadb/proto/coordinator.proto b/idl/chromadb/proto/coordinator.proto index 2a557f99613..0871f3f3c52 100644 --- a/idl/chromadb/proto/coordinator.proto +++ b/idl/chromadb/proto/coordinator.proto @@ -5,6 +5,35 @@ package chroma; import "chromadb/proto/chroma.proto"; import "google/protobuf/empty.proto"; +message CreateDatabaseRequest { + string id = 1; + string name = 2; + string tenant = 3; +} + +message GetDatabaseRequest { + string name = 1; + string tenant = 2; +} + +message GetDatabaseResponse { + Database database = 1; + Status status = 2; +} + +message CreateTenantRequest { + string name = 2; // Names are globally unique +} + +message GetTenantRequest { + string name = 1; +} + +message GetTenantResponse { + Tenant tenant = 1; + Status status = 2; +} + message CreateSegmentRequest { Segment segment = 1; } @@ -18,7 +47,7 @@ message GetSegmentsRequest { optional string type = 2; optional SegmentScope scope = 3; optional string topic = 4; - optional string collection = 5; + optional string collection = 5; // Collection ID } message GetSegmentsResponse { @@ -49,6 +78,8 @@ message CreateCollectionRequest { optional UpdateMetadata metadata = 3; optional int32 dimension = 4; optional bool get_or_create = 5; + string tenant = 6; + string database = 7; } message CreateCollectionResponse { @@ -59,12 +90,16 @@ message CreateCollectionResponse { message DeleteCollectionRequest { string id = 1; + string tenant = 2; + string database = 3; } message GetCollectionsRequest { optional string id = 1; optional string name = 2; optional string topic = 3; + string tenant = 4; + string database = 5; } message GetCollectionsResponse { @@ -84,6 +119,10 @@ message UpdateCollectionRequest { } service SysDB { + rpc CreateDatabase(CreateDatabaseRequest) returns 
(ChromaResponse) {} + rpc GetDatabase(GetDatabaseRequest) returns (GetDatabaseResponse) {} + rpc CreateTenant(CreateTenantRequest) returns (ChromaResponse) {} + rpc GetTenant(GetTenantRequest) returns (GetTenantResponse) {} rpc CreateSegment(CreateSegmentRequest) returns (ChromaResponse) {} rpc DeleteSegment(DeleteSegmentRequest) returns (ChromaResponse) {} rpc GetSegments(GetSegmentsRequest) returns (GetSegmentsResponse) {} diff --git a/requirements.txt b/requirements.txt index f3093341f14..4ad9f29aeb2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,7 @@ bcrypt==4.0.1 chroma-hnswlib==0.7.3 fastapi>=0.95.2 graphlib_backport==1.0.3; python_version < '3.9' -grpcio==1.58.0 +grpcio>=1.58.0 importlib-resources kubernetes>=28.1.0 numpy==1.21.6; python_version < '3.8'
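
For reviewers, the client-facing flow exercised by the new tests can be summarized in a short usage sketch. This is illustrative only and is distilled from the test code in this patch (AdminClient, Client, set_tenant/set_database, and the tenant/database keyword arguments); the tenant, database, and collection names, the bare EphemeralClient entry point, and the final print are assumptions made for the example rather than anything added by the change itself.

# Usage sketch of the tenant/database API exercised by the tests above.
# Assumes an embedded (ephemeral) instance; "acme", "prod", "reports", and "docs"
# are illustrative names, not values used anywhere in this patch.
import chromadb
from chromadb.api.client import AdminClient, Client
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT

client = chromadb.EphemeralClient()  # starts in the default tenant and database

# Tenants and databases are managed through an AdminClient bound to the same System.
admin = AdminClient.from_system(client._system)
admin.create_tenant("acme")
admin.create_database("prod", tenant="acme")
admin.create_database("reports")  # created under the default tenant

# A Client can be constructed directly against a tenant/database pair...
scoped = Client(tenant="acme", database="prod", settings=client._system.settings)
scoped.create_collection("docs", metadata={"database": "prod"})

# ...or re-pointed later without rebuilding the underlying System.
scoped.set_tenant(DEFAULT_TENANT, DEFAULT_DATABASE)  # switch tenant and database together
scoped.set_database("reports")  # switch databases within the current tenant
print(scoped.list_collections())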