This repository has been archived by the owner on Feb 12, 2024. It is now read-only.

Commit

make script more robust (#25)
emrgnt-cmplxty authored Sep 22, 2023
1 parent 4d92c9c commit d7471d1
Showing 4 changed files with 47 additions and 163 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
@@ -36,6 +36,7 @@ vllm = { version = "0.1.7", optional = true }
llama-index = { version = "^0.8.29.post1", optional = true }
# chroma
chromadb = { version = "^0.4.12", optional = true }
retrying = "^1.3.4"

[tool.poetry.extras]
anthropic_support = ["anthropic"]
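Note (editorial, not part of the commit): Poetry's caret constraint "^1.3.4" allows any retrying version >= 1.3.4 and < 2.0.0. Unlike the vllm, llama-index, and chromadb entries above, which are marked optional and exposed through [tool.poetry.extras], retrying is added as a regular required dependency.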
Empty file.
157 changes: 0 additions & 157 deletions sciphi/examples/data_generation/runner.py

This file was deleted.

52 changes: 46 additions & 6 deletions sciphi/examples/populate_chroma/runner.py
@@ -9,6 +9,7 @@
from chromadb.config import Settings
from datasets import load_dataset
from openai.embeddings_utils import get_embeddings
from retrying import retry

from sciphi.core.utils import get_configured_logger

@@ -20,10 +21,36 @@ def chunk_text(text: str, chunk_size: int) -> list[str]:
    return [text[i : i + chunk_size] for i in range(0, len(text), chunk_size)]


# Define a retrying decorator with specific parameters
@retry(
    stop_max_attempt_number=3, wait_fixed=2000
)  # Retries 3 times with a 2-second wait between retries
def robust_get_embeddings(chunks, engine):
    try:
        return get_embeddings(chunks, engine=engine)
    except Exception as e:
        logger.error(
            f"Failed to get embeddings for chunks {chunks}, with error {e}"
        )
        raise  # Reraise the exception to be caught by the retrying mechanism


MAX_EMBEDDING_BATCH_SIZE = 2048


def worker(worker_args: tuple) -> None:
    """Worker function to populate ChromaDB with a batch of entries."""
    thread_name = current_thread().name
    entries_batch, parsed_ids, logger, logger_interval = worker_args
    (
        collection,
        entries_batch,
        parsed_ids,
        chunk_size,
        batch_size,
        embedding_engine,
        logger,
        logger_interval,
    ) = worker_args
    logger.info(f"Starting worker thread: {thread_name}")

    local_buffer: dict[str, list] = {
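
Note (editorial sketch, not part of the commit): the `retrying` decorator applied to `robust_get_embeddings` above makes at most three attempts in total, waits 2000 ms between attempts, and re-raises the exception if the final attempt still fails. A minimal, self-contained illustration of that behaviour; the function name `flaky_call` and its body are made up for the example:

from retrying import retry

attempts = 0

@retry(stop_max_attempt_number=3, wait_fixed=2000)  # 3 attempts total, 2 s apart
def flaky_call() -> str:
    global attempts
    attempts += 1
    if attempts < 3:
        raise RuntimeError("transient failure")  # caught and retried by the decorator
    return "ok"

print(flaky_call())  # prints "ok" on the third attempt
print(attempts)      # prints 3
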
@@ -53,9 +80,13 @@ def worker(worker_args: tuple) -> None:
            continue

        local_buffer["documents"].extend(chunks)
        local_buffer["embeddings"].extend(
            get_embeddings(chunks, engine=embedding_engine)
        )

        for i in range(0, len(chunks), MAX_EMBEDDING_BATCH_SIZE):
            batch_of_chunks = chunks[i : i + MAX_EMBEDDING_BATCH_SIZE]
            local_buffer["embeddings"].extend(
                get_embeddings(batch_of_chunks, engine=embedding_engine)
            )

        local_buffer["metadatas"].extend(
            [
                {
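
Note (editorial sketch, not part of the commit): the loop added above caps each embeddings request at MAX_EMBEDDING_BATCH_SIZE chunks instead of sending every chunk in a single call. The same slicing pattern in isolation; the helper name `split_into_batches` is made up for the example:

MAX_EMBEDDING_BATCH_SIZE = 2048

def split_into_batches(items: list[str], batch_size: int = MAX_EMBEDDING_BATCH_SIZE) -> list[list[str]]:
    # Each slice holds at most batch_size items; the last one may be shorter.
    return [items[i : i + batch_size] for i in range(0, len(items), batch_size)]

chunks = [f"chunk-{n}" for n in range(5000)]
print([len(batch) for batch in split_into_batches(chunks)])  # [2048, 2048, 904]
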
@@ -126,7 +157,7 @@ def batch_dataset(dataset, batch_size):
batch_size = 64
batches_per_split = 8
# Process dataset in multiple threads
num_threads = 1
num_threads = 6
# For logging
# TODO - Modify to ensure we are logging by-process
log_level = "INFO"
@@ -183,7 +214,16 @@ def batch_dataset(dataset, batch_size):
logger.info("Creating the dataset batches...")
with ThreadPoolExecutor(max_workers=num_threads) as executor:
args_for_workers = (
(batch, parsed_ids, logger, sample_log_interval)
(
collection,
batch,
parsed_ids,
chunk_size,
batch_size,
embedding_engine,
logger,
sample_log_interval,
)
for batch in batch_dataset(dataset, batches_per_split * batch_size)
)
# The map method blocks until all results are returned
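
Note (editorial sketch, not part of the commit): with this change each worker receives one packed tuple carrying the Chroma collection handle, its slice of the dataset, and the shared settings, and executor.map blocks until every batch has been processed across the six threads. A simplified, self-contained illustration of that dispatch pattern; the stubbed worker body and the placeholder values for fields not shown in this diff (collection, parsed_ids, chunk_size, embedding_engine, logger, log interval) are made up for the example:

from concurrent.futures import ThreadPoolExecutor

def worker(worker_args: tuple) -> int:
    # Stub: the real worker chunks, embeds, and upserts its batch into ChromaDB.
    (collection, batch, parsed_ids, chunk_size,
     batch_size, embedding_engine, logger, log_interval) = worker_args
    return len(batch)

num_threads = 6
batches = [["entry"] * 8, ["entry"] * 8, ["entry"] * 3]
args_for_workers = (
    (None, batch, set(), 1024, 64, "example-embedding-engine", None, 100)
    for batch in batches
)
with ThreadPoolExecutor(max_workers=num_threads) as executor:
    results = list(executor.map(worker, args_for_workers))
print(results)  # [8, 8, 3]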
