Skip to content

Commit

Permalink
Migrate API to BaseAPI, ClientAPI, ServerAPI. Wrap and modify the rel…
Browse files Browse the repository at this point in the history
…evant systems. Change all tests to use ServerAPI
  • Loading branch information
HammadB committed Oct 17, 2023
1 parent e564860 commit 99b8b75
Show file tree
Hide file tree
Showing 23 changed files with 890 additions and 122 deletions.
77 changes: 56 additions & 21 deletions chromadb/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from typing import Dict
import logging
from chromadb.api.client import Client as ClientCreator
import chromadb.config
from chromadb.config import Settings, System
from chromadb.api import API
from chromadb.config import Settings
from chromadb.api import ClientAPI
from chromadb.api.models.Collection import Collection
from chromadb.api.types import (
CollectionMetadata,
Expand Down Expand Up @@ -35,8 +36,6 @@
"QueryResult",
"GetResult",
]
from chromadb.telemetry.events import ClientStartEvent
from chromadb.telemetry import Telemetry


logger = logging.getLogger(__name__)
Expand All @@ -56,12 +55,14 @@
is_client = False
try:
from chromadb.is_thin_client import is_thin_client # type: ignore

is_client = is_thin_client
except ImportError:
is_client = False

if not is_client:
import sqlite3

if sqlite3.sqlite_version_info < (3, 35, 0):
if IN_COLAB:
# In Colab, hotswap to pysqlite-binary if it's too old
Expand Down Expand Up @@ -90,7 +91,7 @@ def get_settings() -> Settings:
return __settings


def EphemeralClient(settings: Settings = Settings()) -> API:
def EphemeralClient(settings: Settings = Settings()) -> ClientAPI:
"""
Creates an in-memory instance of Chroma. This is useful for testing and
development, but not recommended for production use.
Expand All @@ -100,7 +101,12 @@ def EphemeralClient(settings: Settings = Settings()) -> API:
return Client(settings)


def PersistentClient(path: str = "./chroma", settings: Settings = Settings()) -> API:
def PersistentClient(
path: str = "./chroma",
tenant: str = "default",
namespace: str = "default",
settings: Settings = Settings(),
) -> ClientAPI:
"""
Creates a persistent instance of Chroma that saves to disk. This is useful for
testing and development, but not recommended for production use.
Expand All @@ -111,16 +117,18 @@ def PersistentClient(path: str = "./chroma", settings: Settings = Settings()) ->
settings.persist_directory = path
settings.is_persistent = True

return Client(settings)
return ClientCreator(tenant=tenant, namespace=namespace, settings=settings)


def HttpClient(
host: str = "localhost",
port: str = "8000",
ssl: bool = False,
headers: Dict[str, str] = {},
tenant: str = "default",
namespace: str = "default",
settings: Settings = Settings(),
) -> API:
) -> ClientAPI:
"""
Creates a client that connects to a remote Chroma server. This supports
many clients connecting to the same server, and is the recommended way to
Expand All @@ -139,20 +147,47 @@ def HttpClient(
settings.chroma_server_ssl_enabled = ssl
settings.chroma_server_headers = headers

return Client(settings)
return ClientCreator(tenant=tenant, namespace=namespace, settings=settings)


def Client(settings: Settings = __settings) -> API:
# TODO: replace default tenant and namespace strings with constants
def Client(
settings: Settings = __settings, tenant: str = "default", namespace: str = "default"
) -> ClientAPI:
"""Return a running chroma.API instance"""

system = System(settings)

telemetry_client = system.instance(Telemetry)
api = system.instance(API)

system.start()

# Submit event for client start
telemetry_client.capture(ClientStartEvent())

return api
# Change this to actually check if an "API" instance already exists, wrap it in a
# tenant/namespace aware "Client", and return it
# this way we can support multiple clients in the same process but using the same
# chroma instance

# API is thread safe, so we can just return the same instance
# This way a "Client" will just be a wrapper around an API instance that is
# tenant/namespace aware

# To do this we will
# 1. Have a global dict of API instances, keyed by path
# 2. When a client is requested, check if one exists in the dict, and if so check if its
# settings match the requested settings
# 3. If the settings match, construct a new Client that wraps the existing API instance with
# the tenant/namespace
# 4. If the settings don't match, error out because we don't support changing the settings
# for a given database
# 5. If no client exists in the dict, create a new API instance, wrap it in a Client, and
# add it to the dict

# The hierarchy then becomes
# For local
# Path -> Tenant -> Namespace -> API
# For remote
# Host -> Tenant -> Namespace -> API

# A given API for a path is a singleton, and is shared between all tenants and namespaces
# for that path

# A DB exists at a path or host, and has tenants and namespaces

# All our tests currently use system.instance(API) assuming that's the root object
# This is likely fine,

return ClientCreator(tenant=tenant, namespace=namespace, settings=settings)
222 changes: 219 additions & 3 deletions chromadb/api/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from abc import ABC, abstractmethod
from typing import Sequence, Optional
from uuid import UUID

from overrides import override
from chromadb.config import DEFAULT_NAMESPACE, DEFAULT_TENANT
from chromadb.api.models.Collection import Collection
from chromadb.api.types import (
CollectionMetadata,
Expand All @@ -19,7 +22,7 @@
import chromadb.utils.embedding_functions as ef


class API(Component, ABC):
class BaseAPI(ABC):
@abstractmethod
def heartbeat(self) -> int:
"""Get the current time in nanoseconds since epoch.
Expand Down Expand Up @@ -371,10 +374,10 @@ def get_version(self) -> str:

@abstractmethod
def get_settings(self) -> Settings:
"""Get the settings used to initialize the client.
"""Get the settings used to initialize.
Returns:
Settings: The settings used to initialize the client.
Settings: The settings used to initialize.
"""
pass
Expand All @@ -385,3 +388,216 @@ def max_batch_size(self) -> int:
"""Return the maximum number of records that can be submitted in a single call
to submit_embeddings."""
pass


class ClientAPI(BaseAPI, ABC):
    """Abstract API that layers tenant/namespace selection and a system-cache
    reset hook on top of ``BaseAPI``."""

    @abstractmethod
    def set_namespace(self, namespace: str) -> None:
        """Set the namespace for the client.

        Args:
            namespace: The namespace to set.
        """

    @abstractmethod
    def set_tenant(self, tenant: str) -> None:
        """Set the tenant for the client.

        Args:
            tenant: The tenant to set.
        """

    @staticmethod
    @abstractmethod
    def clear_system_cache() -> None:
        """Clear the system cache so that new systems can be created for an
        existing path.

        This should only be used for testing purposes.
        """


class ServerAPI(BaseAPI, Component):
    """Server-side API: the relevant ``BaseAPI`` methods re-declared so that a
    tenant and namespace are passed in explicitly. This is the root component
    of the Chroma System."""

    # Every method below is abstract and only widens the corresponding BaseAPI
    # signature with two trailing keyword parameters, ``tenant`` and
    # ``namespace``, defaulting to DEFAULT_TENANT / DEFAULT_NAMESPACE.
    #
    # NOTE(review): the methods deliberately carry no docstrings of their own;
    # `overrides` is believed to copy the BaseAPI docstring onto an overriding
    # method that has none -- confirm before adding per-method docstrings.
    #
    # NOTE(review): the `{}` and list defaults are single shared objects created
    # at definition time; implementations must treat them as read-only.

    #
    # COLLECTION METHODS
    #

    @abstractmethod
    @override
    def list_collections(
        self,
        tenant: str = DEFAULT_TENANT,
        namespace: str = DEFAULT_NAMESPACE,
    ) -> Sequence[Collection]:
        ...

    @abstractmethod
    @override
    def create_collection(
        self,
        name: str,
        metadata: Optional[CollectionMetadata] = None,
        embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
        get_or_create: bool = False,
        tenant: str = DEFAULT_TENANT,
        namespace: str = DEFAULT_NAMESPACE,
    ) -> Collection:
        ...

    @abstractmethod
    @override
    def get_collection(
        self,
        name: str,
        embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
        tenant: str = DEFAULT_TENANT,
        namespace: str = DEFAULT_NAMESPACE,
    ) -> Collection:
        ...

    @abstractmethod
    @override
    def get_or_create_collection(
        self,
        name: str,
        metadata: Optional[CollectionMetadata] = None,
        embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
        tenant: str = DEFAULT_TENANT,
        namespace: str = DEFAULT_NAMESPACE,
    ) -> Collection:
        ...

    @abstractmethod
    @override
    def _modify(
        self,
        id: UUID,
        new_name: Optional[str] = None,
        new_metadata: Optional[CollectionMetadata] = None,
        tenant: str = DEFAULT_TENANT,
        namespace: str = DEFAULT_NAMESPACE,
    ) -> None:
        ...

    @abstractmethod
    @override
    def delete_collection(
        self,
        name: str,
        tenant: str = DEFAULT_TENANT,
        namespace: str = DEFAULT_NAMESPACE,
    ) -> None:
        ...

    #
    # ITEM METHODS
    #

    @abstractmethod
    @override
    def _add(
        self,
        ids: IDs,
        collection_id: UUID,
        embeddings: Embeddings,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        tenant: str = DEFAULT_TENANT,
        namespace: str = DEFAULT_NAMESPACE,
    ) -> bool:
        ...

    @abstractmethod
    @override
    def _update(
        self,
        collection_id: UUID,
        ids: IDs,
        embeddings: Optional[Embeddings] = None,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        tenant: str = DEFAULT_TENANT,
        namespace: str = DEFAULT_NAMESPACE,
    ) -> bool:
        ...

    @abstractmethod
    @override
    def _upsert(
        self,
        collection_id: UUID,
        ids: IDs,
        embeddings: Embeddings,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        tenant: str = DEFAULT_TENANT,
        namespace: str = DEFAULT_NAMESPACE,
    ) -> bool:
        ...

    @abstractmethod
    @override
    def _count(
        self,
        collection_id: UUID,
        tenant: str = DEFAULT_TENANT,
        namespace: str = DEFAULT_NAMESPACE,
    ) -> int:
        ...

    @abstractmethod
    @override
    def _peek(
        self,
        collection_id: UUID,
        n: int = 10,
        tenant: str = DEFAULT_TENANT,
        namespace: str = DEFAULT_NAMESPACE,
    ) -> GetResult:
        ...

    @abstractmethod
    @override
    def _get(
        self,
        collection_id: UUID,
        ids: Optional[IDs] = None,
        where: Optional[Where] = {},
        sort: Optional[str] = None,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        page: Optional[int] = None,
        page_size: Optional[int] = None,
        where_document: Optional[WhereDocument] = {},
        include: Include = ["embeddings", "metadatas", "documents"],
        tenant: str = DEFAULT_TENANT,
        namespace: str = DEFAULT_NAMESPACE,
    ) -> GetResult:
        ...

    @abstractmethod
    @override
    def _delete(
        self,
        collection_id: UUID,
        ids: Optional[IDs],
        where: Optional[Where] = {},
        where_document: Optional[WhereDocument] = {},
        tenant: str = DEFAULT_TENANT,
        namespace: str = DEFAULT_NAMESPACE,
    ) -> IDs:
        ...

    @abstractmethod
    @override
    def _query(
        self,
        collection_id: UUID,
        query_embeddings: Embeddings,
        n_results: int = 10,
        where: Where = {},
        where_document: WhereDocument = {},
        include: Include = ["embeddings", "metadatas", "documents", "distances"],
        tenant: str = DEFAULT_TENANT,
        namespace: str = DEFAULT_NAMESPACE,
    ) -> QueryResult:
        ...
Loading

0 comments on commit 99b8b75

Please sign in to comment.