Remove custom kernel code
NielsRogge committed Jan 26, 2023
1 parent ab92eff commit 76d090f
Showing 2 changed files with 4 additions and 130 deletions.
51 changes: 0 additions & 51 deletions src/transformers/models/deta/load_custom.py

This file was deleted.

83 changes: 4 additions & 79 deletions src/transformers/models/deta/modeling_deta.py
@@ -24,98 +24,34 @@
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable

from ...activations import ACT2FN
from ...file_utils import (
    ModelOutput,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_scipy_available,
    is_torch_cuda_available,
    is_vision_available,
    replace_return_docstrings,
    requires_backends,
)
from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import meshgrid
from ...utils import is_ninja_available, is_torchvision_available, logging
from ...utils import is_torchvision_available, logging
from ..auto import AutoBackbone
from .configuration_deta import DetaConfig
from .load_custom import load_cuda_kernels


logger = logging.get_logger(__name__)

# Move this so the kernels are not compiled at import time; this needs to happen later, e.g. in __init__.
if is_torch_cuda_available() and is_ninja_available():
    logger.info("Loading custom CUDA kernels...")
    try:
        MultiScaleDeformableAttention = load_cuda_kernels()
    except Exception as e:
        logger.warning(f"Could not load the custom kernel for multi-scale deformable attention: {e}")
        MultiScaleDeformableAttention = None
else:
    MultiScaleDeformableAttention = None

if is_vision_available():
    from transformers.image_transforms import center_to_corners_format

if is_torchvision_available():
    from torchvision.ops.boxes import batched_nms


# Copied from transformers.models.deformable_detr.modeling_deformable_detr.MultiScaleDeformableAttentionFunction
class MultiScaleDeformableAttentionFunction(Function):
    @staticmethod
    def forward(
        context,
        value,
        value_spatial_shapes,
        value_level_start_index,
        sampling_locations,
        attention_weights,
        im2col_step,
    ):
        context.im2col_step = im2col_step
        output = MultiScaleDeformableAttention.ms_deform_attn_forward(
            value,
            value_spatial_shapes,
            value_level_start_index,
            sampling_locations,
            attention_weights,
            context.im2col_step,
        )
        context.save_for_backward(
            value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights
        )
        return output

    @staticmethod
    @once_differentiable
    def backward(context, grad_output):
        (
            value,
            value_spatial_shapes,
            value_level_start_index,
            sampling_locations,
            attention_weights,
        ) = context.saved_tensors
        grad_value, grad_sampling_loc, grad_attn_weight = MultiScaleDeformableAttention.ms_deform_attn_backward(
            value,
            value_spatial_shapes,
            value_level_start_index,
            sampling_locations,
            attention_weights,
            grad_output,
            context.im2col_step,
        )

        return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None


if is_scipy_available():
    from scipy.optimize import linear_sum_assignment
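
The removed comment above notes that kernel compilation should not happen at import time but later, e.g. in __init__. Purely as an illustration of that idea (not part of this commit), a hypothetical sketch that assumes the pre-commit layout in which deta/load_custom.py still exists; the helper name _maybe_load_cuda_kernels is made up:

from torch import nn

from ...file_utils import is_torch_cuda_available
from ...utils import is_ninja_available, logging
from .load_custom import load_cuda_kernels  # this module is deleted by the commit

logger = logging.get_logger(__name__)

# Populated lazily on first use instead of at import time.
MultiScaleDeformableAttention = None


def _maybe_load_cuda_kernels():
    # Hypothetical helper: compile/load the kernels the first time they are needed.
    global MultiScaleDeformableAttention
    if MultiScaleDeformableAttention is None and is_torch_cuda_available() and is_ninja_available():
        try:
            MultiScaleDeformableAttention = load_cuda_kernels()
        except Exception as e:
            logger.warning(f"Could not load the custom kernel for multi-scale deformable attention: {e}")


class DetaMultiscaleDeformableAttention(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int, n_levels: int, n_points: int):
        super().__init__()
        _maybe_load_cuda_kernels()  # deferred from import time to module construction
        # ... rest of the constructor as in the file above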

@@ -552,12 +488,12 @@ def multi_scale_deformable_attention(
    return output.transpose(1, 2).contiguous()


# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrMultiscaleDeformableAttention with DeformableDetr->Deta,Deformable DETR->DETA
class DetaMultiscaleDeformableAttention(nn.Module):
    """
    Multiscale deformable attention as proposed in Deformable DETR.
    """

    # Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrMultiscaleDeformableAttention.__init__ with DeformableDetr->Deta
    def __init__(self, embed_dim: int, num_heads: int, n_levels: int, n_points: int):
        super().__init__()
        if embed_dim % num_heads != 0:
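
For reference, the constructor validates that embed_dim is evenly divisible by num_heads. A minimal instantiation sketch with illustrative values (assumptions for this example, not taken from the diff):

# 256 % 8 == 0, so each head works on a 32-dimensional slice of the hidden states.
attention = DetaMultiscaleDeformableAttention(embed_dim=256, num_heads=8, n_levels=4, n_points=4)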
@@ -661,19 +597,8 @@ def forward(
            )
        else:
            raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")
        try:
            # GPU
            output = MultiScaleDeformableAttentionFunction.apply(
                value,
                spatial_shapes,
                level_start_index,
                sampling_locations,
                attention_weights,
                self.im2col_step,
            )
        except Exception:
            # CPU
            output = multi_scale_deformable_attention(value, spatial_shapes, sampling_locations, attention_weights)
        # PyTorch implementation (for now)
        output = multi_scale_deformable_attention(value, spatial_shapes, sampling_locations, attention_weights)
        output = self.output_proj(output)

        return output, attention_weights
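
For context, the pure-PyTorch path that forward now always takes (multi_scale_deformable_attention, whose body is collapsed in the diff above) bilinearly samples the flattened multi-scale value maps with torch.nn.functional.grid_sample and combines the samples with the predicted attention weights. A condensed, approximate sketch of that computation, as an illustration rather than the exact library function:

import torch
from torch.nn.functional import grid_sample


def ms_deform_attn_pytorch(value, spatial_shapes, sampling_locations, attention_weights):
    # value: (batch, sum(H*W), num_heads, head_dim)
    # sampling_locations: (batch, num_queries, num_heads, num_levels, num_points, 2), in [0, 1]
    # attention_weights: (batch, num_queries, num_heads, num_levels, num_points)
    batch, _, num_heads, head_dim = value.shape
    _, num_queries, _, num_levels, num_points, _ = sampling_locations.shape
    value_list = value.split([int(h * w) for h, w in spatial_shapes], dim=1)
    grids = 2 * sampling_locations - 1  # grid_sample expects coordinates in [-1, 1]
    sampled = []
    for level, (h, w) in enumerate(spatial_shapes):
        # (batch * num_heads, head_dim, H, W)
        value_l = value_list[level].flatten(2).transpose(1, 2).reshape(batch * num_heads, head_dim, int(h), int(w))
        # (batch * num_heads, num_queries, num_points, 2)
        grid_l = grids[:, :, :, level].transpose(1, 2).flatten(0, 1)
        # bilinear sampling of each point: (batch * num_heads, head_dim, num_queries, num_points)
        sampled.append(grid_sample(value_l, grid_l, mode="bilinear", padding_mode="zeros", align_corners=False))
    # (batch * num_heads, 1, num_queries, num_levels * num_points)
    weights = attention_weights.transpose(1, 2).reshape(batch * num_heads, 1, num_queries, num_levels * num_points)
    out = (torch.stack(sampled, dim=-2).flatten(-2) * weights).sum(-1)
    return out.view(batch, num_heads * head_dim, num_queries).transpose(1, 2).contiguous()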
