Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
✨ Support MPS acceleration in OpenCLIP embeddings
Currently, attempting to use OpenCLIP embeddings on a metal performance shader (MPS) device results in the following error: ``` RuntimeError: slow_conv2d_forward_mps: input(device='cpu') and weight(device=mps:0') must be on the same device in add. ``` These changes fix this error by explicitly moving input tensors to the model device in `OpenCLIPEmbeddingFunction` embedding methods. This provides a significant speedup (~2x on my M1 MacBook Pro) compared to running on CPU.
- Loading branch information