From e5648601a46ca1da1d50ff0ee71b30891ddd6cca Mon Sep 17 00:00:00 2001 From: hammadb Date: Tue, 17 Oct 2023 16:53:08 -0700 Subject: [PATCH 01/22] [CLN] Remove support for 3.7, add support for 3.11 --- .github/workflows/chroma-client-integration-test.yml | 4 ++-- .github/workflows/chroma-cluster-test.yml | 2 +- .github/workflows/chroma-integration-test.yml | 2 +- .github/workflows/chroma-test.yml | 4 ++-- clients/python/pyproject.toml | 4 ++-- pyproject.toml | 4 ++-- 6 files changed, 10 insertions(+), 10 deletions(-) diff --git a/.github/workflows/chroma-client-integration-test.yml b/.github/workflows/chroma-client-integration-test.yml index 25788090ef2..5724959c254 100644 --- a/.github/workflows/chroma-client-integration-test.yml +++ b/.github/workflows/chroma-client-integration-test.yml @@ -9,13 +9,13 @@ on: - main - '**' workflow_dispatch: - + jobs: test: timeout-minutes: 90 strategy: matrix: - python: ['3.7', '3.8', '3.9', '3.10'] + python: ['3.8', '3.9', '3.10', '3.11'] platform: [ubuntu-latest, windows-latest] runs-on: ${{ matrix.platform }} steps: diff --git a/.github/workflows/chroma-cluster-test.yml b/.github/workflows/chroma-cluster-test.yml index fc8e514f323..422ce9190d0 100644 --- a/.github/workflows/chroma-cluster-test.yml +++ b/.github/workflows/chroma-cluster-test.yml @@ -14,7 +14,7 @@ jobs: test: strategy: matrix: - python: ['3.7'] + python: ['3.8'] platform: [ubuntu-latest] testfile: ["chromadb/test/ingest/test_producer_consumer.py", "chromadb/test/segment/distributed/test_memberlist_provider.py",] diff --git a/.github/workflows/chroma-integration-test.yml b/.github/workflows/chroma-integration-test.yml index 963a7b6ed63..4d9c0085d18 100644 --- a/.github/workflows/chroma-integration-test.yml +++ b/.github/workflows/chroma-integration-test.yml @@ -15,7 +15,7 @@ jobs: test: strategy: matrix: - python: ['3.7'] + python: ['3.8'] platform: [ubuntu-latest, windows-latest] testfile: ["--ignore-glob 'chromadb/test/property/*' 
--ignore='chromadb/test/test_cli.py'", "chromadb/test/property/test_add.py", diff --git a/.github/workflows/chroma-test.yml b/.github/workflows/chroma-test.yml index 90ff2b66940..4ef9c64ed7b 100644 --- a/.github/workflows/chroma-test.yml +++ b/.github/workflows/chroma-test.yml @@ -16,7 +16,7 @@ jobs: timeout-minutes: 90 strategy: matrix: - python: ['3.7', '3.8', '3.9', '3.10'] + python: ['3.8', '3.9', '3.10', '3.11'] platform: [ubuntu-latest, windows-latest] testfile: ["--ignore-glob 'chromadb/test/property/*' --ignore-glob 'chromadb/test/stress/*'", "chromadb/test/property/test_add.py", @@ -44,7 +44,7 @@ jobs: timeout-minutes: 90 strategy: matrix: - python: ['3.7'] + python: ['3.8'] platform: ['16core-64gb-ubuntu-latest', '16core-64gb-windows-latest'] testfile: ["'chromadb/test/stress/'"] runs-on: ${{ matrix.platform }} diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index 3afff14f4be..a276becb196 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -8,7 +8,7 @@ authors = [ ] description = "Chroma Client." readme = "README.md" -requires-python = ">=3.7" +requires-python = ">=3.8" classifiers = [ "Programming Language :: Python :: 3", "License :: OSI Approved :: Apache Software License", @@ -26,7 +26,7 @@ dependencies = [ [tool.black] line-length = 88 required-version = "23.3.0" # Black will refuse to run if it's not this version. -target-version = ['py36', 'py37', 'py38', 'py39', 'py310'] +target-version = ['py38', 'py39', 'py310', 'py311'] [tool.pytest.ini_options] pythonpath = ["."] diff --git a/pyproject.toml b/pyproject.toml index 7dd144cf3ab..5e7ca2ea247 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ authors = [ ] description = "Chroma." 
readme = "README.md" -requires-python = ">=3.7" +requires-python = ">=3.8" classifiers = [ "Programming Language :: Python :: 3", "License :: OSI Approved :: Apache Software License", @@ -42,7 +42,7 @@ dependencies = [ [tool.black] line-length = 88 required-version = "23.3.0" # Black will refuse to run if it's not this version. -target-version = ['py36', 'py37', 'py38', 'py39', 'py310'] +target-version = ['py38', 'py39', 'py310', 'py311'] [tool.pytest.ini_options] pythonpath = ["."] From 0321fe8edca4e05f3d9fac0c66f3e05814811818 Mon Sep 17 00:00:00 2001 From: hammadb Date: Fri, 13 Oct 2023 00:25:58 -0700 Subject: [PATCH 02/22] Rename namespace to database --- chromadb/__init__.py | 79 +++- chromadb/api/__init__.py | 222 +++++++++- chromadb/api/client.py | 382 ++++++++++++++++++ chromadb/api/fastapi.py | 52 ++- chromadb/api/models/Collection.py | 8 +- chromadb/api/segment.py | 54 ++- chromadb/config.py | 6 +- .../sysdb/00004-tenants-databases.sqlite.sql | 19 + chromadb/server/fastapi/__init__.py | 8 +- chromadb/test/auth/test_token_auth.py | 8 +- chromadb/test/conftest.py | 28 +- chromadb/test/property/test_add.py | 12 +- chromadb/test/property/test_collections.py | 6 +- .../property/test_cross_version_persist.py | 16 +- chromadb/test/property/test_embeddings.py | 20 +- chromadb/test/property/test_filtering.py | 10 +- chromadb/test/property/test_persist.py | 12 +- chromadb/test/stress/test_many_collections.py | 4 +- chromadb/test/test_api.py | 8 +- chromadb/test/test_chroma.py | 15 +- chromadb/test/test_client.py | 27 +- chromadb/test/test_multithreaded.py | 12 +- chromadb/utils/batch_utils.py | 4 +- 23 files changed, 889 insertions(+), 123 deletions(-) create mode 100644 chromadb/api/client.py create mode 100644 chromadb/migrations/sysdb/00004-tenants-databases.sqlite.sql diff --git a/chromadb/__init__.py b/chromadb/__init__.py index 9c0b8000a14..8360a69498c 100644 --- a/chromadb/__init__.py +++ b/chromadb/__init__.py @@ -1,8 +1,9 @@ from typing import Dict import 
logging +from chromadb.api.client import Client as ClientCreator import chromadb.config -from chromadb.config import Settings, System -from chromadb.api import API +from chromadb.config import Settings +from chromadb.api import ClientAPI from chromadb.api.models.Collection import Collection from chromadb.api.types import ( CollectionMetadata, @@ -35,8 +36,6 @@ "QueryResult", "GetResult", ] -from chromadb.telemetry.events import ClientStartEvent -from chromadb.telemetry import Telemetry logger = logging.getLogger(__name__) @@ -55,13 +54,15 @@ is_client = False try: - from chromadb.is_thin_client import is_thin_client # type: ignore + from chromadb.is_thin_client import is_thin_client + is_client = is_thin_client except ImportError: is_client = False if not is_client: import sqlite3 + if sqlite3.sqlite_version_info < (3, 35, 0): if IN_COLAB: # In Colab, hotswap to pysqlite-binary if it's too old @@ -90,7 +91,7 @@ def get_settings() -> Settings: return __settings -def EphemeralClient(settings: Settings = Settings()) -> API: +def EphemeralClient(settings: Settings = Settings()) -> ClientAPI: """ Creates an in-memory instance of Chroma. This is useful for testing and development, but not recommended for production use. @@ -100,7 +101,12 @@ def EphemeralClient(settings: Settings = Settings()) -> API: return Client(settings) -def PersistentClient(path: str = "./chroma", settings: Settings = Settings()) -> API: +def PersistentClient( + path: str = "./chroma", + tenant: str = "default", + database: str = "default", + settings: Settings = Settings(), +) -> ClientAPI: """ Creates a persistent instance of Chroma that saves to disk. This is useful for testing and development, but not recommended for production use. 
@@ -111,7 +117,7 @@ def PersistentClient(path: str = "./chroma", settings: Settings = Settings()) -> settings.persist_directory = path settings.is_persistent = True - return Client(settings) + return ClientCreator(tenant=tenant, database=database, settings=settings) def HttpClient( @@ -119,8 +125,10 @@ def HttpClient( port: str = "8000", ssl: bool = False, headers: Dict[str, str] = {}, + tenant: str = "default", + database: str = "default", settings: Settings = Settings(), -) -> API: +) -> ClientAPI: """ Creates a client that connects to a remote Chroma server. This supports many clients connecting to the same server, and is the recommended way to @@ -139,20 +147,47 @@ def HttpClient( settings.chroma_server_ssl_enabled = ssl settings.chroma_server_headers = headers - return Client(settings) + return ClientCreator(tenant=tenant, database=database, settings=settings) -def Client(settings: Settings = __settings) -> API: +# TODO: replace default tenant and database strings with constants +def Client( + settings: Settings = __settings, tenant: str = "default", database: str = "default" +) -> ClientAPI: """Return a running chroma.API instance""" - system = System(settings) - - telemetry_client = system.instance(Telemetry) - api = system.instance(API) - - system.start() - - # Submit event for client start - telemetry_client.capture(ClientStartEvent()) - - return api + # Change this to actually check if an "API" instance already exists, wrap it in a + # tenant/database aware "Client", and return it + # this way we can support multiple clients in the same process but using the same + # chroma instance + + # API is thread safe, so we can just return the same instance + # This way a "Client" will just be a wrapper around an API instance that is + # tenant/database aware + + # To do this we will + # 1. Have a global dict of API instances, keyed by path + # 2. 
When a client is requested, check if one exists in the dict, and if so check if its + # settings match the requested settings + # 3. If the settings match, construct a new Client that wraps the existing API instance with + # the tenant/database + # 4. If the settings don't match, error out because we don't support changing the settings + # for a given database + # 5. If no client exists in the dict, create a new API instance, wrap it in a Client, and + # add it to the dict + + # The hierarchy then becomes + # For local + # Path -> Tenant -> Namespace -> API + # For remote + # Host -> Tenant -> Namespace -> API + + # A given API for a path is a singleton, and is shared between all tenants and namespaces + # for that path + + # A DB exists at a path or host, and has tenants and namespaces + + # All our tests currently use system.instance(API) assuming that's the root object + # This is likely fine. + + return ClientCreator(tenant=tenant, database=database, settings=settings) diff --git a/chromadb/api/__init__.py b/chromadb/api/__init__.py index 50f2ff1ecef..2287bd929d2 100644 --- a/chromadb/api/__init__.py +++ b/chromadb/api/__init__.py @@ -1,6 +1,9 @@ from abc import ABC, abstractmethod from typing import Sequence, Optional from uuid import UUID + +from overrides import override +from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT from chromadb.api.models.Collection import Collection from chromadb.api.types import ( CollectionMetadata, @@ -19,7 +22,7 @@ import chromadb.utils.embedding_functions as ef -class API(Component, ABC): +class BaseAPI(ABC): @abstractmethod def heartbeat(self) -> int: """Get the current time in nanoseconds since epoch. @@ -371,10 +374,10 @@ def get_version(self) -> str: @abstractmethod def get_settings(self) -> Settings: - """Get the settings used to initialize the client. + """Get the settings used to initialize. Returns: - Settings: The settings used to initialize the client. + Settings: The settings used to initialize. 
""" pass @@ -385,3 +388,216 @@ def max_batch_size(self) -> int: """Return the maximum number of records that can be submitted in a single call to submit_embeddings.""" pass + + +class ClientAPI(BaseAPI, ABC): + @abstractmethod + def set_database(self, database: str) -> None: + """Set the database for the client. + + Args: + database: The database to set. + + """ + pass + + @abstractmethod + def set_tenant(self, tenant: str) -> None: + """Set the tenant for the client. + + Args: + tenant: The tenant to set. + + """ + pass + + @staticmethod + @abstractmethod + def clear_system_cache() -> None: + """Clear the system cache so that new systems can be created for an existing path. + This should only be used for testing purposes.""" + pass + + +class ServerAPI(BaseAPI, Component): + """An API instance that extends the relevant Base API methods by passing + in a tenant and database. This is the root component of the Chroma System""" + + @abstractmethod + @override + def list_collections( + self, tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE + ) -> Sequence[Collection]: + pass + + @abstractmethod + @override + def create_collection( + self, + name: str, + metadata: Optional[CollectionMetadata] = None, + embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), + get_or_create: bool = False, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> Collection: + pass + + @abstractmethod + @override + def get_collection( + self, + name: str, + embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> Collection: + pass + + @abstractmethod + @override + def get_or_create_collection( + self, + name: str, + metadata: Optional[CollectionMetadata] = None, + embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> Collection: + pass 
+ + @abstractmethod + @override + def _modify( + self, + id: UUID, + new_name: Optional[str] = None, + new_metadata: Optional[CollectionMetadata] = None, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> None: + pass + + @abstractmethod + @override + def delete_collection( + self, + name: str, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> None: + pass + + # + # ITEM METHODS + # + + @abstractmethod + @override + def _add( + self, + ids: IDs, + collection_id: UUID, + embeddings: Embeddings, + metadatas: Optional[Metadatas] = None, + documents: Optional[Documents] = None, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> bool: + pass + + @abstractmethod + @override + def _update( + self, + collection_id: UUID, + ids: IDs, + embeddings: Optional[Embeddings] = None, + metadatas: Optional[Metadatas] = None, + documents: Optional[Documents] = None, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> bool: + pass + + @abstractmethod + @override + def _upsert( + self, + collection_id: UUID, + ids: IDs, + embeddings: Embeddings, + metadatas: Optional[Metadatas] = None, + documents: Optional[Documents] = None, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> bool: + pass + + @abstractmethod + @override + def _count( + self, + collection_id: UUID, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> int: + pass + + @abstractmethod + @override + def _peek( + self, + collection_id: UUID, + n: int = 10, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> GetResult: + pass + + @abstractmethod + @override + def _get( + self, + collection_id: UUID, + ids: Optional[IDs] = None, + where: Optional[Where] = {}, + sort: Optional[str] = None, + limit: Optional[int] = None, + offset: Optional[int] = None, + page: Optional[int] = None, + page_size: Optional[int] = None, + where_document: Optional[WhereDocument] = {}, + 
include: Include = ["embeddings", "metadatas", "documents"], + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> GetResult: + pass + + @abstractmethod + @override + def _delete( + self, + collection_id: UUID, + ids: Optional[IDs], + where: Optional[Where] = {}, + where_document: Optional[WhereDocument] = {}, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> IDs: + pass + + @abstractmethod + @override + def _query( + self, + collection_id: UUID, + query_embeddings: Embeddings, + n_results: int = 10, + where: Where = {}, + where_document: WhereDocument = {}, + include: Include = ["embeddings", "metadatas", "documents", "distances"], + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> QueryResult: + pass diff --git a/chromadb/api/client.py b/chromadb/api/client.py new file mode 100644 index 00000000000..c007495705b --- /dev/null +++ b/chromadb/api/client.py @@ -0,0 +1,382 @@ +from typing import ClassVar, Dict, Optional, Sequence +from uuid import UUID + +from overrides import override +from chromadb.api import ClientAPI, ServerAPI +from chromadb.api.types import ( + CollectionMetadata, + Documents, + EmbeddingFunction, + Embeddings, + GetResult, + IDs, + Include, + Metadatas, + QueryResult, +) +from chromadb.config import Settings, System +from chromadb.telemetry import Telemetry +from chromadb.telemetry.events import ClientStartEvent +from chromadb.config import DEFAULT_TENANT, DEFAULT_DATABASE +from chromadb.api.models.Collection import Collection +from chromadb.types import Where, WhereDocument +import chromadb.utils.embedding_functions as ef + + +class Client(ClientAPI): + """A client for Chroma. This is the main entrypoint for interacting with Chroma. + A client internally stores its tenant and database and proxies calls to a + Server API instance of Chroma. 
It treats the Server API and corresponding System + as a singleton, so multiple clients connecting to the same resource will share the + same API instance. + + Client implementations should implement their own API-caching strategies. + """ + + tenant: str = DEFAULT_TENANT + database: str = DEFAULT_DATABASE + + _identifer_to_system: ClassVar[Dict[str, System]] = {} + _identifier: str + _server: ServerAPI + + # region Initialization + def __new__( + cls, + tenant: str = "default", + database: str = "default", + settings: Settings = Settings(), + ) -> "Client": + identifier = cls._get_identifier_from_settings(settings) + cls._create_system_if_not_exists(identifier, settings) + instance = super().__new__(cls) + return instance + + def __init__( + self, + tenant: str = "default", + database: str = "default", + settings: Settings = Settings(), + ) -> None: + self.tenant = tenant + self.database = database + self._identifier = self._get_identifier_from_settings(settings) + + # Get the root system component we want to interact with + self._server = self._system.instance(ServerAPI) + + # Submit event for a client start + telemetry_client = self._system.instance(Telemetry) + telemetry_client.capture(ClientStartEvent()) + + @classmethod + def _create_system_if_not_exists( + cls, identifier: str, settings: Settings + ) -> System: + if identifier not in cls._identifer_to_system: + new_system = System(settings) + cls._identifer_to_system[identifier] = new_system + + new_system.instance(Telemetry) + new_system.instance(ServerAPI) + + new_system.start() + else: + previous_system = cls._identifer_to_system[identifier] + + # For now, the settings must match + if previous_system.settings != settings: + raise ValueError( + f"An instance of Chroma already exists for {identifier} with different settings" + ) + + return cls._identifer_to_system[identifier] + + @staticmethod + def _get_identifier_from_settings(settings: Settings) -> str: + identifier = "" + api_impl = 
settings.chroma_api_impl + + if api_impl is None: + raise ValueError("Chroma API implementation must be set in settings") + elif api_impl == "chromadb.api.segment.SegmentAPI": + if settings.is_persistent: + identifier = settings.persist_directory + else: + identifier = ( + "ephemeral" # TODO: support pathing and multiple ephemeral clients + ) + elif api_impl == "chromadb.api.fastapi.FastAPI": + identifier = ( + f"{settings.chroma_server_host}:{settings.chroma_server_http_port}" + ) + else: + raise ValueError(f"Unsupported Chroma API implementation {api_impl}") + + return identifier + + @staticmethod + @override + def clear_system_cache() -> None: + Client._identifer_to_system = {} + + @property + def _system(self) -> System: + return self._identifer_to_system[self._identifier] + + # endregion + + # region BaseAPI Methods + # Note - we could do this in less verbose ways, but they break type checking + @override + def heartbeat(self) -> int: + return self._server.heartbeat() + + @override + def list_collections(self) -> Sequence[Collection]: + return self._server.list_collections(tenant=self.tenant, database=self.database) + + @override + def create_collection( + self, + name: str, + metadata: Optional[CollectionMetadata] = None, + embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), + get_or_create: bool = False, + ) -> Collection: + return self._server.create_collection( + name=name, + metadata=metadata, + embedding_function=embedding_function, + tenant=self.tenant, + database=self.database, + ) + + @override + def get_collection( + self, + name: str, + embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), + ) -> Collection: + return self._server.get_collection( + name=name, + embedding_function=embedding_function, + tenant=self.tenant, + database=self.database, + ) + + @override + def get_or_create_collection( + self, + name: str, + metadata: Optional[CollectionMetadata] = None, + embedding_function: 
Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), + ) -> Collection: + return self._server.get_or_create_collection( + name=name, + metadata=metadata, + embedding_function=embedding_function, + tenant=self.tenant, + database=self.database, + ) + + @override + def _modify( + self, + id: UUID, + new_name: Optional[str] = None, + new_metadata: Optional[CollectionMetadata] = None, + ) -> None: + return self._server._modify( + id=id, + new_name=new_name, + new_metadata=new_metadata, + tenant=self.tenant, + database=self.database, + ) + + @override + def delete_collection( + self, + name: str, + ) -> None: + return self._server.delete_collection( + name=name, + tenant=self.tenant, + database=self.database, + ) + + # + # ITEM METHODS + # + + @override + def _add( + self, + ids: IDs, + collection_id: UUID, + embeddings: Embeddings, + metadatas: Optional[Metadatas] = None, + documents: Optional[Documents] = None, + ) -> bool: + return self._server._add( + ids=ids, + collection_id=collection_id, + embeddings=embeddings, + metadatas=metadatas, + documents=documents, + tenant=self.tenant, + database=self.database, + ) + + @override + def _update( + self, + collection_id: UUID, + ids: IDs, + embeddings: Optional[Embeddings] = None, + metadatas: Optional[Metadatas] = None, + documents: Optional[Documents] = None, + ) -> bool: + return self._server._update( + collection_id=collection_id, + ids=ids, + embeddings=embeddings, + metadatas=metadatas, + documents=documents, + tenant=self.tenant, + database=self.database, + ) + + @override + def _upsert( + self, + collection_id: UUID, + ids: IDs, + embeddings: Embeddings, + metadatas: Optional[Metadatas] = None, + documents: Optional[Documents] = None, + ) -> bool: + return self._server._upsert( + collection_id=collection_id, + ids=ids, + embeddings=embeddings, + metadatas=metadatas, + documents=documents, + tenant=self.tenant, + database=self.database, + ) + + @override + def _count(self, collection_id: UUID) -> int: + 
return self._server._count( + collection_id=collection_id, + tenant=self.tenant, + database=self.database, + ) + + @override + def _peek(self, collection_id: UUID, n: int = 10) -> GetResult: + return self._server._peek( + collection_id=collection_id, + n=n, + tenant=self.tenant, + database=self.database, + ) + + @override + def _get( + self, + collection_id: UUID, + ids: Optional[IDs] = None, + where: Optional[Where] = {}, + sort: Optional[str] = None, + limit: Optional[int] = None, + offset: Optional[int] = None, + page: Optional[int] = None, + page_size: Optional[int] = None, + where_document: Optional[WhereDocument] = {}, + include: Include = ["embeddings", "metadatas", "documents"], + ) -> GetResult: + return self._server._get( + collection_id=collection_id, + ids=ids, + where=where, + sort=sort, + limit=limit, + offset=offset, + page=page, + page_size=page_size, + where_document=where_document, + include=include, + tenant=self.tenant, + database=self.database, + ) + + def _delete( + self, + collection_id: UUID, + ids: Optional[IDs], + where: Optional[Where] = {}, + where_document: Optional[WhereDocument] = {}, + ) -> IDs: + return self._server._delete( + collection_id=collection_id, + ids=ids, + where=where, + where_document=where_document, + tenant=self.tenant, + database=self.database, + ) + + @override + def _query( + self, + collection_id: UUID, + query_embeddings: Embeddings, + n_results: int = 10, + where: Where = {}, + where_document: WhereDocument = {}, + include: Include = ["embeddings", "metadatas", "documents", "distances"], + ) -> QueryResult: + return self._server._query( + collection_id=collection_id, + query_embeddings=query_embeddings, + n_results=n_results, + where=where, + where_document=where_document, + include=include, + tenant=self.tenant, + database=self.database, + ) + + @override + def reset(self) -> bool: + return self._server.reset() + + @override + def get_version(self) -> str: + return self._server.get_version() + + @override + def 
get_settings(self) -> Settings: + return self._server.get_settings() + + @property + @override + def max_batch_size(self) -> int: + return self._server.max_batch_size + + # endregion + + # region ClientAPI Methods + + @override + def set_database(self, database: str) -> None: + self.database = database + + @override + def set_tenant(self, tenant: str) -> None: + self.tenant = tenant + + # endregion diff --git a/chromadb/api/fastapi.py b/chromadb/api/fastapi.py index 2ddd537ebff..2ee75612c23 100644 --- a/chromadb/api/fastapi.py +++ b/chromadb/api/fastapi.py @@ -9,7 +9,7 @@ import chromadb.errors as errors import chromadb.utils.embedding_functions as ef -from chromadb.api import API +from chromadb.api import ServerAPI from chromadb.api.models.Collection import Collection from chromadb.api.types import ( Documents, @@ -30,14 +30,14 @@ ) from chromadb.auth.providers import RequestsClientAuthProtocolAdapter from chromadb.auth.registry import resolve_provider -from chromadb.config import Settings, System +from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings, System from chromadb.telemetry import Telemetry from urllib.parse import urlparse, urlunparse, quote logger = logging.getLogger(__name__) -class FastAPI(API): +class FastAPI(ServerAPI): _settings: Settings _max_batch_size: int = -1 @@ -135,7 +135,9 @@ def heartbeat(self) -> int: return int(resp.json()["nanosecond heartbeat"]) @override - def list_collections(self) -> Sequence[Collection]: + def list_collections( + self, tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE + ) -> Sequence[Collection]: """Returns a list of all collections""" resp = self._session.get(self._api_url + "/collections") raise_chroma_error(resp) @@ -153,6 +155,8 @@ def create_collection( metadata: Optional[CollectionMetadata] = None, embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), get_or_create: bool = False, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, ) -> 
Collection: """Creates a collection""" resp = self._session.post( @@ -176,6 +180,8 @@ def get_collection( self, name: str, embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, ) -> Collection: """Returns a collection""" resp = self._session.get(self._api_url + "/collections/" + name) @@ -195,6 +201,8 @@ def get_or_create_collection( name: str, metadata: Optional[CollectionMetadata] = None, embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, ) -> Collection: return self.create_collection( name, metadata, embedding_function, get_or_create=True @@ -206,6 +214,8 @@ def _modify( id: UUID, new_name: Optional[str] = None, new_metadata: Optional[CollectionMetadata] = None, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, ) -> None: """Updates a collection""" resp = self._session.put( @@ -215,13 +225,23 @@ def _modify( raise_chroma_error(resp) @override - def delete_collection(self, name: str) -> None: + def delete_collection( + self, + name: str, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> None: """Deletes a collection""" resp = self._session.delete(self._api_url + "/collections/" + name) raise_chroma_error(resp) @override - def _count(self, collection_id: UUID) -> int: + def _count( + self, + collection_id: UUID, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> int: """Returns the number of embeddings in the database""" resp = self._session.get( self._api_url + "/collections/" + str(collection_id) + "/count" @@ -230,7 +250,13 @@ def _count(self, collection_id: UUID) -> int: return cast(int, resp.json()) @override - def _peek(self, collection_id: UUID, n: int = 10) -> GetResult: + def _peek( + self, + collection_id: UUID, + n: int = 10, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) 
-> GetResult: return self._get( collection_id, limit=n, @@ -250,6 +276,8 @@ def _get( page_size: Optional[int] = None, where_document: Optional[WhereDocument] = {}, include: Include = ["metadatas", "documents"], + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, ) -> GetResult: if page and page_size: offset = (page - 1) * page_size @@ -286,6 +314,8 @@ def _delete( ids: Optional[IDs] = None, where: Optional[Where] = {}, where_document: Optional[WhereDocument] = {}, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, ) -> IDs: """Deletes embeddings from the database""" resp = self._session.post( @@ -329,6 +359,8 @@ def _add( embeddings: Embeddings, metadatas: Optional[Metadatas] = None, documents: Optional[Documents] = None, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, ) -> bool: """ Adds a batch of embeddings to the database @@ -348,6 +380,8 @@ def _update( embeddings: Optional[Embeddings] = None, metadatas: Optional[Metadatas] = None, documents: Optional[Documents] = None, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, ) -> bool: """ Updates a batch of embeddings in the database @@ -369,6 +403,8 @@ def _upsert( embeddings: Embeddings, metadatas: Optional[Metadatas] = None, documents: Optional[Documents] = None, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, ) -> bool: """ Upserts a batch of embeddings in the database @@ -391,6 +427,8 @@ def _query( where: Optional[Where] = {}, where_document: Optional[WhereDocument] = {}, include: Include = ["metadatas", "documents", "distances"], + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, ) -> QueryResult: """Gets the nearest neighbors of a single embedding""" resp = self._session.post( diff --git a/chromadb/api/models/Collection.py b/chromadb/api/models/Collection.py index c11a04b1fa4..d1f4b296712 100644 --- a/chromadb/api/models/Collection.py +++ b/chromadb/api/models/Collection.py @@ -33,19 +33,21 @@ logger 
= logging.getLogger(__name__) if TYPE_CHECKING: - from chromadb.api import API + from chromadb.api import ServerAPI class Collection(BaseModel): name: str id: UUID metadata: Optional[CollectionMetadata] = None - _client: "API" = PrivateAttr() + _client: "ServerAPI" = PrivateAttr() _embedding_function: Optional[EmbeddingFunction] = PrivateAttr() + # TODO: STORE THE TENANT AND NAMESPACE IN THE COLLECTION OBJECT + def __init__( self, - client: "API", + client: "ServerAPI", name: str, id: UUID, embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), diff --git a/chromadb/api/segment.py b/chromadb/api/segment.py index cfe1300e76e..27b53ea41ef 100644 --- a/chromadb/api/segment.py +++ b/chromadb/api/segment.py @@ -1,5 +1,5 @@ -from chromadb.api import API -from chromadb.config import Settings, System +from chromadb.api import ServerAPI +from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings, System from chromadb.db.system import SysDB from chromadb.segment import SegmentManager, MetadataReader, VectorReader from chromadb.telemetry import Telemetry @@ -71,7 +71,7 @@ def check_index_name(index_name: str) -> None: raise ValueError(msg) -class SegmentAPI(API): +class SegmentAPI(ServerAPI): """API implementation utilizing the new segment-based internal architecture""" _settings: Settings @@ -104,6 +104,8 @@ def create_collection( metadata: Optional[CollectionMetadata] = None, embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), get_or_create: bool = False, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, ) -> Collection: if metadata is not None: validate_metadata(metadata) @@ -148,6 +150,8 @@ def get_or_create_collection( name: str, metadata: Optional[CollectionMetadata] = None, embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, ) -> Collection: return self.create_collection( name=name, @@ 
-164,6 +168,8 @@ def get_collection( self, name: str, embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, ) -> Collection: existing = self._sysdb.get_collections(name=name) @@ -179,7 +185,11 @@ def get_collection( raise ValueError(f"Collection {name} does not exist.") @override - def list_collections(self) -> Sequence[Collection]: + def list_collections( + self, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> Sequence[Collection]: collections = [] db_collections = self._sysdb.get_collections() for db_collection in db_collections: @@ -199,6 +209,8 @@ def _modify( id: UUID, new_name: Optional[str] = None, new_metadata: Optional[CollectionMetadata] = None, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, ) -> None: if new_name: # backwards compatibility in naming requirements (for now) @@ -217,7 +229,12 @@ def _modify( self._sysdb.update_collection(id, metadata=new_metadata) @override - def delete_collection(self, name: str) -> None: + def delete_collection( + self, + name: str, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> None: existing = self._sysdb.get_collections(name=name) if existing: @@ -237,6 +254,8 @@ def _add( embeddings: Embeddings, metadatas: Optional[Metadatas] = None, documents: Optional[Documents] = None, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, ) -> bool: coll = self._get_collection(collection_id) self._manager.hint_use_collection(collection_id, t.Operation.ADD) @@ -274,6 +293,8 @@ def _update( embeddings: Optional[Embeddings] = None, metadatas: Optional[Metadatas] = None, documents: Optional[Documents] = None, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, ) -> bool: coll = self._get_collection(collection_id) self._manager.hint_use_collection(collection_id, t.Operation.UPDATE) @@ -313,6 +334,8 @@ def _upsert( embeddings: Embeddings, 
metadatas: Optional[Metadatas] = None, documents: Optional[Documents] = None, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, ) -> bool: coll = self._get_collection(collection_id) self._manager.hint_use_collection(collection_id, t.Operation.UPSERT) @@ -347,6 +370,8 @@ def _get( page_size: Optional[int] = None, where_document: Optional[WhereDocument] = {}, include: Include = ["embeddings", "metadatas", "documents"], + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, ) -> GetResult: where = validate_where(where) if where is not None and len(where) > 0 else None where_document = ( @@ -414,6 +439,8 @@ def _delete( ids: Optional[IDs] = None, where: Optional[Where] = None, where_document: Optional[WhereDocument] = None, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, ) -> IDs: where = validate_where(where) if where is not None and len(where) > 0 else None where_document = ( @@ -469,7 +496,12 @@ def _delete( return ids_to_delete @override - def _count(self, collection_id: UUID) -> int: + def _count( + self, + collection_id: UUID, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> int: metadata_segment = self._manager.get_segment(collection_id, MetadataReader) return metadata_segment.count() @@ -482,6 +514,8 @@ def _query( where: Where = {}, where_document: WhereDocument = {}, include: Include = ["documents", "metadatas", "distances"], + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, ) -> QueryResult: where = validate_where(where) if where is not None and len(where) > 0 else where where_document = ( @@ -574,7 +608,13 @@ def _query( ) @override - def _peek(self, collection_id: UUID, n: int = 10) -> GetResult: + def _peek( + self, + collection_id: UUID, + n: int = 10, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> GetResult: return self._get(collection_id, limit=n) @override diff --git a/chromadb/config.py b/chromadb/config.py index 
eb7bca93ef5..eb575696049 100644 --- a/chromadb/config.py +++ b/chromadb/config.py @@ -63,7 +63,8 @@ # TODO: Don't use concrete types here to avoid circular deps. Strings are fine for right here! _abstract_type_keys: Dict[str, str] = { - "chromadb.api.API": "chroma_api_impl", + "chromadb.api.API": "chroma_api_impl", # NOTE: this is to support legacy api construction. Use ServerAPI instead + "chromadb.api.ServerAPI": "chroma_api_impl", "chromadb.telemetry.Telemetry": "chroma_telemetry_impl", "chromadb.ingest.Producer": "chroma_producer_impl", "chromadb.ingest.Consumer": "chroma_consumer_impl", @@ -74,6 +75,9 @@ "chromadb.segment.distributed.MemberlistProvider": "chroma_memberlist_provider_impl", } +DEFAULT_TENANT = "default" +DEFAULT_DATABASE = "default" + class Settings(BaseSettings): # type: ignore environment: str = "" diff --git a/chromadb/migrations/sysdb/00004-tenants-databases.sqlite.sql b/chromadb/migrations/sysdb/00004-tenants-databases.sqlite.sql new file mode 100644 index 00000000000..1c40e823480 --- /dev/null +++ b/chromadb/migrations/sysdb/00004-tenants-databases.sqlite.sql @@ -0,0 +1,19 @@ +CREATE TABLE tenants ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL, + UNIQUE (name) -- Maybe not needed since we want to support slug ids +); + +CREATE TABLE databases ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL, + tenant_id TEXT NOT NULL REFERENCES tenants(id) ON DELETE CASCADE, + UNIQUE (name) +); + +ALTER TABLE collections + ADD COLUMN database_id TEXT NOT NULL REFERENCES databases(id) DEFAULT 'default'; -- ON DELETE CASCADE not supported by sqlite in ALTER TABLE + +-- Create default tenant and database +INSERT INTO tenants (id, name) VALUES ('default', 'default'); +INSERT INTO databases (id, name, tenant_id) VALUES ('default', 'default', 'default'); diff --git a/chromadb/server/fastapi/__init__.py b/chromadb/server/fastapi/__init__.py index e92d16d63ba..ec8bcac7ea1 100644 --- a/chromadb/server/fastapi/__init__.py +++ b/chromadb/server/fastapi/__init__.py @@ 
-15,9 +15,10 @@ FastAPIChromaAuthMiddleware, FastAPIChromaAuthMiddlewareWrapper, ) -from chromadb.config import Settings +from chromadb.config import Settings, System import chromadb.server import chromadb.api +from chromadb.api import ServerAPI from chromadb.errors import ( ChromaError, InvalidUUIDError, @@ -99,12 +100,15 @@ def include_in_schema(path: str) -> bool: super().add_api_route(path, *args, **kwargs) +# TODO: add tenant/namespace to all routes class FastAPI(chromadb.server.Server): def __init__(self, settings: Settings): super().__init__(settings) Telemetry.SERVER_CONTEXT = ServerContext.FASTAPI self._app = fastapi.FastAPI(debug=True) - self._api: chromadb.api.API = chromadb.Client(settings) + self._system = System(settings) + self._api: ServerAPI = self._system.instance(ServerAPI) + self._system.start() self._app.middleware("http")(catch_exceptions_middleware) self._app.add_middleware( diff --git a/chromadb/test/auth/test_token_auth.py b/chromadb/test/auth/test_token_auth.py index 4e99baae306..50e88e296a9 100644 --- a/chromadb/test/auth/test_token_auth.py +++ b/chromadb/test/auth/test_token_auth.py @@ -5,7 +5,7 @@ import pytest from hypothesis import given, settings -from chromadb.api import API +from chromadb.api import ServerAPI from chromadb.config import System from chromadb.test.conftest import _fastapi_fixture @@ -64,7 +64,7 @@ def test_fastapi_server_token_auth(token_config: Dict[str, Any]) -> None: ) _sys: System = next(api) _sys.reset_state() - _api = _sys.instance(API) + _api = _sys.instance(ServerAPI) _api.heartbeat() assert _api.list_collections() == [] @@ -103,7 +103,7 @@ def test_invalid_token(tconf: Dict[str, Any], inval_tok: str) -> None: with pytest.raises(Exception) as e: _sys: System = next(api) _sys.reset_state() - _sys.instance(API) + _sys.instance(ServerAPI) assert "Invalid token" in str(e) @@ -131,7 +131,7 @@ def test_fastapi_server_token_auth_wrong_token( ) _sys: System = next(api) _sys.reset_state() - _api = _sys.instance(API) + 
_api = _sys.instance(ServerAPI) _api.heartbeat() with pytest.raises(Exception) as e: _api.list_collections() diff --git a/chromadb/test/conftest.py b/chromadb/test/conftest.py index af66ef2513f..238748bcbbb 100644 --- a/chromadb/test/conftest.py +++ b/chromadb/test/conftest.py @@ -22,7 +22,7 @@ from typing_extensions import Protocol import chromadb.server.fastapi -from chromadb.api import API +from chromadb.api import ServerAPI from chromadb.config import Settings, System from chromadb.db.mixins import embeddings_queue from chromadb.ingest import Producer @@ -98,7 +98,7 @@ def _run_server( uvicorn.run(server.app(), host="0.0.0.0", port=port, log_level="error") -def _await_server(api: API, attempts: int = 0) -> None: +def _await_server(api: ServerAPI, attempts: int = 0) -> None: try: api.heartbeat() except ConnectionError as e: @@ -172,7 +172,7 @@ def _fastapi_fixture( chroma_client_auth_token_transport_header=chroma_client_auth_token_transport_header, ) system = System(settings) - api = system.instance(API) + api = system.instance(ServerAPI) system.start() _await_server(api) yield system @@ -198,7 +198,7 @@ def basic_http_client() -> Generator[System, None, None]: allow_reset=True, ) system = System(settings) - api = system.instance(API) + api = system.instance(ServerAPI) _await_server(api) system.start() yield system @@ -361,41 +361,43 @@ def system_fixtures_wrong_auth() -> List[Callable[[], Generator[System, None, No @pytest.fixture(scope="module", params=system_fixtures_wrong_auth()) -def system_wrong_auth(request: pytest.FixtureRequest) -> Generator[API, None, None]: +def system_wrong_auth( + request: pytest.FixtureRequest, +) -> Generator[ServerAPI, None, None]: yield next(request.param()) @pytest.fixture(scope="module", params=system_fixtures()) -def system(request: pytest.FixtureRequest) -> Generator[API, None, None]: +def system(request: pytest.FixtureRequest) -> Generator[ServerAPI, None, None]: yield next(request.param()) @pytest.fixture(scope="module", 
params=system_fixtures_auth()) -def system_auth(request: pytest.FixtureRequest) -> Generator[API, None, None]: +def system_auth(request: pytest.FixtureRequest) -> Generator[ServerAPI, None, None]: yield next(request.param()) @pytest.fixture(scope="function") -def api(system: System) -> Generator[API, None, None]: +def api(system: System) -> Generator[ServerAPI, None, None]: system.reset_state() - api = system.instance(API) + api = system.instance(ServerAPI) yield api @pytest.fixture(scope="function") def api_wrong_cred( system_wrong_auth: System, -) -> Generator[API, None, None]: +) -> Generator[ServerAPI, None, None]: system_wrong_auth.reset_state() - api = system_wrong_auth.instance(API) + api = system_wrong_auth.instance(ServerAPI) yield api @pytest.fixture(scope="function") -def api_with_server_auth(system_auth: System) -> Generator[API, None, None]: +def api_with_server_auth(system_auth: System) -> Generator[ServerAPI, None, None]: _sys = system_auth _sys.reset_state() - api = _sys.instance(API) + api = _sys.instance(ServerAPI) yield api diff --git a/chromadb/test/property/test_add.py b/chromadb/test/property/test_add.py index 1980ed2a9d9..5f8991b00ed 100644 --- a/chromadb/test/property/test_add.py +++ b/chromadb/test/property/test_add.py @@ -5,7 +5,7 @@ import pytest import hypothesis.strategies as st from hypothesis import given, settings -from chromadb.api import API +from chromadb.api import ServerAPI from chromadb.api.types import Embeddings, Metadatas import chromadb.test.property.strategies as strategies import chromadb.test.property.invariants as invariants @@ -17,7 +17,7 @@ @given(collection=collection_st, record_set=strategies.recordsets(collection_st)) @settings(deadline=None) def test_add( - api: API, + api: ServerAPI, collection: strategies.Collection, record_set: strategies.RecordSet, ) -> None: @@ -69,7 +69,7 @@ def create_large_recordset( @given(collection=collection_st) @settings(deadline=None, max_examples=1) -def test_add_large(api: API, 
collection: strategies.Collection) -> None: +def test_add_large(api: ServerAPI, collection: strategies.Collection) -> None: api.reset() record_set = create_large_recordset( min_size=api.max_batch_size, @@ -99,7 +99,7 @@ def test_add_large(api: API, collection: strategies.Collection) -> None: @given(collection=collection_st) @settings(deadline=None, max_examples=1) -def test_add_large_exceeding(api: API, collection: strategies.Collection) -> None: +def test_add_large_exceeding(api: ServerAPI, collection: strategies.Collection) -> None: api.reset() record_set = create_large_recordset( min_size=api.max_batch_size, @@ -126,7 +126,7 @@ def test_add_large_exceeding(api: API, collection: strategies.Collection) -> Non reason="This is expected to fail right now. We should change the API to sort the \ ids by input order." ) -def test_out_of_order_ids(api: API) -> None: +def test_out_of_order_ids(api: ServerAPI) -> None: api.reset() ooo_ids = [ "40", @@ -165,7 +165,7 @@ def test_out_of_order_ids(api: API) -> None: assert get_ids == ooo_ids -def test_add_partial(api: API) -> None: +def test_add_partial(api: ServerAPI) -> None: """Tests adding a record set with some of the fields set to None.""" api.reset() diff --git a/chromadb/test/property/test_collections.py b/chromadb/test/property/test_collections.py index 60e3de7592c..2d41c62ef80 100644 --- a/chromadb/test/property/test_collections.py +++ b/chromadb/test/property/test_collections.py @@ -2,7 +2,7 @@ import logging import hypothesis.strategies as st import chromadb.test.property.strategies as strategies -from chromadb.api import API +from chromadb.api import ClientAPI import chromadb.api.types as types from hypothesis.stateful import ( Bundle, @@ -23,7 +23,7 @@ class CollectionStateMachine(RuleBasedStateMachine): collections = Bundle("collections") - def __init__(self, api: API): + def __init__(self, api: ClientAPI): super().__init__() self.model = {} self.api = api @@ -203,6 +203,6 @@ def modify_coll( return 
multiple(coll) -def test_collections(caplog: pytest.LogCaptureFixture, api: API) -> None: +def test_collections(caplog: pytest.LogCaptureFixture, api: ClientAPI) -> None: caplog.set_level(logging.ERROR) run_state_machine_as_test(lambda: CollectionStateMachine(api)) # type: ignore diff --git a/chromadb/test/property/test_cross_version_persist.py b/chromadb/test/property/test_cross_version_persist.py index 529fe02dda7..11780d4d675 100644 --- a/chromadb/test/property/test_cross_version_persist.py +++ b/chromadb/test/property/test_cross_version_persist.py @@ -5,14 +5,14 @@ import subprocess import tempfile from types import ModuleType -from typing import Generator, List, Tuple, Dict, Any, Callable +from typing import Generator, List, Tuple, Dict, Any, Callable, Type from hypothesis import given, settings import hypothesis.strategies as st import pytest import json from urllib import request from chromadb import config -from chromadb.api import API +from chromadb.api import ServerAPI from chromadb.api.types import Documents, EmbeddingFunction, Embeddings import chromadb.test.property.strategies as strategies import chromadb.test.property.invariants as invariants @@ -84,6 +84,12 @@ def patch_for_version( patch(collection, embeddings) +def api_import_for_version(module: Any, version: str) -> Type: # type: ignore + if packaging_version.Version(version) <= packaging_version.Version("0.4.14"): + return module.api.API # type: ignore + return module.api.ServerAPI # type: ignore + + def configurations(versions: List[str]) -> List[Tuple[str, Settings]]: return [ ( @@ -197,13 +203,13 @@ def persist_generated_data_with_old_version( try: old_module = switch_to_version(version) system = old_module.config.System(settings) - api: API = system.instance(API) + api = system.instance(api_import_for_version(old_module, version)) system.start() api.reset() coll = api.create_collection( name=collection_strategy.name, - metadata=collection_strategy.metadata, # type: ignore + 
metadata=collection_strategy.metadata, # In order to test old versions, we can't rely on the not_implemented function embedding_function=not_implemented_ef(), ) @@ -288,7 +294,7 @@ def test_cycle_versions( # Switch to the current version (local working directory) and check the invariants # are preserved for the collection system = config.System(settings) - api = system.instance(API) + api = system.instance(ServerAPI) system.start() coll = api.get_collection( name=collection_strategy.name, diff --git a/chromadb/test/property/test_embeddings.py b/chromadb/test/property/test_embeddings.py index 0e402cca1a8..7fc2491c14b 100644 --- a/chromadb/test/property/test_embeddings.py +++ b/chromadb/test/property/test_embeddings.py @@ -5,7 +5,7 @@ from dataclasses import dataclass from chromadb.api.types import ID, Include, IDs import chromadb.errors as errors -from chromadb.api import API +from chromadb.api import ServerAPI from chromadb.api.models.Collection import Collection import chromadb.test.property.strategies as strategies from hypothesis.stateful import ( @@ -64,7 +64,7 @@ class EmbeddingStateMachine(RuleBasedStateMachine): collection: Collection embedding_ids: Bundle[ID] = Bundle("embedding_ids") - def __init__(self, api: API): + def __init__(self, api: ServerAPI): super().__init__() self.api = api self._rules_strategy = strategies.DeterministicRuleStrategy(self) # type: ignore @@ -294,13 +294,13 @@ def on_state_change(self, new_state: str) -> None: pass -def test_embeddings_state(caplog: pytest.LogCaptureFixture, api: API) -> None: +def test_embeddings_state(caplog: pytest.LogCaptureFixture, api: ServerAPI) -> None: caplog.set_level(logging.ERROR) run_state_machine_as_test(lambda: EmbeddingStateMachine(api)) # type: ignore print_traces() -def test_multi_add(api: API) -> None: +def test_multi_add(api: ServerAPI) -> None: api.reset() coll = api.create_collection(name="foo") coll.add(ids=["a"], embeddings=[[0.0]]) @@ -319,7 +319,7 @@ def test_multi_add(api: API) -> None: 
assert coll.count() == 0 -def test_dup_add(api: API) -> None: +def test_dup_add(api: ServerAPI) -> None: api.reset() coll = api.create_collection(name="foo") with pytest.raises(errors.DuplicateIDError): @@ -328,7 +328,7 @@ def test_dup_add(api: API) -> None: coll.upsert(ids=["a", "a"], embeddings=[[0.0], [1.1]]) -def test_query_without_add(api: API) -> None: +def test_query_without_add(api: ServerAPI) -> None: api.reset() coll = api.create_collection(name="foo") fields: Include = ["documents", "metadatas", "embeddings", "distances"] @@ -343,7 +343,7 @@ def test_query_without_add(api: API) -> None: assert all([len(result) == 0 for result in field_results]) -def test_get_non_existent(api: API) -> None: +def test_get_non_existent(api: ServerAPI) -> None: api.reset() coll = api.create_collection(name="foo") result = coll.get(ids=["a"], include=["documents", "metadatas", "embeddings"]) @@ -355,7 +355,7 @@ def test_get_non_existent(api: API) -> None: # TODO: Use SQL escaping correctly internally @pytest.mark.xfail(reason="We don't properly escape SQL internally, causing problems") -def test_escape_chars_in_ids(api: API) -> None: +def test_escape_chars_in_ids(api: ServerAPI) -> None: api.reset() id = "\x1f" coll = api.create_collection(name="foo") @@ -375,7 +375,7 @@ def test_escape_chars_in_ids(api: API) -> None: {"where_document": {}, "where": {}}, ], ) -def test_delete_empty_fails(api: API, kwargs: dict): +def test_delete_empty_fails(api: ServerAPI, kwargs: dict): api.reset() coll = api.create_collection(name="foo") with pytest.raises(Exception) as e: @@ -398,7 +398,7 @@ def test_delete_empty_fails(api: API, kwargs: dict): }, ], ) -def test_delete_success(api: API, kwargs: dict): +def test_delete_success(api: ServerAPI, kwargs: dict): api.reset() coll = api.create_collection(name="foo") # Should not raise diff --git a/chromadb/test/property/test_filtering.py b/chromadb/test/property/test_filtering.py index ddcdefb0ed3..e55e5d18cf5 100644 --- 
a/chromadb/test/property/test_filtering.py +++ b/chromadb/test/property/test_filtering.py @@ -1,7 +1,7 @@ from typing import Any, Dict, List, cast from hypothesis import given, settings, HealthCheck import pytest -from chromadb.api import API +from chromadb.api import ServerAPI from chromadb.test.property import invariants from chromadb.api.types import ( Document, @@ -165,7 +165,7 @@ def _filter_embedding_set( filters=st.lists(strategies.filters(collection_st, recordset_st), min_size=1), ) def test_filterable_metadata_get( - caplog, api: API, collection: strategies.Collection, record_set, filters + caplog, api: ServerAPI, collection: strategies.Collection, record_set, filters ) -> None: caplog.set_level(logging.ERROR) @@ -204,7 +204,7 @@ def test_filterable_metadata_get( ) def test_filterable_metadata_query( caplog: pytest.LogCaptureFixture, - api: API, + api: ServerAPI, collection: strategies.Collection, record_set: strategies.RecordSet, filters: List[strategies.Filter], @@ -257,7 +257,7 @@ def test_filterable_metadata_query( assert len(result_ids.intersection(expected_ids)) == len(result_ids) -def test_empty_filter(api: API) -> None: +def test_empty_filter(api: ServerAPI) -> None: """Test that a filter where no document matches returns an empty result""" api.reset() coll = api.create_collection(name="test") @@ -291,7 +291,7 @@ def test_empty_filter(api: API) -> None: assert res["metadatas"] == [[], []] -def test_boolean_metadata(api: API) -> None: +def test_boolean_metadata(api: ServerAPI) -> None: """Test that metadata with boolean values is correctly filtered""" api.reset() coll = api.create_collection(name="test") diff --git a/chromadb/test/property/test_persist.py b/chromadb/test/property/test_persist.py index ea95f684f60..e7b1f7017d1 100644 --- a/chromadb/test/property/test_persist.py +++ b/chromadb/test/property/test_persist.py @@ -6,7 +6,7 @@ import hypothesis.strategies as st import pytest import chromadb -from chromadb.api import API +from chromadb.api 
import ClientAPI, ServerAPI from chromadb.config import Settings, System import chromadb.test.property.strategies as strategies import chromadb.test.property.invariants as invariants @@ -26,7 +26,7 @@ import shutil import tempfile -CreatePersistAPI = Callable[[], API] +CreatePersistAPI = Callable[[], ServerAPI] configurations = [ Settings( @@ -71,7 +71,7 @@ def test_persist( embeddings_strategy: strategies.RecordSet, ) -> None: system_1 = System(settings) - api_1 = system_1.instance(API) + api_1 = system_1.instance(ServerAPI) system_1.start() api_1.reset() @@ -103,7 +103,7 @@ def test_persist( del system_1 system_2 = System(settings) - api_2 = system_2.instance(API) + api_2 = system_2.instance(ServerAPI) system_2.start() coll = api_2.get_collection( @@ -133,7 +133,7 @@ def load_and_check( ) -> None: try: system = System(settings) - api = system.instance(API) + api = system.instance(ServerAPI) system.start() coll = api.get_collection( @@ -157,7 +157,7 @@ class PersistEmbeddingsStateMachineStates(EmbeddingStateMachineStates): class PersistEmbeddingsStateMachine(EmbeddingStateMachine): - def __init__(self, api: API, settings: Settings): + def __init__(self, api: ClientAPI, settings: Settings): self.api = api self.settings = settings self.last_persist_delay = 10 diff --git a/chromadb/test/stress/test_many_collections.py b/chromadb/test/stress/test_many_collections.py index 7e65c4b790d..29951fa452a 100644 --- a/chromadb/test/stress/test_many_collections.py +++ b/chromadb/test/stress/test_many_collections.py @@ -1,11 +1,11 @@ from typing import List import numpy as np -from chromadb.api import API +from chromadb.api import ServerAPI from chromadb.api.models.Collection import Collection -def test_many_collections(api: API) -> None: +def test_many_collections(api: ServerAPI) -> None: """Test that we can create a large number of collections and that the system # remains responsive.""" api.reset() diff --git a/chromadb/test/test_api.py b/chromadb/test/test_api.py index 
8a12a1d9735..ed3c87ee682 100644 --- a/chromadb/test/test_api.py +++ b/chromadb/test/test_api.py @@ -21,7 +21,7 @@ @pytest.fixture def local_persist_api(): - yield chromadb.Client( + client = chromadb.Client( Settings( chroma_api_impl="chromadb.api.segment.SegmentAPI", chroma_sysdb_impl="chromadb.db.impl.sqlite.SqliteDB", @@ -33,6 +33,8 @@ def local_persist_api(): persist_directory=persist_dir, ), ) + yield client + client.clear_system_cache() if os.path.exists(persist_dir): shutil.rmtree(persist_dir, ignore_errors=True) @@ -40,7 +42,7 @@ def local_persist_api(): # https://docs.pytest.org/en/6.2.x/fixture.html#fixtures-can-be-requested-more-than-once-per-test-return-values-are-cached @pytest.fixture def local_persist_api_cache_bust(): - yield chromadb.Client( + client = chromadb.Client( Settings( chroma_api_impl="chromadb.api.segment.SegmentAPI", chroma_sysdb_impl="chromadb.db.impl.sqlite.SqliteDB", @@ -52,6 +54,8 @@ def local_persist_api_cache_bust(): persist_directory=persist_dir, ), ) + yield client + client.clear_system_cache() if os.path.exists(persist_dir): shutil.rmtree(persist_dir, ignore_errors=True) diff --git a/chromadb/test/test_chroma.py b/chromadb/test/test_chroma.py index 42b14411519..9d88ea8cc49 100644 --- a/chromadb/test/test_chroma.py +++ b/chromadb/test/test_chroma.py @@ -47,19 +47,21 @@ class GetAPITest(unittest.TestCase): @patch("chromadb.api.segment.SegmentAPI", autospec=True) @patch.dict(os.environ, {}, clear=True) def test_local(self, mock_api: Mock) -> None: - chromadb.Client(chromadb.config.Settings(persist_directory="./foo")) + client = chromadb.Client(chromadb.config.Settings(persist_directory="./foo")) assert mock_api.called + client.clear_system_cache() @patch("chromadb.db.impl.sqlite.SqliteDB", autospec=True) @patch.dict(os.environ, {}, clear=True) def test_local_db(self, mock_db: Mock) -> None: - chromadb.Client(chromadb.config.Settings(persist_directory="./foo")) + client = 
chromadb.Client(chromadb.config.Settings(persist_directory="./foo")) assert mock_db.called + client.clear_system_cache() @patch("chromadb.api.fastapi.FastAPI", autospec=True) @patch.dict(os.environ, {}, clear=True) def test_fastapi(self, mock: Mock) -> None: - chromadb.Client( + client = chromadb.Client( chromadb.config.Settings( chroma_api_impl="chromadb.api.fastapi.FastAPI", persist_directory="./foo", @@ -68,6 +70,7 @@ def test_fastapi(self, mock: Mock) -> None: ) ) assert mock.called + client.clear_system_cache() @patch("chromadb.api.fastapi.FastAPI", autospec=True) @patch.dict(os.environ, {}, clear=True) @@ -78,7 +81,7 @@ def test_settings_pass_to_fastapi(self, mock: Mock) -> None: chroma_server_http_port="80", chroma_server_headers={"foo": "bar"}, ) - chromadb.Client(settings) + client = chromadb.Client(settings) # Check that the mock was called assert mock.called @@ -93,11 +96,12 @@ def test_settings_pass_to_fastapi(self, mock: Mock) -> None: # Check if the settings passed to the mock match the settings we used # raise Exception(passed_settings.settings) assert passed_settings.settings == settings + client.clear_system_cache() def test_legacy_values() -> None: with pytest.raises(ValueError): - chromadb.Client( + client = chromadb.Client( chromadb.config.Settings( chroma_api_impl="chromadb.api.local.LocalAPI", persist_directory="./foo", @@ -105,3 +109,4 @@ def test_legacy_values() -> None: chroma_server_http_port="80", ) ) + client.clear_system_cache() diff --git a/chromadb/test/test_client.py b/chromadb/test/test_client.py index 1164e1e699d..d4f1de9ae9f 100644 --- a/chromadb/test/test_client.py +++ b/chromadb/test/test_client.py @@ -1,37 +1,44 @@ +from typing import Generator import chromadb -from chromadb.api import API +from chromadb.api import ClientAPI import chromadb.server.fastapi import pytest import tempfile @pytest.fixture -def ephemeral_api() -> API: - return chromadb.EphemeralClient() +def ephemeral_api() -> Generator[ClientAPI, None, None]: + 
client = chromadb.EphemeralClient() + yield client + client.clear_system_cache() @pytest.fixture -def persistent_api() -> API: - return chromadb.PersistentClient( +def persistent_api() -> Generator[ClientAPI, None, None]: + client = chromadb.PersistentClient( path=tempfile.gettempdir() + "/test_server", ) + yield client + client.clear_system_cache() @pytest.fixture -def http_api() -> API: - return chromadb.HttpClient() +def http_api() -> Generator[ClientAPI, None, None]: + client = chromadb.HttpClient() + yield client + client.clear_system_cache() -def test_ephemeral_client(ephemeral_api: API) -> None: +def test_ephemeral_client(ephemeral_api: ClientAPI) -> None: settings = ephemeral_api.get_settings() assert settings.is_persistent is False -def test_persistent_client(persistent_api: API) -> None: +def test_persistent_client(persistent_api: ClientAPI) -> None: settings = persistent_api.get_settings() assert settings.is_persistent is True -def test_http_client(http_api: API) -> None: +def test_http_client(http_api: ClientAPI) -> None: settings = http_api.get_settings() assert settings.chroma_api_impl == "chromadb.api.fastapi.FastAPI" diff --git a/chromadb/test/test_multithreaded.py b/chromadb/test/test_multithreaded.py index 57c259dad99..c0b05e88324 100644 --- a/chromadb/test/test_multithreaded.py +++ b/chromadb/test/test_multithreaded.py @@ -5,7 +5,7 @@ from typing import Any, Dict, List, Optional, Set, Tuple, cast import numpy as np -from chromadb.api import API +from chromadb.api import ServerAPI import chromadb.test.property.invariants as invariants from chromadb.test.property.strategies import RecordSet from chromadb.test.property.strategies import test_hnsw_config @@ -37,7 +37,7 @@ def generate_record_set(N: int, D: int) -> RecordSet: # Hypothesis is bad at generating large datasets so we manually generate data in # this test to test multithreaded add with larger datasets -def _test_multithreaded_add(api: API, N: int, D: int, num_workers: int) -> None: +def 
_test_multithreaded_add(api: ServerAPI, N: int, D: int, num_workers: int) -> None: records_set = generate_record_set(N, D) ids = records_set["ids"] embeddings = records_set["embeddings"] @@ -95,7 +95,9 @@ def _test_multithreaded_add(api: API, N: int, D: int, num_workers: int) -> None: ) -def _test_interleaved_add_query(api: API, N: int, D: int, num_workers: int) -> None: +def _test_interleaved_add_query( + api: ServerAPI, N: int, D: int, num_workers: int +) -> None: """Test that will use multiple threads to interleave operations on the db and verify they work correctly""" api.reset() @@ -207,14 +209,14 @@ def perform_operation( ) -def test_multithreaded_add(api: API) -> None: +def test_multithreaded_add(api: ServerAPI) -> None: for i in range(3): num_workers = random.randint(2, multiprocessing.cpu_count() * 2) N, D = generate_data_shape() _test_multithreaded_add(api, N, D, num_workers) -def test_interleaved_add_query(api: API) -> None: +def test_interleaved_add_query(api: ServerAPI) -> None: for i in range(3): num_workers = random.randint(2, multiprocessing.cpu_count() * 2) N, D = generate_data_shape() diff --git a/chromadb/utils/batch_utils.py b/chromadb/utils/batch_utils.py index c8c1ac1e476..9c588270f25 100644 --- a/chromadb/utils/batch_utils.py +++ b/chromadb/utils/batch_utils.py @@ -1,5 +1,5 @@ from typing import Optional, Tuple, List -from chromadb.api import API +from chromadb.api import BaseAPI from chromadb.api.types import ( Documents, Embeddings, @@ -9,7 +9,7 @@ def create_batches( - api: API, + api: BaseAPI, ids: IDs, embeddings: Optional[Embeddings] = None, metadatas: Optional[Metadatas] = None, From 8dad8e0264fbef357f020ddc03568285ed5b5e07 Mon Sep 17 00:00:00 2001 From: hammadb Date: Wed, 18 Oct 2023 21:23:38 -0700 Subject: [PATCH 03/22] Simplify --- chromadb/api/__init__.py | 130 ++------------ chromadb/api/client.py | 114 ++++++------ chromadb/api/fastapi.py | 38 ++-- chromadb/api/models/Collection.py | 2 - chromadb/api/segment.py | 40 ++--- 
chromadb/db/impl/grpc/client.py | 30 +++- chromadb/db/impl/grpc/server.py | 115 ++++++++++-- chromadb/db/impl/sqlite.py | 6 +- chromadb/db/mixins/sysdb.py | 94 +++++++++- chromadb/db/system.py | 20 ++- .../sysdb/00004-tenants-databases.sqlite.sql | 28 ++- chromadb/proto/coordinator_pb2.py | 52 +++--- chromadb/proto/coordinator_pb2.pyi | 34 +++- chromadb/proto/coordinator_pb2_grpc.py | 45 +++++ chromadb/test/db/test_system.py | 170 ++++++++++++++++++ idl/chromadb/proto/coordinator.proto | 15 +- 16 files changed, 649 insertions(+), 284 deletions(-) diff --git a/chromadb/api/__init__.py b/chromadb/api/__init__.py index 2287bd929d2..28a62fcb5be 100644 --- a/chromadb/api/__init__.py +++ b/chromadb/api/__init__.py @@ -419,7 +419,19 @@ def clear_system_cache() -> None: pass -class ServerAPI(BaseAPI, Component): +class AdminAPI(ABC): + @abstractmethod + def create_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None: + """Create a new database. + + Args: + database: The name of the database to create. + + """ + pass + + +class ServerAPI(BaseAPI, AdminAPI, Component): """An API instance that extends the relevant Base API methods by passing in a tenant and database. 
This is the root component of the Chroma System""" @@ -473,8 +485,6 @@ def _modify( id: UUID, new_name: Optional[str] = None, new_metadata: Optional[CollectionMetadata] = None, - tenant: str = DEFAULT_TENANT, - database: str = DEFAULT_DATABASE, ) -> None: pass @@ -487,117 +497,3 @@ def delete_collection( database: str = DEFAULT_DATABASE, ) -> None: pass - - # - # ITEM METHODS - # - - @abstractmethod - @override - def _add( - self, - ids: IDs, - collection_id: UUID, - embeddings: Embeddings, - metadatas: Optional[Metadatas] = None, - documents: Optional[Documents] = None, - tenant: str = DEFAULT_TENANT, - database: str = DEFAULT_DATABASE, - ) -> bool: - pass - - @abstractmethod - @override - def _update( - self, - collection_id: UUID, - ids: IDs, - embeddings: Optional[Embeddings] = None, - metadatas: Optional[Metadatas] = None, - documents: Optional[Documents] = None, - tenant: str = DEFAULT_TENANT, - database: str = DEFAULT_DATABASE, - ) -> bool: - pass - - @abstractmethod - @override - def _upsert( - self, - collection_id: UUID, - ids: IDs, - embeddings: Embeddings, - metadatas: Optional[Metadatas] = None, - documents: Optional[Documents] = None, - tenant: str = DEFAULT_TENANT, - database: str = DEFAULT_DATABASE, - ) -> bool: - pass - - @abstractmethod - @override - def _count( - self, - collection_id: UUID, - tenant: str = DEFAULT_TENANT, - database: str = DEFAULT_DATABASE, - ) -> int: - pass - - @abstractmethod - @override - def _peek( - self, - collection_id: UUID, - n: int = 10, - tenant: str = DEFAULT_TENANT, - database: str = DEFAULT_DATABASE, - ) -> GetResult: - pass - - @abstractmethod - @override - def _get( - self, - collection_id: UUID, - ids: Optional[IDs] = None, - where: Optional[Where] = {}, - sort: Optional[str] = None, - limit: Optional[int] = None, - offset: Optional[int] = None, - page: Optional[int] = None, - page_size: Optional[int] = None, - where_document: Optional[WhereDocument] = {}, - include: Include = ["embeddings", "metadatas", 
"documents"], - tenant: str = DEFAULT_TENANT, - database: str = DEFAULT_DATABASE, - ) -> GetResult: - pass - - @abstractmethod - @override - def _delete( - self, - collection_id: UUID, - ids: Optional[IDs], - where: Optional[Where] = {}, - where_document: Optional[WhereDocument] = {}, - tenant: str = DEFAULT_TENANT, - database: str = DEFAULT_DATABASE, - ) -> IDs: - pass - - @abstractmethod - @override - def _query( - self, - collection_id: UUID, - query_embeddings: Embeddings, - n_results: int = 10, - where: Where = {}, - where_document: WhereDocument = {}, - include: Include = ["embeddings", "metadatas", "documents", "distances"], - tenant: str = DEFAULT_TENANT, - database: str = DEFAULT_DATABASE, - ) -> QueryResult: - pass diff --git a/chromadb/api/client.py b/chromadb/api/client.py index c007495705b..24b087cc30c 100644 --- a/chromadb/api/client.py +++ b/chromadb/api/client.py @@ -2,7 +2,7 @@ from uuid import UUID from overrides import override -from chromadb.api import ClientAPI, ServerAPI +from chromadb.api import AdminAPI, ClientAPI, ServerAPI from chromadb.api.types import ( CollectionMetadata, Documents, @@ -23,51 +23,17 @@ import chromadb.utils.embedding_functions as ef -class Client(ClientAPI): - """A client for Chroma. This is the main entrypoint for interacting with Chroma. - A client internally stores its tenant and database and proxies calls to a - Server API instance of Chroma. It treats the Server API and corresponding System - as a singleton, so multiple clients connecting to the same resource will share the - same API instance. - - Client implementations should be implement their own API-caching strategies. 
- """ - - tenant: str = DEFAULT_TENANT - database: str = DEFAULT_DATABASE - +class SharedSystemClient: _identifer_to_system: ClassVar[Dict[str, System]] = {} _identifier: str - _server: ServerAPI # region Initialization - def __new__( - cls, - tenant: str = "default", - database: str = "default", - settings: Settings = Settings(), - ) -> "Client": - identifier = cls._get_identifier_from_settings(settings) - cls._create_system_if_not_exists(identifier, settings) - instance = super().__new__(cls) - return instance - def __init__( self, - tenant: str = "default", - database: str = "default", settings: Settings = Settings(), ) -> None: - self.tenant = tenant - self.database = database - self._identifier = self._get_identifier_from_settings(settings) - - # Get the root system component we want to interact with - self._server = self._system.instance(ServerAPI) - - # Submit event for a client start - telemetry_client = self._system.instance(Telemetry) - telemetry_client.capture(ClientStartEvent()) + self._identifier = SharedSystemClient._get_identifier_from_settings(settings) + SharedSystemClient._create_system_if_not_exists(self._identifier, settings) @classmethod def _create_system_if_not_exists( @@ -116,13 +82,48 @@ def _get_identifier_from_settings(settings: Settings) -> str: return identifier @staticmethod - @override def clear_system_cache() -> None: - Client._identifer_to_system = {} + SharedSystemClient._identifer_to_system = {} @property def _system(self) -> System: - return self._identifer_to_system[self._identifier] + return SharedSystemClient._identifer_to_system[self._identifier] + + # endregion + + +class Client(SharedSystemClient, ClientAPI): + """A client for Chroma. This is the main entrypoint for interacting with Chroma. + A client internally stores its tenant and database and proxies calls to a + Server API instance of Chroma. 
It treats the Server API and corresponding System + as a singleton, so multiple clients connecting to the same resource will share the + same API instance. + + Client implementations should be implement their own API-caching strategies. + """ + + tenant: str = DEFAULT_TENANT + database: str = DEFAULT_DATABASE + + _server: ServerAPI + + # region Initialization + def __init__( + self, + tenant: str = "default", + database: str = "default", + settings: Settings = Settings(), + ) -> None: + super().__init__(settings=settings) + self.tenant = tenant + self.database = database + + # Get the root system component we want to interact with + self._server = self._system.instance(ServerAPI) + + # Submit event for a client start + telemetry_client = self._system.instance(Telemetry) + telemetry_client.capture(ClientStartEvent()) # endregion @@ -150,6 +151,7 @@ def create_collection( embedding_function=embedding_function, tenant=self.tenant, database=self.database, + get_or_create=get_or_create, ) @override @@ -191,8 +193,6 @@ def _modify( id=id, new_name=new_name, new_metadata=new_metadata, - tenant=self.tenant, - database=self.database, ) @override @@ -225,8 +225,6 @@ def _add( embeddings=embeddings, metadatas=metadatas, documents=documents, - tenant=self.tenant, - database=self.database, ) @override @@ -244,8 +242,6 @@ def _update( embeddings=embeddings, metadatas=metadatas, documents=documents, - tenant=self.tenant, - database=self.database, ) @override @@ -263,16 +259,12 @@ def _upsert( embeddings=embeddings, metadatas=metadatas, documents=documents, - tenant=self.tenant, - database=self.database, ) @override def _count(self, collection_id: UUID) -> int: return self._server._count( collection_id=collection_id, - tenant=self.tenant, - database=self.database, ) @override @@ -280,8 +272,6 @@ def _peek(self, collection_id: UUID, n: int = 10) -> GetResult: return self._server._peek( collection_id=collection_id, n=n, - tenant=self.tenant, - database=self.database, ) @override @@ 
-309,8 +299,6 @@ def _get( page_size=page_size, where_document=where_document, include=include, - tenant=self.tenant, - database=self.database, ) def _delete( @@ -325,8 +313,6 @@ def _delete( ids=ids, where=where, where_document=where_document, - tenant=self.tenant, - database=self.database, ) @override @@ -346,8 +332,6 @@ def _query( where=where, where_document=where_document, include=include, - tenant=self.tenant, - database=self.database, ) @override @@ -380,3 +364,15 @@ def set_tenant(self, tenant: str) -> None: self.tenant = tenant # endregion + + +class AdminClient(AdminAPI, SharedSystemClient): + _server: ServerAPI + + def __init__(self, settings: Settings = Settings()) -> None: + super().__init__(settings) + self._server = self._system.instance(ServerAPI) + + @override + def create_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None: + return self._server.create_database(name=name, tenant=tenant) diff --git a/chromadb/api/fastapi.py b/chromadb/api/fastapi.py index 2ee75612c23..784d9e79dbb 100644 --- a/chromadb/api/fastapi.py +++ b/chromadb/api/fastapi.py @@ -134,6 +134,19 @@ def heartbeat(self) -> int: raise_chroma_error(resp) return int(resp.json()["nanosecond heartbeat"]) + @override + def create_database( + self, + name: str, + tenant: str = DEFAULT_TENANT, + ) -> None: + """Creates a database""" + resp = self._session.post( + self._api_url + "/databases", + data=json.dumps({"name": name, "tenant": tenant}), + ) + raise_chroma_error(resp) + @override def list_collections( self, tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE @@ -205,7 +218,12 @@ def get_or_create_collection( database: str = DEFAULT_DATABASE, ) -> Collection: return self.create_collection( - name, metadata, embedding_function, get_or_create=True + name, + metadata, + embedding_function, + get_or_create=True, + tenant=tenant, + database=database, ) @override @@ -214,8 +232,6 @@ def _modify( id: UUID, new_name: Optional[str] = None, new_metadata: 
Optional[CollectionMetadata] = None, - tenant: str = DEFAULT_TENANT, - database: str = DEFAULT_DATABASE, ) -> None: """Updates a collection""" resp = self._session.put( @@ -239,8 +255,6 @@ def delete_collection( def _count( self, collection_id: UUID, - tenant: str = DEFAULT_TENANT, - database: str = DEFAULT_DATABASE, ) -> int: """Returns the number of embeddings in the database""" resp = self._session.get( @@ -254,8 +268,6 @@ def _peek( self, collection_id: UUID, n: int = 10, - tenant: str = DEFAULT_TENANT, - database: str = DEFAULT_DATABASE, ) -> GetResult: return self._get( collection_id, @@ -276,8 +288,6 @@ def _get( page_size: Optional[int] = None, where_document: Optional[WhereDocument] = {}, include: Include = ["metadatas", "documents"], - tenant: str = DEFAULT_TENANT, - database: str = DEFAULT_DATABASE, ) -> GetResult: if page and page_size: offset = (page - 1) * page_size @@ -314,8 +324,6 @@ def _delete( ids: Optional[IDs] = None, where: Optional[Where] = {}, where_document: Optional[WhereDocument] = {}, - tenant: str = DEFAULT_TENANT, - database: str = DEFAULT_DATABASE, ) -> IDs: """Deletes embeddings from the database""" resp = self._session.post( @@ -359,8 +367,6 @@ def _add( embeddings: Embeddings, metadatas: Optional[Metadatas] = None, documents: Optional[Documents] = None, - tenant: str = DEFAULT_TENANT, - database: str = DEFAULT_DATABASE, ) -> bool: """ Adds a batch of embeddings to the database @@ -380,8 +386,6 @@ def _update( embeddings: Optional[Embeddings] = None, metadatas: Optional[Metadatas] = None, documents: Optional[Documents] = None, - tenant: str = DEFAULT_TENANT, - database: str = DEFAULT_DATABASE, ) -> bool: """ Updates a batch of embeddings in the database @@ -403,8 +407,6 @@ def _upsert( embeddings: Embeddings, metadatas: Optional[Metadatas] = None, documents: Optional[Documents] = None, - tenant: str = DEFAULT_TENANT, - database: str = DEFAULT_DATABASE, ) -> bool: """ Upserts a batch of embeddings in the database @@ -427,8 +429,6 @@ 
def _query( where: Optional[Where] = {}, where_document: Optional[WhereDocument] = {}, include: Include = ["metadatas", "documents", "distances"], - tenant: str = DEFAULT_TENANT, - database: str = DEFAULT_DATABASE, ) -> QueryResult: """Gets the nearest neighbors of a single embedding""" resp = self._session.post( diff --git a/chromadb/api/models/Collection.py b/chromadb/api/models/Collection.py index d1f4b296712..f605d9d9d84 100644 --- a/chromadb/api/models/Collection.py +++ b/chromadb/api/models/Collection.py @@ -43,8 +43,6 @@ class Collection(BaseModel): _client: "ServerAPI" = PrivateAttr() _embedding_function: Optional[EmbeddingFunction] = PrivateAttr() - # TODO: STORE THE TENANT AND NAMESPACE IN THE COLLECTION OBJECT - def __init__( self, client: "ServerAPI", diff --git a/chromadb/api/segment.py b/chromadb/api/segment.py index 27b53ea41ef..37887d5432a 100644 --- a/chromadb/api/segment.py +++ b/chromadb/api/segment.py @@ -94,6 +94,14 @@ def __init__(self, system: System): def heartbeat(self) -> int: return int(time.time_ns()) + @override + def create_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None: + self._sysdb.create_database( + id=uuid4(), + name=name, + tenant=tenant, + ) + # TODO: Actually fix CollectionMetadata type to remove type: ignore flags. This is # necessary because changing the value type from `Any` to`` `Union[str, int, float]` # causes the system to somehow convert all values to strings. @@ -121,6 +129,8 @@ def create_collection( metadata=metadata, dimension=None, get_or_create=get_or_create, + tenant=tenant, + database=database, ) if created: @@ -158,6 +168,8 @@ def get_or_create_collection( metadata=metadata, embedding_function=embedding_function, get_or_create=True, + tenant=tenant, + database=database, ) # TODO: Actually fix CollectionMetadata type to remove type: ignore flags. 
This is @@ -171,7 +183,9 @@ def get_collection( tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE, ) -> Collection: - existing = self._sysdb.get_collections(name=name) + existing = self._sysdb.get_collections( + name=name, tenant=tenant, database=database + ) if existing: return Collection( @@ -191,7 +205,7 @@ def list_collections( database: str = DEFAULT_DATABASE, ) -> Sequence[Collection]: collections = [] - db_collections = self._sysdb.get_collections() + db_collections = self._sysdb.get_collections(tenant=tenant, database=database) for db_collection in db_collections: collections.append( Collection( @@ -209,8 +223,6 @@ def _modify( id: UUID, new_name: Optional[str] = None, new_metadata: Optional[CollectionMetadata] = None, - tenant: str = DEFAULT_TENANT, - database: str = DEFAULT_DATABASE, ) -> None: if new_name: # backwards compatibility in naming requirements (for now) @@ -235,7 +247,9 @@ def delete_collection( tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE, ) -> None: - existing = self._sysdb.get_collections(name=name) + existing = self._sysdb.get_collections( + name=name, tenant=tenant, database=database + ) if existing: self._sysdb.delete_collection(existing[0]["id"]) @@ -254,8 +268,6 @@ def _add( embeddings: Embeddings, metadatas: Optional[Metadatas] = None, documents: Optional[Documents] = None, - tenant: str = DEFAULT_TENANT, - database: str = DEFAULT_DATABASE, ) -> bool: coll = self._get_collection(collection_id) self._manager.hint_use_collection(collection_id, t.Operation.ADD) @@ -293,8 +305,6 @@ def _update( embeddings: Optional[Embeddings] = None, metadatas: Optional[Metadatas] = None, documents: Optional[Documents] = None, - tenant: str = DEFAULT_TENANT, - database: str = DEFAULT_DATABASE, ) -> bool: coll = self._get_collection(collection_id) self._manager.hint_use_collection(collection_id, t.Operation.UPDATE) @@ -334,8 +344,6 @@ def _upsert( embeddings: Embeddings, metadatas: Optional[Metadatas] = None, documents: 
Optional[Documents] = None, - tenant: str = DEFAULT_TENANT, - database: str = DEFAULT_DATABASE, ) -> bool: coll = self._get_collection(collection_id) self._manager.hint_use_collection(collection_id, t.Operation.UPSERT) @@ -370,8 +378,6 @@ def _get( page_size: Optional[int] = None, where_document: Optional[WhereDocument] = {}, include: Include = ["embeddings", "metadatas", "documents"], - tenant: str = DEFAULT_TENANT, - database: str = DEFAULT_DATABASE, ) -> GetResult: where = validate_where(where) if where is not None and len(where) > 0 else None where_document = ( @@ -439,8 +445,6 @@ def _delete( ids: Optional[IDs] = None, where: Optional[Where] = None, where_document: Optional[WhereDocument] = None, - tenant: str = DEFAULT_TENANT, - database: str = DEFAULT_DATABASE, ) -> IDs: where = validate_where(where) if where is not None and len(where) > 0 else None where_document = ( @@ -499,8 +503,6 @@ def _delete( def _count( self, collection_id: UUID, - tenant: str = DEFAULT_TENANT, - database: str = DEFAULT_DATABASE, ) -> int: metadata_segment = self._manager.get_segment(collection_id, MetadataReader) return metadata_segment.count() @@ -514,8 +516,6 @@ def _query( where: Where = {}, where_document: WhereDocument = {}, include: Include = ["documents", "metadatas", "distances"], - tenant: str = DEFAULT_TENANT, - database: str = DEFAULT_DATABASE, ) -> QueryResult: where = validate_where(where) if where is not None and len(where) > 0 else where where_document = ( @@ -612,8 +612,6 @@ def _peek( self, collection_id: UUID, n: int = 10, - tenant: str = DEFAULT_TENANT, - database: str = DEFAULT_DATABASE, ) -> GetResult: return self._get(collection_id, limit=n) diff --git a/chromadb/db/impl/grpc/client.py b/chromadb/db/impl/grpc/client.py index 04d4302062a..fdf78ee957a 100644 --- a/chromadb/db/impl/grpc/client.py +++ b/chromadb/db/impl/grpc/client.py @@ -1,7 +1,7 @@ from typing import List, Optional, Sequence, Tuple, Union, cast from uuid import UUID from overrides import 
overrides -from chromadb.config import System +from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, System from chromadb.db.base import NotFoundError, UniqueConstraintError from chromadb.db.system import SysDB from chromadb.proto.convert import ( @@ -13,6 +13,7 @@ ) from chromadb.proto.coordinator_pb2 import ( CreateCollectionRequest, + CreateDatabaseRequest, CreateSegmentRequest, DeleteCollectionRequest, DeleteSegmentRequest, @@ -71,6 +72,15 @@ def reset_state(self) -> None: self._sys_db_stub.ResetState(Empty()) return super().reset_state() + @overrides + def create_database( + self, id: UUID, name: str, tenant: str = DEFAULT_TENANT + ) -> None: + request = CreateDatabaseRequest(id=id.hex, name=name, tenant=tenant) + response = self._sys_db_stub.CreateDatabase(request) + if response.status.code == 409: + raise UniqueConstraintError() + @overrides def create_segment(self, segment: Segment) -> None: proto_segment = to_proto_segment(segment) @@ -164,6 +174,8 @@ def create_collection( metadata: Optional[Metadata] = None, dimension: Optional[int] = None, get_or_create: bool = False, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, ) -> Tuple[Collection, bool]: request = CreateCollectionRequest( id=id.hex, @@ -171,6 +183,8 @@ def create_collection( metadata=to_proto_update_metadata(metadata) if metadata else None, dimension=dimension, get_or_create=get_or_create, + tenant=tenant, + database=database, ) response = self._sys_db_stub.CreateCollection(request) if response.status.code == 409: @@ -179,9 +193,13 @@ def create_collection( return collection, response.created @overrides - def delete_collection(self, id: UUID) -> None: + def delete_collection( + self, id: UUID, tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE + ) -> None: request = DeleteCollectionRequest( id=id.hex, + tenant=tenant, + database=database, ) response = self._sys_db_stub.DeleteCollection(request) if response.status.code == 404: @@ -193,11 +211,15 @@ def 
get_collections( id: Optional[UUID] = None, topic: Optional[str] = None, name: Optional[str] = None, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, ) -> Sequence[Collection]: request = GetCollectionsRequest( id=id.hex if id else None, topic=topic, name=name, + tenant=tenant, + database=database, ) response: GetCollectionsResponse = self._sys_db_stub.GetCollections(request) results: List[Collection] = [] @@ -243,7 +265,9 @@ def update_collection( request.ClearField("metadata") request.reset_metadata = True - self._sys_db_stub.UpdateCollection(request) + response = self._sys_db_stub.UpdateCollection(request) + if response.status.code == 404: + raise NotFoundError() def reset_and_wait_for_ready(self) -> None: self._sys_db_stub.ResetState(Empty(), wait_for_ready=True) diff --git a/chromadb/db/impl/grpc/server.py b/chromadb/db/impl/grpc/server.py index 436a57fc167..b028ea5a3b4 100644 --- a/chromadb/db/impl/grpc/server.py +++ b/chromadb/db/impl/grpc/server.py @@ -16,6 +16,7 @@ from chromadb.proto.coordinator_pb2 import ( CreateCollectionRequest, CreateCollectionResponse, + CreateDatabaseRequest, CreateSegmentRequest, DeleteCollectionRequest, DeleteSegmentRequest, @@ -43,7 +44,9 @@ class GrpcMockSysDB(SysDBServicer, Component): _server_port: int _assignment_policy: CollectionAssignmentPolicy _segments: Dict[str, Segment] = {} - _collections: Dict[str, Collection] = {} + _tenants_to_databases_to_collections: Dict[ + str, Dict[str, Dict[str, Collection]] + ] = {} def __init__(self, system: System): self._server_port = system.settings.require("chroma_server_grpc_port") @@ -66,9 +69,31 @@ def stop(self) -> None: @overrides def reset_state(self) -> None: self._segments = {} - self._collections = {} + self._tenants_to_databases_to_collections = {} + # Create defaults + self._tenants_to_databases_to_collections["default"] = {} + self._tenants_to_databases_to_collections["default"]["default"] = {} return super().reset_state() + 
@overrides(check_signature=False) + def CreateDatabase( + self, request: CreateDatabaseRequest, context: grpc.ServicerContext + ) -> proto.ChromaResponse: + tenant = request.tenant + database = request.name + if tenant not in self._tenants_to_databases_to_collections: + return proto.ChromaResponse( + status=proto.Status(code=404, reason=f"Tenant {tenant} not found") + ) + if database in self._tenants_to_databases_to_collections[tenant]: + return proto.ChromaResponse( + status=proto.Status( + code=409, reason=f"Database {database} already exists" + ) + ) + self._tenants_to_databases_to_collections[tenant][database] = {} + return proto.ChromaResponse(status=proto.Status(code=200)) + # We are forced to use check_signature=False because the generated proto code # does not have type annotations for the request and response objects. # TODO: investigate generating types for the request and response objects @@ -171,9 +196,47 @@ def CreateCollection( self, request: CreateCollectionRequest, context: grpc.ServicerContext ) -> CreateCollectionResponse: collection_name = request.name - matches = [ - c for c in self._collections.values() if c["name"] == collection_name - ] + tenant = request.tenant + database = request.database + if tenant not in self._tenants_to_databases_to_collections: + return CreateCollectionResponse( + status=proto.Status(code=404, reason=f"Tenant {tenant} not found") + ) + if database not in self._tenants_to_databases_to_collections[tenant]: + return CreateCollectionResponse( + status=proto.Status(code=404, reason=f"Database {database} not found") + ) + + # Check if the collection already exists globally by id + for ( + search_tenant, + databases, + ) in self._tenants_to_databases_to_collections.items(): + for search_database, search_collections in databases.items(): + if request.id in search_collections: + if ( + search_tenant != request.tenant + or search_database != request.database + ): + return CreateCollectionResponse( + status=proto.Status( + 
code=409, + reason=f"Collection {request.id} already exists in tenant {search_tenant} database {search_database}", + ) + ) + elif not request.get_or_create: + # If the id exists for this tenant and database, and we are not doing a get_or_create, then + # we should return a 409 + return CreateCollectionResponse( + status=proto.Status( + code=409, + reason=f"Collection {request.id} already exists in tenant {search_tenant} database {search_database}", + ) + ) + + # Check if the collection already exists in this database by name + collections = self._tenants_to_databases_to_collections[tenant][database] + matches = [c for c in collections.values() if c["name"] == collection_name] assert len(matches) <= 1 if len(matches) > 0: if request.get_or_create: @@ -201,7 +264,7 @@ def CreateCollection( dimension=request.dimension, topic=self._assignment_policy.assign_collection(id), ) - self._collections[request.id] = new_collection + collections[request.id] = new_collection return CreateCollectionResponse( status=proto.Status(code=200), collection=to_proto_collection(new_collection), @@ -213,8 +276,19 @@ def DeleteCollection( self, request: DeleteCollectionRequest, context: grpc.ServicerContext ) -> proto.ChromaResponse: collection_id = request.id - if collection_id in self._collections: - del self._collections[collection_id] + tenant = request.tenant + database = request.database + if tenant not in self._tenants_to_databases_to_collections: + return proto.ChromaResponse( + status=proto.Status(code=404, reason=f"Tenant {tenant} not found") + ) + if database not in self._tenants_to_databases_to_collections[tenant]: + return proto.ChromaResponse( + status=proto.Status(code=404, reason=f"Database {database} not found") + ) + collections = self._tenants_to_databases_to_collections[tenant][database] + if collection_id in collections: + del collections[collection_id] return proto.ChromaResponse(status=proto.Status(code=200)) else: return proto.ChromaResponse( @@ -231,8 +305,20 @@ def 
GetCollections( target_topic = request.topic if request.HasField("topic") else None target_name = request.name if request.HasField("name") else None + tenant = request.tenant + database = request.database + if tenant not in self._tenants_to_databases_to_collections: + return GetCollectionsResponse( + status=proto.Status(code=404, reason=f"Tenant {tenant} not found") + ) + if database not in self._tenants_to_databases_to_collections[tenant]: + return GetCollectionsResponse( + status=proto.Status(code=404, reason=f"Database {database} not found") + ) + collections = self._tenants_to_databases_to_collections[tenant][database] + found_collections = [] - for collection in self._collections.values(): + for collection in collections.values(): if target_id and collection["id"] != target_id: continue if target_topic and collection["topic"] != target_topic: @@ -251,14 +337,21 @@ def UpdateCollection( self, request: UpdateCollectionRequest, context: grpc.ServicerContext ) -> proto.ChromaResponse: id_to_update = UUID(request.id) - if id_to_update.hex not in self._collections: + # Find the collection with this id + collections = {} + for tenant, databases in self._tenants_to_databases_to_collections.items(): + for database, maybe_collections in databases.items(): + if id_to_update.hex in maybe_collections: + collections = maybe_collections + + if id_to_update.hex not in collections: return proto.ChromaResponse( status=proto.Status( code=404, reason=f"Collection {id_to_update} not found" ) ) else: - collection = self._collections[id_to_update.hex] + collection = collections[id_to_update.hex] if request.HasField("topic"): collection["topic"] = request.topic if request.HasField("name"): diff --git a/chromadb/db/impl/sqlite.py b/chromadb/db/impl/sqlite.py index aed14deb8e2..10036d3c239 100644 --- a/chromadb/db/impl/sqlite.py +++ b/chromadb/db/impl/sqlite.py @@ -4,7 +4,6 @@ import chromadb.db.base as base from chromadb.db.mixins.embeddings_queue import SqlEmbeddingsQueue from 
chromadb.db.mixins.sysdb import SqlSysDB -from chromadb.utils.delete_file import delete_file import sqlite3 from overrides import override import pypika @@ -139,8 +138,9 @@ def reset_state(self) -> None: for row in cur.fetchall(): cur.execute(f"DROP TABLE IF EXISTS {row[0]}") self._conn_pool.close() - if self._is_persistent: - delete_file(self._db_file) + # TODO: clean this up ---- I don't think its correct or needed + # if self._is_persistent: + # delete_file(self._db_file) self.start() super().reset_state() diff --git a/chromadb/db/mixins/sysdb.py b/chromadb/db/mixins/sysdb.py index d105918e700..607d8dbea62 100644 --- a/chromadb/db/mixins/sysdb.py +++ b/chromadb/db/mixins/sysdb.py @@ -4,7 +4,7 @@ from pypika import Table, Column from itertools import groupby -from chromadb.config import System +from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, System from chromadb.db.base import ( Cursor, SqlDB, @@ -41,6 +41,38 @@ def start(self) -> None: super().start() self._producer = self._system.instance(Producer) + @override + def create_database( + self, id: UUID, name: str, tenant: str = DEFAULT_TENANT + ) -> None: + with self.tx() as cur: + # Get the tenant id for the tenant name and then insert the database with the id, name and tenant id + databases = Table("databases") + tenants = Table("tenants") + insert_database = ( + self.querybuilder() + .into(databases) + .columns(databases.id, databases.name, databases.tenant_id) + .insert( + ParameterValue(self.uuid_to_db(id)), + ParameterValue(name), + self.querybuilder() + .select(tenants.id) + .from_(tenants) + .where(tenants.id == ParameterValue(tenant)), + ) + ) + sql, params = get_sql(insert_database, self.parameter_format()) + try: + cur.execute(sql, params) + # TODO: database doesn't exist + # TODO: tenant doesn't exist + # TODO: implement unique constraint error lol... 
+ except self.unique_constraint_error() as e: + raise UniqueConstraintError( + f"Database {name} already exists for tenant {tenant}" + ) from e + @override def create_segment(self, segment: Segment) -> None: with self.tx() as cur: @@ -88,11 +120,13 @@ def create_collection( metadata: Optional[Metadata] = None, dimension: Optional[int] = None, get_or_create: bool = False, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, ) -> Tuple[Collection, bool]: if id is None and not get_or_create: raise ValueError("id must be specified if get_or_create is False") - existing = self.get_collections(name=name) + existing = self.get_collections(name=name, tenant=tenant, database=database) if existing: if get_or_create: collection = existing[0] @@ -101,7 +135,12 @@ def create_collection( collection["id"], metadata=metadata, ) - return self.get_collections(id=collection["id"])[0], False + return ( + self.get_collections( + id=collection["id"], tenant=tenant, database=database + )[0], + False, + ) else: raise UniqueConstraintError(f"Collection {name} already exists") @@ -112,6 +151,8 @@ def create_collection( with self.tx() as cur: collections = Table("collections") + databases = Table("databases") + insert_collection = ( self.querybuilder() .into(collections) @@ -120,12 +161,19 @@ def create_collection( collections.topic, collections.name, collections.dimension, + collections.database_id, ) .insert( ParameterValue(self.uuid_to_db(collection["id"])), ParameterValue(collection["topic"]), ParameterValue(collection["name"]), ParameterValue(collection["dimension"]), + # Get the database id for the database with the given name and tenant + self.querybuilder() + .select(databases.id) + .from_(databases) + .where(databases.name == ParameterValue(database)) + .where(databases.tenant_id == ParameterValue(tenant)), ) ) sql, params = get_sql(insert_collection, self.parameter_format()) @@ -220,8 +268,16 @@ def get_collections( id: Optional[UUID] = None, topic: Optional[str] = 
None, name: Optional[str] = None, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, ) -> Sequence[Collection]: """Get collections by name, embedding function and/or metadata""" + + if name is not None and (tenant is None or database is None): + raise ValueError( + "If name is specified, tenant and database must also be specified in order to uniquely identify the collection" + ) + collections_t = Table("collections") metadata_t = Table("collection_metadata") q = ( @@ -248,6 +304,17 @@ def get_collections( if name: q = q.where(collections_t.name == ParameterValue(name)) + if tenant and database: + databases_t = Table("databases") + q = q.where( + collections_t.database_id + == self.querybuilder() + .select(databases_t.id) + .from_(databases_t) + .where(databases_t.name == ParameterValue(database)) + .where(databases_t.tenant_id == ParameterValue(tenant)) + ) + with self.tx() as cur: sql, params = get_sql(q, self.parameter_format()) rows = cur.execute(sql, params).fetchall() @@ -291,13 +358,27 @@ def delete_segment(self, id: UUID) -> None: raise NotFoundError(f"Segment {id} not found") @override - def delete_collection(self, id: UUID) -> None: + def delete_collection( + self, + id: UUID, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> None: """Delete a topic and all associated segments from the SysDB""" t = Table("collections") + databases_t = Table("databases") q = ( self.querybuilder() .from_(t) .where(t.id == ParameterValue(self.uuid_to_db(id))) + .where( + t.database_id + == self.querybuilder() + .select(databases_t.id) + .from_(databases_t) + .where(databases_t.name == ParameterValue(database)) + .where(databases_t.tenant_id == ParameterValue(tenant)) + ) .delete() ) with self.tx() as cur: @@ -391,7 +472,10 @@ def update_collection( with self.tx() as cur: sql, params = get_sql(q, self.parameter_format()) if sql: # pypika emits a blank string if nothing to do - cur.execute(sql, params) + sql = sql + " RETURNING id" + 
result = cur.execute(sql, params) + if not result.fetchone(): + raise NotFoundError(f"Collection {id} not found") # TODO: Update to use better semantics where it's possible to update # individual keys without wiping all the existing metadata. diff --git a/chromadb/db/system.py b/chromadb/db/system.py index b4975339fbc..b3506062563 100644 --- a/chromadb/db/system.py +++ b/chromadb/db/system.py @@ -10,12 +10,20 @@ Unspecified, UpdateMetadata, ) -from chromadb.config import Component +from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Component class SysDB(Component): """Data interface for Chroma's System database""" + @abstractmethod + def create_database( + self, id: UUID, name: str, tenant: str = DEFAULT_TENANT + ) -> None: + """Create a new database in the System database. Raises DuplicateError if the Database + already exists.""" + pass + @abstractmethod def create_segment(self, segment: Segment) -> None: """Create a new segment in the System database. Raises DuplicateError if the ID @@ -60,6 +68,8 @@ def create_collection( metadata: Optional[Metadata] = None, dimension: Optional[int] = None, get_or_create: bool = False, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, ) -> Tuple[Collection, bool]: """Create a new collection any associated resources (Such as the necessary topics) in the SysDB. 
If get_or_create is True, the @@ -73,7 +83,9 @@ def create_collection( pass @abstractmethod - def delete_collection(self, id: UUID) -> None: + def delete_collection( + self, id: UUID, tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE + ) -> None: """Delete a collection, topic, all associated segments and any associate resources from the SysDB and the system at large.""" pass @@ -84,8 +96,10 @@ def get_collections( id: Optional[UUID] = None, topic: Optional[str] = None, name: Optional[str] = None, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, ) -> Sequence[Collection]: - """Find collections by id, topic or name""" + """Find collections by id, topic or name. If name is provided, tenant and database must also be provided.""" pass @abstractmethod diff --git a/chromadb/migrations/sysdb/00004-tenants-databases.sqlite.sql b/chromadb/migrations/sysdb/00004-tenants-databases.sqlite.sql index 1c40e823480..e1386efba42 100644 --- a/chromadb/migrations/sysdb/00004-tenants-databases.sqlite.sql +++ b/chromadb/migrations/sysdb/00004-tenants-databases.sqlite.sql @@ -1,19 +1,29 @@ -CREATE TABLE tenants ( +CREATE TABLE tenants ( -- todo: make this idempotent by checking if table exists by using CREATE TABLE IF NOT EXISTS id TEXT PRIMARY KEY, - name TEXT NOT NULL, - UNIQUE (name) -- Maybe not needed since we want to support slug ids + UNIQUE (id) -- Maybe not needed since we want to support slug ids ); CREATE TABLE databases ( - id TEXT PRIMARY KEY, - name TEXT NOT NULL, + id TEXT PRIMARY KEY, -- unique globally + name TEXT NOT NULL, -- unique per tenant tenant_id TEXT NOT NULL REFERENCES tenants(id) ON DELETE CASCADE, - UNIQUE (name) + UNIQUE (tenant_id, name) -- Ensure that a tenant has only one database with a given name ); -ALTER TABLE collections - ADD COLUMN database_id TEXT NOT NULL REFERENCES databases(id) DEFAULT 'default'; -- ON DELETE CASCADE not supported by sqlite in ALTER TABLE +CREATE TABLE collections_tmp ( + id TEXT PRIMARY KEY, -- 
unique globally + name TEXT NOT NULL, -- unique per database + topic TEXT NOT NULL, + dimension INTEGER, + database_id TEXT NOT NULL REFERENCES databases(id) ON DELETE CASCADE, + UNIQUE (name, database_id) +); -- Create default tenant and database -INSERT INTO tenants (id, name) VALUES ('default', 'default'); +INSERT INTO tenants (id) VALUES ('default'); -- should ids be uuids? INSERT INTO databases (id, name, tenant_id) VALUES ('default', 'default', 'default'); + +INSERT INTO collections_tmp (id, name, topic, dimension, database_id) + SELECT id, name, topic, dimension, 'default' FROM collections; +DROP TABLE collections; +ALTER TABLE collections_tmp RENAME TO collections; diff --git a/chromadb/proto/coordinator_pb2.py b/chromadb/proto/coordinator_pb2.py index 118405d423a..35459ff1ae6 100644 --- a/chromadb/proto/coordinator_pb2.py +++ b/chromadb/proto/coordinator_pb2.py @@ -15,35 +15,37 @@ from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n chromadb/proto/coordinator.proto\x12\x06\x63hroma\x1a\x1b\x63hromadb/proto/chroma.proto\x1a\x1bgoogle/protobuf/empty.proto\"8\n\x14\x43reateSegmentRequest\x12 \n\x07segment\x18\x01 \x01(\x0b\x32\x0f.chroma.Segment\"\"\n\x14\x44\x65leteSegmentRequest\x12\n\n\x02id\x18\x01 \x01(\t\"\xc2\x01\n\x12GetSegmentsRequest\x12\x0f\n\x02id\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x11\n\x04type\x18\x02 \x01(\tH\x01\x88\x01\x01\x12(\n\x05scope\x18\x03 \x01(\x0e\x32\x14.chroma.SegmentScopeH\x02\x88\x01\x01\x12\x12\n\x05topic\x18\x04 \x01(\tH\x03\x88\x01\x01\x12\x17\n\ncollection\x18\x05 \x01(\tH\x04\x88\x01\x01\x42\x05\n\x03_idB\x07\n\x05_typeB\x08\n\x06_scopeB\x08\n\x06_topicB\r\n\x0b_collection\"X\n\x13GetSegmentsResponse\x12!\n\x08segments\x18\x01 \x03(\x0b\x32\x0f.chroma.Segment\x12\x1e\n\x06status\x18\x02 \x01(\x0b\x32\x0e.chroma.Status\"\xfa\x01\n\x14UpdateSegmentRequest\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0f\n\x05topic\x18\x02 
\x01(\tH\x00\x12\x15\n\x0breset_topic\x18\x03 \x01(\x08H\x00\x12\x14\n\ncollection\x18\x04 \x01(\tH\x01\x12\x1a\n\x10reset_collection\x18\x05 \x01(\x08H\x01\x12*\n\x08metadata\x18\x06 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x02\x12\x18\n\x0ereset_metadata\x18\x07 \x01(\x08H\x02\x42\x0e\n\x0ctopic_updateB\x13\n\x11\x63ollection_updateB\x11\n\x0fmetadata_update\"\xc3\x01\n\x17\x43reateCollectionRequest\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12-\n\x08metadata\x18\x03 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x00\x88\x01\x01\x12\x16\n\tdimension\x18\x04 \x01(\x05H\x01\x88\x01\x01\x12\x1a\n\rget_or_create\x18\x05 \x01(\x08H\x02\x88\x01\x01\x42\x0b\n\t_metadataB\x0c\n\n_dimensionB\x10\n\x0e_get_or_create\"s\n\x18\x43reateCollectionResponse\x12&\n\ncollection\x18\x01 \x01(\x0b\x32\x12.chroma.Collection\x12\x0f\n\x07\x63reated\x18\x02 \x01(\x08\x12\x1e\n\x06status\x18\x03 \x01(\x0b\x32\x0e.chroma.Status\"%\n\x17\x44\x65leteCollectionRequest\x12\n\n\x02id\x18\x01 \x01(\t\"i\n\x15GetCollectionsRequest\x12\x0f\n\x02id\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x11\n\x04name\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x12\n\x05topic\x18\x03 \x01(\tH\x02\x88\x01\x01\x42\x05\n\x03_idB\x07\n\x05_nameB\x08\n\x06_topic\"a\n\x16GetCollectionsResponse\x12\'\n\x0b\x63ollections\x18\x01 \x03(\x0b\x32\x12.chroma.Collection\x12\x1e\n\x06status\x18\x02 \x01(\x0b\x32\x0e.chroma.Status\"\xde\x01\n\x17UpdateCollectionRequest\x12\n\n\x02id\x18\x01 \x01(\t\x12\x12\n\x05topic\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x11\n\x04name\x18\x03 \x01(\tH\x02\x88\x01\x01\x12\x16\n\tdimension\x18\x04 \x01(\x05H\x03\x88\x01\x01\x12*\n\x08metadata\x18\x05 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x00\x12\x18\n\x0ereset_metadata\x18\x06 
\x01(\x08H\x00\x42\x11\n\x0fmetadata_updateB\x08\n\x06_topicB\x07\n\x05_nameB\x0c\n\n_dimension2\xb6\x05\n\x05SysDB\x12G\n\rCreateSegment\x12\x1c.chroma.CreateSegmentRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12G\n\rDeleteSegment\x12\x1c.chroma.DeleteSegmentRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12H\n\x0bGetSegments\x12\x1a.chroma.GetSegmentsRequest\x1a\x1b.chroma.GetSegmentsResponse\"\x00\x12G\n\rUpdateSegment\x12\x1c.chroma.UpdateSegmentRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12W\n\x10\x43reateCollection\x12\x1f.chroma.CreateCollectionRequest\x1a .chroma.CreateCollectionResponse\"\x00\x12M\n\x10\x44\x65leteCollection\x12\x1f.chroma.DeleteCollectionRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12Q\n\x0eGetCollections\x12\x1d.chroma.GetCollectionsRequest\x1a\x1e.chroma.GetCollectionsResponse\"\x00\x12M\n\x10UpdateCollection\x12\x1f.chroma.UpdateCollectionRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12>\n\nResetState\x12\x16.google.protobuf.Empty\x1a\x16.chroma.ChromaResponse\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n chromadb/proto/coordinator.proto\x12\x06\x63hroma\x1a\x1b\x63hromadb/proto/chroma.proto\x1a\x1bgoogle/protobuf/empty.proto\"A\n\x15\x43reateDatabaseRequest\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x0e\n\x06tenant\x18\x03 \x01(\t\"8\n\x14\x43reateSegmentRequest\x12 \n\x07segment\x18\x01 \x01(\x0b\x32\x0f.chroma.Segment\"\"\n\x14\x44\x65leteSegmentRequest\x12\n\n\x02id\x18\x01 \x01(\t\"\xc2\x01\n\x12GetSegmentsRequest\x12\x0f\n\x02id\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x11\n\x04type\x18\x02 \x01(\tH\x01\x88\x01\x01\x12(\n\x05scope\x18\x03 \x01(\x0e\x32\x14.chroma.SegmentScopeH\x02\x88\x01\x01\x12\x12\n\x05topic\x18\x04 \x01(\tH\x03\x88\x01\x01\x12\x17\n\ncollection\x18\x05 \x01(\tH\x04\x88\x01\x01\x42\x05\n\x03_idB\x07\n\x05_typeB\x08\n\x06_scopeB\x08\n\x06_topicB\r\n\x0b_collection\"X\n\x13GetSegmentsResponse\x12!\n\x08segments\x18\x01 
\x03(\x0b\x32\x0f.chroma.Segment\x12\x1e\n\x06status\x18\x02 \x01(\x0b\x32\x0e.chroma.Status\"\xfa\x01\n\x14UpdateSegmentRequest\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0f\n\x05topic\x18\x02 \x01(\tH\x00\x12\x15\n\x0breset_topic\x18\x03 \x01(\x08H\x00\x12\x14\n\ncollection\x18\x04 \x01(\tH\x01\x12\x1a\n\x10reset_collection\x18\x05 \x01(\x08H\x01\x12*\n\x08metadata\x18\x06 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x02\x12\x18\n\x0ereset_metadata\x18\x07 \x01(\x08H\x02\x42\x0e\n\x0ctopic_updateB\x13\n\x11\x63ollection_updateB\x11\n\x0fmetadata_update\"\xe5\x01\n\x17\x43reateCollectionRequest\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12-\n\x08metadata\x18\x03 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x00\x88\x01\x01\x12\x16\n\tdimension\x18\x04 \x01(\x05H\x01\x88\x01\x01\x12\x1a\n\rget_or_create\x18\x05 \x01(\x08H\x02\x88\x01\x01\x12\x0e\n\x06tenant\x18\x06 \x01(\t\x12\x10\n\x08\x64\x61tabase\x18\x07 \x01(\tB\x0b\n\t_metadataB\x0c\n\n_dimensionB\x10\n\x0e_get_or_create\"s\n\x18\x43reateCollectionResponse\x12&\n\ncollection\x18\x01 \x01(\x0b\x32\x12.chroma.Collection\x12\x0f\n\x07\x63reated\x18\x02 \x01(\x08\x12\x1e\n\x06status\x18\x03 \x01(\x0b\x32\x0e.chroma.Status\"G\n\x17\x44\x65leteCollectionRequest\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0e\n\x06tenant\x18\x02 \x01(\t\x12\x10\n\x08\x64\x61tabase\x18\x03 \x01(\t\"\x8b\x01\n\x15GetCollectionsRequest\x12\x0f\n\x02id\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x11\n\x04name\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x12\n\x05topic\x18\x03 \x01(\tH\x02\x88\x01\x01\x12\x0e\n\x06tenant\x18\x04 \x01(\t\x12\x10\n\x08\x64\x61tabase\x18\x05 \x01(\tB\x05\n\x03_idB\x07\n\x05_nameB\x08\n\x06_topic\"a\n\x16GetCollectionsResponse\x12\'\n\x0b\x63ollections\x18\x01 \x03(\x0b\x32\x12.chroma.Collection\x12\x1e\n\x06status\x18\x02 \x01(\x0b\x32\x0e.chroma.Status\"\xde\x01\n\x17UpdateCollectionRequest\x12\n\n\x02id\x18\x01 \x01(\t\x12\x12\n\x05topic\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x11\n\x04name\x18\x03 
\x01(\tH\x02\x88\x01\x01\x12\x16\n\tdimension\x18\x04 \x01(\x05H\x03\x88\x01\x01\x12*\n\x08metadata\x18\x05 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x00\x12\x18\n\x0ereset_metadata\x18\x06 \x01(\x08H\x00\x42\x11\n\x0fmetadata_updateB\x08\n\x06_topicB\x07\n\x05_nameB\x0c\n\n_dimension2\x81\x06\n\x05SysDB\x12I\n\x0e\x43reateDatabase\x12\x1d.chroma.CreateDatabaseRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12G\n\rCreateSegment\x12\x1c.chroma.CreateSegmentRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12G\n\rDeleteSegment\x12\x1c.chroma.DeleteSegmentRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12H\n\x0bGetSegments\x12\x1a.chroma.GetSegmentsRequest\x1a\x1b.chroma.GetSegmentsResponse\"\x00\x12G\n\rUpdateSegment\x12\x1c.chroma.UpdateSegmentRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12W\n\x10\x43reateCollection\x12\x1f.chroma.CreateCollectionRequest\x1a .chroma.CreateCollectionResponse\"\x00\x12M\n\x10\x44\x65leteCollection\x12\x1f.chroma.DeleteCollectionRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12Q\n\x0eGetCollections\x12\x1d.chroma.GetCollectionsRequest\x1a\x1e.chroma.GetCollectionsResponse\"\x00\x12M\n\x10UpdateCollection\x12\x1f.chroma.UpdateCollectionRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12>\n\nResetState\x12\x16.google.protobuf.Empty\x1a\x16.chroma.ChromaResponse\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'chromadb.proto.coordinator_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None - _globals['_CREATESEGMENTREQUEST']._serialized_start=102 - _globals['_CREATESEGMENTREQUEST']._serialized_end=158 - _globals['_DELETESEGMENTREQUEST']._serialized_start=160 - _globals['_DELETESEGMENTREQUEST']._serialized_end=194 - _globals['_GETSEGMENTSREQUEST']._serialized_start=197 - _globals['_GETSEGMENTSREQUEST']._serialized_end=391 - _globals['_GETSEGMENTSRESPONSE']._serialized_start=393 - 
_globals['_GETSEGMENTSRESPONSE']._serialized_end=481 - _globals['_UPDATESEGMENTREQUEST']._serialized_start=484 - _globals['_UPDATESEGMENTREQUEST']._serialized_end=734 - _globals['_CREATECOLLECTIONREQUEST']._serialized_start=737 - _globals['_CREATECOLLECTIONREQUEST']._serialized_end=932 - _globals['_CREATECOLLECTIONRESPONSE']._serialized_start=934 - _globals['_CREATECOLLECTIONRESPONSE']._serialized_end=1049 - _globals['_DELETECOLLECTIONREQUEST']._serialized_start=1051 - _globals['_DELETECOLLECTIONREQUEST']._serialized_end=1088 - _globals['_GETCOLLECTIONSREQUEST']._serialized_start=1090 - _globals['_GETCOLLECTIONSREQUEST']._serialized_end=1195 - _globals['_GETCOLLECTIONSRESPONSE']._serialized_start=1197 - _globals['_GETCOLLECTIONSRESPONSE']._serialized_end=1294 - _globals['_UPDATECOLLECTIONREQUEST']._serialized_start=1297 - _globals['_UPDATECOLLECTIONREQUEST']._serialized_end=1519 - _globals['_SYSDB']._serialized_start=1522 - _globals['_SYSDB']._serialized_end=2216 + _globals['_CREATEDATABASEREQUEST']._serialized_start=102 + _globals['_CREATEDATABASEREQUEST']._serialized_end=167 + _globals['_CREATESEGMENTREQUEST']._serialized_start=169 + _globals['_CREATESEGMENTREQUEST']._serialized_end=225 + _globals['_DELETESEGMENTREQUEST']._serialized_start=227 + _globals['_DELETESEGMENTREQUEST']._serialized_end=261 + _globals['_GETSEGMENTSREQUEST']._serialized_start=264 + _globals['_GETSEGMENTSREQUEST']._serialized_end=458 + _globals['_GETSEGMENTSRESPONSE']._serialized_start=460 + _globals['_GETSEGMENTSRESPONSE']._serialized_end=548 + _globals['_UPDATESEGMENTREQUEST']._serialized_start=551 + _globals['_UPDATESEGMENTREQUEST']._serialized_end=801 + _globals['_CREATECOLLECTIONREQUEST']._serialized_start=804 + _globals['_CREATECOLLECTIONREQUEST']._serialized_end=1033 + _globals['_CREATECOLLECTIONRESPONSE']._serialized_start=1035 + _globals['_CREATECOLLECTIONRESPONSE']._serialized_end=1150 + _globals['_DELETECOLLECTIONREQUEST']._serialized_start=1152 + 
_globals['_DELETECOLLECTIONREQUEST']._serialized_end=1223 + _globals['_GETCOLLECTIONSREQUEST']._serialized_start=1226 + _globals['_GETCOLLECTIONSREQUEST']._serialized_end=1365 + _globals['_GETCOLLECTIONSRESPONSE']._serialized_start=1367 + _globals['_GETCOLLECTIONSRESPONSE']._serialized_end=1464 + _globals['_UPDATECOLLECTIONREQUEST']._serialized_start=1467 + _globals['_UPDATECOLLECTIONREQUEST']._serialized_end=1689 + _globals['_SYSDB']._serialized_start=1692 + _globals['_SYSDB']._serialized_end=2461 # @@protoc_insertion_point(module_scope) diff --git a/chromadb/proto/coordinator_pb2.pyi b/chromadb/proto/coordinator_pb2.pyi index 6b9c974e424..a77c0d828a9 100644 --- a/chromadb/proto/coordinator_pb2.pyi +++ b/chromadb/proto/coordinator_pb2.pyi @@ -7,6 +7,16 @@ from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Map DESCRIPTOR: _descriptor.FileDescriptor +class CreateDatabaseRequest(_message.Message): + __slots__ = ["id", "name", "tenant"] + ID_FIELD_NUMBER: _ClassVar[int] + NAME_FIELD_NUMBER: _ClassVar[int] + TENANT_FIELD_NUMBER: _ClassVar[int] + id: str + name: str + tenant: str + def __init__(self, id: _Optional[str] = ..., name: _Optional[str] = ..., tenant: _Optional[str] = ...) -> None: ... + class CreateSegmentRequest(_message.Message): __slots__ = ["segment"] SEGMENT_FIELD_NUMBER: _ClassVar[int] @@ -60,18 +70,22 @@ class UpdateSegmentRequest(_message.Message): def __init__(self, id: _Optional[str] = ..., topic: _Optional[str] = ..., reset_topic: bool = ..., collection: _Optional[str] = ..., reset_collection: bool = ..., metadata: _Optional[_Union[_chroma_pb2.UpdateMetadata, _Mapping]] = ..., reset_metadata: bool = ...) -> None: ... 
class CreateCollectionRequest(_message.Message): - __slots__ = ["id", "name", "metadata", "dimension", "get_or_create"] + __slots__ = ["id", "name", "metadata", "dimension", "get_or_create", "tenant", "database"] ID_FIELD_NUMBER: _ClassVar[int] NAME_FIELD_NUMBER: _ClassVar[int] METADATA_FIELD_NUMBER: _ClassVar[int] DIMENSION_FIELD_NUMBER: _ClassVar[int] GET_OR_CREATE_FIELD_NUMBER: _ClassVar[int] + TENANT_FIELD_NUMBER: _ClassVar[int] + DATABASE_FIELD_NUMBER: _ClassVar[int] id: str name: str metadata: _chroma_pb2.UpdateMetadata dimension: int get_or_create: bool - def __init__(self, id: _Optional[str] = ..., name: _Optional[str] = ..., metadata: _Optional[_Union[_chroma_pb2.UpdateMetadata, _Mapping]] = ..., dimension: _Optional[int] = ..., get_or_create: bool = ...) -> None: ... + tenant: str + database: str + def __init__(self, id: _Optional[str] = ..., name: _Optional[str] = ..., metadata: _Optional[_Union[_chroma_pb2.UpdateMetadata, _Mapping]] = ..., dimension: _Optional[int] = ..., get_or_create: bool = ..., tenant: _Optional[str] = ..., database: _Optional[str] = ...) -> None: ... class CreateCollectionResponse(_message.Message): __slots__ = ["collection", "created", "status"] @@ -84,20 +98,28 @@ class CreateCollectionResponse(_message.Message): def __init__(self, collection: _Optional[_Union[_chroma_pb2.Collection, _Mapping]] = ..., created: bool = ..., status: _Optional[_Union[_chroma_pb2.Status, _Mapping]] = ...) -> None: ... class DeleteCollectionRequest(_message.Message): - __slots__ = ["id"] + __slots__ = ["id", "tenant", "database"] ID_FIELD_NUMBER: _ClassVar[int] + TENANT_FIELD_NUMBER: _ClassVar[int] + DATABASE_FIELD_NUMBER: _ClassVar[int] id: str - def __init__(self, id: _Optional[str] = ...) -> None: ... + tenant: str + database: str + def __init__(self, id: _Optional[str] = ..., tenant: _Optional[str] = ..., database: _Optional[str] = ...) -> None: ... 
class GetCollectionsRequest(_message.Message): - __slots__ = ["id", "name", "topic"] + __slots__ = ["id", "name", "topic", "tenant", "database"] ID_FIELD_NUMBER: _ClassVar[int] NAME_FIELD_NUMBER: _ClassVar[int] TOPIC_FIELD_NUMBER: _ClassVar[int] + TENANT_FIELD_NUMBER: _ClassVar[int] + DATABASE_FIELD_NUMBER: _ClassVar[int] id: str name: str topic: str - def __init__(self, id: _Optional[str] = ..., name: _Optional[str] = ..., topic: _Optional[str] = ...) -> None: ... + tenant: str + database: str + def __init__(self, id: _Optional[str] = ..., name: _Optional[str] = ..., topic: _Optional[str] = ..., tenant: _Optional[str] = ..., database: _Optional[str] = ...) -> None: ... class GetCollectionsResponse(_message.Message): __slots__ = ["collections", "status"] diff --git a/chromadb/proto/coordinator_pb2_grpc.py b/chromadb/proto/coordinator_pb2_grpc.py index a3a1e03227b..d8fe2eb147a 100644 --- a/chromadb/proto/coordinator_pb2_grpc.py +++ b/chromadb/proto/coordinator_pb2_grpc.py @@ -16,6 +16,11 @@ def __init__(self, channel): Args: channel: A grpc.Channel. 
""" + self.CreateDatabase = channel.unary_unary( + "/chroma.SysDB/CreateDatabase", + request_serializer=chromadb_dot_proto_dot_coordinator__pb2.CreateDatabaseRequest.SerializeToString, + response_deserializer=chromadb_dot_proto_dot_chroma__pb2.ChromaResponse.FromString, + ) self.CreateSegment = channel.unary_unary( "/chroma.SysDB/CreateSegment", request_serializer=chromadb_dot_proto_dot_coordinator__pb2.CreateSegmentRequest.SerializeToString, @@ -66,6 +71,12 @@ def __init__(self, channel): class SysDBServicer(object): """Missing associated documentation comment in .proto file.""" + def CreateDatabase(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + def CreateSegment(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) @@ -123,6 +134,11 @@ def ResetState(self, request, context): def add_SysDBServicer_to_server(servicer, server): rpc_method_handlers = { + "CreateDatabase": grpc.unary_unary_rpc_method_handler( + servicer.CreateDatabase, + request_deserializer=chromadb_dot_proto_dot_coordinator__pb2.CreateDatabaseRequest.FromString, + response_serializer=chromadb_dot_proto_dot_chroma__pb2.ChromaResponse.SerializeToString, + ), "CreateSegment": grpc.unary_unary_rpc_method_handler( servicer.CreateSegment, request_deserializer=chromadb_dot_proto_dot_coordinator__pb2.CreateSegmentRequest.FromString, @@ -179,6 +195,35 @@ def add_SysDBServicer_to_server(servicer, server): class SysDB(object): """Missing associated documentation comment in .proto file.""" + @staticmethod + def CreateDatabase( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return 
grpc.experimental.unary_unary( + request, + target, + "/chroma.SysDB/CreateDatabase", + chromadb_dot_proto_dot_coordinator__pb2.CreateDatabaseRequest.SerializeToString, + chromadb_dot_proto_dot_chroma__pb2.ChromaResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) + @staticmethod def CreateSegment( request, diff --git a/chromadb/test/db/test_system.py b/chromadb/test/db/test_system.py index 541643a2ff6..b40bc32108e 100644 --- a/chromadb/test/db/test_system.py +++ b/chromadb/test/db/test_system.py @@ -3,6 +3,7 @@ import tempfile import pytest from typing import Generator, List, Callable, Dict, Union + from chromadb.db.impl.grpc.client import GrpcSysDB from chromadb.db.impl.grpc.server import GrpcMockSysDB from chromadb.types import Collection, Segment, SegmentScope @@ -107,6 +108,7 @@ def sysdb(request: FixtureRequest) -> Generator[SysDB, None, None]: yield next(request.param()) +# region Collection tests def test_create_get_delete_collections(sysdb: SysDB) -> None: sysdb.reset_state() @@ -296,6 +298,171 @@ def test_get_or_create_collection(sysdb: SysDB) -> None: assert result["metadata"] == overlayed_metadata +def test_create_get_delete_database_and_collection(sysdb: SysDB) -> None: + sysdb.reset_state() + + # Create a new database + sysdb.create_database(id=uuid.uuid4(), name="new_database") + + # Create a new collection in the new database + sysdb.create_collection( + id=sample_collections[0]["id"], + name=sample_collections[0]["name"], + metadata=sample_collections[0]["metadata"], + dimension=sample_collections[0]["dimension"], + database="new_database", + ) + + # Create a new collection with the same id but different name in the new database + # and expect an error + with pytest.raises(UniqueConstraintError): + sysdb.create_collection( + id=sample_collections[0]["id"], + name="new_name", + metadata=sample_collections[0]["metadata"], + 
dimension=sample_collections[0]["dimension"], + database="new_database", + get_or_create=False, + ) + + # Create a new collection in the default database + sysdb.create_collection( + id=sample_collections[1]["id"], + name=sample_collections[1]["name"], + metadata=sample_collections[1]["metadata"], + dimension=sample_collections[1]["dimension"], + ) + + # Check that the new database and collections exist + result = sysdb.get_collections( + name=sample_collections[0]["name"], database="new_database" + ) + assert len(result) == 1 + assert result[0] == sample_collections[0] + + # Check that the collection in the default database exists + result = sysdb.get_collections(name=sample_collections[1]["name"]) + assert len(result) == 1 + assert result[0] == sample_collections[1] + + # Get for a database that doesn't exist with a name that exists in the new database and expect no results + assert ( + len( + sysdb.get_collections( + name=sample_collections[0]["name"], database="fake_db" + ) + ) + == 0 + ) + + # Delete the collection in the new database + sysdb.delete_collection(id=sample_collections[0]["id"], database="new_database") + + # Check that the collection in the new database was deleted + result = sysdb.get_collections(database="new_database") + assert len(result) == 0 + + # Check that the collection in the default database still exists + result = sysdb.get_collections(name=sample_collections[1]["name"]) + assert len(result) == 1 + assert result[0] == sample_collections[1] + + # Delete the deleted collection in the default database and expect an error + with pytest.raises(NotFoundError): + sysdb.delete_collection(id=sample_collections[0]["id"]) + + # Delete the existing collection in the new database and expect an error + with pytest.raises(NotFoundError): + sysdb.delete_collection(id=sample_collections[1]["id"], database="new_database") + + +def test_create_update_with_database(sysdb: SysDB) -> None: + sysdb.reset_state() + + # Create a new database + 
sysdb.create_database(id=uuid.uuid4(), name="new_database") + + # Create a new collection in the new database + sysdb.create_collection( + id=sample_collections[0]["id"], + name=sample_collections[0]["name"], + metadata=sample_collections[0]["metadata"], + dimension=sample_collections[0]["dimension"], + database="new_database", + ) + + # Create a new collection in the default database + sysdb.create_collection( + id=sample_collections[1]["id"], + name=sample_collections[1]["name"], + metadata=sample_collections[1]["metadata"], + dimension=sample_collections[1]["dimension"], + ) + + # Update the collection in the default database + sysdb.update_collection( + id=sample_collections[1]["id"], + name="new_name_1", + ) + + # Check that the collection in the default database was updated + result = sysdb.get_collections(id=sample_collections[1]["id"]) + assert len(result) == 1 + assert result[0]["name"] == "new_name_1" + + # Update the collection in the new database + sysdb.update_collection( + id=sample_collections[0]["id"], + name="new_name_0", + ) + + # Check that the collection in the new database was updated + result = sysdb.get_collections( + id=sample_collections[0]["id"], database="new_database" + ) + assert len(result) == 1 + assert result[0]["name"] == "new_name_0" + + # Try to create the collection in the default database in the new database and expect an error + with pytest.raises(UniqueConstraintError): + sysdb.create_collection( + id=sample_collections[1]["id"], + name=sample_collections[1]["name"], + metadata=sample_collections[1]["metadata"], + dimension=sample_collections[1]["dimension"], + database="new_database", + ) + + +def test_get_multiple_with_database(sysdb: SysDB) -> None: + sysdb.reset_state() + + # Create a new database + sysdb.create_database(id=uuid.uuid4(), name="new_database") + + # Create sample collections in the new database + for collection in sample_collections: + sysdb.create_collection( + id=collection["id"], + 
name=collection["name"], + metadata=collection["metadata"], + dimension=collection["dimension"], + database="new_database", + ) + + # Get all collections in the new database + result = sysdb.get_collections(database="new_database") + assert len(result) == len(sample_collections) + assert sorted(result, key=lambda c: c["name"]) == sample_collections + + # Get all collections in the default database + result = sysdb.get_collections() + assert len(result) == 0 + + +# endregion + +# region Segment tests sample_segments = [ Segment( id=uuid.UUID("00000000-d7d7-413b-92e1-731098a6e492"), @@ -459,3 +626,6 @@ def test_update_segment(sysdb: SysDB) -> None: sysdb.update_segment(segment["id"], metadata=None) result = sysdb.get_segments(id=segment["id"]) assert result == [segment] + + +# endregion diff --git a/idl/chromadb/proto/coordinator.proto b/idl/chromadb/proto/coordinator.proto index 2a557f99613..4ff45d97bbd 100644 --- a/idl/chromadb/proto/coordinator.proto +++ b/idl/chromadb/proto/coordinator.proto @@ -5,6 +5,12 @@ package chroma; import "chromadb/proto/chroma.proto"; import "google/protobuf/empty.proto"; +message CreateDatabaseRequest { + string id = 1; + string name = 2; + string tenant = 3; +} + message CreateSegmentRequest { Segment segment = 1; } @@ -18,7 +24,7 @@ message GetSegmentsRequest { optional string type = 2; optional SegmentScope scope = 3; optional string topic = 4; - optional string collection = 5; + optional string collection = 5; // Collection ID } message GetSegmentsResponse { @@ -49,6 +55,8 @@ message CreateCollectionRequest { optional UpdateMetadata metadata = 3; optional int32 dimension = 4; optional bool get_or_create = 5; + string tenant = 6; + string database = 7; } message CreateCollectionResponse { @@ -59,12 +67,16 @@ message CreateCollectionResponse { message DeleteCollectionRequest { string id = 1; + string tenant = 2; + string database = 3; } message GetCollectionsRequest { optional string id = 1; optional string name = 2; optional string 
topic = 3; + string tenant = 4; + string database = 5; } message GetCollectionsResponse { @@ -84,6 +96,7 @@ message UpdateCollectionRequest { } service SysDB { + rpc CreateDatabase(CreateDatabaseRequest) returns (ChromaResponse) {} rpc CreateSegment(CreateSegmentRequest) returns (ChromaResponse) {} rpc DeleteSegment(DeleteSegmentRequest) returns (ChromaResponse) {} rpc GetSegments(GetSegmentsRequest) returns (GetSegmentsResponse) {} From 2bd9df122fee2a653c5dc5c3030c190ac88bd240 Mon Sep 17 00:00:00 2001 From: hammadb Date: Thu, 19 Oct 2023 11:54:17 -0700 Subject: [PATCH 04/22] plumbing + concurrency test --- chromadb/api/__init__.py | 3 + chromadb/api/client.py | 42 ++++++++++++- chromadb/api/fastapi.py | 25 ++++++-- chromadb/api/segment.py | 4 +- chromadb/server/fastapi/__init__.py | 55 +++++++++++++---- chromadb/server/fastapi/types.py | 4 ++ chromadb/test/client/test_database_tenant.py | 60 +++++++++++++++++++ .../test_multiple_clients_concurrency.py | 43 +++++++++++++ chromadb/test/conftest.py | 10 +++- 9 files changed, 227 insertions(+), 19 deletions(-) create mode 100644 chromadb/test/client/test_database_tenant.py create mode 100644 chromadb/test/client/test_multiple_clients_concurrency.py diff --git a/chromadb/api/__init__.py b/chromadb/api/__init__.py index 28a62fcb5be..6c0e56e5c8f 100644 --- a/chromadb/api/__init__.py +++ b/chromadb/api/__init__.py @@ -391,6 +391,9 @@ def max_batch_size(self) -> int: class ClientAPI(BaseAPI, ABC): + tenant: str + database: str + @abstractmethod def set_database(self, database: str) -> None: """Set the database for the client. 
diff --git a/chromadb/api/client.py b/chromadb/api/client.py index 24b087cc30c..3589af0192f 100644 --- a/chromadb/api/client.py +++ b/chromadb/api/client.py @@ -1,4 +1,4 @@ -from typing import ClassVar, Dict, Optional, Sequence +from typing import ClassVar, Dict, Optional, Sequence, TypeVar from uuid import UUID from overrides import override @@ -22,6 +22,8 @@ from chromadb.types import Where, WhereDocument import chromadb.utils.embedding_functions as ef +C = TypeVar("C", "SharedSystemClient", "Client", "AdminClient") + class SharedSystemClient: _identifer_to_system: ClassVar[Dict[str, System]] = {} @@ -81,6 +83,20 @@ def _get_identifier_from_settings(settings: Settings) -> str: return identifier + @staticmethod + def _populate_data_from_system(system: System) -> str: + identifier = SharedSystemClient._get_identifier_from_settings(system.settings) + SharedSystemClient._identifer_to_system[identifier] = system + return identifier + + @classmethod + def from_system(cls, system: System) -> "SharedSystemClient": + """Create a client from an existing system. 
This is useful for testing and debugging.""" + + SharedSystemClient._populate_data_from_system(system) + instance = cls(system.settings) + return instance + @staticmethod def clear_system_cache() -> None: SharedSystemClient._identifer_to_system = {} @@ -125,6 +141,18 @@ def __init__( telemetry_client = self._system.instance(Telemetry) telemetry_client.capture(ClientStartEvent()) + @classmethod + @override + def from_system( + cls, + system: System, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> "Client": + SharedSystemClient._populate_data_from_system(system) + instance = cls(tenant=tenant, database=database, settings=system.settings) + return instance + # endregion # region BaseAPI Methods @@ -366,7 +394,7 @@ def set_tenant(self, tenant: str) -> None: # endregion -class AdminClient(AdminAPI, SharedSystemClient): +class AdminClient(SharedSystemClient, AdminAPI): _server: ServerAPI def __init__(self, settings: Settings = Settings()) -> None: @@ -376,3 +404,13 @@ def __init__(self, settings: Settings = Settings()) -> None: @override def create_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None: return self._server.create_database(name=name, tenant=tenant) + + @classmethod + @override + def from_system( + cls, + system: System, + ) -> "AdminClient": + SharedSystemClient._populate_data_from_system(system) + instance = cls(settings=system.settings) + return instance diff --git a/chromadb/api/fastapi.py b/chromadb/api/fastapi.py index 784d9e79dbb..150edaf7dad 100644 --- a/chromadb/api/fastapi.py +++ b/chromadb/api/fastapi.py @@ -143,7 +143,8 @@ def create_database( """Creates a database""" resp = self._session.post( self._api_url + "/databases", - data=json.dumps({"name": name, "tenant": tenant}), + data=json.dumps({"name": name}), + params={"tenant": tenant}, ) raise_chroma_error(resp) @@ -152,7 +153,10 @@ def list_collections( self, tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE ) -> Sequence[Collection]: 
"""Returns a list of all collections""" - resp = self._session.get(self._api_url + "/collections") + resp = self._session.get( + self._api_url + "/collections", + params={"tenant": tenant, "database": database}, + ) raise_chroma_error(resp) json_collections = resp.json() collections = [] @@ -175,8 +179,13 @@ def create_collection( resp = self._session.post( self._api_url + "/collections", data=json.dumps( - {"name": name, "metadata": metadata, "get_or_create": get_or_create} + { + "name": name, + "metadata": metadata, + "get_or_create": get_or_create, + } ), + params={"tenant": tenant, "database": database}, ) raise_chroma_error(resp) resp_json = resp.json() @@ -197,7 +206,10 @@ def get_collection( database: str = DEFAULT_DATABASE, ) -> Collection: """Returns a collection""" - resp = self._session.get(self._api_url + "/collections/" + name) + resp = self._session.get( + self._api_url + "/collections/" + name, + params={"tenant": tenant, "database": database}, + ) raise_chroma_error(resp) resp_json = resp.json() return Collection( @@ -248,7 +260,10 @@ def delete_collection( database: str = DEFAULT_DATABASE, ) -> None: """Deletes a collection""" - resp = self._session.delete(self._api_url + "/collections/" + name) + resp = self._session.delete( + self._api_url + "/collections/" + name, + params={"tenant": tenant, "database": database}, + ) raise_chroma_error(resp) @override diff --git a/chromadb/api/segment.py b/chromadb/api/segment.py index 37887d5432a..a959f7a2be3 100644 --- a/chromadb/api/segment.py +++ b/chromadb/api/segment.py @@ -252,7 +252,9 @@ def delete_collection( ) if existing: - self._sysdb.delete_collection(existing[0]["id"]) + self._sysdb.delete_collection( + existing[0]["id"], tenant=tenant, database=database + ) for s in self._manager.delete_segments(existing[0]["id"]): self._sysdb.delete_segment(s) if existing and existing[0]["id"] in self._collection_cache: diff --git a/chromadb/server/fastapi/__init__.py b/chromadb/server/fastapi/__init__.py index 
ec8bcac7ea1..ee0e74eb2ea 100644 --- a/chromadb/server/fastapi/__init__.py +++ b/chromadb/server/fastapi/__init__.py @@ -15,7 +15,7 @@ FastAPIChromaAuthMiddleware, FastAPIChromaAuthMiddlewareWrapper, ) -from chromadb.config import Settings, System +from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings, System import chromadb.server import chromadb.api from chromadb.api import ServerAPI @@ -26,6 +26,7 @@ ) from chromadb.server.fastapi.types import ( AddEmbedding, + CreateDatabase, DeleteEmbedding, GetEmbedding, QueryEmbedding, @@ -100,7 +101,6 @@ def include_in_schema(path: str) -> bool: super().add_api_route(path, *args, **kwargs) -# TODO: add tenant/namespace to all routes class FastAPI(chromadb.server.Server): def __init__(self, settings: Settings): super().__init__(settings) @@ -134,6 +134,13 @@ def __init__(self, settings: Settings): "/api/v1/pre-flight-checks", self.pre_flight_checks, methods=["GET"] ) + self.router.add_api_route( + "/api/v1/databases", + self.create_database, + methods=["POST"], + response_model=None, + ) + self.router.add_api_route( "/api/v1/collections", self.list_collections, @@ -225,18 +232,39 @@ def heartbeat(self) -> Dict[str, int]: def version(self) -> str: return self._api.get_version() - def list_collections(self) -> Sequence[Collection]: - return self._api.list_collections() - - def create_collection(self, collection: CreateCollection) -> Collection: + def create_database( + self, database: CreateDatabase, tenant: str = DEFAULT_TENANT + ) -> None: + return self._api.create_database(database.name, tenant) + + def list_collections( + self, tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE + ) -> Sequence[Collection]: + return self._api.list_collections(tenant=tenant, database=database) + + def create_collection( + self, + collection: CreateCollection, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> Collection: return self._api.create_collection( name=collection.name, 
metadata=collection.metadata, get_or_create=collection.get_or_create, + tenant=tenant, + database=database, ) - def get_collection(self, collection_name: str) -> Collection: - return self._api.get_collection(collection_name) + def get_collection( + self, + collection_name: str, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> Collection: + return self._api.get_collection( + collection_name, tenant=tenant, database=database + ) def update_collection( self, collection_id: str, collection: UpdateCollection @@ -247,8 +275,15 @@ def update_collection( new_metadata=collection.new_metadata, ) - def delete_collection(self, collection_name: str) -> None: - return self._api.delete_collection(collection_name) + def delete_collection( + self, + collection_name: str, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> None: + return self._api.delete_collection( + collection_name, tenant=tenant, database=database + ) def add(self, collection_id: str, add: AddEmbedding) -> None: try: diff --git a/chromadb/server/fastapi/types.py b/chromadb/server/fastapi/types.py index 306f0e5fcb3..a27b6162df6 100644 --- a/chromadb/server/fastapi/types.py +++ b/chromadb/server/fastapi/types.py @@ -59,3 +59,7 @@ class CreateCollection(BaseModel): # type: ignore class UpdateCollection(BaseModel): # type: ignore new_name: Optional[str] = None new_metadata: Optional[CollectionMetadata] = None + + +class CreateDatabase(BaseModel): + name: str diff --git a/chromadb/test/client/test_database_tenant.py b/chromadb/test/client/test_database_tenant.py new file mode 100644 index 00000000000..8a3f4aff479 --- /dev/null +++ b/chromadb/test/client/test_database_tenant.py @@ -0,0 +1,60 @@ +from chromadb.api.client import AdminClient, Client + + +def test_database_tenant_collections(client: Client) -> None: + # Create a new database in the default tenant + admin_client = AdminClient.from_system(client._system) + admin_client.create_database("test_db") + + # Create 
collections in this new database + client.set_database("test_db") + client.create_collection("collection", metadata={"database": "test_db"}) + + # Create collections in the default database + client.set_database("default") + client.create_collection("collection", metadata={"database": "default"}) + + # List collections in the default database + collections = client.list_collections() + assert len(collections) == 1 + assert collections[0].name == "collection" + assert collections[0].metadata == {"database": "default"} + + # List collections in the new database + client.set_database("test_db") + collections = client.list_collections() + assert len(collections) == 1 + assert collections[0].metadata == {"database": "test_db"} + + # Update the metadata in both databases to different values + client.set_database("default") + client.list_collections()[0].modify(metadata={"database": "default2"}) + + client.set_database("test_db") + client.list_collections()[0].modify(metadata={"database": "test_db2"}) + + # Validate that the metadata was updated + client.set_database("default") + collections = client.list_collections() + assert len(collections) == 1 + assert collections[0].metadata == {"database": "default2"} + + client.set_database("test_db") + collections = client.list_collections() + assert len(collections) == 1 + assert collections[0].metadata == {"database": "test_db2"} + + # Delete the collections and make sure databases are isolated + client.set_database("default") + client.delete_collection("collection") + + collections = client.list_collections() + assert len(collections) == 0 + + client.set_database("test_db") + collections = client.list_collections() + assert len(collections) == 1 + + client.delete_collection("collection") + collections = client.list_collections() + assert len(collections) == 0 diff --git a/chromadb/test/client/test_multiple_clients_concurrency.py b/chromadb/test/client/test_multiple_clients_concurrency.py new file mode 100644 index 
00000000000..15e467bd8b3 --- /dev/null +++ b/chromadb/test/client/test_multiple_clients_concurrency.py @@ -0,0 +1,43 @@ +from concurrent.futures import ThreadPoolExecutor +from chromadb.api.client import AdminClient, Client + + +def test_multiple_clients_concurrently(client: Client) -> None: + """Tests running multiple clients, each against their own database, concurrently.""" + admin_client = AdminClient.from_system(client._system) + admin_client.create_database("test_db") + + CLIENT_COUNT = 100 + COLLECTION_COUNT = 500 + + # Each database will create the same collections by name, with differing metadata + databases = [f"db{i}" for i in range(CLIENT_COUNT)] + for database in databases: + admin_client.create_database(database) + + collections = [f"collection{i}" for i in range(COLLECTION_COUNT)] + + # Create N clients, each on a separate thread, each with their own database + def run_target(n: int) -> None: + thread_client = Client( + tenant="default", database=databases[n], settings=client._system.settings + ) + for collection in collections: + thread_client.create_collection( + collection, metadata={"database": databases[n]} + ) + + with ThreadPoolExecutor(max_workers=CLIENT_COUNT) as executor: + executor.map(run_target, range(CLIENT_COUNT)) + + # Create a final client, which will be used to verify the collections were created + client = Client(settings=client._system.settings) + + # Verify that the collections were created + for database in databases: + client.set_database(database) + seen_collections = client.list_collections() + assert len(seen_collections) == COLLECTION_COUNT + for collection in seen_collections: + assert collection.name in collections + assert collection.metadata == {"database": database} diff --git a/chromadb/test/conftest.py b/chromadb/test/conftest.py index 238748bcbbb..b27e1af7c01 100644 --- a/chromadb/test/conftest.py +++ b/chromadb/test/conftest.py @@ -22,11 +22,12 @@ from typing_extensions import Protocol import
chromadb.server.fastapi -from chromadb.api import ServerAPI +from chromadb.api import ClientAPI, ServerAPI from chromadb.config import Settings, System from chromadb.db.mixins import embeddings_queue from chromadb.ingest import Producer from chromadb.types import SeqId, SubmitEmbeddingRecord +from chromadb.api.client import Client as ClientCreator root_logger = logging.getLogger() root_logger.setLevel(logging.DEBUG) # This will only run when testing @@ -384,6 +385,13 @@ def api(system: System) -> Generator[ServerAPI, None, None]: yield api +@pytest.fixture(scope="function") +def client(system: System) -> Generator[ClientAPI, None, None]: + system.reset_state() + client = ClientCreator.from_system(system) + yield client + + @pytest.fixture(scope="function") def api_wrong_cred( system_wrong_auth: System, From 82cd7b4c0e9f56429a27a24cad03957334a8acaa Mon Sep 17 00:00:00 2001 From: hammadb Date: Thu, 19 Oct 2023 12:49:41 -0700 Subject: [PATCH 05/22] reset clients --- chromadb/test/client/test_database_tenant.py | 1 + chromadb/test/client/test_multiple_clients_concurrency.py | 1 + 2 files changed, 2 insertions(+) diff --git a/chromadb/test/client/test_database_tenant.py b/chromadb/test/client/test_database_tenant.py index 8a3f4aff479..1ce7c895893 100644 --- a/chromadb/test/client/test_database_tenant.py +++ b/chromadb/test/client/test_database_tenant.py @@ -2,6 +2,7 @@ def test_database_tenant_collections(client: Client) -> None: + client.reset() # Create a new database in the default tenant admin_client = AdminClient.from_system(client._system) admin_client.create_database("test_db") diff --git a/chromadb/test/client/test_multiple_clients_concurrency.py b/chromadb/test/client/test_multiple_clients_concurrency.py index 15e467bd8b3..2ee71652c46 100644 --- a/chromadb/test/client/test_multiple_clients_concurrency.py +++ b/chromadb/test/client/test_multiple_clients_concurrency.py @@ -4,6 +4,7 @@ def test_multiple_clients_concurrently(client: Client) -> None: """Tests 
running multiple clients, each against their own database, concurrently.""" + client.reset() admin_client = AdminClient.from_system(client._system) admin_client.create_database("test_db") From c66327bb815e78de243a9edf07b2b59031589c9b Mon Sep 17 00:00:00 2001 From: hammadb Date: Thu, 19 Oct 2023 14:40:10 -0700 Subject: [PATCH 06/22] Add create_tenant and multi tests --- chromadb/api/__init__.py | 10 +++ chromadb/api/client.py | 4 + chromadb/api/fastapi.py | 8 ++ chromadb/api/segment.py | 6 ++ chromadb/db/impl/grpc/client.py | 8 ++ chromadb/db/impl/grpc/server.py | 12 +++ chromadb/db/mixins/sysdb.py | 19 ++++- chromadb/db/system.py | 10 ++- chromadb/proto/coordinator_pb2.py | 52 ++++++------- chromadb/proto/coordinator_pb2.pyi | 6 ++ chromadb/proto/coordinator_pb2_grpc.py | 45 +++++++++++ chromadb/server/fastapi/__init__.py | 10 +++ .../test_multiple_clients_concurrency.py | 4 +- chromadb/test/db/test_system.py | 74 ++++++++++++++++++- .../test_collections_with_database_tenant.py | 0 idl/chromadb/proto/coordinator.proto | 5 ++ 16 files changed, 241 insertions(+), 32 deletions(-) create mode 100644 chromadb/test/property/test_collections_with_database_tenant.py diff --git a/chromadb/api/__init__.py b/chromadb/api/__init__.py index 6c0e56e5c8f..91c8e9517d1 100644 --- a/chromadb/api/__init__.py +++ b/chromadb/api/__init__.py @@ -433,6 +433,16 @@ def create_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None: """ pass + @abstractmethod + def create_tenant(self, name: str) -> None: + """Create a new tenant. + + Args: + name: The name of the tenant to create.
+ + """ + pass + class ServerAPI(BaseAPI, AdminAPI, Component): """An API instance that extends the relevant Base API methods by passing diff --git a/chromadb/api/client.py b/chromadb/api/client.py index 3589af0192f..d7c3c236455 100644 --- a/chromadb/api/client.py +++ b/chromadb/api/client.py @@ -405,6 +405,10 @@ def __init__(self, settings: Settings = Settings()) -> None: def create_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None: return self._server.create_database(name=name, tenant=tenant) + @override + def create_tenant(self, name: str) -> None: + return self._server.create_tenant(name=name) + @classmethod @override def from_system( diff --git a/chromadb/api/fastapi.py b/chromadb/api/fastapi.py index 150edaf7dad..f75a8648f17 100644 --- a/chromadb/api/fastapi.py +++ b/chromadb/api/fastapi.py @@ -148,6 +148,14 @@ def create_database( ) raise_chroma_error(resp) + @override + def create_tenant(self, name: str) -> None: + resp = self._session.post( + self._api_url + "/tenants", + data=json.dumps({"name": name}), + ) + raise_chroma_error(resp) + @override def list_collections( self, tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE diff --git a/chromadb/api/segment.py b/chromadb/api/segment.py index a959f7a2be3..e43a63c7b31 100644 --- a/chromadb/api/segment.py +++ b/chromadb/api/segment.py @@ -102,6 +102,12 @@ def create_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None: tenant=tenant, ) + @override + def create_tenant(self, name: str) -> None: + self._sysdb.create_tenant( + name=name, + ) + # TODO: Actually fix CollectionMetadata type to remove type: ignore flags. This is # necessary because changing the value type from `Any` to`` `Union[str, int, float]` # causes the system to somehow convert all values to strings. 
diff --git a/chromadb/db/impl/grpc/client.py b/chromadb/db/impl/grpc/client.py index fdf78ee957a..967be7c1194 100644 --- a/chromadb/db/impl/grpc/client.py +++ b/chromadb/db/impl/grpc/client.py @@ -15,6 +15,7 @@ CreateCollectionRequest, CreateDatabaseRequest, CreateSegmentRequest, + CreateTenantRequest, DeleteCollectionRequest, DeleteSegmentRequest, GetCollectionsRequest, @@ -81,6 +82,13 @@ def create_database( if response.status.code == 409: raise UniqueConstraintError() + @overrides + def create_tenant(self, name: str) -> None: + request = CreateTenantRequest(name=name) + response = self._sys_db_stub.CreateTenant(request) + if response.status.code == 409: + raise UniqueConstraintError() + @overrides def create_segment(self, segment: Segment) -> None: proto_segment = to_proto_segment(segment) diff --git a/chromadb/db/impl/grpc/server.py b/chromadb/db/impl/grpc/server.py index b028ea5a3b4..f308dd7714c 100644 --- a/chromadb/db/impl/grpc/server.py +++ b/chromadb/db/impl/grpc/server.py @@ -94,6 +94,18 @@ def CreateDatabase( self._tenants_to_databases_to_collections[tenant][database] = {} return proto.ChromaResponse(status=proto.Status(code=200)) + @overrides(check_signature=False) + def CreateTenant( + self, request: CreateDatabaseRequest, context: grpc.ServicerContext + ) -> proto.ChromaResponse: + tenant = request.name + if tenant in self._tenants_to_databases_to_collections: + return proto.ChromaResponse( + status=proto.Status(code=409, reason=f"Tenant {tenant} already exists") + ) + self._tenants_to_databases_to_collections[tenant] = {} + return proto.ChromaResponse(status=proto.Status(code=200)) + # We are forced to use check_signature=False because the generated proto code # does not have type annotations for the request and response objects. 
# TODO: investigate generating types for the request and response objects diff --git a/chromadb/db/mixins/sysdb.py b/chromadb/db/mixins/sysdb.py index 607d8dbea62..dbf84f477f8 100644 --- a/chromadb/db/mixins/sysdb.py +++ b/chromadb/db/mixins/sysdb.py @@ -65,14 +65,29 @@ def create_database( sql, params = get_sql(insert_database, self.parameter_format()) try: cur.execute(sql, params) - # TODO: database doesn't exist - # TODO: tenant doesn't exist + # TODO: tenant doesn't exist test # TODO: implement unique constraint error lol... except self.unique_constraint_error() as e: raise UniqueConstraintError( f"Database {name} already exists for tenant {tenant}" ) from e + @override + def create_tenant(self, name: str) -> None: + with self.tx() as cur: + tenants = Table("tenants") + insert_tenant = ( + self.querybuilder() + .into(tenants) + .columns(tenants.id) + .insert(ParameterValue(name)) + ) + sql, params = get_sql(insert_tenant, self.parameter_format()) + try: + cur.execute(sql, params) + except self.unique_constraint_error() as e: + raise UniqueConstraintError(f"Tenant {name} already exists") from e + @override def create_segment(self, segment: Segment) -> None: with self.tx() as cur: diff --git a/chromadb/db/system.py b/chromadb/db/system.py index b3506062563..3fd76feff36 100644 --- a/chromadb/db/system.py +++ b/chromadb/db/system.py @@ -20,13 +20,19 @@ class SysDB(Component): def create_database( self, id: UUID, name: str, tenant: str = DEFAULT_TENANT ) -> None: - """Create a new database in the System database. Raises DuplicateError if the Database + """Create a new database in the System database. Raises an Error if the Database already exists.""" pass + @abstractmethod + def create_tenant(self, name: str) -> None: + """Create a new tenant in the System database. The name must be unique. + Raises an Error if the Tenant already exists.""" + pass + @abstractmethod def create_segment(self, segment: Segment) -> None: - """Create a new segment in the System database. 
Raises DuplicateError if the ID + """Create a new segment in the System database. Raises an Error if the ID already exists.""" pass diff --git a/chromadb/proto/coordinator_pb2.py b/chromadb/proto/coordinator_pb2.py index 35459ff1ae6..7f865bb217f 100644 --- a/chromadb/proto/coordinator_pb2.py +++ b/chromadb/proto/coordinator_pb2.py @@ -15,7 +15,7 @@ from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n chromadb/proto/coordinator.proto\x12\x06\x63hroma\x1a\x1b\x63hromadb/proto/chroma.proto\x1a\x1bgoogle/protobuf/empty.proto\"A\n\x15\x43reateDatabaseRequest\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x0e\n\x06tenant\x18\x03 \x01(\t\"8\n\x14\x43reateSegmentRequest\x12 \n\x07segment\x18\x01 \x01(\x0b\x32\x0f.chroma.Segment\"\"\n\x14\x44\x65leteSegmentRequest\x12\n\n\x02id\x18\x01 \x01(\t\"\xc2\x01\n\x12GetSegmentsRequest\x12\x0f\n\x02id\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x11\n\x04type\x18\x02 \x01(\tH\x01\x88\x01\x01\x12(\n\x05scope\x18\x03 \x01(\x0e\x32\x14.chroma.SegmentScopeH\x02\x88\x01\x01\x12\x12\n\x05topic\x18\x04 \x01(\tH\x03\x88\x01\x01\x12\x17\n\ncollection\x18\x05 \x01(\tH\x04\x88\x01\x01\x42\x05\n\x03_idB\x07\n\x05_typeB\x08\n\x06_scopeB\x08\n\x06_topicB\r\n\x0b_collection\"X\n\x13GetSegmentsResponse\x12!\n\x08segments\x18\x01 \x03(\x0b\x32\x0f.chroma.Segment\x12\x1e\n\x06status\x18\x02 \x01(\x0b\x32\x0e.chroma.Status\"\xfa\x01\n\x14UpdateSegmentRequest\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0f\n\x05topic\x18\x02 \x01(\tH\x00\x12\x15\n\x0breset_topic\x18\x03 \x01(\x08H\x00\x12\x14\n\ncollection\x18\x04 \x01(\tH\x01\x12\x1a\n\x10reset_collection\x18\x05 \x01(\x08H\x01\x12*\n\x08metadata\x18\x06 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x02\x12\x18\n\x0ereset_metadata\x18\x07 \x01(\x08H\x02\x42\x0e\n\x0ctopic_updateB\x13\n\x11\x63ollection_updateB\x11\n\x0fmetadata_update\"\xe5\x01\n\x17\x43reateCollectionRequest\x12\n\n\x02id\x18\x01 
\x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12-\n\x08metadata\x18\x03 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x00\x88\x01\x01\x12\x16\n\tdimension\x18\x04 \x01(\x05H\x01\x88\x01\x01\x12\x1a\n\rget_or_create\x18\x05 \x01(\x08H\x02\x88\x01\x01\x12\x0e\n\x06tenant\x18\x06 \x01(\t\x12\x10\n\x08\x64\x61tabase\x18\x07 \x01(\tB\x0b\n\t_metadataB\x0c\n\n_dimensionB\x10\n\x0e_get_or_create\"s\n\x18\x43reateCollectionResponse\x12&\n\ncollection\x18\x01 \x01(\x0b\x32\x12.chroma.Collection\x12\x0f\n\x07\x63reated\x18\x02 \x01(\x08\x12\x1e\n\x06status\x18\x03 \x01(\x0b\x32\x0e.chroma.Status\"G\n\x17\x44\x65leteCollectionRequest\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0e\n\x06tenant\x18\x02 \x01(\t\x12\x10\n\x08\x64\x61tabase\x18\x03 \x01(\t\"\x8b\x01\n\x15GetCollectionsRequest\x12\x0f\n\x02id\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x11\n\x04name\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x12\n\x05topic\x18\x03 \x01(\tH\x02\x88\x01\x01\x12\x0e\n\x06tenant\x18\x04 \x01(\t\x12\x10\n\x08\x64\x61tabase\x18\x05 \x01(\tB\x05\n\x03_idB\x07\n\x05_nameB\x08\n\x06_topic\"a\n\x16GetCollectionsResponse\x12\'\n\x0b\x63ollections\x18\x01 \x03(\x0b\x32\x12.chroma.Collection\x12\x1e\n\x06status\x18\x02 \x01(\x0b\x32\x0e.chroma.Status\"\xde\x01\n\x17UpdateCollectionRequest\x12\n\n\x02id\x18\x01 \x01(\t\x12\x12\n\x05topic\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x11\n\x04name\x18\x03 \x01(\tH\x02\x88\x01\x01\x12\x16\n\tdimension\x18\x04 \x01(\x05H\x03\x88\x01\x01\x12*\n\x08metadata\x18\x05 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x00\x12\x18\n\x0ereset_metadata\x18\x06 
\x01(\x08H\x00\x42\x11\n\x0fmetadata_updateB\x08\n\x06_topicB\x07\n\x05_nameB\x0c\n\n_dimension2\x81\x06\n\x05SysDB\x12I\n\x0e\x43reateDatabase\x12\x1d.chroma.CreateDatabaseRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12G\n\rCreateSegment\x12\x1c.chroma.CreateSegmentRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12G\n\rDeleteSegment\x12\x1c.chroma.DeleteSegmentRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12H\n\x0bGetSegments\x12\x1a.chroma.GetSegmentsRequest\x1a\x1b.chroma.GetSegmentsResponse\"\x00\x12G\n\rUpdateSegment\x12\x1c.chroma.UpdateSegmentRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12W\n\x10\x43reateCollection\x12\x1f.chroma.CreateCollectionRequest\x1a .chroma.CreateCollectionResponse\"\x00\x12M\n\x10\x44\x65leteCollection\x12\x1f.chroma.DeleteCollectionRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12Q\n\x0eGetCollections\x12\x1d.chroma.GetCollectionsRequest\x1a\x1e.chroma.GetCollectionsResponse\"\x00\x12M\n\x10UpdateCollection\x12\x1f.chroma.UpdateCollectionRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12>\n\nResetState\x12\x16.google.protobuf.Empty\x1a\x16.chroma.ChromaResponse\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n chromadb/proto/coordinator.proto\x12\x06\x63hroma\x1a\x1b\x63hromadb/proto/chroma.proto\x1a\x1bgoogle/protobuf/empty.proto\"A\n\x15\x43reateDatabaseRequest\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x0e\n\x06tenant\x18\x03 \x01(\t\"#\n\x13\x43reateTenantRequest\x12\x0c\n\x04name\x18\x02 \x01(\t\"8\n\x14\x43reateSegmentRequest\x12 \n\x07segment\x18\x01 \x01(\x0b\x32\x0f.chroma.Segment\"\"\n\x14\x44\x65leteSegmentRequest\x12\n\n\x02id\x18\x01 \x01(\t\"\xc2\x01\n\x12GetSegmentsRequest\x12\x0f\n\x02id\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x11\n\x04type\x18\x02 \x01(\tH\x01\x88\x01\x01\x12(\n\x05scope\x18\x03 \x01(\x0e\x32\x14.chroma.SegmentScopeH\x02\x88\x01\x01\x12\x12\n\x05topic\x18\x04 \x01(\tH\x03\x88\x01\x01\x12\x17\n\ncollection\x18\x05 
\x01(\tH\x04\x88\x01\x01\x42\x05\n\x03_idB\x07\n\x05_typeB\x08\n\x06_scopeB\x08\n\x06_topicB\r\n\x0b_collection\"X\n\x13GetSegmentsResponse\x12!\n\x08segments\x18\x01 \x03(\x0b\x32\x0f.chroma.Segment\x12\x1e\n\x06status\x18\x02 \x01(\x0b\x32\x0e.chroma.Status\"\xfa\x01\n\x14UpdateSegmentRequest\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0f\n\x05topic\x18\x02 \x01(\tH\x00\x12\x15\n\x0breset_topic\x18\x03 \x01(\x08H\x00\x12\x14\n\ncollection\x18\x04 \x01(\tH\x01\x12\x1a\n\x10reset_collection\x18\x05 \x01(\x08H\x01\x12*\n\x08metadata\x18\x06 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x02\x12\x18\n\x0ereset_metadata\x18\x07 \x01(\x08H\x02\x42\x0e\n\x0ctopic_updateB\x13\n\x11\x63ollection_updateB\x11\n\x0fmetadata_update\"\xe5\x01\n\x17\x43reateCollectionRequest\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12-\n\x08metadata\x18\x03 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x00\x88\x01\x01\x12\x16\n\tdimension\x18\x04 \x01(\x05H\x01\x88\x01\x01\x12\x1a\n\rget_or_create\x18\x05 \x01(\x08H\x02\x88\x01\x01\x12\x0e\n\x06tenant\x18\x06 \x01(\t\x12\x10\n\x08\x64\x61tabase\x18\x07 \x01(\tB\x0b\n\t_metadataB\x0c\n\n_dimensionB\x10\n\x0e_get_or_create\"s\n\x18\x43reateCollectionResponse\x12&\n\ncollection\x18\x01 \x01(\x0b\x32\x12.chroma.Collection\x12\x0f\n\x07\x63reated\x18\x02 \x01(\x08\x12\x1e\n\x06status\x18\x03 \x01(\x0b\x32\x0e.chroma.Status\"G\n\x17\x44\x65leteCollectionRequest\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0e\n\x06tenant\x18\x02 \x01(\t\x12\x10\n\x08\x64\x61tabase\x18\x03 \x01(\t\"\x8b\x01\n\x15GetCollectionsRequest\x12\x0f\n\x02id\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x11\n\x04name\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x12\n\x05topic\x18\x03 \x01(\tH\x02\x88\x01\x01\x12\x0e\n\x06tenant\x18\x04 \x01(\t\x12\x10\n\x08\x64\x61tabase\x18\x05 \x01(\tB\x05\n\x03_idB\x07\n\x05_nameB\x08\n\x06_topic\"a\n\x16GetCollectionsResponse\x12\'\n\x0b\x63ollections\x18\x01 \x03(\x0b\x32\x12.chroma.Collection\x12\x1e\n\x06status\x18\x02 
\x01(\x0b\x32\x0e.chroma.Status\"\xde\x01\n\x17UpdateCollectionRequest\x12\n\n\x02id\x18\x01 \x01(\t\x12\x12\n\x05topic\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x11\n\x04name\x18\x03 \x01(\tH\x02\x88\x01\x01\x12\x16\n\tdimension\x18\x04 \x01(\x05H\x03\x88\x01\x01\x12*\n\x08metadata\x18\x05 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x00\x12\x18\n\x0ereset_metadata\x18\x06 \x01(\x08H\x00\x42\x11\n\x0fmetadata_updateB\x08\n\x06_topicB\x07\n\x05_nameB\x0c\n\n_dimension2\xc8\x06\n\x05SysDB\x12I\n\x0e\x43reateDatabase\x12\x1d.chroma.CreateDatabaseRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12\x45\n\x0c\x43reateTenant\x12\x1b.chroma.CreateTenantRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12G\n\rCreateSegment\x12\x1c.chroma.CreateSegmentRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12G\n\rDeleteSegment\x12\x1c.chroma.DeleteSegmentRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12H\n\x0bGetSegments\x12\x1a.chroma.GetSegmentsRequest\x1a\x1b.chroma.GetSegmentsResponse\"\x00\x12G\n\rUpdateSegment\x12\x1c.chroma.UpdateSegmentRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12W\n\x10\x43reateCollection\x12\x1f.chroma.CreateCollectionRequest\x1a .chroma.CreateCollectionResponse\"\x00\x12M\n\x10\x44\x65leteCollection\x12\x1f.chroma.DeleteCollectionRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12Q\n\x0eGetCollections\x12\x1d.chroma.GetCollectionsRequest\x1a\x1e.chroma.GetCollectionsResponse\"\x00\x12M\n\x10UpdateCollection\x12\x1f.chroma.UpdateCollectionRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12>\n\nResetState\x12\x16.google.protobuf.Empty\x1a\x16.chroma.ChromaResponse\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -24,28 +24,30 @@ DESCRIPTOR._options = None _globals['_CREATEDATABASEREQUEST']._serialized_start=102 _globals['_CREATEDATABASEREQUEST']._serialized_end=167 - _globals['_CREATESEGMENTREQUEST']._serialized_start=169 - _globals['_CREATESEGMENTREQUEST']._serialized_end=225 - 
_globals['_DELETESEGMENTREQUEST']._serialized_start=227 - _globals['_DELETESEGMENTREQUEST']._serialized_end=261 - _globals['_GETSEGMENTSREQUEST']._serialized_start=264 - _globals['_GETSEGMENTSREQUEST']._serialized_end=458 - _globals['_GETSEGMENTSRESPONSE']._serialized_start=460 - _globals['_GETSEGMENTSRESPONSE']._serialized_end=548 - _globals['_UPDATESEGMENTREQUEST']._serialized_start=551 - _globals['_UPDATESEGMENTREQUEST']._serialized_end=801 - _globals['_CREATECOLLECTIONREQUEST']._serialized_start=804 - _globals['_CREATECOLLECTIONREQUEST']._serialized_end=1033 - _globals['_CREATECOLLECTIONRESPONSE']._serialized_start=1035 - _globals['_CREATECOLLECTIONRESPONSE']._serialized_end=1150 - _globals['_DELETECOLLECTIONREQUEST']._serialized_start=1152 - _globals['_DELETECOLLECTIONREQUEST']._serialized_end=1223 - _globals['_GETCOLLECTIONSREQUEST']._serialized_start=1226 - _globals['_GETCOLLECTIONSREQUEST']._serialized_end=1365 - _globals['_GETCOLLECTIONSRESPONSE']._serialized_start=1367 - _globals['_GETCOLLECTIONSRESPONSE']._serialized_end=1464 - _globals['_UPDATECOLLECTIONREQUEST']._serialized_start=1467 - _globals['_UPDATECOLLECTIONREQUEST']._serialized_end=1689 - _globals['_SYSDB']._serialized_start=1692 - _globals['_SYSDB']._serialized_end=2461 + _globals['_CREATETENANTREQUEST']._serialized_start=169 + _globals['_CREATETENANTREQUEST']._serialized_end=204 + _globals['_CREATESEGMENTREQUEST']._serialized_start=206 + _globals['_CREATESEGMENTREQUEST']._serialized_end=262 + _globals['_DELETESEGMENTREQUEST']._serialized_start=264 + _globals['_DELETESEGMENTREQUEST']._serialized_end=298 + _globals['_GETSEGMENTSREQUEST']._serialized_start=301 + _globals['_GETSEGMENTSREQUEST']._serialized_end=495 + _globals['_GETSEGMENTSRESPONSE']._serialized_start=497 + _globals['_GETSEGMENTSRESPONSE']._serialized_end=585 + _globals['_UPDATESEGMENTREQUEST']._serialized_start=588 + _globals['_UPDATESEGMENTREQUEST']._serialized_end=838 + _globals['_CREATECOLLECTIONREQUEST']._serialized_start=841 + 
_globals['_CREATECOLLECTIONREQUEST']._serialized_end=1070 + _globals['_CREATECOLLECTIONRESPONSE']._serialized_start=1072 + _globals['_CREATECOLLECTIONRESPONSE']._serialized_end=1187 + _globals['_DELETECOLLECTIONREQUEST']._serialized_start=1189 + _globals['_DELETECOLLECTIONREQUEST']._serialized_end=1260 + _globals['_GETCOLLECTIONSREQUEST']._serialized_start=1263 + _globals['_GETCOLLECTIONSREQUEST']._serialized_end=1402 + _globals['_GETCOLLECTIONSRESPONSE']._serialized_start=1404 + _globals['_GETCOLLECTIONSRESPONSE']._serialized_end=1501 + _globals['_UPDATECOLLECTIONREQUEST']._serialized_start=1504 + _globals['_UPDATECOLLECTIONREQUEST']._serialized_end=1726 + _globals['_SYSDB']._serialized_start=1729 + _globals['_SYSDB']._serialized_end=2569 # @@protoc_insertion_point(module_scope) diff --git a/chromadb/proto/coordinator_pb2.pyi b/chromadb/proto/coordinator_pb2.pyi index a77c0d828a9..047eccba5f2 100644 --- a/chromadb/proto/coordinator_pb2.pyi +++ b/chromadb/proto/coordinator_pb2.pyi @@ -17,6 +17,12 @@ class CreateDatabaseRequest(_message.Message): tenant: str def __init__(self, id: _Optional[str] = ..., name: _Optional[str] = ..., tenant: _Optional[str] = ...) -> None: ... +class CreateTenantRequest(_message.Message): + __slots__ = ["name"] + NAME_FIELD_NUMBER: _ClassVar[int] + name: str + def __init__(self, name: _Optional[str] = ...) -> None: ... 
+ class CreateSegmentRequest(_message.Message): __slots__ = ["segment"] SEGMENT_FIELD_NUMBER: _ClassVar[int] diff --git a/chromadb/proto/coordinator_pb2_grpc.py b/chromadb/proto/coordinator_pb2_grpc.py index d8fe2eb147a..f164b41f19a 100644 --- a/chromadb/proto/coordinator_pb2_grpc.py +++ b/chromadb/proto/coordinator_pb2_grpc.py @@ -21,6 +21,11 @@ def __init__(self, channel): request_serializer=chromadb_dot_proto_dot_coordinator__pb2.CreateDatabaseRequest.SerializeToString, response_deserializer=chromadb_dot_proto_dot_chroma__pb2.ChromaResponse.FromString, ) + self.CreateTenant = channel.unary_unary( + "/chroma.SysDB/CreateTenant", + request_serializer=chromadb_dot_proto_dot_coordinator__pb2.CreateTenantRequest.SerializeToString, + response_deserializer=chromadb_dot_proto_dot_chroma__pb2.ChromaResponse.FromString, + ) self.CreateSegment = channel.unary_unary( "/chroma.SysDB/CreateSegment", request_serializer=chromadb_dot_proto_dot_coordinator__pb2.CreateSegmentRequest.SerializeToString, @@ -77,6 +82,12 @@ def CreateDatabase(self, request, context): context.set_details("Method not implemented!") raise NotImplementedError("Method not implemented!") + def CreateTenant(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + def CreateSegment(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) @@ -139,6 +150,11 @@ def add_SysDBServicer_to_server(servicer, server): request_deserializer=chromadb_dot_proto_dot_coordinator__pb2.CreateDatabaseRequest.FromString, response_serializer=chromadb_dot_proto_dot_chroma__pb2.ChromaResponse.SerializeToString, ), + "CreateTenant": grpc.unary_unary_rpc_method_handler( + servicer.CreateTenant, + 
request_deserializer=chromadb_dot_proto_dot_coordinator__pb2.CreateTenantRequest.FromString, + response_serializer=chromadb_dot_proto_dot_chroma__pb2.ChromaResponse.SerializeToString, + ), "CreateSegment": grpc.unary_unary_rpc_method_handler( servicer.CreateSegment, request_deserializer=chromadb_dot_proto_dot_coordinator__pb2.CreateSegmentRequest.FromString, @@ -224,6 +240,35 @@ def CreateDatabase( metadata, ) + @staticmethod + def CreateTenant( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, + target, + "/chroma.SysDB/CreateTenant", + chromadb_dot_proto_dot_coordinator__pb2.CreateTenantRequest.SerializeToString, + chromadb_dot_proto_dot_chroma__pb2.ChromaResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) + @staticmethod def CreateSegment( request, diff --git a/chromadb/server/fastapi/__init__.py b/chromadb/server/fastapi/__init__.py index ee0e74eb2ea..f31b127d207 100644 --- a/chromadb/server/fastapi/__init__.py +++ b/chromadb/server/fastapi/__init__.py @@ -141,6 +141,13 @@ def __init__(self, settings: Settings): response_model=None, ) + self.router.add_api_route( + "/api/v1/tenants", + self.create_tenant, + methods=["POST"], + response_model=None, + ) + self.router.add_api_route( "/api/v1/collections", self.list_collections, @@ -237,6 +244,9 @@ def create_database( ) -> None: return self._api.create_database(database.name, tenant) + def create_tenant(self, name: str) -> None: + return self._api.create_tenant(name) + def list_collections( self, tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE ) -> Sequence[Collection]: diff --git a/chromadb/test/client/test_multiple_clients_concurrency.py b/chromadb/test/client/test_multiple_clients_concurrency.py index 
2ee71652c46..b62000696fa 100644 --- a/chromadb/test/client/test_multiple_clients_concurrency.py +++ b/chromadb/test/client/test_multiple_clients_concurrency.py @@ -8,8 +8,8 @@ def test_multiple_clients_concurrently(client: Client) -> None: admin_client = AdminClient.from_system(client._system) admin_client.create_database("test_db") - CLIENT_COUNT = 100 - COLLECTION_COUNT = 500 + CLIENT_COUNT = 50 + COLLECTION_COUNT = 10 # Each database will create the same collections by name, with differing metadata databases = [f"db{i}" for i in range(CLIENT_COUNT)] diff --git a/chromadb/test/db/test_system.py b/chromadb/test/db/test_system.py index b40bc32108e..26385578bdb 100644 --- a/chromadb/test/db/test_system.py +++ b/chromadb/test/db/test_system.py @@ -8,7 +8,7 @@ from chromadb.db.impl.grpc.server import GrpcMockSysDB from chromadb.types import Collection, Segment, SegmentScope from chromadb.db.impl.sqlite import SqliteDB -from chromadb.config import Component, System, Settings +from chromadb.config import DEFAULT_TENANT, Component, System, Settings from chromadb.db.system import SysDB from chromadb.db.base import NotFoundError, UniqueConstraintError from pytest import FixtureRequest @@ -460,6 +460,78 @@ def test_get_multiple_with_database(sysdb: SysDB) -> None: assert len(result) == 0 +def test_create_database_with_tenants(sysdb: SysDB) -> None: + sysdb.reset_state() + + # Create a new tenant + sysdb.create_tenant(name="tenant1") + + # Create tenant that already exits and expect an error + with pytest.raises(UniqueConstraintError): + sysdb.create_tenant(name="tenant1") + + with pytest.raises(UniqueConstraintError): + sysdb.create_tenant(name=DEFAULT_TENANT) + + # Create a new database within this tenant and also in the default tenant + sysdb.create_database(id=uuid.uuid4(), name="new_database", tenant="tenant1") + sysdb.create_database(id=uuid.uuid4(), name="new_database") + + # Create a new collection in the new tenant + sysdb.create_collection( + 
id=sample_collections[0]["id"], + name=sample_collections[0]["name"], + metadata=sample_collections[0]["metadata"], + dimension=sample_collections[0]["dimension"], + database="new_database", + tenant="tenant1", + ) + + # Create a new collection in the default tenant + sysdb.create_collection( + id=sample_collections[1]["id"], + name=sample_collections[1]["name"], + metadata=sample_collections[1]["metadata"], + dimension=sample_collections[1]["dimension"], + database="new_database", + ) + + # Check that both tenants have the correct collections + result = sysdb.get_collections(database="new_database", tenant="tenant1") + assert len(result) == 1 + assert result[0] == sample_collections[0] + + result = sysdb.get_collections(database="new_database") + assert len(result) == 1 + assert result[0] == sample_collections[1] + + # Creating a collection id that already exists in a tenant that does not have it + # should error + with pytest.raises(UniqueConstraintError): + sysdb.create_collection( + id=sample_collections[0]["id"], + name=sample_collections[0]["name"], + metadata=sample_collections[0]["metadata"], + dimension=sample_collections[0]["dimension"], + database="new_database", + ) + + with pytest.raises(UniqueConstraintError): + sysdb.create_collection( + id=sample_collections[1]["id"], + name=sample_collections[1]["name"], + metadata=sample_collections[1]["metadata"], + dimension=sample_collections[1]["dimension"], + database="new_database", + tenant="tenant1", + ) + + # A new tenant DOES NOT have a default database. 
This does not error, instead 0 + # results are returned + result = sysdb.get_collections(database="default", tenant="tenant1") + assert len(result) == 0 + + # endregion # region Segment tests diff --git a/chromadb/test/property/test_collections_with_database_tenant.py b/chromadb/test/property/test_collections_with_database_tenant.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/idl/chromadb/proto/coordinator.proto b/idl/chromadb/proto/coordinator.proto index 4ff45d97bbd..c0ecb134fa9 100644 --- a/idl/chromadb/proto/coordinator.proto +++ b/idl/chromadb/proto/coordinator.proto @@ -11,6 +11,10 @@ message CreateDatabaseRequest { string tenant = 3; } +message CreateTenantRequest { + string name = 2; // Names are globally unique +} + message CreateSegmentRequest { Segment segment = 1; } @@ -97,6 +101,7 @@ message UpdateCollectionRequest { service SysDB { rpc CreateDatabase(CreateDatabaseRequest) returns (ChromaResponse) {} + rpc CreateTenant(CreateTenantRequest) returns (ChromaResponse) {} rpc CreateSegment(CreateSegmentRequest) returns (ChromaResponse) {} rpc DeleteSegment(DeleteSegmentRequest) returns (ChromaResponse) {} rpc GetSegments(GetSegmentsRequest) returns (GetSegmentsResponse) {} From 3553ad92e003b6d9102f6cbc893c2b32f0f1bf43 Mon Sep 17 00:00:00 2001 From: hammadb Date: Thu, 19 Oct 2023 22:10:57 -0700 Subject: [PATCH 07/22] Add tenant/db name verification. Fix fastapi createtenant request. 
Add property test --- chromadb/api/segment.py | 6 + .../sysdb/00004-tenants-databases.sqlite.sql | 2 +- chromadb/server/fastapi/__init__.py | 5 +- chromadb/server/fastapi/types.py | 4 + chromadb/test/client/test_database_tenant.py | 15 +++ chromadb/test/property/strategies.py | 2 + chromadb/test/property/test_collections.py | 32 ++++-- .../test_collections_with_database_tenant.py | 105 ++++++++++++++++++ 8 files changed, 159 insertions(+), 12 deletions(-) diff --git a/chromadb/api/segment.py b/chromadb/api/segment.py index e43a63c7b31..8cb95d53080 100644 --- a/chromadb/api/segment.py +++ b/chromadb/api/segment.py @@ -96,6 +96,9 @@ def heartbeat(self) -> int: @override def create_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None: + if len(name) < 3: + raise ValueError("Database name must be at least 3 characters long") + self._sysdb.create_database( id=uuid4(), name=name, @@ -104,6 +107,9 @@ def create_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None: @override def create_tenant(self, name: str) -> None: + if len(name) < 3: + raise ValueError("Tenant name must be at least 3 characters long") + self._sysdb.create_tenant( name=name, ) diff --git a/chromadb/migrations/sysdb/00004-tenants-databases.sqlite.sql b/chromadb/migrations/sysdb/00004-tenants-databases.sqlite.sql index e1386efba42..be044a8216b 100644 --- a/chromadb/migrations/sysdb/00004-tenants-databases.sqlite.sql +++ b/chromadb/migrations/sysdb/00004-tenants-databases.sqlite.sql @@ -20,7 +20,7 @@ CREATE TABLE collections_tmp ( ); -- Create default tenant and database -INSERT INTO tenants (id) VALUES ('default'); -- should ids be uuids? 
+INSERT INTO tenants (id) VALUES ('default'); -- The default tenant id is 'default' others are UUIDs INSERT INTO databases (id, name, tenant_id) VALUES ('default', 'default', 'default'); INSERT INTO collections_tmp (id, name, topic, dimension, database_id) diff --git a/chromadb/server/fastapi/__init__.py b/chromadb/server/fastapi/__init__.py index f31b127d207..d432fcf3028 100644 --- a/chromadb/server/fastapi/__init__.py +++ b/chromadb/server/fastapi/__init__.py @@ -27,6 +27,7 @@ from chromadb.server.fastapi.types import ( AddEmbedding, CreateDatabase, + CreateTenant, DeleteEmbedding, GetEmbedding, QueryEmbedding, @@ -244,8 +245,8 @@ def create_database( ) -> None: return self._api.create_database(database.name, tenant) - def create_tenant(self, name: str) -> None: - return self._api.create_tenant(name) + def create_tenant(self, tenant: CreateTenant) -> None: + return self._api.create_tenant(tenant.name) def list_collections( self, tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE diff --git a/chromadb/server/fastapi/types.py b/chromadb/server/fastapi/types.py index a27b6162df6..5f1665c91bd 100644 --- a/chromadb/server/fastapi/types.py +++ b/chromadb/server/fastapi/types.py @@ -63,3 +63,7 @@ class UpdateCollection(BaseModel): # type: ignore class CreateDatabase(BaseModel): name: str + + +class CreateTenant(BaseModel): + name: str diff --git a/chromadb/test/client/test_database_tenant.py b/chromadb/test/client/test_database_tenant.py index 1ce7c895893..55672d36aac 100644 --- a/chromadb/test/client/test_database_tenant.py +++ b/chromadb/test/client/test_database_tenant.py @@ -1,3 +1,4 @@ +import pytest from chromadb.api.client import AdminClient, Client @@ -59,3 +60,17 @@ def test_database_tenant_collections(client: Client) -> None: client.delete_collection("collection") collections = client.list_collections() assert len(collections) == 0 + + +def test_min_len_name(client: Client) -> None: + client.reset() + + # Create a new database in the default tenant 
with a name of length 1 + # and expect an error + admin_client = AdminClient.from_system(client._system) + with pytest.raises(Exception): + admin_client.create_database("a") + + # Create a tenant with a name of length 1 and expect an error + with pytest.raises(Exception): + admin_client.create_tenant("a") diff --git a/chromadb/test/property/strategies.py b/chromadb/test/property/strategies.py index e8540ef37aa..3583dadfba9 100644 --- a/chromadb/test/property/strategies.py +++ b/chromadb/test/property/strategies.py @@ -95,10 +95,12 @@ class Record(TypedDict): # TODO: support empty strings everywhere sql_alphabet = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_" safe_text = st.text(alphabet=sql_alphabet, min_size=1) +tenant_database_name = st.text(alphabet=sql_alphabet, min_size=3) # Workaround for FastAPI json encoding peculiarities # https://github.com/tiangolo/fastapi/blob/8ac8d70d52bb0dd9eb55ba4e22d3e383943da05c/fastapi/encoders.py#L104 safe_text = safe_text.filter(lambda s: not s.startswith("_sa")) +tenant_database_name = tenant_database_name.filter(lambda s: not s.startswith("_sa")) safe_integers = st.integers( min_value=-(2**31), max_value=2**31 - 1 diff --git a/chromadb/test/property/test_collections.py b/chromadb/test/property/test_collections.py index 2d41c62ef80..d7cd8492117 100644 --- a/chromadb/test/property/test_collections.py +++ b/chromadb/test/property/test_collections.py @@ -19,19 +19,19 @@ class CollectionStateMachine(RuleBasedStateMachine): collections: Bundle[strategies.Collection] - model: Dict[str, Optional[types.CollectionMetadata]] + _model: Dict[str, Optional[types.CollectionMetadata]] collections = Bundle("collections") def __init__(self, api: ClientAPI): super().__init__() - self.model = {} + self._model = {} self.api = api @initialize() def initialize(self) -> None: self.api.reset() - self.model = {} + self._model = {} @rule(target=collections, coll=strategies.collections()) def create_coll( @@ -54,7 +54,7 @@ def 
create_coll( metadata=coll.metadata, embedding_function=coll.embedding_function, ) - self.model[coll.name] = coll.metadata + self.set_model(coll.name, coll.metadata) assert c.name == coll.name assert c.metadata == coll.metadata @@ -74,7 +74,7 @@ def get_coll(self, coll: strategies.Collection) -> None: def delete_coll(self, coll: strategies.Collection) -> None: if coll.name in self.model: self.api.delete_collection(name=coll.name) - del self.model[coll.name] + self.delete_from_model(coll.name) else: with pytest.raises(Exception): self.api.delete_collection(name=coll.name) @@ -140,7 +140,7 @@ def get_or_create_coll( coll.metadata = ( self.model[coll.name] if new_metadata is None else new_metadata ) - self.model[coll.name] = coll.metadata + self.set_model(coll.name, coll.metadata) # Update API c = self.api.get_or_create_collection( @@ -183,7 +183,7 @@ def modify_coll( ) return multiple() coll.metadata = new_metadata - self.model[coll.name] = coll.metadata + self.set_model(coll.name, coll.metadata) if new_name is not None: if new_name in self.model and new_name != coll.name: @@ -191,8 +191,8 @@ def modify_coll( c.modify(metadata=new_metadata, name=new_name) return multiple() - del self.model[coll.name] - self.model[new_name] = coll.metadata + self.delete_from_model(coll.name) + self.set_model(new_name, coll.metadata) coll.name = new_name c.modify(metadata=new_metadata, name=new_name) @@ -202,6 +202,20 @@ def modify_coll( assert c.metadata == coll.metadata return multiple(coll) + def set_model( + self, name: str, metadata: Optional[types.CollectionMetadata] + ) -> None: + model = self.model + model[name] = metadata + + def delete_from_model(self, name: str) -> None: + model = self.model + del model[name] + + @property + def model(self) -> Dict[str, Optional[types.CollectionMetadata]]: + return self._model + def test_collections(caplog: pytest.LogCaptureFixture, api: ClientAPI) -> None: caplog.set_level(logging.ERROR) diff --git 
a/chromadb/test/property/test_collections_with_database_tenant.py b/chromadb/test/property/test_collections_with_database_tenant.py index e69de29bb2d..7b0b6ecc183 100644 --- a/chromadb/test/property/test_collections_with_database_tenant.py +++ b/chromadb/test/property/test_collections_with_database_tenant.py @@ -0,0 +1,105 @@ +import logging +from typing import Dict, Optional, Tuple +import pytest +from chromadb.api import AdminAPI +import chromadb.api.types as types +from chromadb.api.client import AdminClient, Client +from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT +from chromadb.test.property.test_collections import CollectionStateMachine +from hypothesis.stateful import ( + Bundle, + rule, + initialize, + multiple, + run_state_machine_as_test, + MultipleResults, +) +import chromadb.test.property.strategies as strategies + + +class TenantDatabaseCollectionStateMachine(CollectionStateMachine): + """A collection state machine test that includes tenant and database information, + and switches between them.""" + + tenants: Bundle[str] + databases: Bundle[Tuple[str, str]] # database to tenant it belongs to + tenant_to_database_to_model: Dict[ + str, Dict[str, Dict[str, Optional[types.CollectionMetadata]]] + ] + admin_client: AdminAPI + curr_tenant: str + curr_database: str + + tenants = Bundle("tenants") + databases = Bundle("databases") + + def __init__(self, client: Client): + super().__init__(client) + self.api = client + self.admin_client = AdminClient.from_system(client._system) + + @initialize() + def initialize(self) -> None: + self.api.reset() + self.tenant_to_database_to_model = {} + self.curr_tenant = DEFAULT_TENANT + self.curr_database = DEFAULT_DATABASE + self.api.set_tenant(DEFAULT_TENANT) + self.api.set_database(DEFAULT_DATABASE) + self.tenant_to_database_to_model[self.curr_tenant] = {} + self.tenant_to_database_to_model[self.curr_tenant][self.curr_database] = {} + + @rule(target=tenants, name=strategies.tenant_database_name) + def 
create_tenant(self, name: str) -> MultipleResults[str]: + # Check if tenant already exists + if name in self.tenant_to_database_to_model: + with pytest.raises(Exception): + self.admin_client.create_tenant(name) + return multiple() + + self.admin_client.create_tenant(name) + # When we create a tenant, create a default database for it just for testing + # since the state machine could call collection operations before creating a + # database + self.admin_client.create_database(DEFAULT_DATABASE, tenant=name) + self.tenant_to_database_to_model[name] = {} + self.tenant_to_database_to_model[name][DEFAULT_DATABASE] = {} + return multiple(name) + + @rule(target=databases, name=strategies.tenant_database_name) + def create_database(self, name: str) -> MultipleResults[Tuple[str, str]]: + # If database already exists in current tenant, raise an error + if name in self.tenant_to_database_to_model[self.curr_tenant]: + with pytest.raises(Exception): + self.admin_client.create_database(name, tenant=self.curr_tenant) + return multiple() + + self.admin_client.create_database(name, tenant=self.curr_tenant) + self.tenant_to_database_to_model[self.curr_tenant][name] = {} + return multiple((name, self.curr_tenant)) + + @rule(database=databases) + def set_database_and_tenant(self, database: Dict[str, str]) -> None: + # Get a database and switch to the database and the tenant it belongs to + database_name = database[0] + tenant_name = database[1] + self.api.set_tenant(tenant_name) + self.api.set_database(database_name) + self.curr_database = database_name + self.curr_tenant = tenant_name + + @rule(tenant=tenants) + def set_tenant(self, tenant: str) -> None: + self.api.set_tenant(tenant) + self.api.set_database(DEFAULT_DATABASE) + self.curr_tenant = tenant + self.curr_database = DEFAULT_DATABASE + + @property + def model(self) -> Dict[str, Optional[types.CollectionMetadata]]: + return self.tenant_to_database_to_model[self.curr_tenant][self.curr_database] + + +def test_collections(caplog: 
pytest.LogCaptureFixture, client: Client) -> None: + caplog.set_level(logging.ERROR) + run_state_machine_as_test(lambda: TenantDatabaseCollectionStateMachine(client)) # type: ignore From 85e8b0243be6d64a37018edec9fe494018e07f81 Mon Sep 17 00:00:00 2001 From: hammadb Date: Mon, 23 Oct 2023 08:22:33 -0700 Subject: [PATCH 08/22] Commit get behavior --- chromadb/__init__.py | 64 +++++-------- chromadb/api/__init__.py | 26 +++++- chromadb/api/client.py | 10 ++- chromadb/api/fastapi.py | 27 ++++++ chromadb/api/segment.py | 8 ++ chromadb/db/impl/grpc/client.py | 26 ++++++ chromadb/db/impl/grpc/server.py | 43 +++++++++ chromadb/db/mixins/sysdb.py | 42 +++++++++ chromadb/db/system.py | 13 +++ .../sysdb/00004-tenants-databases.sqlite.sql | 10 +-- chromadb/proto/chroma_pb2.py | 74 +++++++-------- chromadb/proto/chroma_pb2.pyi | 16 ++++ chromadb/proto/coordinator_pb2.py | 62 +++++++------ chromadb/proto/coordinator_pb2.pyi | 30 +++++++ chromadb/proto/coordinator_pb2_grpc.py | 90 +++++++++++++++++++ chromadb/test/db/test_system.py | 32 +++++++ chromadb/types.py | 10 +++ idl/chromadb/proto/chroma.proto | 10 +++ idl/chromadb/proto/coordinator.proto | 21 +++++ 19 files changed, 500 insertions(+), 114 deletions(-) diff --git a/chromadb/__init__.py b/chromadb/__init__.py index 8360a69498c..7aab0e3d79e 100644 --- a/chromadb/__init__.py +++ b/chromadb/__init__.py @@ -1,9 +1,10 @@ from typing import Dict import logging from chromadb.api.client import Client as ClientCreator +from chromadb.api.client import AdminClient as AdminClientCreator import chromadb.config -from chromadb.config import Settings -from chromadb.api import ClientAPI +from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings +from chromadb.api import AdminAPI, ClientAPI from chromadb.api.models.Collection import Collection from chromadb.api.types import ( CollectionMetadata, @@ -91,21 +92,25 @@ def get_settings() -> Settings: return __settings -def EphemeralClient(settings: Settings = Settings()) -> 
ClientAPI: +def EphemeralClient( + settings: Settings = Settings(), + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, +) -> ClientAPI: """ Creates an in-memory instance of Chroma. This is useful for testing and development, but not recommended for production use. """ settings.is_persistent = False - return Client(settings) + return ClientCreator(settings=settings, tenant=tenant, database=database) def PersistentClient( path: str = "./chroma", - tenant: str = "default", - database: str = "default", settings: Settings = Settings(), + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, ) -> ClientAPI: """ Creates a persistent instance of Chroma that saves to disk. This is useful for @@ -125,9 +130,9 @@ def HttpClient( port: str = "8000", ssl: bool = False, headers: Dict[str, str] = {}, - tenant: str = "default", - database: str = "default", settings: Settings = Settings(), + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, ) -> ClientAPI: """ Creates a client that connects to a remote Chroma server. 
This supports @@ -150,44 +155,15 @@ def HttpClient( return ClientCreator(tenant=tenant, database=database, settings=settings) -# TODO: replace default tenant and database strings with constants +def AdminClient(settings: Settings = Settings()) -> AdminAPI: + return AdminClientCreator(settings=settings) + + def Client( - settings: Settings = __settings, tenant: str = "default", database: str = "default" + settings: Settings = __settings, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, ) -> ClientAPI: """Return a running chroma.API instance""" - # Change this to actually check if an "API" instance already exists, wrap it in a - # tenant/database aware "Client", and return it - # this way we can support multiple clients in the same process but using the same - # chroma instance - - # API is thread safe, so we can just return the same instance - # This way a "Client" will just be a wrapper around an API instance that is - # tenant/database aware - - # To do this we will - # 1. Have a global dict of API instances, keyed by path - # 2. When a client is requested, check if one exists in the dict, and if so check if its - # settings match the requested settings - # 3. If the settings match, construct a new Client that wraps the existing API instance with - # the tenant/database - # 4. If the settings don't match, error out because we don't support changing the settings - # got a given database - # 5. 
If no client exists in the dict, create a new API instance, wrap it in a Client, and - # add it to the dict - - # The hierarchy then becomes - # For local - # Path -> Tenant -> Namespace -> API - # For remote - # Host -> Tenant -> Namespace -> API - - # A given API for a path is a singleton, and is shared between all tenants and namespaces - # for that path - - # A DB exists at a path or host, and has tenants and namespaces - - # All our tests currently use system.instance(API) assuming thats the root object - # This is likely fine, - return ClientCreator(tenant=tenant, database=database, settings=settings) diff --git a/chromadb/api/__init__.py b/chromadb/api/__init__.py index 91c8e9517d1..01cdcf943ce 100644 --- a/chromadb/api/__init__.py +++ b/chromadb/api/__init__.py @@ -19,6 +19,7 @@ WhereDocument, ) from chromadb.config import Component, Settings +from chromadb.types import Database, Tenant import chromadb.utils.embedding_functions as ef @@ -425,7 +426,7 @@ def clear_system_cache() -> None: class AdminAPI(ABC): @abstractmethod def create_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None: - """Create a new database. + """Create a new database. Raises an error if the database already exists. Args: database: The name of the database to create. @@ -433,9 +434,20 @@ def create_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None: """ pass + @abstractmethod + def get_database(self, name: str, tenant: str = DEFAULT_TENANT) -> Database: + """Get a database. Raises an error if the database does not exist. + + Args: + database: The name of the database to get. + tenant: The tenant of the database to get. + + """ + pass + @abstractmethod def create_tenant(self, name: str) -> None: - """Create a new tenant. + """Create a new tenant. Raises an error if the tenant already exists. Args: tenant: The name of the tenant to create. 
@@ -443,6 +455,16 @@ def create_tenant(self, name: str) -> None: """ pass + @abstractmethod + def get_tenant(self, name: str) -> Tenant: + """Get a tenant. Raises an error if the tenant does not exist. + + Args: + tenant: The name of the tenant to get. + + """ + pass + class ServerAPI(BaseAPI, AdminAPI, Component): """An API instance that extends the relevant Base API methods by passing diff --git a/chromadb/api/client.py b/chromadb/api/client.py index d7c3c236455..e1996628d2d 100644 --- a/chromadb/api/client.py +++ b/chromadb/api/client.py @@ -19,7 +19,7 @@ from chromadb.telemetry.events import ClientStartEvent from chromadb.config import DEFAULT_TENANT, DEFAULT_DATABASE from chromadb.api.models.Collection import Collection -from chromadb.types import Where, WhereDocument +from chromadb.types import Database, Tenant, Where, WhereDocument import chromadb.utils.embedding_functions as ef C = TypeVar("C", "SharedSystemClient", "Client", "AdminClient") @@ -405,10 +405,18 @@ def __init__(self, settings: Settings = Settings()) -> None: def create_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None: return self._server.create_database(name=name, tenant=tenant) + @override + def get_database(self, name: str, tenant: str = DEFAULT_TENANT) -> Database: + return self._server.get_database(name=name, tenant=tenant) + @override def create_tenant(self, name: str) -> None: return self._server.create_tenant(name=name) + @override + def get_tenant(self, name: str) -> Tenant: + return self._server.get_tenant(name=name) + @classmethod @override def from_system( diff --git a/chromadb/api/fastapi.py b/chromadb/api/fastapi.py index f75a8648f17..a9539138207 100644 --- a/chromadb/api/fastapi.py +++ b/chromadb/api/fastapi.py @@ -8,6 +8,7 @@ from overrides import override import chromadb.errors as errors +from chromadb.types import Database, Tenant import chromadb.utils.embedding_functions as ef from chromadb.api import ServerAPI from chromadb.api.models.Collection import 
Collection @@ -148,6 +149,23 @@ def create_database( ) raise_chroma_error(resp) + @override + def get_database( + self, + name: str, + tenant: str = DEFAULT_TENANT, + ) -> Database: + """Returns a database""" + resp = self._session.get( + self._api_url + "/databases/" + name, + params={"tenant": tenant}, + ) + raise_chroma_error(resp) + resp_json = resp.json() + return Database( + id=resp_json["id"], name=resp_json["name"], tenant=resp_json["tenant"] + ) + @override def create_tenant(self, name: str) -> None: resp = self._session.post( @@ -156,6 +174,15 @@ def create_tenant(self, name: str) -> None: ) raise_chroma_error(resp) + @override + def get_tenant(self, name: str) -> Tenant: + resp = self._session.get( + self._api_url + "/tenants/" + name, + ) + raise_chroma_error(resp) + resp_json = resp.json() + return Tenant(name=resp_json["name"]) + @override def list_collections( self, tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE diff --git a/chromadb/api/segment.py b/chromadb/api/segment.py index 8cb95d53080..741afea807b 100644 --- a/chromadb/api/segment.py +++ b/chromadb/api/segment.py @@ -105,6 +105,10 @@ def create_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None: tenant=tenant, ) + @override + def get_database(self, name: str, tenant: str = DEFAULT_TENANT) -> t.Database: + return self._sysdb.get_database(name=name, tenant=tenant) + @override def create_tenant(self, name: str) -> None: if len(name) < 3: @@ -114,6 +118,10 @@ def create_tenant(self, name: str) -> None: name=name, ) + @override + def get_tenant(self, name: str) -> t.Tenant: + return self._sysdb.get_tenant(name=name) + # TODO: Actually fix CollectionMetadata type to remove type: ignore flags. This is # necessary because changing the value type from `Any` to`` `Union[str, int, float]` # causes the system to somehow convert all values to strings. 
diff --git a/chromadb/db/impl/grpc/client.py b/chromadb/db/impl/grpc/client.py index 967be7c1194..e1b279528f0 100644 --- a/chromadb/db/impl/grpc/client.py +++ b/chromadb/db/impl/grpc/client.py @@ -20,17 +20,21 @@ DeleteSegmentRequest, GetCollectionsRequest, GetCollectionsResponse, + GetDatabaseRequest, GetSegmentsRequest, + GetTenantRequest, UpdateCollectionRequest, UpdateSegmentRequest, ) from chromadb.proto.coordinator_pb2_grpc import SysDBStub from chromadb.types import ( Collection, + Database, Metadata, OptionalArgument, Segment, SegmentScope, + Tenant, Unspecified, UpdateMetadata, ) @@ -82,6 +86,18 @@ def create_database( if response.status.code == 409: raise UniqueConstraintError() + @overrides + def get_database(self, name: str, tenant: str = DEFAULT_TENANT) -> Database: + request = GetDatabaseRequest(name=name, tenant=tenant) + response = self._sys_db_stub.GetDatabase(request) + if response.status.code == 404: + raise NotFoundError() + return Database( + id=UUID(hex=response.database.id), + name=response.database.name, + tenant=response.database.tenant, + ) + @overrides def create_tenant(self, name: str) -> None: request = CreateTenantRequest(name=name) @@ -89,6 +105,16 @@ def create_tenant(self, name: str) -> None: if response.status.code == 409: raise UniqueConstraintError() + @overrides + def get_tenant(self, name: str) -> Tenant: + request = GetTenantRequest(name=name) + response = self._sys_db_stub.GetTenant(request) + if response.status.code == 404: + raise NotFoundError() + return Tenant( + name=response.tenant.name, + ) + @overrides def create_segment(self, segment: Segment) -> None: proto_segment = to_proto_segment(segment) diff --git a/chromadb/db/impl/grpc/server.py b/chromadb/db/impl/grpc/server.py index f308dd7714c..b050602720a 100644 --- a/chromadb/db/impl/grpc/server.py +++ b/chromadb/db/impl/grpc/server.py @@ -22,8 +22,12 @@ DeleteSegmentRequest, GetCollectionsRequest, GetCollectionsResponse, + GetDatabaseRequest, + GetDatabaseResponse, 
GetSegmentsRequest, GetSegmentsResponse, + GetTenantRequest, + GetTenantResponse, UpdateCollectionRequest, UpdateSegmentRequest, ) @@ -47,6 +51,7 @@ class GrpcMockSysDB(SysDBServicer, Component): _tenants_to_databases_to_collections: Dict[ str, Dict[str, Dict[str, Collection]] ] = {} + _tenants_to_database_to_id: Dict[str, Dict[str, UUID]] = {} def __init__(self, system: System): self._server_port = system.settings.require("chroma_server_grpc_port") @@ -73,6 +78,8 @@ def reset_state(self) -> None: # Create defaults self._tenants_to_databases_to_collections["default"] = {} self._tenants_to_databases_to_collections["default"]["default"] = {} + self._tenants_to_database_to_id["default"] = {} + self._tenants_to_database_to_id["default"]["default"] = UUID(int=0) return super().reset_state() @overrides(check_signature=False) @@ -92,8 +99,29 @@ def CreateDatabase( ) ) self._tenants_to_databases_to_collections[tenant][database] = {} + self._tenants_to_database_to_id[tenant][database] = UUID(hex=request.id) return proto.ChromaResponse(status=proto.Status(code=200)) + @overrides(check_signature=False) + def GetDatabase( + self, request: GetDatabaseRequest, context: grpc.ServicerContext + ) -> GetDatabaseResponse: + tenant = request.tenant + database = request.name + if tenant not in self._tenants_to_databases_to_collections: + return GetDatabaseResponse( + status=proto.Status(code=404, reason=f"Tenant {tenant} not found") + ) + if database not in self._tenants_to_databases_to_collections[tenant]: + return GetDatabaseResponse( + status=proto.Status(code=404, reason=f"Database {database} not found") + ) + id = self._tenants_to_database_to_id[tenant][database] + return GetDatabaseResponse( + status=proto.Status(code=200), + database=proto.Database(id=id.hex, name=database, tenant=tenant), + ) + @overrides(check_signature=False) def CreateTenant( self, request: CreateDatabaseRequest, context: grpc.ServicerContext @@ -104,8 +132,23 @@ def CreateTenant( 
status=proto.Status(code=409, reason=f"Tenant {tenant} already exists") ) self._tenants_to_databases_to_collections[tenant] = {} + self._tenants_to_database_to_id[tenant] = {} return proto.ChromaResponse(status=proto.Status(code=200)) + @overrides(check_signature=False) + def GetTenant( + self, request: GetTenantRequest, context: grpc.ServicerContext + ) -> GetTenantResponse: + tenant = request.name + if tenant not in self._tenants_to_databases_to_collections: + return GetTenantResponse( + status=proto.Status(code=404, reason=f"Tenant {tenant} not found") + ) + return GetTenantResponse( + status=proto.Status(code=200), + tenant=proto.Tenant(name=tenant), + ) + # We are forced to use check_signature=False because the generated proto code # does not have type annotations for the request and response objects. # TODO: investigate generating types for the request and response objects diff --git a/chromadb/db/mixins/sysdb.py b/chromadb/db/mixins/sysdb.py index dbf84f477f8..7009f78a281 100644 --- a/chromadb/db/mixins/sysdb.py +++ b/chromadb/db/mixins/sysdb.py @@ -16,11 +16,13 @@ from chromadb.db.system import SysDB from chromadb.ingest import CollectionAssignmentPolicy, Producer from chromadb.types import ( + Database, OptionalArgument, Segment, Metadata, Collection, SegmentScope, + Tenant, Unspecified, UpdateMetadata, ) @@ -72,6 +74,30 @@ def create_database( f"Database {name} already exists for tenant {tenant}" ) from e + @override + def get_database(self, name: str, tenant: str = DEFAULT_TENANT) -> Database: + with self.tx() as cur: + databases = Table("databases") + q = ( + self.querybuilder() + .from_(databases) + .select(databases.id, databases.name) + .where(databases.name == ParameterValue(name)) + .where(databases.tenant_id == ParameterValue(tenant)) + ) + sql, params = get_sql(q, self.parameter_format()) + row = cur.execute(sql, params).fetchone() + if not row: + raise NotFoundError(f"Database {name} not found for tenant {tenant}") + if row[0] is None: + raise 
NotFoundError(f"Database {name} not found for tenant {tenant}") + id: UUID = cast(UUID, self.uuid_from_db(row[0])) + return Database( + id=id, + name=row[1], + tenant=tenant, + ) + @override def create_tenant(self, name: str) -> None: with self.tx() as cur: @@ -88,6 +114,22 @@ def create_tenant(self, name: str) -> None: except self.unique_constraint_error() as e: raise UniqueConstraintError(f"Tenant {name} already exists") from e + @override + def get_tenant(self, name: str) -> Tenant: + with self.tx() as cur: + tenants = Table("tenants") + q = ( + self.querybuilder() + .from_(tenants) + .select(tenants.id) + .where(tenants.id == ParameterValue(name)) + ) + sql, params = get_sql(q, self.parameter_format()) + row = cur.execute(sql, params).fetchone() + if not row: + raise NotFoundError(f"Tenant {name} not found") + return Tenant(name=name) + @override def create_segment(self, segment: Segment) -> None: with self.tx() as cur: diff --git a/chromadb/db/system.py b/chromadb/db/system.py index 3fd76feff36..17606b740b6 100644 --- a/chromadb/db/system.py +++ b/chromadb/db/system.py @@ -3,6 +3,8 @@ from uuid import UUID from chromadb.types import ( Collection, + Database, + Tenant, Metadata, Segment, SegmentScope, @@ -24,12 +26,23 @@ def create_database( already exists.""" pass + @abstractmethod + def get_database(self, name: str, tenant: str = DEFAULT_TENANT) -> Database: + """Get a database by name and tenant. Raises an Error if the Database does not + exist.""" + pass + @abstractmethod def create_tenant(self, name: str) -> None: """Create a new tenant in the System database. The name must be unique. Raises an Error if the Tenant already exists.""" pass + @abstractmethod + def get_tenant(self, name: str) -> Tenant: + """Get a tenant by name. Raises an Error if the Tenant does not exist.""" + pass + @abstractmethod def create_segment(self, segment: Segment) -> None: """Create a new segment in the System database. 
Raises an Error if the ID diff --git a/chromadb/migrations/sysdb/00004-tenants-databases.sqlite.sql b/chromadb/migrations/sysdb/00004-tenants-databases.sqlite.sql index be044a8216b..838114e1dff 100644 --- a/chromadb/migrations/sysdb/00004-tenants-databases.sqlite.sql +++ b/chromadb/migrations/sysdb/00004-tenants-databases.sqlite.sql @@ -1,16 +1,16 @@ -CREATE TABLE tenants ( -- todo: make this idempotent by checking if table exists by using CREATE TABLE IF NOT EXISTS +CREATE TABLE IF NOT EXISTS tenants ( id TEXT PRIMARY KEY, - UNIQUE (id) -- Maybe not needed since we want to support slug ids + UNIQUE (id) ); -CREATE TABLE databases ( +CREATE TABLE IF NOT EXISTS databases ( id TEXT PRIMARY KEY, -- unique globally name TEXT NOT NULL, -- unique per tenant tenant_id TEXT NOT NULL REFERENCES tenants(id) ON DELETE CASCADE, UNIQUE (tenant_id, name) -- Ensure that a tenant has only one database with a given name ); -CREATE TABLE collections_tmp ( +CREATE TABLE IF NOT EXISTS collections_tmp ( id TEXT PRIMARY KEY, -- unique globally name TEXT NOT NULL, -- unique per database topic TEXT NOT NULL, @@ -21,7 +21,7 @@ CREATE TABLE collections_tmp ( -- Create default tenant and database INSERT INTO tenants (id) VALUES ('default'); -- The default tenant id is 'default' others are UUIDs -INSERT INTO databases (id, name, tenant_id) VALUES ('default', 'default', 'default'); +INSERT INTO databases (id, name, tenant_id) VALUES ('00000000-0000-0000-0000-000000000000', 'default', 'default'); INSERT INTO collections_tmp (id, name, topic, dimension, database_id) SELECT id, name, topic, dimension, 'default' FROM collections; diff --git a/chromadb/proto/chroma_pb2.py b/chromadb/proto/chroma_pb2.py index bd069cc74f4..3644a70c74d 100644 --- a/chromadb/proto/chroma_pb2.py +++ b/chromadb/proto/chroma_pb2.py @@ -13,7 +13,7 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1b\x63hromadb/proto/chroma.proto\x12\x06\x63hroma\"&\n\x06Status\x12\x0e\n\x06reason\x18\x01 
\x01(\t\x12\x0c\n\x04\x63ode\x18\x02 \x01(\x05\"0\n\x0e\x43hromaResponse\x12\x1e\n\x06status\x18\x01 \x01(\x0b\x32\x0e.chroma.Status\"U\n\x06Vector\x12\x11\n\tdimension\x18\x01 \x01(\x05\x12\x0e\n\x06vector\x18\x02 \x01(\x0c\x12(\n\x08\x65ncoding\x18\x03 \x01(\x0e\x32\x16.chroma.ScalarEncoding\"\xca\x01\n\x07Segment\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0c\n\x04type\x18\x02 \x01(\t\x12#\n\x05scope\x18\x03 \x01(\x0e\x32\x14.chroma.SegmentScope\x12\x12\n\x05topic\x18\x04 \x01(\tH\x00\x88\x01\x01\x12\x17\n\ncollection\x18\x05 \x01(\tH\x01\x88\x01\x01\x12-\n\x08metadata\x18\x06 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x02\x88\x01\x01\x42\x08\n\x06_topicB\r\n\x0b_collectionB\x0b\n\t_metadata\"\x97\x01\n\nCollection\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\r\n\x05topic\x18\x03 \x01(\t\x12-\n\x08metadata\x18\x04 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x00\x88\x01\x01\x12\x16\n\tdimension\x18\x05 \x01(\x05H\x01\x88\x01\x01\x42\x0b\n\t_metadataB\x0c\n\n_dimension\"b\n\x13UpdateMetadataValue\x12\x16\n\x0cstring_value\x18\x01 \x01(\tH\x00\x12\x13\n\tint_value\x18\x02 \x01(\x03H\x00\x12\x15\n\x0b\x66loat_value\x18\x03 \x01(\x01H\x00\x42\x07\n\x05value\"\x96\x01\n\x0eUpdateMetadata\x12\x36\n\x08metadata\x18\x01 \x03(\x0b\x32$.chroma.UpdateMetadata.MetadataEntry\x1aL\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12*\n\x05value\x18\x02 \x01(\x0b\x32\x1b.chroma.UpdateMetadataValue:\x02\x38\x01\"\xb5\x01\n\x15SubmitEmbeddingRecord\x12\n\n\x02id\x18\x01 \x01(\t\x12#\n\x06vector\x18\x02 \x01(\x0b\x32\x0e.chroma.VectorH\x00\x88\x01\x01\x12-\n\x08metadata\x18\x03 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x01\x88\x01\x01\x12$\n\toperation\x18\x04 \x01(\x0e\x32\x11.chroma.OperationB\t\n\x07_vectorB\x0b\n\t_metadata\"S\n\x15VectorEmbeddingRecord\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0e\n\x06seq_id\x18\x02 \x01(\x0c\x12\x1e\n\x06vector\x18\x03 \x01(\x0b\x32\x0e.chroma.Vector\"q\n\x11VectorQueryResult\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0e\n\x06seq_id\x18\x02 
\x01(\x0c\x12\x10\n\x08\x64istance\x18\x03 \x01(\x01\x12#\n\x06vector\x18\x04 \x01(\x0b\x32\x0e.chroma.VectorH\x00\x88\x01\x01\x42\t\n\x07_vector\"@\n\x12VectorQueryResults\x12*\n\x07results\x18\x01 \x03(\x0b\x32\x19.chroma.VectorQueryResult\"(\n\x15SegmentServerResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\"4\n\x11GetVectorsRequest\x12\x0b\n\x03ids\x18\x01 \x03(\t\x12\x12\n\nsegment_id\x18\x02 \x01(\t\"D\n\x12GetVectorsResponse\x12.\n\x07records\x18\x01 \x03(\x0b\x32\x1d.chroma.VectorEmbeddingRecord\"\x86\x01\n\x13QueryVectorsRequest\x12\x1f\n\x07vectors\x18\x01 \x03(\x0b\x32\x0e.chroma.Vector\x12\t\n\x01k\x18\x02 \x01(\x05\x12\x13\n\x0b\x61llowed_ids\x18\x03 \x03(\t\x12\x1a\n\x12include_embeddings\x18\x04 \x01(\x08\x12\x12\n\nsegment_id\x18\x05 \x01(\t\"C\n\x14QueryVectorsResponse\x12+\n\x07results\x18\x01 \x03(\x0b\x32\x1a.chroma.VectorQueryResults*8\n\tOperation\x12\x07\n\x03\x41\x44\x44\x10\x00\x12\n\n\x06UPDATE\x10\x01\x12\n\n\x06UPSERT\x10\x02\x12\n\n\x06\x44\x45LETE\x10\x03*(\n\x0eScalarEncoding\x12\x0b\n\x07\x46LOAT32\x10\x00\x12\t\n\x05INT32\x10\x01*(\n\x0cSegmentScope\x12\n\n\x06VECTOR\x10\x00\x12\x0c\n\x08METADATA\x10\x01\x32\x94\x01\n\rSegmentServer\x12?\n\x0bLoadSegment\x12\x0f.chroma.Segment\x1a\x1d.chroma.SegmentServerResponse\"\x00\x12\x42\n\x0eReleaseSegment\x12\x0f.chroma.Segment\x1a\x1d.chroma.SegmentServerResponse\"\x00\x32\xa2\x01\n\x0cVectorReader\x12\x45\n\nGetVectors\x12\x19.chroma.GetVectorsRequest\x1a\x1a.chroma.GetVectorsResponse\"\x00\x12K\n\x0cQueryVectors\x12\x1b.chroma.QueryVectorsRequest\x1a\x1c.chroma.QueryVectorsResponse\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1b\x63hromadb/proto/chroma.proto\x12\x06\x63hroma\"&\n\x06Status\x12\x0e\n\x06reason\x18\x01 \x01(\t\x12\x0c\n\x04\x63ode\x18\x02 \x01(\x05\"0\n\x0e\x43hromaResponse\x12\x1e\n\x06status\x18\x01 \x01(\x0b\x32\x0e.chroma.Status\"U\n\x06Vector\x12\x11\n\tdimension\x18\x01 \x01(\x05\x12\x0e\n\x06vector\x18\x02 
\x01(\x0c\x12(\n\x08\x65ncoding\x18\x03 \x01(\x0e\x32\x16.chroma.ScalarEncoding\"\xca\x01\n\x07Segment\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0c\n\x04type\x18\x02 \x01(\t\x12#\n\x05scope\x18\x03 \x01(\x0e\x32\x14.chroma.SegmentScope\x12\x12\n\x05topic\x18\x04 \x01(\tH\x00\x88\x01\x01\x12\x17\n\ncollection\x18\x05 \x01(\tH\x01\x88\x01\x01\x12-\n\x08metadata\x18\x06 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x02\x88\x01\x01\x42\x08\n\x06_topicB\r\n\x0b_collectionB\x0b\n\t_metadata\"\x97\x01\n\nCollection\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\r\n\x05topic\x18\x03 \x01(\t\x12-\n\x08metadata\x18\x04 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x00\x88\x01\x01\x12\x16\n\tdimension\x18\x05 \x01(\x05H\x01\x88\x01\x01\x42\x0b\n\t_metadataB\x0c\n\n_dimension\"4\n\x08\x44\x61tabase\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x0e\n\x06tenant\x18\x03 \x01(\t\"\x16\n\x06Tenant\x12\x0c\n\x04name\x18\x01 \x01(\t\"b\n\x13UpdateMetadataValue\x12\x16\n\x0cstring_value\x18\x01 \x01(\tH\x00\x12\x13\n\tint_value\x18\x02 \x01(\x03H\x00\x12\x15\n\x0b\x66loat_value\x18\x03 \x01(\x01H\x00\x42\x07\n\x05value\"\x96\x01\n\x0eUpdateMetadata\x12\x36\n\x08metadata\x18\x01 \x03(\x0b\x32$.chroma.UpdateMetadata.MetadataEntry\x1aL\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12*\n\x05value\x18\x02 \x01(\x0b\x32\x1b.chroma.UpdateMetadataValue:\x02\x38\x01\"\xb5\x01\n\x15SubmitEmbeddingRecord\x12\n\n\x02id\x18\x01 \x01(\t\x12#\n\x06vector\x18\x02 \x01(\x0b\x32\x0e.chroma.VectorH\x00\x88\x01\x01\x12-\n\x08metadata\x18\x03 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x01\x88\x01\x01\x12$\n\toperation\x18\x04 \x01(\x0e\x32\x11.chroma.OperationB\t\n\x07_vectorB\x0b\n\t_metadata\"S\n\x15VectorEmbeddingRecord\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0e\n\x06seq_id\x18\x02 \x01(\x0c\x12\x1e\n\x06vector\x18\x03 \x01(\x0b\x32\x0e.chroma.Vector\"q\n\x11VectorQueryResult\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0e\n\x06seq_id\x18\x02 \x01(\x0c\x12\x10\n\x08\x64istance\x18\x03 
\x01(\x01\x12#\n\x06vector\x18\x04 \x01(\x0b\x32\x0e.chroma.VectorH\x00\x88\x01\x01\x42\t\n\x07_vector\"@\n\x12VectorQueryResults\x12*\n\x07results\x18\x01 \x03(\x0b\x32\x19.chroma.VectorQueryResult\"(\n\x15SegmentServerResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\"4\n\x11GetVectorsRequest\x12\x0b\n\x03ids\x18\x01 \x03(\t\x12\x12\n\nsegment_id\x18\x02 \x01(\t\"D\n\x12GetVectorsResponse\x12.\n\x07records\x18\x01 \x03(\x0b\x32\x1d.chroma.VectorEmbeddingRecord\"\x86\x01\n\x13QueryVectorsRequest\x12\x1f\n\x07vectors\x18\x01 \x03(\x0b\x32\x0e.chroma.Vector\x12\t\n\x01k\x18\x02 \x01(\x05\x12\x13\n\x0b\x61llowed_ids\x18\x03 \x03(\t\x12\x1a\n\x12include_embeddings\x18\x04 \x01(\x08\x12\x12\n\nsegment_id\x18\x05 \x01(\t\"C\n\x14QueryVectorsResponse\x12+\n\x07results\x18\x01 \x03(\x0b\x32\x1a.chroma.VectorQueryResults*8\n\tOperation\x12\x07\n\x03\x41\x44\x44\x10\x00\x12\n\n\x06UPDATE\x10\x01\x12\n\n\x06UPSERT\x10\x02\x12\n\n\x06\x44\x45LETE\x10\x03*(\n\x0eScalarEncoding\x12\x0b\n\x07\x46LOAT32\x10\x00\x12\t\n\x05INT32\x10\x01*(\n\x0cSegmentScope\x12\n\n\x06VECTOR\x10\x00\x12\x0c\n\x08METADATA\x10\x01\x32\x94\x01\n\rSegmentServer\x12?\n\x0bLoadSegment\x12\x0f.chroma.Segment\x1a\x1d.chroma.SegmentServerResponse\"\x00\x12\x42\n\x0eReleaseSegment\x12\x0f.chroma.Segment\x1a\x1d.chroma.SegmentServerResponse\"\x00\x32\xa2\x01\n\x0cVectorReader\x12\x45\n\nGetVectors\x12\x19.chroma.GetVectorsRequest\x1a\x1a.chroma.GetVectorsResponse\"\x00\x12K\n\x0cQueryVectors\x12\x1b.chroma.QueryVectorsRequest\x1a\x1c.chroma.QueryVectorsResponse\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -22,12 +22,12 @@ DESCRIPTOR._options = None _UPDATEMETADATA_METADATAENTRY._options = None _UPDATEMETADATA_METADATAENTRY._serialized_options = b'8\001' - _globals['_OPERATION']._serialized_start=1650 - _globals['_OPERATION']._serialized_end=1706 - _globals['_SCALARENCODING']._serialized_start=1708 - _globals['_SCALARENCODING']._serialized_end=1748 - 
_globals['_SEGMENTSCOPE']._serialized_start=1750 - _globals['_SEGMENTSCOPE']._serialized_end=1790 + _globals['_OPERATION']._serialized_start=1728 + _globals['_OPERATION']._serialized_end=1784 + _globals['_SCALARENCODING']._serialized_start=1786 + _globals['_SCALARENCODING']._serialized_end=1826 + _globals['_SEGMENTSCOPE']._serialized_start=1828 + _globals['_SEGMENTSCOPE']._serialized_end=1868 _globals['_STATUS']._serialized_start=39 _globals['_STATUS']._serialized_end=77 _globals['_CHROMARESPONSE']._serialized_start=79 @@ -38,32 +38,36 @@ _globals['_SEGMENT']._serialized_end=419 _globals['_COLLECTION']._serialized_start=422 _globals['_COLLECTION']._serialized_end=573 - _globals['_UPDATEMETADATAVALUE']._serialized_start=575 - _globals['_UPDATEMETADATAVALUE']._serialized_end=673 - _globals['_UPDATEMETADATA']._serialized_start=676 - _globals['_UPDATEMETADATA']._serialized_end=826 - _globals['_UPDATEMETADATA_METADATAENTRY']._serialized_start=750 - _globals['_UPDATEMETADATA_METADATAENTRY']._serialized_end=826 - _globals['_SUBMITEMBEDDINGRECORD']._serialized_start=829 - _globals['_SUBMITEMBEDDINGRECORD']._serialized_end=1010 - _globals['_VECTOREMBEDDINGRECORD']._serialized_start=1012 - _globals['_VECTOREMBEDDINGRECORD']._serialized_end=1095 - _globals['_VECTORQUERYRESULT']._serialized_start=1097 - _globals['_VECTORQUERYRESULT']._serialized_end=1210 - _globals['_VECTORQUERYRESULTS']._serialized_start=1212 - _globals['_VECTORQUERYRESULTS']._serialized_end=1276 - _globals['_SEGMENTSERVERRESPONSE']._serialized_start=1278 - _globals['_SEGMENTSERVERRESPONSE']._serialized_end=1318 - _globals['_GETVECTORSREQUEST']._serialized_start=1320 - _globals['_GETVECTORSREQUEST']._serialized_end=1372 - _globals['_GETVECTORSRESPONSE']._serialized_start=1374 - _globals['_GETVECTORSRESPONSE']._serialized_end=1442 - _globals['_QUERYVECTORSREQUEST']._serialized_start=1445 - _globals['_QUERYVECTORSREQUEST']._serialized_end=1579 - _globals['_QUERYVECTORSRESPONSE']._serialized_start=1581 - 
_globals['_QUERYVECTORSRESPONSE']._serialized_end=1648 - _globals['_SEGMENTSERVER']._serialized_start=1793 - _globals['_SEGMENTSERVER']._serialized_end=1941 - _globals['_VECTORREADER']._serialized_start=1944 - _globals['_VECTORREADER']._serialized_end=2106 + _globals['_DATABASE']._serialized_start=575 + _globals['_DATABASE']._serialized_end=627 + _globals['_TENANT']._serialized_start=629 + _globals['_TENANT']._serialized_end=651 + _globals['_UPDATEMETADATAVALUE']._serialized_start=653 + _globals['_UPDATEMETADATAVALUE']._serialized_end=751 + _globals['_UPDATEMETADATA']._serialized_start=754 + _globals['_UPDATEMETADATA']._serialized_end=904 + _globals['_UPDATEMETADATA_METADATAENTRY']._serialized_start=828 + _globals['_UPDATEMETADATA_METADATAENTRY']._serialized_end=904 + _globals['_SUBMITEMBEDDINGRECORD']._serialized_start=907 + _globals['_SUBMITEMBEDDINGRECORD']._serialized_end=1088 + _globals['_VECTOREMBEDDINGRECORD']._serialized_start=1090 + _globals['_VECTOREMBEDDINGRECORD']._serialized_end=1173 + _globals['_VECTORQUERYRESULT']._serialized_start=1175 + _globals['_VECTORQUERYRESULT']._serialized_end=1288 + _globals['_VECTORQUERYRESULTS']._serialized_start=1290 + _globals['_VECTORQUERYRESULTS']._serialized_end=1354 + _globals['_SEGMENTSERVERRESPONSE']._serialized_start=1356 + _globals['_SEGMENTSERVERRESPONSE']._serialized_end=1396 + _globals['_GETVECTORSREQUEST']._serialized_start=1398 + _globals['_GETVECTORSREQUEST']._serialized_end=1450 + _globals['_GETVECTORSRESPONSE']._serialized_start=1452 + _globals['_GETVECTORSRESPONSE']._serialized_end=1520 + _globals['_QUERYVECTORSREQUEST']._serialized_start=1523 + _globals['_QUERYVECTORSREQUEST']._serialized_end=1657 + _globals['_QUERYVECTORSRESPONSE']._serialized_start=1659 + _globals['_QUERYVECTORSRESPONSE']._serialized_end=1726 + _globals['_SEGMENTSERVER']._serialized_start=1871 + _globals['_SEGMENTSERVER']._serialized_end=2019 + _globals['_VECTORREADER']._serialized_start=2022 + 
_globals['_VECTORREADER']._serialized_end=2184 # @@protoc_insertion_point(module_scope) diff --git a/chromadb/proto/chroma_pb2.pyi b/chromadb/proto/chroma_pb2.pyi index 733cae0a273..386f253e776 100644 --- a/chromadb/proto/chroma_pb2.pyi +++ b/chromadb/proto/chroma_pb2.pyi @@ -85,6 +85,22 @@ class Collection(_message.Message): dimension: int def __init__(self, id: _Optional[str] = ..., name: _Optional[str] = ..., topic: _Optional[str] = ..., metadata: _Optional[_Union[UpdateMetadata, _Mapping]] = ..., dimension: _Optional[int] = ...) -> None: ... +class Database(_message.Message): + __slots__ = ["id", "name", "tenant"] + ID_FIELD_NUMBER: _ClassVar[int] + NAME_FIELD_NUMBER: _ClassVar[int] + TENANT_FIELD_NUMBER: _ClassVar[int] + id: str + name: str + tenant: str + def __init__(self, id: _Optional[str] = ..., name: _Optional[str] = ..., tenant: _Optional[str] = ...) -> None: ... + +class Tenant(_message.Message): + __slots__ = ["name"] + NAME_FIELD_NUMBER: _ClassVar[int] + name: str + def __init__(self, name: _Optional[str] = ...) -> None: ... 
+ class UpdateMetadataValue(_message.Message): __slots__ = ["string_value", "int_value", "float_value"] STRING_VALUE_FIELD_NUMBER: _ClassVar[int] diff --git a/chromadb/proto/coordinator_pb2.py b/chromadb/proto/coordinator_pb2.py index 7f865bb217f..42039c2d23f 100644 --- a/chromadb/proto/coordinator_pb2.py +++ b/chromadb/proto/coordinator_pb2.py @@ -15,7 +15,7 @@ from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n chromadb/proto/coordinator.proto\x12\x06\x63hroma\x1a\x1b\x63hromadb/proto/chroma.proto\x1a\x1bgoogle/protobuf/empty.proto\"A\n\x15\x43reateDatabaseRequest\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x0e\n\x06tenant\x18\x03 \x01(\t\"#\n\x13\x43reateTenantRequest\x12\x0c\n\x04name\x18\x02 \x01(\t\"8\n\x14\x43reateSegmentRequest\x12 \n\x07segment\x18\x01 \x01(\x0b\x32\x0f.chroma.Segment\"\"\n\x14\x44\x65leteSegmentRequest\x12\n\n\x02id\x18\x01 \x01(\t\"\xc2\x01\n\x12GetSegmentsRequest\x12\x0f\n\x02id\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x11\n\x04type\x18\x02 \x01(\tH\x01\x88\x01\x01\x12(\n\x05scope\x18\x03 \x01(\x0e\x32\x14.chroma.SegmentScopeH\x02\x88\x01\x01\x12\x12\n\x05topic\x18\x04 \x01(\tH\x03\x88\x01\x01\x12\x17\n\ncollection\x18\x05 \x01(\tH\x04\x88\x01\x01\x42\x05\n\x03_idB\x07\n\x05_typeB\x08\n\x06_scopeB\x08\n\x06_topicB\r\n\x0b_collection\"X\n\x13GetSegmentsResponse\x12!\n\x08segments\x18\x01 \x03(\x0b\x32\x0f.chroma.Segment\x12\x1e\n\x06status\x18\x02 \x01(\x0b\x32\x0e.chroma.Status\"\xfa\x01\n\x14UpdateSegmentRequest\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0f\n\x05topic\x18\x02 \x01(\tH\x00\x12\x15\n\x0breset_topic\x18\x03 \x01(\x08H\x00\x12\x14\n\ncollection\x18\x04 \x01(\tH\x01\x12\x1a\n\x10reset_collection\x18\x05 \x01(\x08H\x01\x12*\n\x08metadata\x18\x06 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x02\x12\x18\n\x0ereset_metadata\x18\x07 
\x01(\x08H\x02\x42\x0e\n\x0ctopic_updateB\x13\n\x11\x63ollection_updateB\x11\n\x0fmetadata_update\"\xe5\x01\n\x17\x43reateCollectionRequest\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12-\n\x08metadata\x18\x03 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x00\x88\x01\x01\x12\x16\n\tdimension\x18\x04 \x01(\x05H\x01\x88\x01\x01\x12\x1a\n\rget_or_create\x18\x05 \x01(\x08H\x02\x88\x01\x01\x12\x0e\n\x06tenant\x18\x06 \x01(\t\x12\x10\n\x08\x64\x61tabase\x18\x07 \x01(\tB\x0b\n\t_metadataB\x0c\n\n_dimensionB\x10\n\x0e_get_or_create\"s\n\x18\x43reateCollectionResponse\x12&\n\ncollection\x18\x01 \x01(\x0b\x32\x12.chroma.Collection\x12\x0f\n\x07\x63reated\x18\x02 \x01(\x08\x12\x1e\n\x06status\x18\x03 \x01(\x0b\x32\x0e.chroma.Status\"G\n\x17\x44\x65leteCollectionRequest\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0e\n\x06tenant\x18\x02 \x01(\t\x12\x10\n\x08\x64\x61tabase\x18\x03 \x01(\t\"\x8b\x01\n\x15GetCollectionsRequest\x12\x0f\n\x02id\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x11\n\x04name\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x12\n\x05topic\x18\x03 \x01(\tH\x02\x88\x01\x01\x12\x0e\n\x06tenant\x18\x04 \x01(\t\x12\x10\n\x08\x64\x61tabase\x18\x05 \x01(\tB\x05\n\x03_idB\x07\n\x05_nameB\x08\n\x06_topic\"a\n\x16GetCollectionsResponse\x12\'\n\x0b\x63ollections\x18\x01 \x03(\x0b\x32\x12.chroma.Collection\x12\x1e\n\x06status\x18\x02 \x01(\x0b\x32\x0e.chroma.Status\"\xde\x01\n\x17UpdateCollectionRequest\x12\n\n\x02id\x18\x01 \x01(\t\x12\x12\n\x05topic\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x11\n\x04name\x18\x03 \x01(\tH\x02\x88\x01\x01\x12\x16\n\tdimension\x18\x04 \x01(\x05H\x03\x88\x01\x01\x12*\n\x08metadata\x18\x05 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x00\x12\x18\n\x0ereset_metadata\x18\x06 
\x01(\x08H\x00\x42\x11\n\x0fmetadata_updateB\x08\n\x06_topicB\x07\n\x05_nameB\x0c\n\n_dimension2\xc8\x06\n\x05SysDB\x12I\n\x0e\x43reateDatabase\x12\x1d.chroma.CreateDatabaseRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12\x45\n\x0c\x43reateTenant\x12\x1b.chroma.CreateTenantRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12G\n\rCreateSegment\x12\x1c.chroma.CreateSegmentRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12G\n\rDeleteSegment\x12\x1c.chroma.DeleteSegmentRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12H\n\x0bGetSegments\x12\x1a.chroma.GetSegmentsRequest\x1a\x1b.chroma.GetSegmentsResponse\"\x00\x12G\n\rUpdateSegment\x12\x1c.chroma.UpdateSegmentRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12W\n\x10\x43reateCollection\x12\x1f.chroma.CreateCollectionRequest\x1a .chroma.CreateCollectionResponse\"\x00\x12M\n\x10\x44\x65leteCollection\x12\x1f.chroma.DeleteCollectionRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12Q\n\x0eGetCollections\x12\x1d.chroma.GetCollectionsRequest\x1a\x1e.chroma.GetCollectionsResponse\"\x00\x12M\n\x10UpdateCollection\x12\x1f.chroma.UpdateCollectionRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12>\n\nResetState\x12\x16.google.protobuf.Empty\x1a\x16.chroma.ChromaResponse\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n chromadb/proto/coordinator.proto\x12\x06\x63hroma\x1a\x1b\x63hromadb/proto/chroma.proto\x1a\x1bgoogle/protobuf/empty.proto\"A\n\x15\x43reateDatabaseRequest\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x0e\n\x06tenant\x18\x03 \x01(\t\"2\n\x12GetDatabaseRequest\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0e\n\x06tenant\x18\x02 \x01(\t\"Y\n\x13GetDatabaseResponse\x12\"\n\x08\x64\x61tabase\x18\x01 \x01(\x0b\x32\x10.chroma.Database\x12\x1e\n\x06status\x18\x02 \x01(\x0b\x32\x0e.chroma.Status\"#\n\x13\x43reateTenantRequest\x12\x0c\n\x04name\x18\x02 \x01(\t\" \n\x10GetTenantRequest\x12\x0c\n\x04name\x18\x01 \x01(\t\"S\n\x11GetTenantResponse\x12\x1e\n\x06tenant\x18\x01 
\x01(\x0b\x32\x0e.chroma.Tenant\x12\x1e\n\x06status\x18\x02 \x01(\x0b\x32\x0e.chroma.Status\"8\n\x14\x43reateSegmentRequest\x12 \n\x07segment\x18\x01 \x01(\x0b\x32\x0f.chroma.Segment\"\"\n\x14\x44\x65leteSegmentRequest\x12\n\n\x02id\x18\x01 \x01(\t\"\xc2\x01\n\x12GetSegmentsRequest\x12\x0f\n\x02id\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x11\n\x04type\x18\x02 \x01(\tH\x01\x88\x01\x01\x12(\n\x05scope\x18\x03 \x01(\x0e\x32\x14.chroma.SegmentScopeH\x02\x88\x01\x01\x12\x12\n\x05topic\x18\x04 \x01(\tH\x03\x88\x01\x01\x12\x17\n\ncollection\x18\x05 \x01(\tH\x04\x88\x01\x01\x42\x05\n\x03_idB\x07\n\x05_typeB\x08\n\x06_scopeB\x08\n\x06_topicB\r\n\x0b_collection\"X\n\x13GetSegmentsResponse\x12!\n\x08segments\x18\x01 \x03(\x0b\x32\x0f.chroma.Segment\x12\x1e\n\x06status\x18\x02 \x01(\x0b\x32\x0e.chroma.Status\"\xfa\x01\n\x14UpdateSegmentRequest\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0f\n\x05topic\x18\x02 \x01(\tH\x00\x12\x15\n\x0breset_topic\x18\x03 \x01(\x08H\x00\x12\x14\n\ncollection\x18\x04 \x01(\tH\x01\x12\x1a\n\x10reset_collection\x18\x05 \x01(\x08H\x01\x12*\n\x08metadata\x18\x06 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x02\x12\x18\n\x0ereset_metadata\x18\x07 \x01(\x08H\x02\x42\x0e\n\x0ctopic_updateB\x13\n\x11\x63ollection_updateB\x11\n\x0fmetadata_update\"\xe5\x01\n\x17\x43reateCollectionRequest\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12-\n\x08metadata\x18\x03 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x00\x88\x01\x01\x12\x16\n\tdimension\x18\x04 \x01(\x05H\x01\x88\x01\x01\x12\x1a\n\rget_or_create\x18\x05 \x01(\x08H\x02\x88\x01\x01\x12\x0e\n\x06tenant\x18\x06 \x01(\t\x12\x10\n\x08\x64\x61tabase\x18\x07 \x01(\tB\x0b\n\t_metadataB\x0c\n\n_dimensionB\x10\n\x0e_get_or_create\"s\n\x18\x43reateCollectionResponse\x12&\n\ncollection\x18\x01 \x01(\x0b\x32\x12.chroma.Collection\x12\x0f\n\x07\x63reated\x18\x02 \x01(\x08\x12\x1e\n\x06status\x18\x03 \x01(\x0b\x32\x0e.chroma.Status\"G\n\x17\x44\x65leteCollectionRequest\x12\n\n\x02id\x18\x01 
\x01(\t\x12\x0e\n\x06tenant\x18\x02 \x01(\t\x12\x10\n\x08\x64\x61tabase\x18\x03 \x01(\t\"\x8b\x01\n\x15GetCollectionsRequest\x12\x0f\n\x02id\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x11\n\x04name\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x12\n\x05topic\x18\x03 \x01(\tH\x02\x88\x01\x01\x12\x0e\n\x06tenant\x18\x04 \x01(\t\x12\x10\n\x08\x64\x61tabase\x18\x05 \x01(\tB\x05\n\x03_idB\x07\n\x05_nameB\x08\n\x06_topic\"a\n\x16GetCollectionsResponse\x12\'\n\x0b\x63ollections\x18\x01 \x03(\x0b\x32\x12.chroma.Collection\x12\x1e\n\x06status\x18\x02 \x01(\x0b\x32\x0e.chroma.Status\"\xde\x01\n\x17UpdateCollectionRequest\x12\n\n\x02id\x18\x01 \x01(\t\x12\x12\n\x05topic\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x11\n\x04name\x18\x03 \x01(\tH\x02\x88\x01\x01\x12\x16\n\tdimension\x18\x04 \x01(\x05H\x03\x88\x01\x01\x12*\n\x08metadata\x18\x05 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x00\x12\x18\n\x0ereset_metadata\x18\x06 \x01(\x08H\x00\x42\x11\n\x0fmetadata_updateB\x08\n\x06_topicB\x07\n\x05_nameB\x0c\n\n_dimension2\xd6\x07\n\x05SysDB\x12I\n\x0e\x43reateDatabase\x12\x1d.chroma.CreateDatabaseRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12H\n\x0bGetDatabase\x12\x1a.chroma.GetDatabaseRequest\x1a\x1b.chroma.GetDatabaseResponse\"\x00\x12\x45\n\x0c\x43reateTenant\x12\x1b.chroma.CreateTenantRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12\x42\n\tGetTenant\x12\x18.chroma.GetTenantRequest\x1a\x19.chroma.GetTenantResponse\"\x00\x12G\n\rCreateSegment\x12\x1c.chroma.CreateSegmentRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12G\n\rDeleteSegment\x12\x1c.chroma.DeleteSegmentRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12H\n\x0bGetSegments\x12\x1a.chroma.GetSegmentsRequest\x1a\x1b.chroma.GetSegmentsResponse\"\x00\x12G\n\rUpdateSegment\x12\x1c.chroma.UpdateSegmentRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12W\n\x10\x43reateCollection\x12\x1f.chroma.CreateCollectionRequest\x1a 
.chroma.CreateCollectionResponse\"\x00\x12M\n\x10\x44\x65leteCollection\x12\x1f.chroma.DeleteCollectionRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12Q\n\x0eGetCollections\x12\x1d.chroma.GetCollectionsRequest\x1a\x1e.chroma.GetCollectionsResponse\"\x00\x12M\n\x10UpdateCollection\x12\x1f.chroma.UpdateCollectionRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12>\n\nResetState\x12\x16.google.protobuf.Empty\x1a\x16.chroma.ChromaResponse\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -24,30 +24,38 @@ DESCRIPTOR._options = None _globals['_CREATEDATABASEREQUEST']._serialized_start=102 _globals['_CREATEDATABASEREQUEST']._serialized_end=167 - _globals['_CREATETENANTREQUEST']._serialized_start=169 - _globals['_CREATETENANTREQUEST']._serialized_end=204 - _globals['_CREATESEGMENTREQUEST']._serialized_start=206 - _globals['_CREATESEGMENTREQUEST']._serialized_end=262 - _globals['_DELETESEGMENTREQUEST']._serialized_start=264 - _globals['_DELETESEGMENTREQUEST']._serialized_end=298 - _globals['_GETSEGMENTSREQUEST']._serialized_start=301 - _globals['_GETSEGMENTSREQUEST']._serialized_end=495 - _globals['_GETSEGMENTSRESPONSE']._serialized_start=497 - _globals['_GETSEGMENTSRESPONSE']._serialized_end=585 - _globals['_UPDATESEGMENTREQUEST']._serialized_start=588 - _globals['_UPDATESEGMENTREQUEST']._serialized_end=838 - _globals['_CREATECOLLECTIONREQUEST']._serialized_start=841 - _globals['_CREATECOLLECTIONREQUEST']._serialized_end=1070 - _globals['_CREATECOLLECTIONRESPONSE']._serialized_start=1072 - _globals['_CREATECOLLECTIONRESPONSE']._serialized_end=1187 - _globals['_DELETECOLLECTIONREQUEST']._serialized_start=1189 - _globals['_DELETECOLLECTIONREQUEST']._serialized_end=1260 - _globals['_GETCOLLECTIONSREQUEST']._serialized_start=1263 - _globals['_GETCOLLECTIONSREQUEST']._serialized_end=1402 - _globals['_GETCOLLECTIONSRESPONSE']._serialized_start=1404 - _globals['_GETCOLLECTIONSRESPONSE']._serialized_end=1501 - 
_globals['_UPDATECOLLECTIONREQUEST']._serialized_start=1504 - _globals['_UPDATECOLLECTIONREQUEST']._serialized_end=1726 - _globals['_SYSDB']._serialized_start=1729 - _globals['_SYSDB']._serialized_end=2569 + _globals['_GETDATABASEREQUEST']._serialized_start=169 + _globals['_GETDATABASEREQUEST']._serialized_end=219 + _globals['_GETDATABASERESPONSE']._serialized_start=221 + _globals['_GETDATABASERESPONSE']._serialized_end=310 + _globals['_CREATETENANTREQUEST']._serialized_start=312 + _globals['_CREATETENANTREQUEST']._serialized_end=347 + _globals['_GETTENANTREQUEST']._serialized_start=349 + _globals['_GETTENANTREQUEST']._serialized_end=381 + _globals['_GETTENANTRESPONSE']._serialized_start=383 + _globals['_GETTENANTRESPONSE']._serialized_end=466 + _globals['_CREATESEGMENTREQUEST']._serialized_start=468 + _globals['_CREATESEGMENTREQUEST']._serialized_end=524 + _globals['_DELETESEGMENTREQUEST']._serialized_start=526 + _globals['_DELETESEGMENTREQUEST']._serialized_end=560 + _globals['_GETSEGMENTSREQUEST']._serialized_start=563 + _globals['_GETSEGMENTSREQUEST']._serialized_end=757 + _globals['_GETSEGMENTSRESPONSE']._serialized_start=759 + _globals['_GETSEGMENTSRESPONSE']._serialized_end=847 + _globals['_UPDATESEGMENTREQUEST']._serialized_start=850 + _globals['_UPDATESEGMENTREQUEST']._serialized_end=1100 + _globals['_CREATECOLLECTIONREQUEST']._serialized_start=1103 + _globals['_CREATECOLLECTIONREQUEST']._serialized_end=1332 + _globals['_CREATECOLLECTIONRESPONSE']._serialized_start=1334 + _globals['_CREATECOLLECTIONRESPONSE']._serialized_end=1449 + _globals['_DELETECOLLECTIONREQUEST']._serialized_start=1451 + _globals['_DELETECOLLECTIONREQUEST']._serialized_end=1522 + _globals['_GETCOLLECTIONSREQUEST']._serialized_start=1525 + _globals['_GETCOLLECTIONSREQUEST']._serialized_end=1664 + _globals['_GETCOLLECTIONSRESPONSE']._serialized_start=1666 + _globals['_GETCOLLECTIONSRESPONSE']._serialized_end=1763 + _globals['_UPDATECOLLECTIONREQUEST']._serialized_start=1766 + 
_globals['_UPDATECOLLECTIONREQUEST']._serialized_end=1988 + _globals['_SYSDB']._serialized_start=1991 + _globals['_SYSDB']._serialized_end=2973 # @@protoc_insertion_point(module_scope) diff --git a/chromadb/proto/coordinator_pb2.pyi b/chromadb/proto/coordinator_pb2.pyi index 047eccba5f2..81545e4e283 100644 --- a/chromadb/proto/coordinator_pb2.pyi +++ b/chromadb/proto/coordinator_pb2.pyi @@ -17,12 +17,42 @@ class CreateDatabaseRequest(_message.Message): tenant: str def __init__(self, id: _Optional[str] = ..., name: _Optional[str] = ..., tenant: _Optional[str] = ...) -> None: ... +class GetDatabaseRequest(_message.Message): + __slots__ = ["name", "tenant"] + NAME_FIELD_NUMBER: _ClassVar[int] + TENANT_FIELD_NUMBER: _ClassVar[int] + name: str + tenant: str + def __init__(self, name: _Optional[str] = ..., tenant: _Optional[str] = ...) -> None: ... + +class GetDatabaseResponse(_message.Message): + __slots__ = ["database", "status"] + DATABASE_FIELD_NUMBER: _ClassVar[int] + STATUS_FIELD_NUMBER: _ClassVar[int] + database: _chroma_pb2.Database + status: _chroma_pb2.Status + def __init__(self, database: _Optional[_Union[_chroma_pb2.Database, _Mapping]] = ..., status: _Optional[_Union[_chroma_pb2.Status, _Mapping]] = ...) -> None: ... + class CreateTenantRequest(_message.Message): __slots__ = ["name"] NAME_FIELD_NUMBER: _ClassVar[int] name: str def __init__(self, name: _Optional[str] = ...) -> None: ... +class GetTenantRequest(_message.Message): + __slots__ = ["name"] + NAME_FIELD_NUMBER: _ClassVar[int] + name: str + def __init__(self, name: _Optional[str] = ...) -> None: ... + +class GetTenantResponse(_message.Message): + __slots__ = ["tenant", "status"] + TENANT_FIELD_NUMBER: _ClassVar[int] + STATUS_FIELD_NUMBER: _ClassVar[int] + tenant: _chroma_pb2.Tenant + status: _chroma_pb2.Status + def __init__(self, tenant: _Optional[_Union[_chroma_pb2.Tenant, _Mapping]] = ..., status: _Optional[_Union[_chroma_pb2.Status, _Mapping]] = ...) -> None: ... 
+ class CreateSegmentRequest(_message.Message): __slots__ = ["segment"] SEGMENT_FIELD_NUMBER: _ClassVar[int] diff --git a/chromadb/proto/coordinator_pb2_grpc.py b/chromadb/proto/coordinator_pb2_grpc.py index f164b41f19a..117c568c715 100644 --- a/chromadb/proto/coordinator_pb2_grpc.py +++ b/chromadb/proto/coordinator_pb2_grpc.py @@ -21,11 +21,21 @@ def __init__(self, channel): request_serializer=chromadb_dot_proto_dot_coordinator__pb2.CreateDatabaseRequest.SerializeToString, response_deserializer=chromadb_dot_proto_dot_chroma__pb2.ChromaResponse.FromString, ) + self.GetDatabase = channel.unary_unary( + "/chroma.SysDB/GetDatabase", + request_serializer=chromadb_dot_proto_dot_coordinator__pb2.GetDatabaseRequest.SerializeToString, + response_deserializer=chromadb_dot_proto_dot_coordinator__pb2.GetDatabaseResponse.FromString, + ) self.CreateTenant = channel.unary_unary( "/chroma.SysDB/CreateTenant", request_serializer=chromadb_dot_proto_dot_coordinator__pb2.CreateTenantRequest.SerializeToString, response_deserializer=chromadb_dot_proto_dot_chroma__pb2.ChromaResponse.FromString, ) + self.GetTenant = channel.unary_unary( + "/chroma.SysDB/GetTenant", + request_serializer=chromadb_dot_proto_dot_coordinator__pb2.GetTenantRequest.SerializeToString, + response_deserializer=chromadb_dot_proto_dot_coordinator__pb2.GetTenantResponse.FromString, + ) self.CreateSegment = channel.unary_unary( "/chroma.SysDB/CreateSegment", request_serializer=chromadb_dot_proto_dot_coordinator__pb2.CreateSegmentRequest.SerializeToString, @@ -82,12 +92,24 @@ def CreateDatabase(self, request, context): context.set_details("Method not implemented!") raise NotImplementedError("Method not implemented!") + def GetDatabase(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + def CreateTenant(self, request, context): 
"""Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) context.set_details("Method not implemented!") raise NotImplementedError("Method not implemented!") + def GetTenant(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + def CreateSegment(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) @@ -150,11 +172,21 @@ def add_SysDBServicer_to_server(servicer, server): request_deserializer=chromadb_dot_proto_dot_coordinator__pb2.CreateDatabaseRequest.FromString, response_serializer=chromadb_dot_proto_dot_chroma__pb2.ChromaResponse.SerializeToString, ), + "GetDatabase": grpc.unary_unary_rpc_method_handler( + servicer.GetDatabase, + request_deserializer=chromadb_dot_proto_dot_coordinator__pb2.GetDatabaseRequest.FromString, + response_serializer=chromadb_dot_proto_dot_coordinator__pb2.GetDatabaseResponse.SerializeToString, + ), "CreateTenant": grpc.unary_unary_rpc_method_handler( servicer.CreateTenant, request_deserializer=chromadb_dot_proto_dot_coordinator__pb2.CreateTenantRequest.FromString, response_serializer=chromadb_dot_proto_dot_chroma__pb2.ChromaResponse.SerializeToString, ), + "GetTenant": grpc.unary_unary_rpc_method_handler( + servicer.GetTenant, + request_deserializer=chromadb_dot_proto_dot_coordinator__pb2.GetTenantRequest.FromString, + response_serializer=chromadb_dot_proto_dot_coordinator__pb2.GetTenantResponse.SerializeToString, + ), "CreateSegment": grpc.unary_unary_rpc_method_handler( servicer.CreateSegment, request_deserializer=chromadb_dot_proto_dot_coordinator__pb2.CreateSegmentRequest.FromString, @@ -240,6 +272,35 @@ def CreateDatabase( metadata, ) + @staticmethod + def GetDatabase( + request, + target, + options=(), + 
channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, + target, + "/chroma.SysDB/GetDatabase", + chromadb_dot_proto_dot_coordinator__pb2.GetDatabaseRequest.SerializeToString, + chromadb_dot_proto_dot_coordinator__pb2.GetDatabaseResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) + @staticmethod def CreateTenant( request, @@ -269,6 +330,35 @@ def CreateTenant( metadata, ) + @staticmethod + def GetTenant( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, + target, + "/chroma.SysDB/GetTenant", + chromadb_dot_proto_dot_coordinator__pb2.GetTenantRequest.SerializeToString, + chromadb_dot_proto_dot_coordinator__pb2.GetTenantResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) + @staticmethod def CreateSegment( request, diff --git a/chromadb/test/db/test_system.py b/chromadb/test/db/test_system.py index 26385578bdb..b08838e1dff 100644 --- a/chromadb/test/db/test_system.py +++ b/chromadb/test/db/test_system.py @@ -532,6 +532,38 @@ def test_create_database_with_tenants(sysdb: SysDB) -> None: assert len(result) == 0 +def test_get_database_with_tenants(sysdb: SysDB) -> None: + sysdb.reset_state() + + # Create a new tenant + sysdb.create_tenant(name="tenant1") + + # Get the tenant and check that it exists + result = sysdb.get_tenant(name="tenant1") + assert result["name"] == "tenant1" + + # Get a tenant that does not exist and expect an error + with pytest.raises(NotFoundError): + sysdb.get_tenant(name="tenant2") + + # Create a new database within this tenant + 
sysdb.create_database(id=uuid.uuid4(), name="new_database", tenant="tenant1") + + # Get the database and check that it exists + result = sysdb.get_database(name="new_database", tenant="tenant1") + assert result["name"] == "new_database" + assert result["tenant"] == "tenant1" + + # Get a database that does not exist in a tenant that does exist and expect an error + with pytest.raises(NotFoundError): + sysdb.get_database(name="new_database1", tenant="tenant1") + + # Get a database that does not exist in a tenant that does not exist and expect an + # error + with pytest.raises(NotFoundError): + sysdb.get_database(name="new_database1", tenant="tenant2") + + # endregion # region Segment tests diff --git a/chromadb/types.py b/chromadb/types.py index 713cab7757c..ed74b4f5e20 100644 --- a/chromadb/types.py +++ b/chromadb/types.py @@ -29,6 +29,16 @@ class Collection(TypedDict): dimension: Optional[int] +class Database(TypedDict): + id: UUID + name: str + tenant: str + + +class Tenant(TypedDict): + name: str + + class Segment(TypedDict): id: UUID type: NamespacedName diff --git a/idl/chromadb/proto/chroma.proto b/idl/chromadb/proto/chroma.proto index 7b1f10f18ea..5aaff218176 100644 --- a/idl/chromadb/proto/chroma.proto +++ b/idl/chromadb/proto/chroma.proto @@ -55,6 +55,16 @@ message Collection { optional int32 dimension = 5; } +message Database { + string id = 1; + string name = 2; + string tenant = 3; +} + +message Tenant { + string name = 1; +} + message UpdateMetadataValue { oneof value { string string_value = 1; diff --git a/idl/chromadb/proto/coordinator.proto b/idl/chromadb/proto/coordinator.proto index c0ecb134fa9..0871f3f3c52 100644 --- a/idl/chromadb/proto/coordinator.proto +++ b/idl/chromadb/proto/coordinator.proto @@ -11,10 +11,29 @@ message CreateDatabaseRequest { string tenant = 3; } +message GetDatabaseRequest { + string name = 1; + string tenant = 2; +} + +message GetDatabaseResponse { + Database database = 1; + Status status = 2; +} + message 
CreateTenantRequest { string name = 2; // Names are globally unique } +message GetTenantRequest { + string name = 1; +} + +message GetTenantResponse { + Tenant tenant = 1; + Status status = 2; +} + message CreateSegmentRequest { Segment segment = 1; } @@ -101,7 +120,9 @@ message UpdateCollectionRequest { service SysDB { rpc CreateDatabase(CreateDatabaseRequest) returns (ChromaResponse) {} + rpc GetDatabase(GetDatabaseRequest) returns (GetDatabaseResponse) {} rpc CreateTenant(CreateTenantRequest) returns (ChromaResponse) {} + rpc GetTenant(GetTenantRequest) returns (GetTenantResponse) {} rpc CreateSegment(CreateSegmentRequest) returns (ChromaResponse) {} rpc DeleteSegment(DeleteSegmentRequest) returns (ChromaResponse) {} rpc GetSegments(GetSegmentsRequest) returns (GetSegmentsResponse) {} From e14bb2170bd300dd855b3a5d05ba2e12e31c1c08 Mon Sep 17 00:00:00 2001 From: hammadb Date: Mon, 23 Oct 2023 08:31:29 -0700 Subject: [PATCH 09/22] Add validation --- chromadb/api/__init__.py | 16 +++--------- chromadb/api/client.py | 26 ++++++++++++++++--- .../test_collections_with_database_tenant.py | 13 ++++------ 3 files changed, 31 insertions(+), 24 deletions(-) diff --git a/chromadb/api/__init__.py b/chromadb/api/__init__.py index 01cdcf943ce..f2a8099e4fc 100644 --- a/chromadb/api/__init__.py +++ b/chromadb/api/__init__.py @@ -396,21 +396,13 @@ class ClientAPI(BaseAPI, ABC): database: str @abstractmethod - def set_database(self, database: str) -> None: - """Set the database for the client. - - Args: - database: The database to set. - - """ - pass - - @abstractmethod - def set_tenant(self, tenant: str) -> None: - """Set the tenant for the client. + def set_tenant_and_database(self, tenant: str, database: str) -> None: + """Set the tenant and database for the client. Raises an error if the tenant or + database does not exist. Args: tenant: The tenant to set. + database: The database to set. 
""" pass diff --git a/chromadb/api/client.py b/chromadb/api/client.py index e1996628d2d..a80510131ea 100644 --- a/chromadb/api/client.py +++ b/chromadb/api/client.py @@ -122,6 +122,8 @@ class Client(SharedSystemClient, ClientAPI): database: str = DEFAULT_DATABASE _server: ServerAPI + # An internal admin client for verifying that databases and tenants exist + _admin_client: AdminAPI # region Initialization def __init__( @@ -141,6 +143,9 @@ def __init__( telemetry_client = self._system.instance(Telemetry) telemetry_client.capture(ClientStartEvent()) + # Create an admin client for verifying that databases and tenants exist + self._admin_client = AdminClient.from_system(self._system) + @classmethod @override def from_system( @@ -384,12 +389,25 @@ def max_batch_size(self) -> int: # region ClientAPI Methods @override - def set_database(self, database: str) -> None: + def set_tenant_and_database(self, tenant: str, database: str) -> None: + self._validate_tenant_database(database, self.tenant) + self.tenant = tenant self.database = database - @override - def set_tenant(self, tenant: str) -> None: - self.tenant = tenant + def _validate_tenant_database(self, database: str, tenant: str) -> None: + try: + self._admin_client.get_tenant(name=tenant) + except Exception: + raise ValueError( + f"Could not connect to tenant {tenant}. Are you sure it exists?" + ) + + try: + self._admin_client.get_database(name=database, tenant=tenant) + except Exception: + raise ValueError( + f"Could not connect to database {database} for tenant {tenant}. Are you sure it exists?" 
+ ) # endregion diff --git a/chromadb/test/property/test_collections_with_database_tenant.py b/chromadb/test/property/test_collections_with_database_tenant.py index 7b0b6ecc183..27832dda8a3 100644 --- a/chromadb/test/property/test_collections_with_database_tenant.py +++ b/chromadb/test/property/test_collections_with_database_tenant.py @@ -44,8 +44,7 @@ def initialize(self) -> None: self.tenant_to_database_to_model = {} self.curr_tenant = DEFAULT_TENANT self.curr_database = DEFAULT_DATABASE - self.api.set_tenant(DEFAULT_TENANT) - self.api.set_database(DEFAULT_DATABASE) + self.api.set_tenant_and_database(DEFAULT_TENANT, DEFAULT_DATABASE) self.tenant_to_database_to_model[self.curr_tenant] = {} self.tenant_to_database_to_model[self.curr_tenant][self.curr_database] = {} @@ -81,17 +80,15 @@ def create_database(self, name: str) -> MultipleResults[Tuple[str, str]]: @rule(database=databases) def set_database_and_tenant(self, database: Dict[str, str]) -> None: # Get a database and switch to the database and the tenant it belongs to - database_name = database[0] - tenant_name = database[1] - self.api.set_tenant(tenant_name) - self.api.set_database(database_name) + database_name = database[0] # type: ignore + tenant_name = database[1] # type: ignore + self.api.set_tenant_and_database(tenant_name, database_name) self.curr_database = database_name self.curr_tenant = tenant_name @rule(tenant=tenants) def set_tenant(self, tenant: str) -> None: - self.api.set_tenant(tenant) - self.api.set_database(DEFAULT_DATABASE) + self.api.set_tenant_and_database(tenant, DEFAULT_DATABASE) self.curr_tenant = tenant self.curr_database = DEFAULT_DATABASE From 12535e5498cc285e6cc1e41bf0a97eecd68480e1 Mon Sep 17 00:00:00 2001 From: hammadb Date: Mon, 23 Oct 2023 08:32:22 -0700 Subject: [PATCH 10/22] Fix type --- .../test/property/test_collections_with_database_tenant.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git 
a/chromadb/test/property/test_collections_with_database_tenant.py b/chromadb/test/property/test_collections_with_database_tenant.py index 27832dda8a3..4d95b2ec97a 100644 --- a/chromadb/test/property/test_collections_with_database_tenant.py +++ b/chromadb/test/property/test_collections_with_database_tenant.py @@ -78,10 +78,10 @@ def create_database(self, name: str) -> MultipleResults[Tuple[str, str]]: return multiple((name, self.curr_tenant)) @rule(database=databases) - def set_database_and_tenant(self, database: Dict[str, str]) -> None: + def set_database_and_tenant(self, database: Tuple[str, str]) -> None: # Get a database and switch to the database and the tenant it belongs to - database_name = database[0] # type: ignore - tenant_name = database[1] # type: ignore + database_name = database[0] + tenant_name = database[1] self.api.set_tenant_and_database(tenant_name, database_name) self.curr_database = database_name self.curr_tenant = tenant_name From 5baf290fd3af7389a751edcc5655728089b5b8d3 Mon Sep 17 00:00:00 2001 From: hammadb Date: Mon, 23 Oct 2023 08:32:47 -0700 Subject: [PATCH 11/22] Fix type --- chromadb/db/mixins/sysdb.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/chromadb/db/mixins/sysdb.py b/chromadb/db/mixins/sysdb.py index 7009f78a281..0bb5e710de1 100644 --- a/chromadb/db/mixins/sysdb.py +++ b/chromadb/db/mixins/sysdb.py @@ -67,8 +67,6 @@ def create_database( sql, params = get_sql(insert_database, self.parameter_format()) try: cur.execute(sql, params) - # TODO: tenant doesn't exist test - # TODO: implement unique constraint error lol... 
except self.unique_constraint_error() as e: raise UniqueConstraintError( f"Database {name} already exists for tenant {tenant}" From 6bcb1fd49528ab5f970396314688ab2d18b27fa7 Mon Sep 17 00:00:00 2001 From: hammadb Date: Mon, 23 Oct 2023 09:15:48 -0700 Subject: [PATCH 12/22] Add to fastapi --- chromadb/api/client.py | 4 ++-- chromadb/server/fastapi/__init__.py | 21 +++++++++++++++++++++ 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/chromadb/api/client.py b/chromadb/api/client.py index a80510131ea..c7d9af971fc 100644 --- a/chromadb/api/client.py +++ b/chromadb/api/client.py @@ -390,11 +390,11 @@ def max_batch_size(self) -> int: @override def set_tenant_and_database(self, tenant: str, database: str) -> None: - self._validate_tenant_database(database, self.tenant) + self._validate_tenant_database(tenant=tenant, database=database) self.tenant = tenant self.database = database - def _validate_tenant_database(self, database: str, tenant: str) -> None: + def _validate_tenant_database(self, tenant: str, database: str) -> None: try: self._admin_client.get_tenant(name=tenant) except Exception: diff --git a/chromadb/server/fastapi/__init__.py b/chromadb/server/fastapi/__init__.py index d432fcf3028..6f679772504 100644 --- a/chromadb/server/fastapi/__init__.py +++ b/chromadb/server/fastapi/__init__.py @@ -39,6 +39,7 @@ import logging from chromadb.telemetry import ServerContext, Telemetry +from chromadb.types import Database, Tenant logger = logging.getLogger(__name__) @@ -142,6 +143,13 @@ def __init__(self, settings: Settings): response_model=None, ) + self.router.add_api_route( + "/api/v1/databases/{database}", + self.get_database, + methods=["GET"], + response_model=None, + ) + self.router.add_api_route( "/api/v1/tenants", self.create_tenant, @@ -149,6 +157,13 @@ def __init__(self, settings: Settings): response_model=None, ) + self.router.add_api_route( + "/api/v1/tenants/{tenant}", + self.get_tenant, + methods=["GET"], + response_model=None, + ) + 
self.router.add_api_route( "/api/v1/collections", self.list_collections, @@ -245,9 +260,15 @@ def create_database( ) -> None: return self._api.create_database(database.name, tenant) + def get_database(self, database: str, tenant: str = DEFAULT_TENANT) -> Database: + return self._api.get_database(database, tenant) + def create_tenant(self, tenant: CreateTenant) -> None: return self._api.create_tenant(tenant.name) + def get_tenant(self, tenant: str) -> Tenant: + return self._api.get_tenant(tenant) + def list_collections( self, tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE ) -> Sequence[Collection]: From 19a57f6a852131f0e22979d92fb659fb3a270198 Mon Sep 17 00:00:00 2001 From: hammadb Date: Mon, 23 Oct 2023 10:28:45 -0700 Subject: [PATCH 13/22] Add set_database and change set_tenant to take database --- chromadb/api/__init__.py | 12 +++++++++++- chromadb/api/client.py | 7 ++++++- chromadb/test/client/test_database_tenant.py | 18 +++++++++--------- .../test_collections_with_database_tenant.py | 6 +++--- 4 files changed, 29 insertions(+), 14 deletions(-) diff --git a/chromadb/api/__init__.py b/chromadb/api/__init__.py index f2a8099e4fc..aa8650e748c 100644 --- a/chromadb/api/__init__.py +++ b/chromadb/api/__init__.py @@ -396,7 +396,7 @@ class ClientAPI(BaseAPI, ABC): database: str @abstractmethod - def set_tenant_and_database(self, tenant: str, database: str) -> None: + def set_tenant(self, tenant: str, database: str = DEFAULT_DATABASE) -> None: """Set the tenant and database for the client. Raises an error if the tenant or database does not exist. @@ -407,6 +407,16 @@ def set_tenant_and_database(self, tenant: str, database: str) -> None: """ pass + @abstractmethod + def set_database(self, database: str) -> None: + """Set the database for the client. Raises an error if the database does not exist. + + Args: + database: The database to set. 
+ + """ + pass + @staticmethod @abstractmethod def clear_system_cache() -> None: diff --git a/chromadb/api/client.py b/chromadb/api/client.py index c7d9af971fc..e33bc30a451 100644 --- a/chromadb/api/client.py +++ b/chromadb/api/client.py @@ -389,11 +389,16 @@ def max_batch_size(self) -> int: # region ClientAPI Methods @override - def set_tenant_and_database(self, tenant: str, database: str) -> None: + def set_tenant(self, tenant: str, database: str = DEFAULT_DATABASE) -> None: self._validate_tenant_database(tenant=tenant, database=database) self.tenant = tenant self.database = database + @override + def set_database(self, database: str) -> None: + self._validate_tenant_database(tenant=self.tenant, database=database) + self.database = database + def _validate_tenant_database(self, tenant: str, database: str) -> None: try: self._admin_client.get_tenant(name=tenant) diff --git a/chromadb/test/client/test_database_tenant.py b/chromadb/test/client/test_database_tenant.py index 55672d36aac..d8778c536f4 100644 --- a/chromadb/test/client/test_database_tenant.py +++ b/chromadb/test/client/test_database_tenant.py @@ -9,11 +9,11 @@ def test_database_tenant_collections(client: Client) -> None: admin_client.create_database("test_db") # Create collections in this new database - client.set_database("test_db") + client.set_tenant(tenant="default", database="test_db") client.create_collection("collection", metadata={"database": "test_db"}) # Create collections in the default database - client.set_database("default") + client.set_tenant(tenant="default", database="default") client.create_collection("collection", metadata={"database": "default"}) # List collections in the default database @@ -23,37 +23,37 @@ def test_database_tenant_collections(client: Client) -> None: assert collections[0].metadata == {"database": "default"} # List collections in the new database - client.set_database("test_db") + client.set_tenant(tenant="default", database="test_db") collections = 
client.list_collections() assert len(collections) == 1 assert collections[0].metadata == {"database": "test_db"} # Update the metadata in both databases to different values - client.set_database("default") + client.set_tenant(tenant="default", database="default") client.list_collections()[0].modify(metadata={"database": "default2"}) - client.set_database("test_db") + client.set_tenant(tenant="default", database="test_db") client.list_collections()[0].modify(metadata={"database": "test_db2"}) # Validate that the metadata was updated - client.set_database("default") + client.set_tenant(tenant="default", database="default") collections = client.list_collections() assert len(collections) == 1 assert collections[0].metadata == {"database": "default2"} - client.set_database("test_db") + client.set_tenant(tenant="default", database="test_db") collections = client.list_collections() assert len(collections) == 1 assert collections[0].metadata == {"database": "test_db2"} # Delete the collections and make sure databases are isolated - client.set_database("default") + client.set_tenant(tenant="default", database="default") client.delete_collection("collection") collections = client.list_collections() assert len(collections) == 0 - client.set_database("test_db") + client.set_tenant(tenant="default", database="test_db") collections = client.list_collections() assert len(collections) == 1 diff --git a/chromadb/test/property/test_collections_with_database_tenant.py b/chromadb/test/property/test_collections_with_database_tenant.py index 4d95b2ec97a..28ba14f092a 100644 --- a/chromadb/test/property/test_collections_with_database_tenant.py +++ b/chromadb/test/property/test_collections_with_database_tenant.py @@ -44,7 +44,7 @@ def initialize(self) -> None: self.tenant_to_database_to_model = {} self.curr_tenant = DEFAULT_TENANT self.curr_database = DEFAULT_DATABASE - self.api.set_tenant_and_database(DEFAULT_TENANT, DEFAULT_DATABASE) + self.api.set_tenant(DEFAULT_TENANT, 
DEFAULT_DATABASE) self.tenant_to_database_to_model[self.curr_tenant] = {} self.tenant_to_database_to_model[self.curr_tenant][self.curr_database] = {} @@ -82,13 +82,13 @@ def set_database_and_tenant(self, database: Tuple[str, str]) -> None: # Get a database and switch to the database and the tenant it belongs to database_name = database[0] tenant_name = database[1] - self.api.set_tenant_and_database(tenant_name, database_name) + self.api.set_tenant(tenant_name, database_name) self.curr_database = database_name self.curr_tenant = tenant_name @rule(tenant=tenants) def set_tenant(self, tenant: str) -> None: - self.api.set_tenant_and_database(tenant, DEFAULT_DATABASE) + self.api.set_tenant(tenant, DEFAULT_DATABASE) self.curr_tenant = tenant self.curr_database = DEFAULT_DATABASE From ce3d8222ffa4bec888457665b11089258e87fb23 Mon Sep 17 00:00:00 2001 From: hammadb Date: Mon, 23 Oct 2023 10:53:00 -0700 Subject: [PATCH 14/22] cleanup --- chromadb/db/impl/sqlite.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/chromadb/db/impl/sqlite.py b/chromadb/db/impl/sqlite.py index 10036d3c239..da4a5ab3ebd 100644 --- a/chromadb/db/impl/sqlite.py +++ b/chromadb/db/impl/sqlite.py @@ -138,9 +138,6 @@ def reset_state(self) -> None: for row in cur.fetchall(): cur.execute(f"DROP TABLE IF EXISTS {row[0]}") self._conn_pool.close() - # TODO: clean this up ---- I don't think its correct or needed - # if self._is_persistent: - # delete_file(self._db_file) self.start() super().reset_state() From f9b137b933769917935f311c077d09b419b88a34 Mon Sep 17 00:00:00 2001 From: hammadb Date: Mon, 23 Oct 2023 11:32:47 -0700 Subject: [PATCH 15/22] fix bug --- chromadb/migrations/sysdb/00004-tenants-databases.sqlite.sql | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chromadb/migrations/sysdb/00004-tenants-databases.sqlite.sql b/chromadb/migrations/sysdb/00004-tenants-databases.sqlite.sql index 838114e1dff..7c6b7fba511 100644 --- 
a/chromadb/migrations/sysdb/00004-tenants-databases.sqlite.sql +++ b/chromadb/migrations/sysdb/00004-tenants-databases.sqlite.sql @@ -24,6 +24,6 @@ INSERT INTO tenants (id) VALUES ('default'); -- The default tenant id is 'defaul INSERT INTO databases (id, name, tenant_id) VALUES ('00000000-0000-0000-0000-000000000000', 'default', 'default'); INSERT INTO collections_tmp (id, name, topic, dimension, database_id) - SELECT id, name, topic, dimension, 'default' FROM collections; + SELECT id, name, topic, dimension, '00000000-0000-0000-0000-000000000000' FROM collections; DROP TABLE collections; ALTER TABLE collections_tmp RENAME TO collections; From 4b8987d8e16b2b8b3de3f750b1de4cbf521eb982 Mon Sep 17 00:00:00 2001 From: hammadb Date: Mon, 23 Oct 2023 15:01:40 -0700 Subject: [PATCH 16/22] PR comments --- chromadb/__init__.py | 16 +++++++++++++++- chromadb/api/__init__.py | 10 ---------- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/chromadb/__init__.py b/chromadb/__init__.py index 7aab0e3d79e..c49d89fa216 100644 --- a/chromadb/__init__.py +++ b/chromadb/__init__.py @@ -100,6 +100,10 @@ def EphemeralClient( """ Creates an in-memory instance of Chroma. This is useful for testing and development, but not recommended for production use. + + Args: + tenant: The tenant to use for this client. Defaults to the default tenant. + database: The database to use for this client. Defaults to the default database. """ settings.is_persistent = False @@ -118,6 +122,8 @@ def PersistentClient( Args: path: The directory to save Chroma's data to. Defaults to "./chroma". + tenant: The tenant to use for this client. Defaults to the default tenant. + database: The database to use for this client. Defaults to the default database. """ settings.persist_directory = path settings.is_persistent = True @@ -144,6 +150,8 @@ def HttpClient( port: The port of the Chroma server. Defaults to "8000". ssl: Whether to use SSL to connect to the Chroma server. Defaults to False. 
headers: A dictionary of headers to send to the Chroma server. Defaults to {}. + tenant: The tenant to use for this client. Defaults to the default tenant. + database: The database to use for this client. Defaults to the default database. """ settings.chroma_api_impl = "chromadb.api.fastapi.FastAPI" @@ -164,6 +172,12 @@ def Client( tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE, ) -> ClientAPI: - """Return a running chroma.API instance""" + """ + Return a running chroma.API instance + + tenant: The tenant to use for this client. Defaults to the default tenant. + database: The database to use for this client. Defaults to the default database. + + """ return ClientCreator(tenant=tenant, database=database, settings=settings) diff --git a/chromadb/api/__init__.py b/chromadb/api/__init__.py index aa8650e748c..ae39b7b13f1 100644 --- a/chromadb/api/__init__.py +++ b/chromadb/api/__init__.py @@ -515,16 +515,6 @@ def get_or_create_collection( ) -> Collection: pass - @abstractmethod - @override - def _modify( - self, - id: UUID, - new_name: Optional[str] = None, - new_metadata: Optional[CollectionMetadata] = None, - ) -> None: - pass - @abstractmethod @override def delete_collection( From f56753a1733634d28b5717f32cc55db00213f921 Mon Sep 17 00:00:00 2001 From: hammadb Date: Mon, 23 Oct 2023 15:02:07 -0700 Subject: [PATCH 17/22] PR comments --- chromadb/config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/chromadb/config.py b/chromadb/config.py index eb575696049..366595ea367 100644 --- a/chromadb/config.py +++ b/chromadb/config.py @@ -75,8 +75,8 @@ "chromadb.segment.distributed.MemberlistProvider": "chroma_memberlist_provider_impl", } -DEFAULT_TENANT = "default" -DEFAULT_DATABASE = "default" +DEFAULT_TENANT = "default_tenant" +DEFAULT_DATABASE = "default_database" class Settings(BaseSettings): # type: ignore From fdd13fd8a27844b01ec614e1cd80daf0b36c428b Mon Sep 17 00:00:00 2001 From: hammadb Date: Mon, 23 Oct 2023 15:12:36 -0700 
Subject: [PATCH 18/22] PR Comments - Migration idempotency and validate tenant/database --- chromadb/api/client.py | 10 ++++------ .../sysdb/00004-tenants-databases.sqlite.sql | 6 +++--- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/chromadb/api/client.py b/chromadb/api/client.py index e33bc30a451..de68aef7ce4 100644 --- a/chromadb/api/client.py +++ b/chromadb/api/client.py @@ -1,4 +1,4 @@ -from typing import ClassVar, Dict, Optional, Sequence, TypeVar +from typing import ClassVar, Dict, Optional, Sequence from uuid import UUID from overrides import override @@ -22,8 +22,6 @@ from chromadb.types import Database, Tenant, Where, WhereDocument import chromadb.utils.embedding_functions as ef -C = TypeVar("C", "SharedSystemClient", "Client", "AdminClient") - class SharedSystemClient: _identifer_to_system: ClassVar[Dict[str, System]] = {} @@ -135,6 +133,9 @@ def __init__( super().__init__(settings=settings) self.tenant = tenant self.database = database + # Create an admin client for verifying that databases and tenants exist + self._admin_client = AdminClient.from_system(self._system) + self._validate_tenant_database(tenant=tenant, database=database) # Get the root system component we want to interact with self._server = self._system.instance(ServerAPI) @@ -143,9 +144,6 @@ def __init__( telemetry_client = self._system.instance(Telemetry) telemetry_client.capture(ClientStartEvent()) - # Create an admin client for verifying that databases and tenants exist - self._admin_client = AdminClient.from_system(self._system) - @classmethod @override def from_system( diff --git a/chromadb/migrations/sysdb/00004-tenants-databases.sqlite.sql b/chromadb/migrations/sysdb/00004-tenants-databases.sqlite.sql index 7c6b7fba511..b117c2547e9 100644 --- a/chromadb/migrations/sysdb/00004-tenants-databases.sqlite.sql +++ b/chromadb/migrations/sysdb/00004-tenants-databases.sqlite.sql @@ -20,10 +20,10 @@ CREATE TABLE IF NOT EXISTS collections_tmp ( ); -- Create default tenant and
database -INSERT INTO tenants (id) VALUES ('default'); -- The default tenant id is 'default' others are UUIDs -INSERT INTO databases (id, name, tenant_id) VALUES ('00000000-0000-0000-0000-000000000000', 'default', 'default'); +INSERT OR REPLACE INTO tenants (id) VALUES ('default'); -- The default tenant id is 'default' others are UUIDs +INSERT OR REPLACE INTO databases (id, name, tenant_id) VALUES ('00000000-0000-0000-0000-000000000000', 'default', 'default'); -INSERT INTO collections_tmp (id, name, topic, dimension, database_id) +INSERT OR REPLACE INTO collections_tmp (id, name, topic, dimension, database_id) SELECT id, name, topic, dimension, '00000000-0000-0000-0000-000000000000' FROM collections; DROP TABLE collections; ALTER TABLE collections_tmp RENAME TO collections; From a4509e53123292da2370e2a30cfeefc0a21dc8dd Mon Sep 17 00:00:00 2001 From: hammadb Date: Mon, 23 Oct 2023 15:13:21 -0700 Subject: [PATCH 19/22] PR Comments - Migration idempotency and validate tenant/database --- chromadb/migrations/sysdb/00004-tenants-databases.sqlite.sql | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/chromadb/migrations/sysdb/00004-tenants-databases.sqlite.sql b/chromadb/migrations/sysdb/00004-tenants-databases.sqlite.sql index b117c2547e9..43372bf97a8 100644 --- a/chromadb/migrations/sysdb/00004-tenants-databases.sqlite.sql +++ b/chromadb/migrations/sysdb/00004-tenants-databases.sqlite.sql @@ -20,8 +20,8 @@ CREATE TABLE IF NOT EXISTS collections_tmp ( ); -- Create default tenant and database -INSERT OR REPLACE INTO tenants (id) VALUES ('default'); -- The default tenant id is 'default' others are UUIDs -INSERT OR REPLACE INTO databases (id, name, tenant_id) VALUES ('00000000-0000-0000-0000-000000000000', 'default', 'default'); +INSERT OR REPLACE INTO tenants (id) VALUES ('default_tenant'); -- The default tenant id is 'default_tenant' others are UUIDs +INSERT OR REPLACE INTO databases (id, name, tenant_id) VALUES
('00000000-0000-0000-0000-000000000000', 'default_database', 'default_tenant'); INSERT OR REPLACE INTO collections_tmp (id, name, topic, dimension, database_id) SELECT id, name, topic, dimension, '00000000-0000-0000-0000-000000000000' FROM collections; From 5b5ea3485d40350db171e8fd50d47cc79dee53f5 Mon Sep 17 00:00:00 2001 From: hammadb Date: Mon, 23 Oct 2023 15:42:56 -0700 Subject: [PATCH 20/22] Update consts --- chromadb/api/client.py | 4 ++-- chromadb/db/impl/grpc/server.py | 10 ++++---- chromadb/test/client/test_database_tenant.py | 23 ++++++++++--------- .../test_multiple_clients_concurrency.py | 5 +++- chromadb/test/db/test_system.py | 10 ++++++-- 5 files changed, 31 insertions(+), 21 deletions(-) diff --git a/chromadb/api/client.py b/chromadb/api/client.py index f6a648b8070..6e436cf499f 100644 --- a/chromadb/api/client.py +++ b/chromadb/api/client.py @@ -126,8 +126,8 @@ class Client(SharedSystemClient, ClientAPI): # region Initialization def __init__( self, - tenant: str = "default", - database: str = "default", + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, settings: Settings = Settings(), ) -> None: super().__init__(settings=settings) diff --git a/chromadb/db/impl/grpc/server.py b/chromadb/db/impl/grpc/server.py index b050602720a..1a71929214e 100644 --- a/chromadb/db/impl/grpc/server.py +++ b/chromadb/db/impl/grpc/server.py @@ -3,7 +3,7 @@ from uuid import UUID from overrides import overrides from chromadb.ingest import CollectionAssignmentPolicy -from chromadb.config import Component, System +from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Component, System from chromadb.proto.convert import ( from_proto_metadata, from_proto_update_metadata, @@ -76,10 +76,10 @@ def reset_state(self) -> None: self._segments = {} self._tenants_to_databases_to_collections = {} # Create defaults - self._tenants_to_databases_to_collections["default"] = {} - self._tenants_to_databases_to_collections["default"]["default"] = {} - 
self._tenants_to_database_to_id["default"] = {} - self._tenants_to_database_to_id["default"]["default"] = UUID(int=0) + self._tenants_to_databases_to_collections[DEFAULT_TENANT] = {} + self._tenants_to_databases_to_collections[DEFAULT_TENANT][DEFAULT_DATABASE] = {} + self._tenants_to_database_to_id[DEFAULT_TENANT] = {} + self._tenants_to_database_to_id[DEFAULT_TENANT][DEFAULT_DATABASE] = UUID(int=0) return super().reset_state() @overrides(check_signature=False) diff --git a/chromadb/test/client/test_database_tenant.py b/chromadb/test/client/test_database_tenant.py index d8778c536f4..00c502bb0ea 100644 --- a/chromadb/test/client/test_database_tenant.py +++ b/chromadb/test/client/test_database_tenant.py @@ -1,5 +1,6 @@ import pytest from chromadb.api.client import AdminClient, Client +from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT def test_database_tenant_collections(client: Client) -> None: @@ -9,51 +10,51 @@ def test_database_tenant_collections(client: Client) -> None: admin_client.create_database("test_db") # Create collections in this new database - client.set_tenant(tenant="default", database="test_db") + client.set_tenant(tenant=DEFAULT_TENANT, database="test_db") client.create_collection("collection", metadata={"database": "test_db"}) # Create collections in the default database - client.set_tenant(tenant="default", database="default") - client.create_collection("collection", metadata={"database": "default"}) + client.set_tenant(tenant=DEFAULT_TENANT, database=DEFAULT_DATABASE) + client.create_collection("collection", metadata={"database": DEFAULT_DATABASE}) # List collections in the default database collections = client.list_collections() assert len(collections) == 1 assert collections[0].name == "collection" - assert collections[0].metadata == {"database": "default"} + assert collections[0].metadata == {"database": DEFAULT_DATABASE} # List collections in the new database - client.set_tenant(tenant="default", database="test_db") + 
client.set_tenant(tenant=DEFAULT_TENANT, database="test_db") collections = client.list_collections() assert len(collections) == 1 assert collections[0].metadata == {"database": "test_db"} # Update the metadata in both databases to different values - client.set_tenant(tenant="default", database="default") + client.set_tenant(tenant=DEFAULT_TENANT, database=DEFAULT_DATABASE) client.list_collections()[0].modify(metadata={"database": "default2"}) - client.set_tenant(tenant="default", database="test_db") + client.set_tenant(tenant=DEFAULT_TENANT, database="test_db") client.list_collections()[0].modify(metadata={"database": "test_db2"}) # Validate that the metadata was updated - client.set_tenant(tenant="default", database="default") + client.set_tenant(tenant=DEFAULT_TENANT, database=DEFAULT_DATABASE) collections = client.list_collections() assert len(collections) == 1 assert collections[0].metadata == {"database": "default2"} - client.set_tenant(tenant="default", database="test_db") + client.set_tenant(tenant=DEFAULT_TENANT, database="test_db") collections = client.list_collections() assert len(collections) == 1 assert collections[0].metadata == {"database": "test_db2"} # Delete the collections and make sure databases are isolated - client.set_tenant(tenant="default", database="default") + client.set_tenant(tenant=DEFAULT_TENANT, database=DEFAULT_DATABASE) client.delete_collection("collection") collections = client.list_collections() assert len(collections) == 0 - client.set_tenant(tenant="default", database="test_db") + client.set_tenant(tenant=DEFAULT_TENANT, database="test_db") collections = client.list_collections() assert len(collections) == 1 diff --git a/chromadb/test/client/test_multiple_clients_concurrency.py b/chromadb/test/client/test_multiple_clients_concurrency.py index b62000696fa..14054214cbf 100644 --- a/chromadb/test/client/test_multiple_clients_concurrency.py +++ b/chromadb/test/client/test_multiple_clients_concurrency.py @@ -1,5 +1,6 @@ from 
concurrent.futures import ThreadPoolExecutor from chromadb.api.client import AdminClient, Client +from chromadb.config import DEFAULT_TENANT def test_multiple_clients_concurrently(client: Client) -> None: @@ -21,7 +22,9 @@ def test_multiple_clients_concurrently(client: Client) -> None: # Create N clients, each on a seperate thread, each with their own database def run_target(n: int) -> None: thread_client = Client( - tenant="default", database=databases[n], settings=client._system.settings + tenant=DEFAULT_TENANT, + database=databases[n], + settings=client._system.settings, ) for collection in collections: thread_client.create_collection( diff --git a/chromadb/test/db/test_system.py b/chromadb/test/db/test_system.py index b08838e1dff..e801618b4bc 100644 --- a/chromadb/test/db/test_system.py +++ b/chromadb/test/db/test_system.py @@ -8,7 +8,13 @@ from chromadb.db.impl.grpc.server import GrpcMockSysDB from chromadb.types import Collection, Segment, SegmentScope from chromadb.db.impl.sqlite import SqliteDB -from chromadb.config import DEFAULT_TENANT, Component, System, Settings +from chromadb.config import ( + DEFAULT_DATABASE, + DEFAULT_TENANT, + Component, + System, + Settings, +) from chromadb.db.system import SysDB from chromadb.db.base import NotFoundError, UniqueConstraintError from pytest import FixtureRequest @@ -528,7 +534,7 @@ def test_create_database_with_tenants(sysdb: SysDB) -> None: # A new tenant DOES NOT have a default database. 
This does not error, instead 0 # results are returned - result = sysdb.get_collections(database="default", tenant="tenant1") + result = sysdb.get_collections(database=DEFAULT_DATABASE, tenant="tenant1") assert len(result) == 0 From 5df225f5ae3deb8ba437fa4c3e1efb47ceb94c60 Mon Sep 17 00:00:00 2001 From: hammadb Date: Mon, 23 Oct 2023 16:19:23 -0700 Subject: [PATCH 21/22] patch validation --- chromadb/__init__.py | 5 +++++ chromadb/test/test_client.py | 8 +++++--- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/chromadb/__init__.py b/chromadb/__init__.py index 6a4863e894a..174046a9b4f 100644 --- a/chromadb/__init__.py +++ b/chromadb/__init__.py @@ -166,6 +166,11 @@ def HttpClient( def AdminClient(settings: Settings = Settings()) -> AdminAPI: + """ + + Creates an admin client that can be used to create tenants and databases. + + """ return AdminClientCreator(settings=settings) diff --git a/chromadb/test/test_client.py b/chromadb/test/test_client.py index d4f1de9ae9f..2a8aec2764e 100644 --- a/chromadb/test/test_client.py +++ b/chromadb/test/test_client.py @@ -1,4 +1,5 @@ from typing import Generator +from unittest.mock import patch import chromadb from chromadb.api import ClientAPI import chromadb.server.fastapi @@ -24,9 +25,10 @@ def persistent_api() -> Generator[ClientAPI, None, None]: @pytest.fixture def http_api() -> Generator[ClientAPI, None, None]: - client = chromadb.HttpClient() - yield client - client.clear_system_cache() + with patch("chromadb.api.client.Client._validate_tenant_database"): + client = chromadb.HttpClient() + yield client + client.clear_system_cache() def test_ephemeral_client(ephemeral_api: ClientAPI) -> None: From abb5a88edf71fd06bc707e4e2aaff76bd3654b09 Mon Sep 17 00:00:00 2001 From: hammadb Date: Mon, 23 Oct 2023 16:24:10 -0700 Subject: [PATCH 22/22] reorder for docs --- chromadb/__init__.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/chromadb/__init__.py b/chromadb/__init__.py index 
174046a9b4f..eeb5eaf5899 100644 --- a/chromadb/__init__.py +++ b/chromadb/__init__.py @@ -165,15 +165,6 @@ def HttpClient( return ClientCreator(tenant=tenant, database=database, settings=settings) -def AdminClient(settings: Settings = Settings()) -> AdminAPI: - """ - - Creates an admin client that can be used to create tenants and databases. - - """ - return AdminClientCreator(settings=settings) - - def Client( settings: Settings = __settings, tenant: str = DEFAULT_TENANT, @@ -188,3 +179,12 @@ def Client( """ return ClientCreator(tenant=tenant, database=database, settings=settings) + + +def AdminClient(settings: Settings = Settings()) -> AdminAPI: + """ + + Creates an admin client that can be used to create tenants and databases. + + """ + return AdminClientCreator(settings=settings)