From 7e7817a7b493760ba5f57040b5bca985f5b85e14 Mon Sep 17 00:00:00 2001 From: Connor Brinton Date: Fri, 13 Dec 2024 09:26:03 -0500 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20Support=20MPS=20acceleration=20in?= =?UTF-8?q?=20OpenCLIP=20embeddings?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Currently, attempting to use OpenCLIP embeddings on a metal performance shader (MPS) device results in the following error: ``` RuntimeError: slow_conv2d_forward_mps: input(device='cpu') and weight(device='mps:0') must be on the same device in add. ``` These changes fix this error by explicitly moving input tensors to the model device in `OpenCLIPEmbeddingFunction` embedding methods. This provides a significant speedup (~2x on my M1 MacBook Pro) compared to running on CPU. --- .../embedding_functions/open_clip_embedding_function.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/chromadb/utils/embedding_functions/open_clip_embedding_function.py b/chromadb/utils/embedding_functions/open_clip_embedding_function.py index 0d05b6c27b6..a261e4ed2e7 100644 --- a/chromadb/utils/embedding_functions/open_clip_embedding_function.py +++ b/chromadb/utils/embedding_functions/open_clip_embedding_function.py @@ -47,6 +47,7 @@ def __init__( model, _, preprocess = open_clip.create_model_and_transforms( model_name=model_name, pretrained=checkpoint ) + self._device = device self._model = model self._model.to(device) self._preprocess = preprocess @@ -56,14 +57,16 @@ def _encode_image(self, image: Image) -> Embedding: pil_image = self._PILImage.fromarray(image) with self._torch.no_grad(): image_features = self._model.encode_image( - self._preprocess(pil_image).unsqueeze(0) + self._preprocess(pil_image).unsqueeze(0).to(self._device) ) image_features /= image_features.norm(dim=-1, keepdim=True) return cast(Embedding, image_features.squeeze().cpu().numpy()) def _encode_text(self, text: Document) -> Embedding: with self._torch.no_grad(): - 
text_features = self._model.encode_text(self._tokenizer(text)) + text_features = self._model.encode_text( + self._tokenizer(text).to(self._device) + ) text_features /= text_features.norm(dim=-1, keepdim=True) return cast(Embedding, text_features.squeeze().cpu().numpy())