DeformableDetrModel support fp16 (#29013)
* Update ms_deform_attn_cuda.cu

* Update ms_deform_attn_cuda.cuh

* Update modeling_deformable_detr.py

* Update src/transformers/models/deformable_detr/modeling_deformable_detr.py

Co-authored-by: amyeroberts <[email protected]>

* Update modeling_deformable_detr.py

* python utils/check_copies.py --fix_and_overwrite

* Fix dtype mismatch error

* Update test_modeling_deformable_detr.py

* Update test_modeling_deformable_detr.py

* Update modeling_deformable_detr.py

* Update modeling_deformable_detr.py

---------

Co-authored-by: amyeroberts <[email protected]>
DonggeunYu and amyeroberts authored Feb 15, 2024
1 parent 83e96dc commit 5b6fa23
Showing 5 changed files with 29 additions and 16 deletions.
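
The net effect of the commit: DeformableDetrModel (and its object-detection head) can run a half-precision forward pass. Below is a minimal usage sketch of what this enables, assuming a CUDA device and the public SenseTime/deformable-detr checkpoint (both are illustrative choices, not part of the commit):

import numpy as np
import torch
from transformers import AutoImageProcessor, DeformableDetrModel

processor = AutoImageProcessor.from_pretrained("SenseTime/deformable-detr")
model = DeformableDetrModel.from_pretrained(
    "SenseTime/deformable-detr", torch_dtype=torch.float16
).to("cuda")
model.eval()

image = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)  # stand-in for a real image
inputs = processor(images=image, return_tensors="pt").to("cuda")
inputs["pixel_values"] = inputs["pixel_values"].half()
with torch.no_grad():
    outputs = model(**inputs)
print(outputs.last_hidden_state.dtype)  # torch.float16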
ms_deform_attn_cuda.cu
@@ -64,7 +64,7 @@ at::Tensor ms_deform_attn_cuda_forward(
     for (int n = 0; n < batch/im2col_step_; ++n)
     {
         auto columns = output_n.select(0, n);
-        AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
+        AT_DISPATCH_FLOATING_TYPES_AND_HALF(value.type(), "ms_deform_attn_forward_cuda", ([&] {
             ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
                 value.data<scalar_t>() + n * im2col_step_ * per_value_size,
                 spatial_shapes.data<int64_t>(),
@@ -134,7 +134,7 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
     for (int n = 0; n < batch/im2col_step_; ++n)
     {
         auto grad_output_g = grad_output_n.select(0, n);
-        AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
+        AT_DISPATCH_FLOATING_TYPES_AND_HALF(value.type(), "ms_deform_attn_backward_cuda", ([&] {
             ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
                 grad_output_g.data<scalar_t>(),
                 value.data<scalar_t>() + n * im2col_step_ * per_value_size,
ms_deform_attn_cuda.cuh
@@ -72,7 +72,7 @@ at::Tensor ms_deform_attn_cuda_forward(
     for (int n = 0; n < batch/im2col_step_; ++n)
     {
         auto columns = output_n.select(0, n);
-        AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
+        AT_DISPATCH_FLOATING_TYPES_AND_HALF(value.type(), "ms_deform_attn_forward_cuda", ([&] {
             ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
                 value.data<scalar_t>() + n * im2col_step_ * per_value_size,
                 spatial_shapes.data<int64_t>(),
@@ -142,7 +142,7 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
     for (int n = 0; n < batch/im2col_step_; ++n)
     {
         auto grad_output_g = grad_output_n.select(0, n);
-        AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
+        AT_DISPATCH_FLOATING_TYPES_AND_HALF(value.type(), "ms_deform_attn_backward_cuda", ([&] {
             ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
                 grad_output_g.data<scalar_t>(),
                 value.data<scalar_t>() + n * im2col_step_ * per_value_size,
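Both CUDA files receive the same one-line change in the forward and backward paths. AT_DISPATCH_FLOATING_TYPES instantiates the templated kernel body for float32/float64 only, so fp16 tensors used to error out before reaching the kernel; the _AND_HALF variant adds a float16 instantiation. A rough Python analogy of the dispatch behavior (not the real C++ macro):

import torch

def dispatch_floating_types(value, op_name, kernel, include_half=False):
    # Mirrors the macro's effect: pick the kernel instantiation by dtype,
    # or fail the way PyTorch does for unsupported dtypes.
    supported = {torch.float32, torch.float64}
    if include_half:
        supported.add(torch.float16)
    if value.dtype not in supported:
        raise RuntimeError(f'"{op_name}" not implemented for {value.dtype}')
    return kernel(value)

x = torch.ones(4, dtype=torch.float16)
# Before the commit (include_half=False) this raises; after, it dispatches.
y = dispatch_floating_types(x, "ms_deform_attn_forward_cuda", lambda v: v * 2, include_half=True)
print(y.dtype)  # torch.float16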
src/transformers/models/deformable_detr/modeling_deformable_detr.py
@@ -617,7 +617,8 @@ def __init__(self, config: DeformableDetrConfig, num_heads: int, n_points: int):

     def _reset_parameters(self):
         nn.init.constant_(self.sampling_offsets.weight.data, 0.0)
-        thetas = torch.arange(self.n_heads, dtype=torch.int64).float() * (2.0 * math.pi / self.n_heads)
+        default_dtype = torch.get_default_dtype()
+        thetas = torch.arange(self.n_heads, dtype=torch.int64).to(default_dtype) * (2.0 * math.pi / self.n_heads)
         grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
         grid_init = (
             (grid_init / grid_init.abs().max(-1, keepdim=True)[0])
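The init change replaces a hard-coded .float() with torch.get_default_dtype(), so the head-direction grid follows the configured default dtype instead of always being float32. A small standalone sketch of the same computation (n_heads is illustrative; the model reads it from its config):

import math
import torch

n_heads = 8  # illustrative value
default_dtype = torch.get_default_dtype()
thetas = torch.arange(n_heads, dtype=torch.int64).to(default_dtype) * (2.0 * math.pi / n_heads)
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)  # one unit direction per head
print(grid_init.dtype)  # torch.float32 unless the default dtype was changed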
@@ -1171,8 +1172,8 @@ def get_reference_points(spatial_shapes, valid_ratios, device):
         reference_points_list = []
         for level, (height, width) in enumerate(spatial_shapes):
             ref_y, ref_x = meshgrid(
-                torch.linspace(0.5, height - 0.5, height, dtype=torch.float32, device=device),
-                torch.linspace(0.5, width - 0.5, width, dtype=torch.float32, device=device),
+                torch.linspace(0.5, height - 0.5, height, dtype=valid_ratios.dtype, device=device),
+                torch.linspace(0.5, width - 0.5, width, dtype=valid_ratios.dtype, device=device),
                 indexing="ij",
             )
             # TODO: valid_ratios could be useless here. check https://github.com/fundamentalvision/Deformable-DETR/issues/36
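Here the reference-point grid inherits valid_ratios.dtype instead of hard-coded float32, so it matches the half-precision activations it is later combined with. A standalone sketch using torch.meshgrid (the model uses its own meshgrid wrapper; shapes are illustrative):

import torch

valid_ratios = torch.rand(2, 1, 2, dtype=torch.float16)  # (batch, num_levels, 2)
height, width = 4, 6
ref_y, ref_x = torch.meshgrid(
    torch.linspace(0.5, height - 0.5, height, dtype=valid_ratios.dtype),
    torch.linspace(0.5, width - 0.5, width, dtype=valid_ratios.dtype),
    indexing="ij",
)
print(ref_y.dtype, ref_x.dtype)  # torch.float16 torch.float16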
@@ -1540,15 +1541,15 @@ def unfreeze_backbone(self):
         for name, param in self.backbone.conv_encoder.model.named_parameters():
             param.requires_grad_(True)

-    def get_valid_ratio(self, mask):
+    def get_valid_ratio(self, mask, dtype=torch.float32):
         """Get the valid ratio of all feature maps."""

         _, height, width = mask.shape
         valid_height = torch.sum(mask[:, :, 0], 1)
         valid_width = torch.sum(mask[:, 0, :], 1)
-        valid_ratio_heigth = valid_height.float() / height
-        valid_ratio_width = valid_width.float() / width
-        valid_ratio = torch.stack([valid_ratio_width, valid_ratio_heigth], -1)
+        valid_ratio_height = valid_height.to(dtype) / height
+        valid_ratio_width = valid_width.to(dtype) / width
+        valid_ratio = torch.stack([valid_ratio_width, valid_ratio_height], -1)
         return valid_ratio

     def get_proposal_pos_embed(self, proposals):
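get_valid_ratio gains a dtype parameter (and fixes the "heigth" typo) so the padding ratios can be produced directly in the dtype of the flattened features. A worked example on a toy mask where True marks valid, unpadded pixels (shape and values are illustrative):

import torch

mask = torch.zeros(1, 8, 8, dtype=torch.bool)
mask[:, :6, :4] = True  # 6 of 8 rows and 4 of 8 columns are valid
_, height, width = mask.shape
valid_height = torch.sum(mask[:, :, 0], 1)
valid_width = torch.sum(mask[:, 0, :], 1)
dtype = torch.float16
valid_ratio_height = valid_height.to(dtype) / height
valid_ratio_width = valid_width.to(dtype) / width
valid_ratio = torch.stack([valid_ratio_width, valid_ratio_height], -1)
print(valid_ratio)  # tensor([[0.5000, 0.7500]], dtype=torch.float16)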
@@ -1721,7 +1722,7 @@ def forward(
         lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
         spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=source_flatten.device)
         level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
-        valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
+        valid_ratios = torch.stack([self.get_valid_ratio(m, dtype=source_flatten.dtype) for m in masks], 1)
         valid_ratios = valid_ratios.float()

         # Fourth, sent source_flatten + mask_flatten + lvl_pos_embed_flatten (backbone + proj layer output) through encoder
8 changes: 4 additions & 4 deletions src/transformers/models/deta/modeling_deta.py
@@ -1549,15 +1549,15 @@ def unfreeze_backbone(self):
             param.requires_grad_(True)

     # Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrModel.get_valid_ratio
-    def get_valid_ratio(self, mask):
+    def get_valid_ratio(self, mask, dtype=torch.float32):
         """Get the valid ratio of all feature maps."""

         _, height, width = mask.shape
         valid_height = torch.sum(mask[:, :, 0], 1)
         valid_width = torch.sum(mask[:, 0, :], 1)
-        valid_ratio_heigth = valid_height.float() / height
-        valid_ratio_width = valid_width.float() / width
-        valid_ratio = torch.stack([valid_ratio_width, valid_ratio_heigth], -1)
+        valid_ratio_height = valid_height.to(dtype) / height
+        valid_ratio_width = valid_width.to(dtype) / width
+        valid_ratio = torch.stack([valid_ratio_width, valid_ratio_height], -1)
         return valid_ratio

     # Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrModel.get_proposal_pos_embed
12 changes: 12 additions & 0 deletions tests/models/deformable_detr/test_modeling_deformable_detr.py
@@ -583,6 +583,18 @@ def test_two_stage_training(self):
         loss = model(**inputs).loss
         loss.backward()
+
+    def create_and_check_model_fp16_forward(self):
+        model_class = DeformableDetrForObjectDetection
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        model = model_class(config)
+        model.to(torch_device)
+        model.half()
+        model.eval()
+        inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+        output = model(**inputs)["last_hidden_state"]
+        self.parent.assertFalse(torch.isnan(output).any().item())


TOLERANCE = 1e-4
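
The added test follows a common fp16 smoke-test pattern: cast the model with .half(), run a forward pass, and assert the output contains no NaNs from overflow. A self-contained sketch of the same pattern with a stand-in module (half-precision matmuls are reliably available on CUDA, hence the guard):

import torch

if torch.cuda.is_available():
    model = torch.nn.Linear(16, 16).cuda().half().eval()
    inputs = torch.randn(2, 16, device="cuda", dtype=torch.float16)
    with torch.no_grad():
        output = model(inputs)
    assert not torch.isnan(output).any().item()  # fails if fp16 overflowed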

