fix: Added @synchronized to EphemeralClient
- Added a multi-threaded synchronization test
tazarov committed Oct 12, 2023
1 parent 759b306 commit a3eddca
Showing 2 changed files with 79 additions and 4 deletions.
1 change: 1 addition & 0 deletions chromadb/__init__.py
@@ -91,6 +91,7 @@ def get_settings() -> Settings:
return __settings


@synchronized
def EphemeralClient(settings: Settings = Settings()) -> API:
"""
Creates an in-memory instance of Chroma. This is useful for testing and
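
The synchronized decorator applied above is defined elsewhere in the chromadb codebase and is not shown in this diff. As a rough illustration only (the lock name and body below are assumptions, not the repository's implementation), a lock-based decorator of this kind can be sketched as:

import threading
from functools import wraps
from typing import Any, Callable, TypeVar

F = TypeVar("F", bound=Callable[..., Any])

_client_lock = threading.Lock()  # hypothetical module-level lock


def synchronized(func: F) -> F:
    """Serialize all calls to the wrapped function across threads."""

    @wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        with _client_lock:  # only one thread runs the wrapped call at a time
            return func(*args, **kwargs)

    return wrapper  # type: ignore[return-value]

Applied to EphemeralClient, a guard like this ensures that concurrent client construction from multiple threads runs one call at a time, which is the access pattern the new test below exercises.
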
82 changes: 78 additions & 4 deletions chromadb/test/test_multithreaded.py
@@ -1,10 +1,11 @@
import multiprocessing
from concurrent.futures import Future, ThreadPoolExecutor, wait
import random
import uuid
import threading
from typing import Any, Dict, List, Optional, Set, Tuple, cast
import numpy as np

import chromadb
from chromadb.api import API
import chromadb.test.property.invariants as invariants
from chromadb.test.property.strategies import RecordSet
@@ -64,7 +65,8 @@ def _test_multithreaded_add(api: API, N: int, D: int, num_workers: int) -> None:
coll.add,
ids=ids[start:end],
embeddings=embeddings[start:end] if embeddings is not None else None,
metadatas=metadatas[start:end] if metadatas is not None else None, # type: ignore
# type: ignore
metadatas=metadatas[start:end] if metadatas is not None else None,
documents=documents[start:end] if documents is not None else None,
)
futures.append(future)
@@ -141,7 +143,8 @@ def perform_operation(
n_results = 5
with lock:
currently_added_ids = list(added_ids.copy())
currently_added_indices = [ids.index(id) for id in currently_added_ids]
currently_added_indices = [
ids.index(id) for id in currently_added_ids]
if (
len(currently_added_ids) == 0
or len(currently_added_indices) < n_results
@@ -173,7 +176,8 @@ def perform_operation(
to_send = min(batch_size, len(ids) - total_sent)
start = total_sent + 1
end = total_sent + to_send + 1
future = executor.submit(perform_operation, operation, ids[start:end])
future = executor.submit(
perform_operation, operation, ids[start:end])
futures.append(future)
total_sent += to_send
elif operation == 1:
@@ -219,3 +223,73 @@ def test_interleaved_add_query(api: API) -> None:
num_workers = random.randint(2, multiprocessing.cpu_count() * 2)
N, D = generate_data_shape()
_test_interleaved_add_query(api, N, D, num_workers)


def _test_serialized_access(N: int, D: int, num_workers: int) -> None:
records_set = generate_record_set(N, D)
ids = records_set["ids"]
embeddings = records_set["embeddings"]
metadatas = records_set["metadatas"]
documents = records_set["documents"]

print(f"Adding {N} records with {D} dimensions on {num_workers} workers")

# TODO: batch_size and sync_threshold should be configurable
def _create_client_and_add(**add_args) -> None:
# api = chromadb.PersistentClient(path="./test-db")
api = chromadb.EphemeralClient()
coll = api.create_collection(
name=f"test-{uuid.uuid4()}", metadata=test_hnsw_config)
coll.add(**add_args)
_t_record_set = RecordSet(**add_args)
# Check that invariants hold
invariants.count(coll, _t_record_set)
invariants.ids_match(coll, _t_record_set)
invariants.metadatas_match(coll, _t_record_set)
invariants.no_duplicates(coll)
# Check that the ANN accuracy is good
# On a random subset of the dataset
query_indices = random.sample([i for i in range(N)], 10)
n_results = 5
invariants.ann_accuracy(
coll,
records_set,
n_results=n_results,
query_indices=query_indices,
)
with ThreadPoolExecutor(max_workers=num_workers) as executor:

futures: List[Future[Any]] = []
total_sent = -1
while total_sent < len(ids):
# Randomly grab up to 10% of the dataset and send it to the executor
batch_size = random.randint(1, N // 10)
to_send = min(batch_size, len(ids) - total_sent)
start = total_sent + 1
end = total_sent + to_send + 1
if embeddings is not None and len(embeddings[start:end]) == 0:
break
future = executor.submit(
_create_client_and_add,
ids=ids[start:end],
embeddings=embeddings[start:end] if embeddings is not None else None,
# type: ignore
metadatas=metadatas[start:end] if metadatas is not None else None,
documents=documents[start:end] if documents is not None else None,
)
futures.append(future)
total_sent += to_send

wait(futures)

for future in futures:
exception = future.exception()
if exception is not None:
raise exception


def test_serialized_access() -> None:
for i in range(3):
num_workers = random.randint(2, multiprocessing.cpu_count() * 2)
N, D = generate_data_shape()
_test_serialized_access(N, D, num_workers)
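
For context, the scenario this test covers can also be reproduced outside the test suite. The following is an illustrative sketch (not part of the commit) that constructs ephemeral clients from many threads at once, the access pattern that @synchronized now serializes:

from concurrent.futures import ThreadPoolExecutor

import chromadb


def make_client() -> None:
    # Each call constructs an in-memory client; heartbeat() is a lightweight
    # call that confirms the client is usable.
    api = chromadb.EphemeralClient()
    api.heartbeat()


with ThreadPoolExecutor(max_workers=8) as pool:
    for _ in range(32):
        pool.submit(make_client)
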
