Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support : Leverage Accelerate for object detection/segmentation models #28312

Merged
merged 14 commits into from
Feb 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_accelerate_available,
is_scipy_available,
is_timm_available,
is_vision_available,
Expand All @@ -41,6 +42,10 @@
from .configuration_conditional_detr import ConditionalDetrConfig


if is_accelerate_available():
from accelerate import PartialState
from accelerate.utils import reduce

if is_scipy_available():
from scipy.optimize import linear_sum_assignment

Expand Down Expand Up @@ -2507,11 +2512,12 @@ def forward(self, outputs, targets):
# Compute the average number of target boxes across all nodes, for normalization purposes
num_boxes = sum(len(t["class_labels"]) for t in targets)
num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
# (Niels): comment out function below, distributed training to be added
# if is_dist_avail_and_initialized():
# torch.distributed.all_reduce(num_boxes)
# (Niels) in original implementation, num_boxes is divided by get_world_size()
num_boxes = torch.clamp(num_boxes, min=1).item()

world_size = 1
if PartialState._shared_state != {}:
num_boxes = reduce(num_boxes)
world_size = PartialState().num_processes
num_boxes = torch.clamp(num_boxes / world_size, min=1).item()

# Compute all the requested losses
losses = {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import meshgrid
from ...utils import is_ninja_available, logging
from ...utils import is_accelerate_available, is_ninja_available, logging
from ...utils.backbone_utils import load_backbone
from .configuration_deformable_detr import DeformableDetrConfig
from .load_custom import load_cuda_kernels
Expand All @@ -65,6 +65,10 @@
if is_vision_available():
from transformers.image_transforms import center_to_corners_format

if is_accelerate_available():
from accelerate import PartialState
from accelerate.utils import reduce


class MultiScaleDeformableAttentionFunction(Function):
@staticmethod
Expand Down Expand Up @@ -2246,11 +2250,11 @@ def forward(self, outputs, targets):
# Compute the average number of target boxes accross all nodes, for normalization purposes
num_boxes = sum(len(t["class_labels"]) for t in targets)
num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
# (Niels): comment out function below, distributed training to be added
# if is_dist_avail_and_initialized():
# torch.distributed.all_reduce(num_boxes)
# (Niels) in original implementation, num_boxes is divided by get_world_size()
num_boxes = torch.clamp(num_boxes, min=1).item()
world_size = 1
if PartialState._shared_state != {}:
num_boxes = reduce(num_boxes)
world_size = PartialState().num_processes
num_boxes = torch.clamp(num_boxes / world_size, min=1).item()

# Compute all the requested losses
losses = {}
Expand Down
15 changes: 10 additions & 5 deletions src/transformers/models/detr/modeling_detr.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_accelerate_available,
is_scipy_available,
is_timm_available,
is_vision_available,
Expand All @@ -41,6 +42,10 @@
from .configuration_detr import DetrConfig


if is_accelerate_available():
from accelerate import PartialState
from accelerate.utils import reduce

if is_scipy_available():
from scipy.optimize import linear_sum_assignment

Expand Down Expand Up @@ -2204,11 +2209,11 @@ def forward(self, outputs, targets):
# Compute the average number of target boxes across all nodes, for normalization purposes
num_boxes = sum(len(t["class_labels"]) for t in targets)
num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
# (Niels): comment out function below, distributed training to be added
# if is_dist_avail_and_initialized():
# torch.distributed.all_reduce(num_boxes)
# (Niels) in original implementation, num_boxes is divided by get_world_size()
num_boxes = torch.clamp(num_boxes, min=1).item()
world_size = 1
if PartialState._shared_state != {}:
num_boxes = reduce(num_boxes)
world_size = PartialState().num_processes
num_boxes = torch.clamp(num_boxes / world_size, min=1).item()

# Compute all the requested losses
losses = {}
Expand Down
12 changes: 11 additions & 1 deletion src/transformers/models/mask2former/modeling_mask2former.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,18 @@
)
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions
from ...modeling_utils import PreTrainedModel
from ...utils import logging
from ...utils import is_accelerate_available, logging
from ...utils.backbone_utils import load_backbone
from .configuration_mask2former import Mask2FormerConfig


if is_scipy_available():
from scipy.optimize import linear_sum_assignment

if is_accelerate_available():
from accelerate import PartialState
from accelerate.utils import reduce

logger = logging.get_logger(__name__)


Expand Down Expand Up @@ -788,6 +792,12 @@ def get_num_masks(self, class_labels: torch.Tensor, device: torch.device) -> tor
"""
num_masks = sum([len(classes) for classes in class_labels])
num_masks_pt = torch.as_tensor(num_masks, dtype=torch.float, device=device)
world_size = 1
if PartialState._shared_state != {}:
num_masks_pt = reduce(num_masks_pt)
world_size = PartialState().num_processes

num_masks_pt = torch.clamp(num_masks_pt / world_size, min=1)
return num_masks_pt


Expand Down
11 changes: 11 additions & 0 deletions src/transformers/models/maskformer/modeling_maskformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_accelerate_available,
is_scipy_available,
logging,
replace_return_docstrings,
Expand All @@ -42,6 +43,10 @@
from .configuration_maskformer_swin import MaskFormerSwinConfig


if is_accelerate_available():
from accelerate import PartialState
from accelerate.utils import reduce

if is_scipy_available():
from scipy.optimize import linear_sum_assignment

Expand Down Expand Up @@ -1194,6 +1199,12 @@ def get_num_masks(self, class_labels: torch.Tensor, device: torch.device) -> tor
"""
num_masks = sum([len(classes) for classes in class_labels])
num_masks_pt = torch.as_tensor(num_masks, dtype=torch.float, device=device)
world_size = 1
if PartialState._shared_state != {}:
num_masks_pt = reduce(num_masks_pt)
world_size = PartialState().num_processes

num_masks_pt = torch.clamp(num_masks_pt / world_size, min=1)
return num_masks_pt


Expand Down
11 changes: 11 additions & 0 deletions src/transformers/models/oneformer/modeling_oneformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_accelerate_available,
is_scipy_available,
logging,
replace_return_docstrings,
Expand All @@ -40,6 +41,10 @@
from .configuration_oneformer import OneFormerConfig


if is_accelerate_available():
from accelerate import PartialState
from accelerate.utils import reduce

logger = logging.get_logger(__name__)


Expand Down Expand Up @@ -723,6 +728,12 @@ def get_num_masks(self, class_labels: torch.Tensor, device: torch.device) -> tor
"""
num_masks = sum([len(classes) for classes in class_labels])
num_masks_pt = torch.as_tensor([num_masks], dtype=torch.float, device=device)
world_size = 1
if PartialState._shared_state != {}:
num_masks_pt = reduce(num_masks_pt)
world_size = PartialState().num_processes

num_masks_pt = torch.clamp(num_masks_pt / world_size, min=1)
return num_masks_pt


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_accelerate_available,
is_scipy_available,
is_timm_available,
is_vision_available,
Expand All @@ -50,6 +51,10 @@
if is_vision_available():
from transformers.image_transforms import center_to_corners_format

if is_accelerate_available():
from accelerate import PartialState
from accelerate.utils import reduce

logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "TableTransformerConfig"
Expand Down Expand Up @@ -1751,11 +1756,11 @@ def forward(self, outputs, targets):
# Compute the average number of target boxes across all nodes, for normalization purposes
num_boxes = sum(len(t["class_labels"]) for t in targets)
num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
# (Niels): comment out function below, distributed training to be added
# if is_dist_avail_and_initialized():
# torch.distributed.all_reduce(num_boxes)
# (Niels) in original implementation, num_boxes is divided by get_world_size()
num_boxes = torch.clamp(num_boxes, min=1).item()
world_size = 1
if PartialState._shared_state != {}:
num_boxes = reduce(num_boxes)
world_size = PartialState().num_processes
num_boxes = torch.clamp(num_boxes / world_size, min=1).item()

# Compute all the requested losses
losses = {}
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/yolos/image_processing_yolos.py
Original file line number Diff line number Diff line change
Expand Up @@ -1297,7 +1297,7 @@ def preprocess(
encoded_inputs = self.pad(
images,
annotations=annotations,
return_pixel_mask=True,
return_pixel_mask=False,
data_format=data_format,
input_data_format=input_data_format,
update_bboxes=do_convert_annotations,
Expand Down
14 changes: 9 additions & 5 deletions src/transformers/models/yolos/modeling_yolos.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_accelerate_available,
is_scipy_available,
is_vision_available,
logging,
Expand All @@ -48,6 +49,9 @@
if is_vision_available():
from transformers.image_transforms import center_to_corners_format

if is_accelerate_available():
from accelerate import PartialState
from accelerate.utils import reduce

logger = logging.get_logger(__name__)

Expand Down Expand Up @@ -1074,11 +1078,11 @@ def forward(self, outputs, targets):
# Compute the average number of target boxes across all nodes, for normalization purposes
num_boxes = sum(len(t["class_labels"]) for t in targets)
num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
# (Niels): comment out function below, distributed training to be added
# if is_dist_avail_and_initialized():
# torch.distributed.all_reduce(num_boxes)
# (Niels) in original implementation, num_boxes is divided by get_world_size()
num_boxes = torch.clamp(num_boxes, min=1).item()
world_size = 1
if PartialState._shared_state != {}:
num_boxes = reduce(num_boxes)
world_size = PartialState().num_processes
num_boxes = torch.clamp(num_boxes / world_size, min=1).item()

# Compute all the requested losses
losses = {}
Expand Down
Loading