[WIP] [ENH] add exponential backoff and jitter to embedding calls (#1526)

This is a WIP; closes #1524.

*Summarize the changes made by this PR.*
- Improvements & Bug fixes
  - Use `tenacity` to add exponential backoff and jitter to embedding calls
- New functionality
  - Control the parameters of the exponential backoff and jitter, and let the user supply their own wait functions from `tenacity`'s API (see the usage sketch below)

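A minimal usage sketch of the new `embed_with_retries` hook (the embedding function, API key, and retry policy below are illustrative, not part of this diff; any keyword arguments are forwarded to `tenacity.retry`):

```python
from tenacity import stop_after_attempt, wait_random_exponential

from chromadb.utils.embedding_functions import OpenAIEmbeddingFunction

# Hypothetical embedding function and documents; any EmbeddingFunction works.
ef = OpenAIEmbeddingFunction(api_key="sk-...", model_name="text-embedding-ada-002")
docs = ["first document", "second document"]

# retry_kwargs go straight to tenacity.retry, so any tenacity wait/stop
# policy applies, e.g. exponential backoff with jitter capped at 60 seconds.
embeddings = ef.embed_with_retries(
    docs,
    wait=wait_random_exponential(multiplier=1, max=60),
    stop=stop_after_attempt(5),
)
```

Calling it with no keyword arguments simply falls back to `tenacity.retry`'s own defaults.
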
## Test plan
*How are these changes tested?*

- [x] Tests pass locally with `pytest` for Python, `yarn test` for JS

## Documentation Changes
None
rancomp authored Jan 17, 2024
1 parent 7aaf36f commit 9824336
Showing 3 changed files with 37 additions and 25 deletions.
4 changes: 4 additions & 0 deletions chromadb/api/types.py
@@ -16,6 +16,7 @@
WhereDocument,
)
from inspect import signature
+from tenacity import retry

# Re-export types from chromadb.types
__all__ = ["Metadata", "Where", "WhereDocument", "UpdateCollectionMetadata"]
@@ -194,6 +195,9 @@ def __call__(self: EmbeddingFunction[D], input: D) -> Embeddings:

setattr(cls, "__call__", __call__)

+def embed_with_retries(self, input: D, **retry_kwargs: Dict) -> Embeddings:
+    return retry(**retry_kwargs)(self.__call__)(input)


def validate_embedding_function(
embedding_function: EmbeddingFunction[Embeddable],
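
For readers unfamiliar with `tenacity`, the body of `embed_with_retries` above wraps `__call__` in a retry decorator built from the caller's kwargs and invokes it immediately. A rough, illustrative expansion (the function name below is made up for this note):

```python
from tenacity import retry


def embed_with_retries_expanded(ef, input, **retry_kwargs):
    # Wrap the embedding call with the caller-supplied tenacity policy
    # (wait strategy, stop condition, retry predicate, ...).
    retrying_call = retry(**retry_kwargs)(ef.__call__)
    # Invoke it; failures are retried according to that policy.
    return retrying_call(input)
```
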
12 changes: 7 additions & 5 deletions chromadb/cli/utils.py
@@ -3,13 +3,15 @@
import yaml


-def set_log_file_path(log_config_path: str, new_filename: str = "chroma.log") -> Dict[str, Any]:
+def set_log_file_path(
+    log_config_path: str, new_filename: str = "chroma.log"
+) -> Dict[str, Any]:
"""This works with the standard log_config.yml file.
It will not work with custom log configs that may use different handlers"""
-with open(f"{log_config_path}", 'r') as file:
+with open(f"{log_config_path}", "r") as file:
log_config = yaml.safe_load(file)
-for handler in log_config['handlers'].values():
-    if handler.get('class') == 'logging.handlers.RotatingFileHandler':
-        handler['filename'] = new_filename
+for handler in log_config["handlers"].values():
+    if handler.get("class") == "logging.handlers.RotatingFileHandler":
+        handler["filename"] = new_filename

return log_config
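
For context, a short sketch of how this helper is typically used (the config path and log filename below are placeholders; as the docstring notes, it only rewrites the `RotatingFileHandler` of the standard config):

```python
import logging.config

from chromadb.cli.utils import set_log_file_path

# Hypothetical paths: load the stock log_config.yml, point its rotating
# file handler at a new log file, then apply the resulting dict config.
log_config = set_log_file_path("log_config.yml", new_filename="my_chroma.log")
logging.config.dictConfig(log_config)
```
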
46 changes: 26 additions & 20 deletions chromadb/utils/embedding_functions.py
@@ -15,6 +15,7 @@
is_image,
is_document,
)

from pathlib import Path
import os
import tarfile
@@ -73,11 +74,14 @@ def __init__(
self._normalize_embeddings = normalize_embeddings

def __call__(self, input: Documents) -> Embeddings:
-return cast(Embeddings, self._model.encode(
-    list(input),
-    convert_to_numpy=True,
-    normalize_embeddings=self._normalize_embeddings,
-).tolist())
+return cast(
+    Embeddings,
+    self._model.encode(
+        list(input),
+        convert_to_numpy=True,
+        normalize_embeddings=self._normalize_embeddings,
+    ).tolist(),
+)


class Text2VecEmbeddingFunction(EmbeddingFunction[Documents]):
@@ -91,7 +95,9 @@ def __init__(self, model_name: str = "shibing624/text2vec-base-chinese"):
self._model = SentenceModel(model_name_or_path=model_name)

def __call__(self, input: Documents) -> Embeddings:
-return cast(Embeddings, self._model.encode(list(input), convert_to_numpy=True).tolist()) # noqa E501
+return cast(
+    Embeddings, self._model.encode(list(input), convert_to_numpy=True).tolist()
+) # noqa E501


class OpenAIEmbeddingFunction(EmbeddingFunction[Documents]):
@@ -184,12 +190,10 @@ def __call__(self, input: Documents) -> Embeddings:
).data

# Sort resulting embeddings by index
-sorted_embeddings = sorted(
-    embeddings, key=lambda e: e.index
-)
+sorted_embeddings = sorted(embeddings, key=lambda e: e.index)

# Return just the embeddings
-return cast(Embeddings, [result.embedding for result in sorted_embeddings])
+return cast(Embeddings, [result.embedding for result in sorted_embeddings])
else:
if self._api_type == "azure":
embeddings = self._client.create(
@@ -201,9 +205,7 @@
]

# Sort resulting embeddings by index
-sorted_embeddings = sorted(
-    embeddings, key=lambda e: e["index"]
-)
+sorted_embeddings = sorted(embeddings, key=lambda e: e["index"])

# Return just the embeddings
return cast(
@@ -269,9 +271,13 @@ def __call__(self, input: Documents) -> Embeddings:
>>> embeddings = hugging_face(texts)
"""
# Call HuggingFace Embedding API for each document
-return cast(Embeddings, self._session.post(
-    self._api_url, json={"inputs": input, "options": {"wait_for_model": True}}
-).json())
+return cast(
+    Embeddings,
+    self._session.post(
+        self._api_url,
+        json={"inputs": input, "options": {"wait_for_model": True}},
+    ).json(),
+)


class JinaEmbeddingFunction(EmbeddingFunction[Documents]):
@@ -716,7 +722,7 @@ def __call__(self, input: Union[Documents, Images]) -> Embeddings:
class AmazonBedrockEmbeddingFunction(EmbeddingFunction[Documents]):
def __init__(
self,
-session: "boto3.Session", # Quote for forward reference
+session: "boto3.Session", # noqa: F821 # Quote for forward reference
model_name: str = "amazon.titan-embed-text-v1",
**kwargs: Any,
):
@@ -798,9 +804,9 @@ def __call__(self, input: Documents) -> Embeddings:
>>> embeddings = hugging_face(texts)
"""
# Call HuggingFace Embedding Server API for each document
-return cast (Embeddings,self._session.post(
-    self._api_url, json={"inputs": input}
-).json())
+return cast(
+    Embeddings, self._session.post(self._api_url, json={"inputs": input}).json()
+)


# List of all classes in this module
