Skip to content

Commit

Permalink
Update embedding_functions.py
Browse files Browse the repository at this point in the history
  • Loading branch information
david20571015 authored Oct 19, 2023
1 parent ac644a8 commit bca07f6
Showing 1 changed file with 11 additions and 1 deletion.
12 changes: 11 additions & 1 deletion chromadb/utils/embedding_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,19 @@ class SentenceTransformerEmbeddingFunction(EmbeddingFunction):
def __init__(
self,
model_name: str = "all-MiniLM-L6-v2",
device: str = "cpu",
device: Optional[Literal["cpu", "cuda"]] = "cpu",
normalize_embeddings: bool = False,
):
"""
Initialize the SentenceTransformerEmbeddingFunction.
Args:
model_name (str, optional): The name of the model to use for text
embeddings. Defaults to "all-MiniLM-L6-v2".
device (str, optional): Device ("cuda" / "cpu") that should be
used for computation. If None, checks if a GPU can be used.
Defaults to "cpu".
"""
if model_name not in self.models:
try:
from sentence_transformers import SentenceTransformer
Expand Down

0 comments on commit bca07f6

Please sign in to comment.