diff --git a/src/transformers/models/deta/load_custom.py b/src/transformers/models/deta/load_custom.py
deleted file mode 100644
index 2f64663c3f8efd..00000000000000
--- a/src/transformers/models/deta/load_custom.py
+++ /dev/null
@@ -1,51 +0,0 @@
-# coding=utf-8
-# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-""" Loading of DETA's CUDA kernels"""
-
-import os
-
-
-def load_cuda_kernels():
-    from torch.utils.cpp_extension import load
-
-    root = os.path.join(os.path.dirname(os.path.realpath(__file__)), "custom_kernel")
-    src_files = [
-        os.path.join(root, filename)
-        for filename in [
-            "vision.cpp",
-            os.path.join("cpu", "ms_deform_attn_cpu.cpp"),
-            os.path.join("cuda", "ms_deform_attn_cuda.cu"),
-        ]
-    ]
-
-    load(
-        "MultiScaleDeformableAttention",
-        src_files,
-        # verbose=True,
-        with_cuda=True,
-        extra_include_paths=[root],
-        # build_directory=os.path.dirname(os.path.realpath(__file__)),
-        extra_cflags=["-DWITH_CUDA=1"],
-        extra_cuda_cflags=[
-            "-DCUDA_HAS_FP16=1",
-            "-D__CUDA_NO_HALF_OPERATORS__",
-            "-D__CUDA_NO_HALF_CONVERSIONS__",
-            "-D__CUDA_NO_HALF2_OPERATORS__",
-        ],
-    )
-
-    import MultiScaleDeformableAttention as MSDA
-
-    return MSDA
diff --git a/src/transformers/models/deta/modeling_deta.py b/src/transformers/models/deta/modeling_deta.py
index 44abc28e8df04b..cac4b5e618fb26 100644
--- a/src/transformers/models/deta/modeling_deta.py
+++ b/src/transformers/models/deta/modeling_deta.py
@@ -24,8 +24,6 @@
 import torch
 import torch.nn.functional as F
 from torch import Tensor, nn
-from torch.autograd import Function
-from torch.autograd.function import once_differentiable
 
 from ...activations import ACT2FN
 from ...file_utils import (
@@ -33,7 +31,6 @@
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     is_scipy_available,
-    is_torch_cuda_available,
     is_vision_available,
     replace_return_docstrings,
     requires_backends,
@@ -41,24 +38,13 @@
 from ...modeling_outputs import BaseModelOutput
 from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import meshgrid
-from ...utils import is_ninja_available, is_torchvision_available, logging
+from ...utils import is_torchvision_available, logging
 from ..auto import AutoBackbone
 from .configuration_deta import DetaConfig
-from .load_custom import load_cuda_kernels
 
 
 logger = logging.get_logger(__name__)
 
-# Move this to not compile only when importing, this needs to happen later, like in __init__.
-if is_torch_cuda_available() and is_ninja_available():
-    logger.info("Loading custom CUDA kernels...")
-    try:
-        MultiScaleDeformableAttention = load_cuda_kernels()
-    except Exception as e:
-        logger.warning(f"Could not load the custom kernel for multi-scale deformable attention: {e}")
-        MultiScaleDeformableAttention = None
-else:
-    MultiScaleDeformableAttention = None
 
 if is_vision_available():
     from transformers.image_transforms import center_to_corners_format
@@ -66,56 +52,6 @@
 if is_torchvision_available():
     from torchvision.ops.boxes import batched_nms
 
-
-# Copied from transformers.models.deformable_detr.modeling_deformable_detr.MultiScaleDeformableAttentionFunction
-class MultiScaleDeformableAttentionFunction(Function):
-    @staticmethod
-    def forward(
-        context,
-        value,
-        value_spatial_shapes,
-        value_level_start_index,
-        sampling_locations,
-        attention_weights,
-        im2col_step,
-    ):
-        context.im2col_step = im2col_step
-        output = MultiScaleDeformableAttention.ms_deform_attn_forward(
-            value,
-            value_spatial_shapes,
-            value_level_start_index,
-            sampling_locations,
-            attention_weights,
-            context.im2col_step,
-        )
-        context.save_for_backward(
-            value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights
-        )
-        return output
-
-    @staticmethod
-    @once_differentiable
-    def backward(context, grad_output):
-        (
-            value,
-            value_spatial_shapes,
-            value_level_start_index,
-            sampling_locations,
-            attention_weights,
-        ) = context.saved_tensors
-        grad_value, grad_sampling_loc, grad_attn_weight = MultiScaleDeformableAttention.ms_deform_attn_backward(
-            value,
-            value_spatial_shapes,
-            value_level_start_index,
-            sampling_locations,
-            attention_weights,
-            grad_output,
-            context.im2col_step,
-        )
-
-        return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
-
-
 if is_scipy_available():
     from scipy.optimize import linear_sum_assignment
 
@@ -552,12 +488,12 @@
     return output.transpose(1, 2).contiguous()
 
 
-# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrMultiscaleDeformableAttention with DeformableDetr->Deta,Deformable DETR->DETA
 class DetaMultiscaleDeformableAttention(nn.Module):
     """
     Multiscale deformable attention as proposed in Deformable DETR.
     """
 
+    # Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrMultiscaleDeformableAttention.__init__ with DeformableDetr->Deta
     def __init__(self, embed_dim: int, num_heads: int, n_levels: int, n_points: int):
         super().__init__()
         if embed_dim % num_heads != 0:
@@ -661,19 +597,8 @@ def forward(
             )
         else:
             raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")
-        try:
-            # GPU
-            output = MultiScaleDeformableAttentionFunction.apply(
-                value,
-                spatial_shapes,
-                level_start_index,
-                sampling_locations,
-                attention_weights,
-                self.im2col_step,
-            )
-        except Exception:
-            # CPU
-            output = multi_scale_deformable_attention(value, spatial_shapes, sampling_locations, attention_weights)
+        # PyTorch implementation (for now)
+        output = multi_scale_deformable_attention(value, spatial_shapes, sampling_locations, attention_weights)
         output = self.output_proj(output)
         return output, attention_weights
 
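Note: the removed comment ("Move this to not compile only when importing, this needs to happen later, like in __init__") points at an alternative to dropping the kernels entirely: compiling them lazily on first use instead of at import time. Below is a minimal sketch of that idea, not part of this diff. The build call, flags, and source layout are taken from the deleted load_custom.py; the function name lazy_load_cuda_kernels, the module-level _MSDA cache, and the torch.cuda.is_available() gate are illustrative assumptions, and the code only compiles if the custom_kernel sources are present next to it.

import os

_MSDA = None  # cache so the extension is built at most once per process


def lazy_load_cuda_kernels():
    """Compile/load the MultiScaleDeformableAttention extension on first call rather than at import time."""
    global _MSDA
    if _MSDA is not None:
        return _MSDA

    import torch
    from torch.utils.cpp_extension import load  # requires ninja and a CUDA toolchain at runtime

    if not torch.cuda.is_available():
        return None

    # Source layout mirrors the deleted load_custom.py (custom_kernel/{vision.cpp, cpu/, cuda/}).
    root = os.path.join(os.path.dirname(os.path.realpath(__file__)), "custom_kernel")
    src_files = [
        os.path.join(root, "vision.cpp"),
        os.path.join(root, "cpu", "ms_deform_attn_cpu.cpp"),
        os.path.join(root, "cuda", "ms_deform_attn_cuda.cu"),
    ]
    load(
        "MultiScaleDeformableAttention",
        src_files,
        with_cuda=True,
        extra_include_paths=[root],
        extra_cflags=["-DWITH_CUDA=1"],
        extra_cuda_cflags=[
            "-DCUDA_HAS_FP16=1",
            "-D__CUDA_NO_HALF_OPERATORS__",
            "-D__CUDA_NO_HALF_CONVERSIONS__",
            "-D__CUDA_NO_HALF2_OPERATORS__",
        ],
    )
    import MultiScaleDeformableAttention as MSDA

    _MSDA = MSDA
    return _MSDA

With a guard like this, callers can fall back to the pure-PyTorch multi_scale_deformable_attention path whenever lazy_load_cuda_kernels() returns None, which is the behavior this patch makes unconditional.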