Skip to content

Commit

Permalink
feat: Thread-safety for persistent and ephemeral clients
Browse files Browse the repository at this point in the history
  • Loading branch information
tazarov committed Oct 12, 2023
1 parent c0e307e commit 4c0417c
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 18 deletions.
3 changes: 3 additions & 0 deletions chromadb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
WhereDocument,
UpdateCollectionMetadata,
)
from chromadb.utils.locking import synchronized

# Re-export types from chromadb.types
__all__ = [
Expand Down Expand Up @@ -90,6 +91,7 @@ def get_settings() -> Settings:
return __settings


@synchronized
def EphemeralClient(settings: Settings = Settings()) -> API:
"""
Creates an in-memory instance of Chroma. This is useful for testing and
Expand All @@ -100,6 +102,7 @@ def EphemeralClient(settings: Settings = Settings()) -> API:
return Client(settings)


@synchronized
def PersistentClient(path: str = "./chroma", settings: Settings = Settings()) -> API:
"""
Creates a persistent instance of Chroma that saves to disk. This is useful for
Expand Down
72 changes: 54 additions & 18 deletions chromadb/api/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@
import logging
import re

from chromadb.utils.locking import synchronized


logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -96,19 +98,22 @@ def __init__(self, system: System):
self._topic_ns = system.settings.topic_namespace
self._collection_cache = {}

@synchronized
@override
def heartbeat(self) -> int:
return int(time.time_ns())

# TODO: Actually fix CollectionMetadata type to remove type: ignore flags. This is
# necessary because changing the value type from `Any` to`` `Union[str, int, float]`
# causes the system to somehow convert all values to strings.
@synchronized
@override
def create_collection(
self,
name: str,
metadata: Optional[CollectionMetadata] = None,
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(
),
get_or_create: bool = False,
) -> Collection:
existing = self._sysdb.get_collections(name=name)
Expand All @@ -120,7 +125,8 @@ def create_collection(
if get_or_create:
if metadata and existing[0]["metadata"] != metadata:
self._modify(id=existing[0]["id"], new_metadata=metadata)
existing = self._sysdb.get_collections(id=existing[0]["id"])
existing = self._sysdb.get_collections(
id=existing[0]["id"])
return Collection(
client=self,
id=existing[0]["id"],
Expand Down Expand Up @@ -162,12 +168,14 @@ def create_collection(
embedding_function=embedding_function,
)

@synchronized
@override
def get_or_create_collection(
self,
name: str,
metadata: Optional[CollectionMetadata] = None,
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(
),
) -> Collection:
return self.create_collection(
name=name,
Expand All @@ -179,11 +187,13 @@ def get_or_create_collection(
# TODO: Actually fix CollectionMetadata type to remove type: ignore flags. This is
# necessary because changing the value type from `Any` to`` `Union[str, int, float]`
# causes the system to somehow convert all values to strings
@synchronized
@override
def get_collection(
self,
name: str,
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(
),
) -> Collection:
existing = self._sysdb.get_collections(name=name)

Expand All @@ -198,6 +208,7 @@ def get_collection(
else:
raise ValueError(f"Collection {name} does not exist.")

@synchronized
@override
def list_collections(self) -> Sequence[Collection]:
collections = []
Expand All @@ -213,6 +224,7 @@ def list_collections(self) -> Sequence[Collection]:
)
return collections

@synchronized
@override
def _modify(
self,
Expand All @@ -230,12 +242,14 @@ def _modify(
# TODO eventually we'll want to use OptionalArgument and Unspecified in the
# signature of `_modify` but not changing the API right now.
if new_name and new_metadata:
self._sysdb.update_collection(id, name=new_name, metadata=new_metadata)
self._sysdb.update_collection(
id, name=new_name, metadata=new_metadata)
elif new_name:
self._sysdb.update_collection(id, name=new_name)
elif new_metadata:
self._sysdb.update_collection(id, metadata=new_metadata)

@synchronized
@override
def delete_collection(self, name: str) -> None:
existing = self._sysdb.get_collections(name=name)
Expand All @@ -250,6 +264,7 @@ def delete_collection(self, name: str) -> None:
else:
raise ValueError(f"Collection {name} does not exist.")

@synchronized
@override
def _add(
self,
Expand Down Expand Up @@ -287,6 +302,7 @@ def _add(
)
return True

@synchronized
@override
def _update(
self,
Expand Down Expand Up @@ -326,6 +342,7 @@ def _update(

return True

@synchronized
@override
def _upsert(
self,
Expand Down Expand Up @@ -355,6 +372,7 @@ def _upsert(

return True

@synchronized
@override
def _get(
self,
Expand All @@ -369,14 +387,16 @@ def _get(
where_document: Optional[WhereDocument] = {},
include: Include = ["embeddings", "metadatas", "documents"],
) -> GetResult:
where = validate_where(where) if where is not None and len(where) > 0 else None
where = validate_where(where) if where is not None and len(
where) > 0 else None
where_document = (
validate_where_document(where_document)
if where_document is not None and len(where_document) > 0
else None
)

metadata_segment = self._manager.get_segment(collection_id, MetadataReader)
metadata_segment = self._manager.get_segment(
collection_id, MetadataReader)

if sort is not None:
raise NotImplementedError("Sorting is not yet supported")
Expand All @@ -396,7 +416,8 @@ def _get(
vectors: Sequence[t.VectorEmbeddingRecord] = []
if "embeddings" in include:
vector_ids = [r["id"] for r in records]
vector_segment = self._manager.get_segment(collection_id, VectorReader)
vector_segment = self._manager.get_segment(
collection_id, VectorReader)
vectors = vector_segment.get_vectors(ids=vector_ids)

# TODO: Fix type so we don't need to ignore
Expand All @@ -423,10 +444,12 @@ def _get(
embeddings=[r["embedding"] for r in vectors]
if "embeddings" in include
else None,
metadatas=_clean_metadatas(metadatas) if "metadatas" in include else None, # type: ignore
metadatas=_clean_metadatas(
metadatas) if "metadatas" in include else None, # type: ignore
documents=documents if "documents" in include else None, # type: ignore
)

@synchronized
@override
def _delete(
self,
Expand All @@ -435,7 +458,8 @@ def _delete(
where: Optional[Where] = None,
where_document: Optional[WhereDocument] = None,
) -> IDs:
where = validate_where(where) if where is not None and len(where) > 0 else None
where = validate_where(where) if where is not None and len(
where) > 0 else None
where_document = (
validate_where_document(where_document)
if where_document is not None and len(where_document) > 0
Expand Down Expand Up @@ -464,7 +488,8 @@ def _delete(
self._manager.hint_use_collection(collection_id, t.Operation.DELETE)

if (where or where_document) or not ids:
metadata_segment = self._manager.get_segment(collection_id, MetadataReader)
metadata_segment = self._manager.get_segment(
collection_id, MetadataReader)
records = metadata_segment.get_metadata(
where=where, where_document=where_document, ids=ids
)
Expand All @@ -488,11 +513,14 @@ def _delete(
)
return ids_to_delete

@synchronized
@override
def _count(self, collection_id: UUID) -> int:
metadata_segment = self._manager.get_segment(collection_id, MetadataReader)
metadata_segment = self._manager.get_segment(
collection_id, MetadataReader)
return metadata_segment.count()

@synchronized
@override
def _query(
self,
Expand All @@ -503,7 +531,8 @@ def _query(
where_document: WhereDocument = {},
include: Include = ["documents", "metadatas", "distances"],
) -> QueryResult:
where = validate_where(where) if where is not None and len(where) > 0 else where
where = validate_where(where) if where is not None and len(
where) > 0 else where
where_document = (
validate_where_document(where_document)
if where_document is not None and len(where_document) > 0
Expand All @@ -516,7 +545,8 @@ def _query(
for embedding in query_embeddings:
self._validate_dimension(coll, len(embedding), update=False)

metadata_reader = self._manager.get_segment(collection_id, MetadataReader)
metadata_reader = self._manager.get_segment(
collection_id, MetadataReader)

if where or where_document:
records = metadata_reader.get_metadata(
Expand Down Expand Up @@ -546,7 +576,8 @@ def _query(
if "distances" in include:
distances.append([r["distance"] for r in result])
if "embeddings" in include:
embeddings.append([cast(Embedding, r["embedding"]) for r in result])
embeddings.append([cast(Embedding, r["embedding"])
for r in result])

if "documents" in include or "metadatas" in include:
all_ids: Set[str] = set()
Expand All @@ -564,9 +595,11 @@ def _query(
# queries the metadata segment. The metadata segment does not have
# the record. In this case we choose to return potentially
# incorrect data in the form of None.
metadata_list = [metadata_by_id.get(id, None) for id in id_list]
metadata_list = [metadata_by_id.get(
id, None) for id in id_list]
if "metadatas" in include:
metadatas.append(_clean_metadatas(metadata_list)) # type: ignore
metadatas.append(_clean_metadatas(
metadata_list)) # type: ignore
if "documents" in include:
doc_list = [_doc(m) for m in metadata_list]
documents.append(doc_list) # type: ignore
Expand All @@ -592,6 +625,7 @@ def _query(
documents=documents if documents else None,
)

@synchronized
@override
def _peek(self, collection_id: UUID, n: int = 10) -> GetResult:
return self._get(collection_id, limit=n)
Expand All @@ -604,6 +638,7 @@ def get_version(self) -> str:
def reset_state(self) -> None:
self._collection_cache = {}

@synchronized
@override
def reset(self) -> bool:
self._system.reset_state()
Expand All @@ -630,7 +665,8 @@ def _validate_embedding_record(
) -> None:
"""Validate the dimension of an embedding record before submitting it to the system."""
if record["embedding"]:
self._validate_dimension(collection, len(record["embedding"]), update=True)
self._validate_dimension(collection, len(
record["embedding"]), update=True)

def _validate_dimension(
self, collection: t.Collection, dim: int, update: bool
Expand Down
18 changes: 18 additions & 0 deletions chromadb/utils/locking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import threading
from functools import wraps
from typing import Callable

lock = threading.RLock()


def synchronized(function: Callable) -> Callable:
"""
Decorator to synchronize a function call on a global lock. This allows us to
ensure thread safety while allowing multiple persistent or ephemeral clients to
be created.
"""
@wraps(function)
def wrapped(*args, **kwargs):
with lock:
return function(*args, **kwargs)
return wrapped

0 comments on commit 4c0417c

Please sign in to comment.