From 8a0ed0a9a2ee8712b2e2c3b20da2887ef7c55fe6 Mon Sep 17 00:00:00 2001
From: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Date: Thu, 15 Feb 2024 14:02:58 +0000
Subject: [PATCH] Fix copies between DETR and DETA (#29037)

---
 src/transformers/models/deta/modeling_deta.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/transformers/models/deta/modeling_deta.py b/src/transformers/models/deta/modeling_deta.py
index 188b83c4e2e280..7e1b014c834eff 100644
--- a/src/transformers/models/deta/modeling_deta.py
+++ b/src/transformers/models/deta/modeling_deta.py
@@ -627,7 +627,8 @@ def __init__(self, config: DetaConfig, num_heads: int, n_points: int):
 
     def _reset_parameters(self):
         nn.init.constant_(self.sampling_offsets.weight.data, 0.0)
-        thetas = torch.arange(self.n_heads, dtype=torch.int64).float() * (2.0 * math.pi / self.n_heads)
+        default_dtype = torch.get_default_dtype()
+        thetas = torch.arange(self.n_heads, dtype=torch.int64).to(default_dtype) * (2.0 * math.pi / self.n_heads)
         grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
         grid_init = (
             (grid_init / grid_init.abs().max(-1, keepdim=True)[0])
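
Note (not part of the patch): a minimal standalone sketch of what the hunk above
changes. .float() always casts to float32, whereas .to(torch.get_default_dtype())
follows whatever default dtype the surrounding run has configured. The value
n_heads = 8 is an illustrative assumption, not taken from the model config.

    import math
    import torch

    torch.set_default_dtype(torch.float64)  # e.g. a double-precision run

    n_heads = 8  # hypothetical head count for illustration
    # Old behavior: .float() hardcodes float32, ignoring the configured default.
    thetas_old = torch.arange(n_heads, dtype=torch.int64).float() * (2.0 * math.pi / n_heads)
    # New behavior: follows the caller's default dtype, here float64.
    thetas_new = torch.arange(n_heads, dtype=torch.int64).to(torch.get_default_dtype()) * (
        2.0 * math.pi / n_heads
    )

    print(thetas_old.dtype)  # torch.float32
    print(thetas_new.dtype)  # torch.float64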