From a3eddca6ce56bc36a1c119cd95d1e0ec6b7fc18f Mon Sep 17 00:00:00 2001 From: Trayan Azarov Date: Thu, 12 Oct 2023 23:09:37 +0300 Subject: [PATCH] fix: Added @synchronized to EphemeralClient - Added a multi-threaded synchronization test --- chromadb/__init__.py | 1 + chromadb/test/test_multithreaded.py | 82 +++++++++++++++++++++++++++-- 2 files changed, 79 insertions(+), 4 deletions(-) diff --git a/chromadb/__init__.py b/chromadb/__init__.py index 4b2bf8cddef..4643ee2ab8d 100644 --- a/chromadb/__init__.py +++ b/chromadb/__init__.py @@ -91,6 +91,7 @@ def get_settings() -> Settings: return __settings +@synchronized def EphemeralClient(settings: Settings = Settings()) -> API: """ Creates an in-memory instance of Chroma. This is useful for testing and diff --git a/chromadb/test/test_multithreaded.py b/chromadb/test/test_multithreaded.py index 57c259dad99..b69f64019dc 100644 --- a/chromadb/test/test_multithreaded.py +++ b/chromadb/test/test_multithreaded.py @@ -1,10 +1,11 @@ import multiprocessing from concurrent.futures import Future, ThreadPoolExecutor, wait import random +import uuid import threading from typing import Any, Dict, List, Optional, Set, Tuple, cast import numpy as np - +import chromadb from chromadb.api import API import chromadb.test.property.invariants as invariants from chromadb.test.property.strategies import RecordSet @@ -64,7 +65,8 @@ def _test_multithreaded_add(api: API, N: int, D: int, num_workers: int) -> None: coll.add, ids=ids[start:end], embeddings=embeddings[start:end] if embeddings is not None else None, - metadatas=metadatas[start:end] if metadatas is not None else None, # type: ignore + # type: ignore + metadatas=metadatas[start:end] if metadatas is not None else None, documents=documents[start:end] if documents is not None else None, ) futures.append(future) @@ -141,7 +143,8 @@ def perform_operation( n_results = 5 with lock: currently_added_ids = list(added_ids.copy()) - currently_added_indices = [ids.index(id) for id in 
currently_added_ids] + currently_added_indices = [ + ids.index(id) for id in currently_added_ids] if ( len(currently_added_ids) == 0 or len(currently_added_indices) < n_results @@ -173,7 +176,8 @@ def perform_operation( to_send = min(batch_size, len(ids) - total_sent) start = total_sent + 1 end = total_sent + to_send + 1 - future = executor.submit(perform_operation, operation, ids[start:end]) + future = executor.submit( + perform_operation, operation, ids[start:end]) futures.append(future) total_sent += to_send elif operation == 1: @@ -219,3 +223,73 @@ def test_interleaved_add_query(api: API) -> None: num_workers = random.randint(2, multiprocessing.cpu_count() * 2) N, D = generate_data_shape() _test_interleaved_add_query(api, N, D, num_workers) + + +def _test_serialized_access(N: int, D: int, num_workers: int) -> None: + records_set = generate_record_set(N, D) + ids = records_set["ids"] + embeddings = records_set["embeddings"] + metadatas = records_set["metadatas"] + documents = records_set["documents"] + + print(f"Adding {N} records with {D} dimensions on {num_workers} workers") + + # TODO: batch_size and sync_threshold should be configurable + def _create_client_and_add(**add_args) -> None: + # api = chromadb.PersistentClient(path="./test-db") + api = chromadb.EphemeralClient() + coll = api.create_collection( + name=f"test-{uuid.uuid4()}", metadata=test_hnsw_config) + coll.add(**add_args) + _t_record_set = RecordSet(**add_args) + # Check that invariants hold + invariants.count(coll, _t_record_set) + invariants.ids_match(coll, _t_record_set) + invariants.metadatas_match(coll, _t_record_set) + invariants.no_duplicates(coll) + # Check that the ANN accuracy is good + # On a random subset of the dataset + query_indices = random.sample([i for i in range(N)], 10) + n_results = 5 + invariants.ann_accuracy( + coll, + records_set, + n_results=n_results, + query_indices=query_indices, + ) + with ThreadPoolExecutor(max_workers=num_workers) as executor: + + futures: 
List[Future[Any]] = [] + total_sent = -1 + while total_sent < len(ids): + # Randomly grab up to 10% of the dataset and send it to the executor + batch_size = random.randint(1, N // 10) + to_send = min(batch_size, len(ids) - total_sent) + start = total_sent + 1 + end = total_sent + to_send + 1 + if embeddings is not None and len(embeddings[start:end]) == 0: + break + future = executor.submit( + _create_client_and_add, + ids=ids[start:end], + embeddings=embeddings[start:end] if embeddings is not None else None, + # type: ignore + metadatas=metadatas[start:end] if metadatas is not None else None, + documents=documents[start:end] if documents is not None else None, + ) + futures.append(future) + total_sent += to_send + + wait(futures) + + for future in futures: + exception = future.exception() + if exception is not None: + raise exception + + +def test_serialized_access() -> None: + for i in range(3): + num_workers = random.randint(2, multiprocessing.cpu_count() * 2) + N, D = generate_data_shape() + _test_serialized_access(N, D, num_workers)