diff --git a/chromadb/api/types.py b/chromadb/api/types.py index 2aee1beb2b4..7781c422572 100644 --- a/chromadb/api/types.py +++ b/chromadb/api/types.py @@ -16,6 +16,7 @@ WhereDocument, ) from inspect import signature +from tenacity import retry # Re-export types from chromadb.types __all__ = ["Metadata", "Where", "WhereDocument", "UpdateCollectionMetadata"] @@ -194,6 +195,9 @@ def __call__(self: EmbeddingFunction[D], input: D) -> Embeddings: setattr(cls, "__call__", __call__) + def embed_with_retries(self, input: D, **retry_kwargs: Dict) -> Embeddings: + return retry(**retry_kwargs)(self.__call__)(input) + def validate_embedding_function( embedding_function: EmbeddingFunction[Embeddable], diff --git a/chromadb/cli/utils.py b/chromadb/cli/utils.py index d5ef9d95836..383715b1b72 100644 --- a/chromadb/cli/utils.py +++ b/chromadb/cli/utils.py @@ -3,13 +3,15 @@ import yaml -def set_log_file_path(log_config_path: str, new_filename: str = "chroma.log") -> Dict[str, Any]: +def set_log_file_path( + log_config_path: str, new_filename: str = "chroma.log" +) -> Dict[str, Any]: """This works with the standard log_config.yml file. It will not work with custom log configs that may use different handlers""" - with open(f"{log_config_path}", 'r') as file: + with open(f"{log_config_path}", "r") as file: log_config = yaml.safe_load(file) - for handler in log_config['handlers'].values(): - if handler.get('class') == 'logging.handlers.RotatingFileHandler': - handler['filename'] = new_filename + for handler in log_config["handlers"].values(): + if handler.get("class") == "logging.handlers.RotatingFileHandler": + handler["filename"] = new_filename return log_config diff --git a/chromadb/utils/embedding_functions.py b/chromadb/utils/embedding_functions.py index 7259842a429..ec5fc05e3ee 100644 --- a/chromadb/utils/embedding_functions.py +++ b/chromadb/utils/embedding_functions.py @@ -15,6 +15,7 @@ is_image, is_document, ) + from pathlib import Path import os import tarfile @@ -73,11 +74,14 @@ def __init__( self._normalize_embeddings = normalize_embeddings def __call__(self, input: Documents) -> Embeddings: - return cast(Embeddings, self._model.encode( - list(input), - convert_to_numpy=True, - normalize_embeddings=self._normalize_embeddings, - ).tolist()) + return cast( + Embeddings, + self._model.encode( + list(input), + convert_to_numpy=True, + normalize_embeddings=self._normalize_embeddings, + ).tolist(), + ) class Text2VecEmbeddingFunction(EmbeddingFunction[Documents]): @@ -91,7 +95,9 @@ def __init__(self, model_name: str = "shibing624/text2vec-base-chinese"): self._model = SentenceModel(model_name_or_path=model_name) def __call__(self, input: Documents) -> Embeddings: - return cast(Embeddings, self._model.encode(list(input), convert_to_numpy=True).tolist()) # noqa E501 + return cast( + Embeddings, self._model.encode(list(input), convert_to_numpy=True).tolist() + ) # noqa E501 class OpenAIEmbeddingFunction(EmbeddingFunction[Documents]): @@ -184,12 +190,10 @@ def __call__(self, input: Documents) -> Embeddings: ).data # Sort resulting embeddings by index - sorted_embeddings = sorted( - embeddings, key=lambda e: e.index - ) + sorted_embeddings = sorted(embeddings, key=lambda e: e.index) # Return just the embeddings - return cast(Embeddings, [result.embedding for result in sorted_embeddings]) + return cast(Embeddings, [result.embedding for result in sorted_embeddings]) else: if self._api_type == "azure": embeddings = self._client.create( @@ -201,9 +205,7 @@ def __call__(self, input: Documents) -> Embeddings: ] # Sort resulting embeddings by index - sorted_embeddings = sorted( - embeddings, key=lambda e: e["index"] - ) + sorted_embeddings = sorted(embeddings, key=lambda e: e["index"]) # Return just the embeddings return cast( @@ -269,9 +271,13 @@ def __call__(self, input: Documents) -> Embeddings: >>> embeddings = hugging_face(texts) """ # Call HuggingFace Embedding API for each document - return cast(Embeddings, self._session.post( - self._api_url, json={"inputs": input, "options": {"wait_for_model": True}} - ).json()) + return cast( + Embeddings, + self._session.post( + self._api_url, + json={"inputs": input, "options": {"wait_for_model": True}}, + ).json(), + ) class JinaEmbeddingFunction(EmbeddingFunction[Documents]): @@ -716,7 +722,7 @@ def __call__(self, input: Union[Documents, Images]) -> Embeddings: class AmazonBedrockEmbeddingFunction(EmbeddingFunction[Documents]): def __init__( self, - session: "boto3.Session", # Quote for forward reference + session: "boto3.Session", # noqa: F821 # Quote for forward reference model_name: str = "amazon.titan-embed-text-v1", **kwargs: Any, ): @@ -798,9 +804,9 @@ def __call__(self, input: Documents) -> Embeddings: >>> embeddings = hugging_face(texts) """ # Call HuggingFace Embedding Server API for each document - return cast (Embeddings,self._session.post( - self._api_url, json={"inputs": input} - ).json()) + return cast( + Embeddings, self._session.post(self._api_url, json={"inputs": input}).json() + ) # List of all classes in this module