From 5b6fa2306add0cb06dd1a0ecd708633e8c7e5e58 Mon Sep 17 00:00:00 2001 From: Donggeun Yu Date: Thu, 15 Feb 2024 21:31:09 +0900 Subject: [PATCH] DeformableDetrModel support fp16 (#29013) * Update ms_deform_attn_cuda.cu * Update ms_deform_attn_cuda.cuh * Update modeling_deformable_detr.py * Update src/transformers/models/deformable_detr/modeling_deformable_detr.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update modeling_deformable_detr.py * python utils/check_copies.py --fix_and_overwrite * Fix dtype missmatch error * Update test_modeling_deformable_detr.py * Update test_modeling_deformable_detr.py * Update modeling_deformable_detr.py * Update modeling_deformable_detr.py --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- .../deformable_detr/cuda/ms_deform_attn_cuda.cu | 4 ++-- .../cuda/ms_deform_attn_cuda.cuh | 4 ++-- .../deformable_detr/modeling_deformable_detr.py | 17 +++++++++-------- src/transformers/models/deta/modeling_deta.py | 8 ++++---- .../test_modeling_deformable_detr.py | 12 ++++++++++++ 5 files changed, 29 insertions(+), 16 deletions(-) diff --git a/src/transformers/kernels/deformable_detr/cuda/ms_deform_attn_cuda.cu b/src/transformers/kernels/deformable_detr/cuda/ms_deform_attn_cuda.cu index 8ea1d7fabe2684..e8e265219cc38d 100644 --- a/src/transformers/kernels/deformable_detr/cuda/ms_deform_attn_cuda.cu +++ b/src/transformers/kernels/deformable_detr/cuda/ms_deform_attn_cuda.cu @@ -64,7 +64,7 @@ at::Tensor ms_deform_attn_cuda_forward( for (int n = 0; n < batch/im2col_step_; ++n) { auto columns = output_n.select(0, n); - AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(value.type(), "ms_deform_attn_forward_cuda", ([&] { ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(), value.data() + n * im2col_step_ * per_value_size, spatial_shapes.data(), @@ -134,7 +134,7 @@ std::vector ms_deform_attn_cuda_backward( for (int n = 0; n < batch/im2col_step_; ++n) { auto grad_output_g = grad_output_n.select(0, n); - AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(value.type(), "ms_deform_attn_backward_cuda", ([&] { ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(), grad_output_g.data(), value.data() + n * im2col_step_ * per_value_size, diff --git a/src/transformers/kernels/deformable_detr/cuda/ms_deform_attn_cuda.cuh b/src/transformers/kernels/deformable_detr/cuda/ms_deform_attn_cuda.cuh index 34f8ae9cb77bba..5bde73a5a96b8b 100644 --- a/src/transformers/kernels/deformable_detr/cuda/ms_deform_attn_cuda.cuh +++ b/src/transformers/kernels/deformable_detr/cuda/ms_deform_attn_cuda.cuh @@ -72,7 +72,7 @@ at::Tensor ms_deform_attn_cuda_forward( for (int n = 0; n < batch/im2col_step_; ++n) { auto columns = output_n.select(0, n); - AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(value.type(), "ms_deform_attn_forward_cuda", ([&] { ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(), value.data() + n * im2col_step_ * per_value_size, spatial_shapes.data(), @@ -142,7 +142,7 @@ std::vector ms_deform_attn_cuda_backward( for (int n = 0; n < batch/im2col_step_; ++n) { auto grad_output_g = grad_output_n.select(0, n); - AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(value.type(), "ms_deform_attn_backward_cuda", ([&] { ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(), grad_output_g.data(), value.data() + n * im2col_step_ * per_value_size, diff --git a/src/transformers/models/deformable_detr/modeling_deformable_detr.py b/src/transformers/models/deformable_detr/modeling_deformable_detr.py index 001d379e9a1324..3c6e48a6226221 100755 --- a/src/transformers/models/deformable_detr/modeling_deformable_detr.py +++ b/src/transformers/models/deformable_detr/modeling_deformable_detr.py @@ -617,7 +617,8 @@ def __init__(self, config: DeformableDetrConfig, num_heads: int, n_points: int): def _reset_parameters(self): nn.init.constant_(self.sampling_offsets.weight.data, 0.0) - thetas = torch.arange(self.n_heads, dtype=torch.int64).float() * (2.0 * math.pi / self.n_heads) + default_dtype = torch.get_default_dtype() + thetas = torch.arange(self.n_heads, dtype=torch.int64).to(default_dtype) * (2.0 * math.pi / self.n_heads) grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) grid_init = ( (grid_init / grid_init.abs().max(-1, keepdim=True)[0]) @@ -1171,8 +1172,8 @@ def get_reference_points(spatial_shapes, valid_ratios, device): reference_points_list = [] for level, (height, width) in enumerate(spatial_shapes): ref_y, ref_x = meshgrid( - torch.linspace(0.5, height - 0.5, height, dtype=torch.float32, device=device), - torch.linspace(0.5, width - 0.5, width, dtype=torch.float32, device=device), + torch.linspace(0.5, height - 0.5, height, dtype=valid_ratios.dtype, device=device), + torch.linspace(0.5, width - 0.5, width, dtype=valid_ratios.dtype, device=device), indexing="ij", ) # TODO: valid_ratios could be useless here. check https://github.com/fundamentalvision/Deformable-DETR/issues/36 @@ -1540,15 +1541,15 @@ def unfreeze_backbone(self): for name, param in self.backbone.conv_encoder.model.named_parameters(): param.requires_grad_(True) - def get_valid_ratio(self, mask): + def get_valid_ratio(self, mask, dtype=torch.float32): """Get the valid ratio of all feature maps.""" _, height, width = mask.shape valid_height = torch.sum(mask[:, :, 0], 1) valid_width = torch.sum(mask[:, 0, :], 1) - valid_ratio_heigth = valid_height.float() / height - valid_ratio_width = valid_width.float() / width - valid_ratio = torch.stack([valid_ratio_width, valid_ratio_heigth], -1) + valid_ratio_height = valid_height.to(dtype) / height + valid_ratio_width = valid_width.to(dtype) / width + valid_ratio = torch.stack([valid_ratio_width, valid_ratio_height], -1) return valid_ratio def get_proposal_pos_embed(self, proposals): @@ -1721,7 +1722,7 @@ def forward( lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=source_flatten.device) level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) - valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1) + valid_ratios = torch.stack([self.get_valid_ratio(m, dtype=source_flatten.dtype) for m in masks], 1) valid_ratios = valid_ratios.float() # Fourth, sent source_flatten + mask_flatten + lvl_pos_embed_flatten (backbone + proj layer output) through encoder diff --git a/src/transformers/models/deta/modeling_deta.py b/src/transformers/models/deta/modeling_deta.py index ddecd59474f3ea..188b83c4e2e280 100644 --- a/src/transformers/models/deta/modeling_deta.py +++ b/src/transformers/models/deta/modeling_deta.py @@ -1549,15 +1549,15 @@ def unfreeze_backbone(self): param.requires_grad_(True) # Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrModel.get_valid_ratio - def get_valid_ratio(self, mask): + def get_valid_ratio(self, mask, dtype=torch.float32): """Get the valid ratio of all feature maps.""" _, height, width = mask.shape valid_height = torch.sum(mask[:, :, 0], 1) valid_width = torch.sum(mask[:, 0, :], 1) - valid_ratio_heigth = valid_height.float() / height - valid_ratio_width = valid_width.float() / width - valid_ratio = torch.stack([valid_ratio_width, valid_ratio_heigth], -1) + valid_ratio_height = valid_height.to(dtype) / height + valid_ratio_width = valid_width.to(dtype) / width + valid_ratio = torch.stack([valid_ratio_width, valid_ratio_height], -1) return valid_ratio # Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrModel.get_proposal_pos_embed diff --git a/tests/models/deformable_detr/test_modeling_deformable_detr.py b/tests/models/deformable_detr/test_modeling_deformable_detr.py index 2d5a0deec33c0f..c1268fff3c6e64 100644 --- a/tests/models/deformable_detr/test_modeling_deformable_detr.py +++ b/tests/models/deformable_detr/test_modeling_deformable_detr.py @@ -583,6 +583,18 @@ def test_two_stage_training(self): loss = model(**inputs).loss loss.backward() + def create_and_check_model_fp16_forward(self): + model_class = DeformableDetrForObjectDetection + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + model = model_class(config) + model.to(torch_device) + model.half() + model.eval() + inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + output = model(**inputs)["last_hidden_state"] + self.parent.assertFalse(torch.isnan(output).any().item()) + TOLERANCE = 1e-4