diff --git a/src/transformers/models/deta/modeling_deta.py b/src/transformers/models/deta/modeling_deta.py index 188b83c4e2e280..7e1b014c834eff 100644 --- a/src/transformers/models/deta/modeling_deta.py +++ b/src/transformers/models/deta/modeling_deta.py @@ -627,7 +627,8 @@ def __init__(self, config: DetaConfig, num_heads: int, n_points: int): def _reset_parameters(self): nn.init.constant_(self.sampling_offsets.weight.data, 0.0) - thetas = torch.arange(self.n_heads, dtype=torch.int64).float() * (2.0 * math.pi / self.n_heads) + default_dtype = torch.get_default_dtype() + thetas = torch.arange(self.n_heads, dtype=torch.int64).to(default_dtype) * (2.0 * math.pi / self.n_heads) grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) grid_init = ( (grid_init / grid_init.abs().max(-1, keepdim=True)[0])