diff --git a/chromadb/__init__.py b/chromadb/__init__.py index 9c0b8000a146..4643ee2ab8db 100644 --- a/chromadb/__init__.py +++ b/chromadb/__init__.py @@ -18,6 +18,7 @@ WhereDocument, UpdateCollectionMetadata, ) +from chromadb.utils.locking import synchronized # Re-export types from chromadb.types __all__ = [ @@ -90,6 +91,7 @@ def get_settings() -> Settings: return __settings +@synchronized def EphemeralClient(settings: Settings = Settings()) -> API: """ Creates an in-memory instance of Chroma. This is useful for testing and @@ -100,6 +102,7 @@ def EphemeralClient(settings: Settings = Settings()) -> API: return Client(settings) +@synchronized def PersistentClient(path: str = "./chroma", settings: Settings = Settings()) -> API: """ Creates a persistent instance of Chroma that saves to disk. This is useful for diff --git a/chromadb/api/segment.py b/chromadb/api/segment.py index d23139759d9e..7137d721d2d0 100644 --- a/chromadb/api/segment.py +++ b/chromadb/api/segment.py @@ -47,6 +47,8 @@ import logging import re +from chromadb.utils.locking import synchronized + logger = logging.getLogger(__name__) @@ -96,6 +98,7 @@ def __init__(self, system: System): self._topic_ns = system.settings.topic_namespace self._collection_cache = {} + @synchronized @override def heartbeat(self) -> int: return int(time.time_ns()) @@ -103,12 +106,14 @@ def heartbeat(self) -> int: # TODO: Actually fix CollectionMetadata type to remove type: ignore flags. This is # necessary because changing the value type from `Any` to`` `Union[str, int, float]` # causes the system to somehow convert all values to strings. + @synchronized @override def create_collection( self, name: str, metadata: Optional[CollectionMetadata] = None, - embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), + embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction( + ), get_or_create: bool = False, ) -> Collection: existing = self._sysdb.get_collections(name=name) @@ -120,7 +125,8 @@ def create_collection( if get_or_create: if metadata and existing[0]["metadata"] != metadata: self._modify(id=existing[0]["id"], new_metadata=metadata) - existing = self._sysdb.get_collections(id=existing[0]["id"]) + existing = self._sysdb.get_collections( + id=existing[0]["id"]) return Collection( client=self, id=existing[0]["id"], @@ -162,12 +168,14 @@ def create_collection( embedding_function=embedding_function, ) + @synchronized @override def get_or_create_collection( self, name: str, metadata: Optional[CollectionMetadata] = None, - embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), + embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction( + ), ) -> Collection: return self.create_collection( name=name, @@ -179,11 +187,13 @@ def get_or_create_collection( # TODO: Actually fix CollectionMetadata type to remove type: ignore flags. This is # necessary because changing the value type from `Any` to`` `Union[str, int, float]` # causes the system to somehow convert all values to strings + @synchronized @override def get_collection( self, name: str, - embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), + embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction( + ), ) -> Collection: existing = self._sysdb.get_collections(name=name) @@ -198,6 +208,7 @@ def get_collection( else: raise ValueError(f"Collection {name} does not exist.") + @synchronized @override def list_collections(self) -> Sequence[Collection]: collections = [] @@ -213,6 +224,7 @@ def list_collections(self) -> Sequence[Collection]: ) return collections + @synchronized @override def _modify( self, @@ -230,12 +242,14 @@ def _modify( # TODO eventually we'll want to use OptionalArgument and Unspecified in the # signature of `_modify` but not changing the API right now. if new_name and new_metadata: - self._sysdb.update_collection(id, name=new_name, metadata=new_metadata) + self._sysdb.update_collection( + id, name=new_name, metadata=new_metadata) elif new_name: self._sysdb.update_collection(id, name=new_name) elif new_metadata: self._sysdb.update_collection(id, metadata=new_metadata) + @synchronized @override def delete_collection(self, name: str) -> None: existing = self._sysdb.get_collections(name=name) @@ -250,6 +264,7 @@ def delete_collection(self, name: str) -> None: else: raise ValueError(f"Collection {name} does not exist.") + @synchronized @override def _add( self, @@ -287,6 +302,7 @@ def _add( ) return True + @synchronized @override def _update( self, @@ -326,6 +342,7 @@ def _update( return True + @synchronized @override def _upsert( self, @@ -355,6 +372,7 @@ def _upsert( return True + @synchronized @override def _get( self, @@ -369,14 +387,16 @@ def _get( where_document: Optional[WhereDocument] = {}, include: Include = ["embeddings", "metadatas", "documents"], ) -> GetResult: - where = validate_where(where) if where is not None and len(where) > 0 else None + where = validate_where(where) if where is not None and len( + where) > 0 else None where_document = ( validate_where_document(where_document) if where_document is not None and len(where_document) > 0 else None ) - metadata_segment = self._manager.get_segment(collection_id, MetadataReader) + metadata_segment = self._manager.get_segment( + collection_id, MetadataReader) if sort is not None: raise NotImplementedError("Sorting is not yet supported") @@ -396,7 +416,8 @@ def _get( vectors: Sequence[t.VectorEmbeddingRecord] = [] if "embeddings" in include: vector_ids = [r["id"] for r in records] - vector_segment = self._manager.get_segment(collection_id, VectorReader) + vector_segment = self._manager.get_segment( + collection_id, VectorReader) vectors = vector_segment.get_vectors(ids=vector_ids) # TODO: Fix type so we don't need to ignore @@ -423,10 +444,12 @@ def _get( embeddings=[r["embedding"] for r in vectors] if "embeddings" in include else None, - metadatas=_clean_metadatas(metadatas) if "metadatas" in include else None, # type: ignore + metadatas=_clean_metadatas( + metadatas) if "metadatas" in include else None, # type: ignore documents=documents if "documents" in include else None, # type: ignore ) + @synchronized @override def _delete( self, @@ -435,7 +458,8 @@ def _delete( where: Optional[Where] = None, where_document: Optional[WhereDocument] = None, ) -> IDs: - where = validate_where(where) if where is not None and len(where) > 0 else None + where = validate_where(where) if where is not None and len( + where) > 0 else None where_document = ( validate_where_document(where_document) if where_document is not None and len(where_document) > 0 @@ -464,7 +488,8 @@ def _delete( self._manager.hint_use_collection(collection_id, t.Operation.DELETE) if (where or where_document) or not ids: - metadata_segment = self._manager.get_segment(collection_id, MetadataReader) + metadata_segment = self._manager.get_segment( + collection_id, MetadataReader) records = metadata_segment.get_metadata( where=where, where_document=where_document, ids=ids ) @@ -488,11 +513,14 @@ def _delete( ) return ids_to_delete + @synchronized @override def _count(self, collection_id: UUID) -> int: - metadata_segment = self._manager.get_segment(collection_id, MetadataReader) + metadata_segment = self._manager.get_segment( + collection_id, MetadataReader) return metadata_segment.count() + @synchronized @override def _query( self, @@ -503,7 +531,8 @@ def _query( where_document: WhereDocument = {}, include: Include = ["documents", "metadatas", "distances"], ) -> QueryResult: - where = validate_where(where) if where is not None and len(where) > 0 else where + where = validate_where(where) if where is not None and len( + where) > 0 else where where_document = ( validate_where_document(where_document) if where_document is not None and len(where_document) > 0 @@ -516,7 +545,8 @@ def _query( for embedding in query_embeddings: self._validate_dimension(coll, len(embedding), update=False) - metadata_reader = self._manager.get_segment(collection_id, MetadataReader) + metadata_reader = self._manager.get_segment( + collection_id, MetadataReader) if where or where_document: records = metadata_reader.get_metadata( @@ -546,7 +576,8 @@ def _query( if "distances" in include: distances.append([r["distance"] for r in result]) if "embeddings" in include: - embeddings.append([cast(Embedding, r["embedding"]) for r in result]) + embeddings.append([cast(Embedding, r["embedding"]) + for r in result]) if "documents" in include or "metadatas" in include: all_ids: Set[str] = set() @@ -564,9 +595,11 @@ def _query( # queries the metadata segment. The metadata segment does not have # the record. In this case we choose to return potentially # incorrect data in the form of None. - metadata_list = [metadata_by_id.get(id, None) for id in id_list] + metadata_list = [metadata_by_id.get( + id, None) for id in id_list] if "metadatas" in include: - metadatas.append(_clean_metadatas(metadata_list)) # type: ignore + metadatas.append(_clean_metadatas( + metadata_list)) # type: ignore if "documents" in include: doc_list = [_doc(m) for m in metadata_list] documents.append(doc_list) # type: ignore @@ -592,6 +625,7 @@ def _query( documents=documents if documents else None, ) + @synchronized @override def _peek(self, collection_id: UUID, n: int = 10) -> GetResult: return self._get(collection_id, limit=n) @@ -604,6 +638,7 @@ def get_version(self) -> str: def reset_state(self) -> None: self._collection_cache = {} + @synchronized @override def reset(self) -> bool: self._system.reset_state() @@ -630,7 +665,8 @@ def _validate_embedding_record( ) -> None: """Validate the dimension of an embedding record before submitting it to the system.""" if record["embedding"]: - self._validate_dimension(collection, len(record["embedding"]), update=True) + self._validate_dimension(collection, len( + record["embedding"]), update=True) def _validate_dimension( self, collection: t.Collection, dim: int, update: bool diff --git a/chromadb/utils/locking.py b/chromadb/utils/locking.py new file mode 100644 index 000000000000..61647fe99b86 --- /dev/null +++ b/chromadb/utils/locking.py @@ -0,0 +1,18 @@ +import threading +from functools import wraps +from typing import Callable + +lock = threading.RLock() + + +def synchronized(function: Callable) -> Callable: + """ + Decorator to synchronize a function call on a global lock. This allows us to + ensure thread safety while allowing multiple persistent or ephemeral clients to + be created. + """ + @wraps(function) + def wrapped(*args, **kwargs): + with lock: + return function(*args, **kwargs) + return wrapped