From 8635802af9c9ec1c9fa30541992994b88e3ec1d3 Mon Sep 17 00:00:00 2001 From: g-prz <158364710+g-prz@users.noreply.github.com> Date: Tue, 1 Oct 2024 22:08:57 +0300 Subject: [PATCH] Move weight initialization deformabledetr (#33339) * fix(copy): fixup copy * fix(deformable_detr): move weight initialization to the right place * fix(grounding_dino): move weight initialization to the right place * fix(rt_detr): move weight initialization to the right place * [run-slow] deformable_detr, grounding_dino, rt_detr --- .../modeling_deformable_detr.py | 45 +++++++++--------- .../grounding_dino/modeling_grounding_dino.py | 45 +++++++++--------- .../models/rt_detr/modeling_rt_detr.py | 46 +++++++++---------- 3 files changed, 65 insertions(+), 71 deletions(-) diff --git a/src/transformers/models/deformable_detr/modeling_deformable_detr.py b/src/transformers/models/deformable_detr/modeling_deformable_detr.py index 46e00787baf618..1084e7136a428f 100755 --- a/src/transformers/models/deformable_detr/modeling_deformable_detr.py +++ b/src/transformers/models/deformable_detr/modeling_deformable_detr.py @@ -660,29 +660,6 @@ def __init__(self, config: DeformableDetrConfig, num_heads: int, n_points: int): self.disable_custom_kernels = config.disable_custom_kernels - self._reset_parameters() - - def _reset_parameters(self): - nn.init.constant_(self.sampling_offsets.weight.data, 0.0) - default_dtype = torch.get_default_dtype() - thetas = torch.arange(self.n_heads, dtype=torch.int64).to(default_dtype) * (2.0 * math.pi / self.n_heads) - grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) - grid_init = ( - (grid_init / grid_init.abs().max(-1, keepdim=True)[0]) - .view(self.n_heads, 1, 1, 2) - .repeat(1, self.n_levels, self.n_points, 1) - ) - for i in range(self.n_points): - grid_init[:, :, i, :] *= i + 1 - with torch.no_grad(): - self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) - nn.init.constant_(self.attention_weights.weight.data, 0.0) - 
nn.init.constant_(self.attention_weights.bias.data, 0.0) - nn.init.xavier_uniform_(self.value_proj.weight.data) - nn.init.constant_(self.value_proj.bias.data, 0.0) - nn.init.xavier_uniform_(self.output_proj.weight.data) - nn.init.constant_(self.output_proj.bias.data, 0.0) - def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]): return tensor if position_embeddings is None else tensor + position_embeddings @@ -1088,7 +1065,27 @@ def _init_weights(self, module): nn.init.uniform_(module.row_embeddings.weight) nn.init.uniform_(module.column_embeddings.weight) elif isinstance(module, DeformableDetrMultiscaleDeformableAttention): - module._reset_parameters() + nn.init.constant_(module.sampling_offsets.weight.data, 0.0) + default_dtype = torch.get_default_dtype() + thetas = torch.arange(module.n_heads, dtype=torch.int64).to(default_dtype) * ( + 2.0 * math.pi / module.n_heads + ) + grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) + grid_init = ( + (grid_init / grid_init.abs().max(-1, keepdim=True)[0]) + .view(module.n_heads, 1, 1, 2) + .repeat(1, module.n_levels, module.n_points, 1) + ) + for i in range(module.n_points): + grid_init[:, :, i, :] *= i + 1 + with torch.no_grad(): + module.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) + nn.init.constant_(module.attention_weights.weight.data, 0.0) + nn.init.constant_(module.attention_weights.bias.data, 0.0) + nn.init.xavier_uniform_(module.value_proj.weight.data) + nn.init.constant_(module.value_proj.bias.data, 0.0) + nn.init.xavier_uniform_(module.output_proj.weight.data) + nn.init.constant_(module.output_proj.bias.data, 0.0) elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): # Slightly different from the TF version which uses truncated_normal for initialization # cf https://github.com/pytorch/pytorch/pull/5617 diff --git a/src/transformers/models/grounding_dino/modeling_grounding_dino.py b/src/transformers/models/grounding_dino/modeling_grounding_dino.py index 
3b298704de32fb..08e4b27af64d7d 100644 --- a/src/transformers/models/grounding_dino/modeling_grounding_dino.py +++ b/src/transformers/models/grounding_dino/modeling_grounding_dino.py @@ -664,29 +664,6 @@ def __init__(self, config: GroundingDinoConfig, num_heads: int, n_points: int): self.disable_custom_kernels = config.disable_custom_kernels - self._reset_parameters() - - def _reset_parameters(self): - nn.init.constant_(self.sampling_offsets.weight.data, 0.0) - default_dtype = torch.get_default_dtype() - thetas = torch.arange(self.n_heads, dtype=torch.int64).to(default_dtype) * (2.0 * math.pi / self.n_heads) - grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) - grid_init = ( - (grid_init / grid_init.abs().max(-1, keepdim=True)[0]) - .view(self.n_heads, 1, 1, 2) - .repeat(1, self.n_levels, self.n_points, 1) - ) - for i in range(self.n_points): - grid_init[:, :, i, :] *= i + 1 - with torch.no_grad(): - self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) - nn.init.constant_(self.attention_weights.weight.data, 0.0) - nn.init.constant_(self.attention_weights.bias.data, 0.0) - nn.init.xavier_uniform_(self.value_proj.weight.data) - nn.init.constant_(self.value_proj.bias.data, 0.0) - nn.init.xavier_uniform_(self.output_proj.weight.data) - nn.init.constant_(self.output_proj.bias.data, 0.0) - def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]): return tensor if position_embeddings is None else tensor + position_embeddings @@ -1509,7 +1486,27 @@ def _init_weights(self, module): nn.init.uniform_(module.row_embeddings.weight) nn.init.uniform_(module.column_embeddings.weight) elif isinstance(module, GroundingDinoMultiscaleDeformableAttention): - module._reset_parameters() + nn.init.constant_(module.sampling_offsets.weight.data, 0.0) + default_dtype = torch.get_default_dtype() + thetas = torch.arange(module.n_heads, dtype=torch.int64).to(default_dtype) * ( + 2.0 * math.pi / module.n_heads + ) + grid_init = 
torch.stack([thetas.cos(), thetas.sin()], -1) + grid_init = ( + (grid_init / grid_init.abs().max(-1, keepdim=True)[0]) + .view(module.n_heads, 1, 1, 2) + .repeat(1, module.n_levels, module.n_points, 1) + ) + for i in range(module.n_points): + grid_init[:, :, i, :] *= i + 1 + with torch.no_grad(): + module.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) + nn.init.constant_(module.attention_weights.weight.data, 0.0) + nn.init.constant_(module.attention_weights.bias.data, 0.0) + nn.init.xavier_uniform_(module.value_proj.weight.data) + nn.init.constant_(module.value_proj.bias.data, 0.0) + nn.init.xavier_uniform_(module.output_proj.weight.data) + nn.init.constant_(module.output_proj.bias.data, 0.0) elif isinstance(module, GroundingDinoBiMultiHeadAttention): nn.init.xavier_uniform_(module.vision_proj.weight) module.vision_proj.bias.data.fill_(0) diff --git a/src/transformers/models/rt_detr/modeling_rt_detr.py b/src/transformers/models/rt_detr/modeling_rt_detr.py index 4e32434901cdc7..35af2ec8ecfb48 100644 --- a/src/transformers/models/rt_detr/modeling_rt_detr.py +++ b/src/transformers/models/rt_detr/modeling_rt_detr.py @@ -816,29 +816,6 @@ def __init__(self, config: RTDetrConfig, num_heads: int, n_points: int): self.disable_custom_kernels = config.disable_custom_kernels - self._reset_parameters() - - def _reset_parameters(self): - nn.init.constant_(self.sampling_offsets.weight.data, 0.0) - default_dtype = torch.get_default_dtype() - thetas = torch.arange(self.n_heads, dtype=torch.int64).to(default_dtype) * (2.0 * math.pi / self.n_heads) - grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) - grid_init = ( - (grid_init / grid_init.abs().max(-1, keepdim=True)[0]) - .view(self.n_heads, 1, 1, 2) - .repeat(1, self.n_levels, self.n_points, 1) - ) - for i in range(self.n_points): - grid_init[:, :, i, :] *= i + 1 - with torch.no_grad(): - self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) - nn.init.constant_(self.attention_weights.weight.data, 0.0) - 
nn.init.constant_(self.attention_weights.bias.data, 0.0) - nn.init.xavier_uniform_(self.value_proj.weight.data) - nn.init.constant_(self.value_proj.bias.data, 0.0) - nn.init.xavier_uniform_(self.output_proj.weight.data) - nn.init.constant_(self.output_proj.bias.data, 0.0) - def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]): return tensor if position_embeddings is None else tensor + position_embeddings @@ -1176,6 +1153,29 @@ def _init_weights(self, module): nn.init.constant_(layer.layers[-1].weight, 0) nn.init.constant_(layer.layers[-1].bias, 0) + if isinstance(module, RTDetrMultiscaleDeformableAttention): + nn.init.constant_(module.sampling_offsets.weight.data, 0.0) + default_dtype = torch.get_default_dtype() + thetas = torch.arange(module.n_heads, dtype=torch.int64).to(default_dtype) * ( + 2.0 * math.pi / module.n_heads + ) + grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) + grid_init = ( + (grid_init / grid_init.abs().max(-1, keepdim=True)[0]) + .view(module.n_heads, 1, 1, 2) + .repeat(1, module.n_levels, module.n_points, 1) + ) + for i in range(module.n_points): + grid_init[:, :, i, :] *= i + 1 + with torch.no_grad(): + module.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) + nn.init.constant_(module.attention_weights.weight.data, 0.0) + nn.init.constant_(module.attention_weights.bias.data, 0.0) + nn.init.xavier_uniform_(module.value_proj.weight.data) + nn.init.constant_(module.value_proj.bias.data, 0.0) + nn.init.xavier_uniform_(module.output_proj.weight.data) + nn.init.constant_(module.output_proj.bias.data, 0.0) + if isinstance(module, RTDetrModel): prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1) bias = float(-math.log((1 - prior_prob) / prior_prob))