Skip to content

Commit

Permalink
Fix copies between DETR and DETA (#29037)
Browse files Browse the repository at this point in the history
  • Loading branch information
amyeroberts authored Feb 15, 2024
1 parent 5b6fa23 commit 8a0ed0a
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/transformers/models/deta/modeling_deta.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,7 +627,8 @@ def __init__(self, config: DetaConfig, num_heads: int, n_points: int):

def _reset_parameters(self):
nn.init.constant_(self.sampling_offsets.weight.data, 0.0)
thetas = torch.arange(self.n_heads, dtype=torch.int64).float() * (2.0 * math.pi / self.n_heads)
default_dtype = torch.get_default_dtype()
thetas = torch.arange(self.n_heads, dtype=torch.int64).to(default_dtype) * (2.0 * math.pi / self.n_heads)
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
grid_init = (
(grid_init / grid_init.abs().max(-1, keepdim=True)[0])
Expand Down

0 comments on commit 8a0ed0a

Please sign in to comment.