diff --git a/chromadb/utils/embedding_functions.py b/chromadb/utils/embedding_functions.py index aaef53c01e2..3e128be5f73 100644 --- a/chromadb/utils/embedding_functions.py +++ b/chromadb/utils/embedding_functions.py @@ -30,9 +30,19 @@ class SentenceTransformerEmbeddingFunction(EmbeddingFunction): def __init__( self, model_name: str = "all-MiniLM-L6-v2", - device: str = "cpu", + device: Optional[Literal["cpu", "cuda"]] = "cpu", normalize_embeddings: bool = False, ): + """ + Initialize the SentenceTransformerEmbeddingFunction. + + Args: + model_name (str, optional): The name of the model to use for text + embeddings. Defaults to "all-MiniLM-L6-v2". + device (str, optional): Device ("cuda" / "cpu") that should be + used for computation. If None, checks if a GPU can be used. + Defaults to "cpu". + """ if model_name not in self.models: try: from sentence_transformers import SentenceTransformer