From 9ddceddc217ec07c14545408595af6f6310deefb Mon Sep 17 00:00:00 2001
From: fpgaminer
Date: Thu, 31 Oct 2024 10:51:15 -0700
Subject: [PATCH] Bug Fix for issue #34294 (#34295)

Update SiglipVisionEmbeddings.forward to cast input to correct dtype
before embedding it.
---
 src/transformers/models/siglip/modeling_siglip.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/transformers/models/siglip/modeling_siglip.py b/src/transformers/models/siglip/modeling_siglip.py
index a3d06cbb4792b4..a42bcd0e17461e 100644
--- a/src/transformers/models/siglip/modeling_siglip.py
+++ b/src/transformers/models/siglip/modeling_siglip.py
@@ -308,7 +308,8 @@ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width:
     def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
         _, _, height, width = pixel_values.shape
-        patch_embeds = self.patch_embedding(pixel_values)  # shape = [*, width, grid, grid]
+        target_dtype = self.patch_embedding.weight.dtype
+        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # shape = [*, width, grid, grid]
         embeddings = patch_embeds.flatten(2).transpose(1, 2)
         if interpolate_pos_encoding: