diff --git a/src/transformers/models/owlv2/modeling_owlv2.py b/src/transformers/models/owlv2/modeling_owlv2.py index a7924085fcea03..05c5cd4595b5df 100644 --- a/src/transformers/models/owlv2/modeling_owlv2.py +++ b/src/transformers/models/owlv2/modeling_owlv2.py @@ -1276,7 +1276,6 @@ def forward( if query_mask.ndim > 1: query_mask = torch.unsqueeze(query_mask, dim=-2) - pred_logits = pred_logits.to(torch.float64) pred_logits = torch.where(query_mask == 0, -1e6, pred_logits) pred_logits = pred_logits.to(torch.float32) diff --git a/src/transformers/models/owlvit/modeling_owlvit.py b/src/transformers/models/owlvit/modeling_owlvit.py index a7d84455230670..ee6d8aa423d1cf 100644 --- a/src/transformers/models/owlvit/modeling_owlvit.py +++ b/src/transformers/models/owlvit/modeling_owlvit.py @@ -1257,7 +1257,6 @@ def forward( if query_mask.ndim > 1: query_mask = torch.unsqueeze(query_mask, dim=-2) - pred_logits = pred_logits.to(torch.float64) pred_logits = torch.where(query_mask == 0, -1e6, pred_logits) pred_logits = pred_logits.to(torch.float32)