Skip to content

Commit

Permalink
feat: Device param
Browse files Browse the repository at this point in the history
  • Loading branch information
tazarov committed Mar 4, 2024
1 parent 091e466 commit 1d0f324
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion chromadb/utils/embedding_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,7 +671,10 @@ def __call__(self, input: Documents) -> Embeddings:

class OpenCLIPEmbeddingFunction(EmbeddingFunction[Union[Documents, Images]]):
def __init__(
self, model_name: str = "ViT-B-32", checkpoint: str = "laion2b_s34b_b79k"
self,
model_name: str = "ViT-B-32",
checkpoint: str = "laion2b_s34b_b79k",
device: Optional[str] = "cpu",
) -> None:
try:
import open_clip
Expand All @@ -697,6 +700,7 @@ def __init__(
model_name=model_name, pretrained=checkpoint
)
self._model = model
self._model.to(device)
self._preprocess = preprocess
self._tokenizer = open_clip.get_tokenizer(model_name=model_name)

Expand Down

0 comments on commit 1d0f324

Please sign in to comment.