[DETA] Improvement and Sync from DETA especially for training (huggingface#27990)

* [DETA] fix freeze/unfreeze function

* Update src/transformers/models/deta/modeling_deta.py

Co-authored-by: Arthur <[email protected]>

* Update src/transformers/models/deta/modeling_deta.py

Co-authored-by: Arthur <[email protected]>

* add freeze/unfreeze test case in DETA

* fix typo

* fix typo 2

* fix: enable aux and enc loss in training pipeline

* Add unsynced variables from original DETA for training

* modification for passing CI test

* make style

* make fix

* manual make fix

* change DETA modeling test configuration 'two_stage' default to True and make a minor change to dist checking

* remove print

* divide configuration in DetaModel and DetaForObjectDetection

* images smaller than 224 give a topk error

* pred_boxes and logits sizes should match two_stage_num_proposals

* add missing part in DetaConfig

* Update src/transformers/models/deta/modeling_deta.py

Co-authored-by: amyeroberts <[email protected]>

* add docstring to config and prettify TODO part

* change distributed-related code to use accelerate

* Update src/transformers/models/deta/configuration_deta.py

Co-authored-by: amyeroberts <[email protected]>

* Update tests/models/deta/test_modeling_deta.py

Co-authored-by: amyeroberts <[email protected]>

* protect importing accelerate

* change variable name to specific value

* fix wrong import

---------

Co-authored-by: Arthur <[email protected]>
Co-authored-by: amyeroberts <[email protected]>
3 people authored and wgifford committed Jan 21, 2024
1 parent 11f917f commit 111f6c9
Showing 4 changed files with 74 additions and 40 deletions.
6 changes: 6 additions & 0 deletions src/transformers/models/deta/configuration_deta.py
@@ -109,6 +109,10 @@ class DetaConfig(PretrainedConfig):
based on the predictions from the previous layer.
focal_alpha (`float`, *optional*, defaults to 0.25):
Alpha parameter in the focal loss.
assign_first_stage (`bool`, *optional*, defaults to `True`):
Whether to assign each prediction to the ground-truth object with the highest overlap, when that overlap is larger than a threshold of 0.7.
assign_second_stage (`bool`, *optional*, defaults to `True`):
Whether to apply a second assignment procedure in the second stage, which closely follows the first-stage assignment procedure.
Examples:
@@ -161,6 +165,7 @@ def __init__(
two_stage_num_proposals=300,
with_box_refine=True,
assign_first_stage=True,
assign_second_stage=True,
class_cost=1,
bbox_cost=5,
giou_cost=2,
@@ -208,6 +213,7 @@ def __init__(
self.two_stage_num_proposals = two_stage_num_proposals
self.with_box_refine = with_box_refine
self.assign_first_stage = assign_first_stage
self.assign_second_stage = assign_second_stage
if two_stage is True and with_box_refine is False:
raise ValueError("If two_stage is True, with_box_refine must be True.")
# Hungarian matcher
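For orientation (not part of the diff): a minimal sketch of how the new flags fit together; the values are illustrative, and the constraint on `with_box_refine` comes from the `ValueError` guard above.

```python
from transformers import DetaConfig

# two_stage=True requires with_box_refine=True, otherwise __init__ raises a ValueError
config = DetaConfig(
    two_stage=True,
    with_box_refine=True,
    assign_first_stage=True,   # IoU-based first-stage assignment (0.7 threshold)
    assign_second_stage=True,  # second-stage assignment mirrors the first stage
    two_stage_num_proposals=300,
)
```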
2 changes: 1 addition & 1 deletion src/transformers/models/deta/image_processing_deta.py
@@ -1052,7 +1052,7 @@ def post_process_object_detection(
score = all_scores[b]
lbls = all_labels[b]

pre_topk = score.topk(min(10000, len(score))).indices
pre_topk = score.topk(min(10000, num_queries * num_labels)).indices
box = box[pre_topk]
score = score[pre_topk]
lbls = lbls[pre_topk]
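The fix above bounds the `topk` call by the fixed candidate count `num_queries * num_labels`; a standalone sketch of the guard pattern (shapes are illustrative, not taken from the PR):

```python
import torch

num_queries, num_labels = 300, 91
scores = torch.rand(num_queries * num_labels)  # flattened per-class scores
# torch.topk raises "selected index k out of range" whenever k > scores.numel()
k = min(10000, num_queries * num_labels)
pre_topk = scores.topk(k).indices
```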
62 changes: 39 additions & 23 deletions src/transformers/models/deta/modeling_deta.py
@@ -38,14 +38,18 @@
from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import meshgrid
from ...utils import is_torchvision_available, logging, requires_backends
from ...utils import is_accelerate_available, is_torchvision_available, logging, requires_backends
from ..auto import AutoBackbone
from .configuration_deta import DetaConfig


logger = logging.get_logger(__name__)


if is_accelerate_available():
from accelerate import PartialState
from accelerate.utils import reduce

if is_vision_available():
from transformers.image_transforms import center_to_corners_format
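The `accelerate` import is guarded so the module still loads when the optional dependency is absent; the same pattern in isolation:

```python
from transformers.utils import is_accelerate_available

if is_accelerate_available():
    from accelerate import PartialState
    from accelerate.utils import reduce
# without accelerate, the distributed branch in the loss is simply never taken
```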

@@ -105,7 +109,6 @@ class DetaDecoderOutput(


@dataclass
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrModelOutput with DeformableDetr->Deta,Deformable DETR->DETA
class DetaModelOutput(ModelOutput):
"""
Base class for outputs of the Deformable DETR encoder-decoder model.
@@ -147,6 +150,8 @@ class DetaModelOutput(ModelOutput):
foreground and background).
enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
Logits of predicted bounding boxes coordinates in the first stage.
output_proposals (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.two_stage=True`):
Logits of the proposal bounding box coordinates produced by `gen_encoder_output_proposals`.
"""

init_reference_points: torch.FloatTensor = None
@@ -161,10 +166,10 @@ class DetaModelOutput(ModelOutput):
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
enc_outputs_class: Optional[torch.FloatTensor] = None
enc_outputs_coord_logits: Optional[torch.FloatTensor] = None
output_proposals: Optional[torch.FloatTensor] = None


@dataclass
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrObjectDetectionOutput with DeformableDetr->Deta
class DetaObjectDetectionOutput(ModelOutput):
"""
Output type of [`DetaForObjectDetection`].
@@ -223,6 +228,8 @@ class DetaObjectDetectionOutput(ModelOutput):
foreground and background).
enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
Logits of predicted bounding boxes coordinates in the first stage.
output_proposals (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.two_stage=True`):
Logits of the proposal bounding box coordinates produced by `gen_encoder_output_proposals`.
"""

loss: Optional[torch.FloatTensor] = None
@@ -242,6 +249,7 @@ class DetaObjectDetectionOutput(ModelOutput):
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
enc_outputs_class: Optional = None
enc_outputs_coord_logits: Optional = None
output_proposals: Optional[torch.FloatTensor] = None


def _get_clones(module, N):
@@ -1632,6 +1640,7 @@ def forward(
batch_size, _, num_channels = encoder_outputs[0].shape
enc_outputs_class = None
enc_outputs_coord_logits = None
output_proposals = None
if self.config.two_stage:
object_query_embedding, output_proposals, level_ids = self.gen_encoder_output_proposals(
encoder_outputs[0], ~mask_flatten, spatial_shapes
@@ -1746,6 +1755,7 @@ def forward(
encoder_attentions=encoder_outputs.attentions,
enc_outputs_class=enc_outputs_class,
enc_outputs_coord_logits=enc_outputs_coord_logits,
output_proposals=output_proposals,
)


@@ -1804,12 +1814,15 @@ def __init__(self, config: DetaConfig):
self.post_init()

@torch.jit.unused
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrForObjectDetection._set_aux_loss
def _set_aux_loss(self, outputs_class, outputs_coord):
# this is a workaround to make torchscript happy, as torchscript
# doesn't support dictionary with non-homogeneous values, such
# as a dict having both a Tensor and a list.
return [{"logits": a, "pred_boxes": b} for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
aux_loss = [
{"logits": logits, "pred_boxes": pred_boxes}
for logits, pred_boxes in zip(outputs_class.transpose(0, 1)[:-1], outputs_coord.transpose(0, 1)[:-1])
]
return aux_loss
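The reworked `_set_aux_loss` assumes class/box predictions stacked per decoder layer along dim 1; a toy sketch of the shape handling (all dimensions illustrative):

```python
import torch

batch_size, num_layers, num_queries, num_labels = 2, 6, 300, 91
outputs_class = torch.rand(batch_size, num_layers, num_queries, num_labels)
outputs_coord = torch.rand(batch_size, num_layers, num_queries, 4)

# transpose(0, 1) -> (num_layers, batch, ...); [:-1] drops the final layer,
# whose predictions are already supervised by the main loss
aux = [
    {"logits": logits, "pred_boxes": boxes}
    for logits, boxes in zip(outputs_class.transpose(0, 1)[:-1], outputs_coord.transpose(0, 1)[:-1])
]
assert len(aux) == num_layers - 1
assert aux[0]["logits"].shape == (batch_size, num_queries, num_labels)
```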

@add_start_docstrings_to_model_forward(DETA_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=DetaObjectDetectionOutput, config_class=_CONFIG_FOR_DOC)
@@ -1929,21 +1942,25 @@ def forward(
focal_alpha=self.config.focal_alpha,
losses=losses,
num_queries=self.config.num_queries,
assign_first_stage=self.config.assign_first_stage,
assign_second_stage=self.config.assign_second_stage,
)
criterion.to(logits.device)
# Third: compute the losses, based on outputs and labels
outputs_loss = {}
outputs_loss["logits"] = logits
outputs_loss["pred_boxes"] = pred_boxes
outputs_loss["init_reference"] = init_reference
if self.config.auxiliary_loss:
intermediate = outputs.intermediate_hidden_states if return_dict else outputs[4]
outputs_class = self.class_embed(intermediate)
outputs_coord = self.bbox_embed(intermediate).sigmoid()
auxiliary_outputs = self._set_aux_loss(outputs_class, outputs_coord)
outputs_loss["auxiliary_outputs"] = auxiliary_outputs
if self.config.two_stage:
enc_outputs_coord = outputs.enc_outputs_coord_logits.sigmoid()
outputs["enc_outputs"] = {"pred_logits": outputs.enc_outputs_class, "pred_boxes": enc_outputs_coord}
outputs_loss["enc_outputs"] = {
"logits": outputs.enc_outputs_class,
"pred_boxes": enc_outputs_coord,
"anchors": outputs.output_proposals.sigmoid(),
}

loss_dict = criterion(outputs_loss, labels)
# Fourth: compute total loss, as a weighted sum of the various losses
@@ -1953,6 +1970,7 @@ def forward(
aux_weight_dict = {}
for i in range(self.config.decoder_layers - 1):
aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
aux_weight_dict.update({k + "_enc": v for k, v in weight_dict.items()})
weight_dict.update(aux_weight_dict)
loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
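With the added `_enc` entry, encoder (first-stage) losses are weighted like the per-layer auxiliary losses; a sketch of the key expansion (base weights are illustrative, not the exact DETA defaults):

```python
weight_dict = {"loss_ce": 1, "loss_bbox": 5, "loss_giou": 2}  # illustrative weights
decoder_layers = 6

aux_weight_dict = {}
for i in range(decoder_layers - 1):
    aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
aux_weight_dict.update({k + "_enc": v for k, v in weight_dict.items()})
weight_dict.update(aux_weight_dict)
# keys now include loss_ce_0 ... loss_ce_4 as well as loss_ce_enc, etc.
```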

@@ -1983,6 +2001,7 @@ def forward(
init_reference_points=outputs.init_reference_points,
enc_outputs_class=outputs.enc_outputs_class,
enc_outputs_coord_logits=outputs.enc_outputs_coord_logits,
output_proposals=outputs.output_proposals,
)

return dict_outputs
@@ -2192,7 +2211,7 @@ def forward(self, outputs, targets):
List of dicts, such that `len(targets) == batch_size`. The expected keys in each dict depends on the
losses applied, see each loss' doc.
"""
outputs_without_aux = {k: v for k, v in outputs.items() if k != "auxiliary_outputs"}
outputs_without_aux = {k: v for k, v in outputs.items() if k not in ("auxiliary_outputs", "enc_outputs")}

# Retrieve the matching between the outputs of the last layer and the targets
if self.assign_second_stage:
@@ -2203,11 +2222,12 @@ def forward(self, outputs, targets):
# Compute the average number of target boxes accross all nodes, for normalization purposes
num_boxes = sum(len(t["class_labels"]) for t in targets)
num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
# (Niels): comment out function below, distributed training to be added
# if is_dist_avail_and_initialized():
# torch.distributed.all_reduce(num_boxes)
# (Niels) in original implementation, num_boxes is divided by get_world_size()
num_boxes = torch.clamp(num_boxes, min=1).item()
# Check that we have initialized the distributed state
world_size = 1
if PartialState._shared_state != {}:
num_boxes = reduce(num_boxes)
world_size = PartialState().num_processes
num_boxes = torch.clamp(num_boxes / world_size, min=1).item()
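This replaces the commented-out `torch.distributed` calls with accelerate; the same normalization in isolation (a hedged sketch, assuming accelerate is installed):

```python
import torch
from accelerate import PartialState
from accelerate.utils import reduce

num_boxes = torch.as_tensor([17.0])
world_size = 1
# PartialState._shared_state stays {} until an Accelerator/PartialState is created,
# so single-process runs skip the cross-process reduction entirely
if PartialState._shared_state != {}:
    num_boxes = reduce(num_boxes)  # reduces the tensor across all processes
    world_size = PartialState().num_processes
num_boxes = torch.clamp(num_boxes / world_size, min=1).item()
```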

# Compute all the requested losses
losses = {}
@@ -2228,17 +2248,13 @@ def forward(self, outputs, targets):
enc_outputs = outputs["enc_outputs"]
bin_targets = copy.deepcopy(targets)
for bt in bin_targets:
bt["labels"] = torch.zeros_like(bt["labels"])
bt["class_labels"] = torch.zeros_like(bt["class_labels"])
if self.assign_first_stage:
indices = self.stg1_assigner(enc_outputs, bin_targets)
else:
indices = self.matcher(enc_outputs, bin_targets)
for loss in self.losses:
kwargs = {}
if loss == "labels":
# Logging is enabled only for the last layer
kwargs["log"] = False
l_dict = self.get_loss(loss, enc_outputs, bin_targets, indices, num_boxes, **kwargs)
l_dict = self.get_loss(loss, enc_outputs, bin_targets, indices, num_boxes)
l_dict = {k + "_enc": v for k, v in l_dict.items()}
losses.update(l_dict)
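The encoder (first-stage) losses are class-agnostic: every target label is zeroed before matching, so proposals are scored as object vs. background only. A small sketch of that target rewrite:

```python
import copy
import torch

targets = [{"class_labels": torch.tensor([3, 17]), "boxes": torch.rand(2, 4)}]
bin_targets = copy.deepcopy(targets)
for bt in bin_targets:
    bt["class_labels"] = torch.zeros_like(bt["class_labels"])  # all foreground -> class 0
```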

@@ -2662,7 +2678,7 @@ def forward(self, outputs, targets, return_cost_matrix=False):
sampled_idxs,
sampled_gt_classes,
) = self._sample_proposals( # list of sampled proposal_ids, sampled_id -> [0, num_classes)+[bg_label]
matched_idxs, matched_labels, targets[b]["labels"]
matched_idxs, matched_labels, targets[b]["class_labels"]
)
pos_pr_inds = sampled_idxs[sampled_gt_classes != self.bg_label]
pos_gt_inds = matched_idxs[pos_pr_inds]
@@ -2727,7 +2743,7 @@ def forward(self, outputs, targets):
) # proposal_id -> highest_iou_gt_id, proposal_id -> [1 if iou > 0.7, 0 if iou < 0.3, -1 ow]
matched_labels = self._subsample_labels(matched_labels)

all_pr_inds = torch.arange(len(anchors))
all_pr_inds = torch.arange(len(anchors), device=matched_labels.device)
pos_pr_inds = all_pr_inds[matched_labels == 1]
pos_gt_inds = matched_idxs[pos_pr_inds]
pos_pr_inds, pos_gt_inds = self.postprocess_indices(pos_pr_inds, pos_gt_inds, iou)
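The added `device=` argument keeps the index tensor on the same device as the labels, avoiding a CPU/GPU mismatch during boolean indexing; a minimal sketch of the pattern:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
matched_labels = torch.tensor([1, 0, -1, 1], device=device)
# building the arange on matched_labels.device prevents
# "indices should be either on cpu or on the same device" errors
all_pr_inds = torch.arange(len(matched_labels), device=matched_labels.device)
pos_pr_inds = all_pr_inds[matched_labels == 1]
```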
44 changes: 28 additions & 16 deletions tests/models/deta/test_modeling_deta.py
@@ -57,14 +57,17 @@ def __init__(
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
num_queries=12,
two_stage_num_proposals=12,
num_channels=3,
image_size=196,
image_size=224,
n_targets=8,
num_labels=91,
num_feature_levels=4,
encoder_n_points=2,
decoder_n_points=6,
two_stage=False,
two_stage=True,
assign_first_stage=True,
assign_second_stage=True,
):
self.parent = parent
self.batch_size = batch_size
@@ -78,6 +81,7 @@ def __init__(
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.num_queries = num_queries
self.two_stage_num_proposals = two_stage_num_proposals
self.num_channels = num_channels
self.image_size = image_size
self.n_targets = n_targets
@@ -86,6 +90,8 @@ def __init__(
self.encoder_n_points = encoder_n_points
self.decoder_n_points = decoder_n_points
self.two_stage = two_stage
self.assign_first_stage = assign_first_stage
self.assign_second_stage = assign_second_stage

# we also set the expected seq length for both encoder and decoder
self.encoder_seq_length = (
@@ -96,7 +102,7 @@ def __init__(
)
self.decoder_seq_length = self.num_queries

def prepare_config_and_inputs(self):
def prepare_config_and_inputs(self, model_class_name):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])

pixel_mask = torch.ones([self.batch_size, self.image_size, self.image_size], device=torch_device)
@@ -114,10 +120,10 @@ def prepare_config_and_inputs(self):
target["masks"] = torch.rand(self.n_targets, self.image_size, self.image_size, device=torch_device)
labels.append(target)

config = self.get_config()
config = self.get_config(model_class_name)
return config, pixel_values, pixel_mask, labels

def get_config(self):
def get_config(self, model_class_name):
resnet_config = ResNetConfig(
num_channels=3,
embeddings_size=10,
@@ -128,6 +134,9 @@ def get_config(self):
out_features=["stage2", "stage3", "stage4"],
out_indices=[2, 3, 4],
)
two_stage = model_class_name == "DetaForObjectDetection"
assign_first_stage = model_class_name == "DetaForObjectDetection"
assign_second_stage = model_class_name == "DetaForObjectDetection"
return DetaConfig(
d_model=self.hidden_size,
encoder_layers=self.num_hidden_layers,
@@ -139,16 +148,19 @@ def get_config(self):
dropout=self.hidden_dropout_prob,
attention_dropout=self.attention_probs_dropout_prob,
num_queries=self.num_queries,
two_stage_num_proposals=self.two_stage_num_proposals,
num_labels=self.num_labels,
num_feature_levels=self.num_feature_levels,
encoder_n_points=self.encoder_n_points,
decoder_n_points=self.decoder_n_points,
two_stage=self.two_stage,
two_stage=two_stage,
assign_first_stage=assign_first_stage,
assign_second_stage=assign_second_stage,
backbone_config=resnet_config,
)

def prepare_config_and_inputs_for_common(self):
config, pixel_values, pixel_mask, labels = self.prepare_config_and_inputs()
def prepare_config_and_inputs_for_common(self, model_class_name="DetaModel"):
config, pixel_values, pixel_mask, labels = self.prepare_config_and_inputs(model_class_name)
inputs_dict = {"pixel_values": pixel_values, "pixel_mask": pixel_mask}
return config, inputs_dict

@@ -190,14 +202,14 @@ def create_and_check_deta_object_detection_head_model(self, config, pixel_values
result = model(pixel_values=pixel_values, pixel_mask=pixel_mask)
result = model(pixel_values)

self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels))
self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4))
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.two_stage_num_proposals, self.num_labels))
self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.two_stage_num_proposals, 4))

result = model(pixel_values=pixel_values, pixel_mask=pixel_mask, labels=labels)

self.parent.assertEqual(result.loss.shape, ())
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels))
self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4))
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.two_stage_num_proposals, self.num_labels))
self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.two_stage_num_proposals, 4))


@require_torchvision
@@ -267,19 +279,19 @@ def test_config(self):
self.config_tester.check_config_can_be_init_without_params()

def test_deta_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
config_and_inputs = self.model_tester.prepare_config_and_inputs(model_class_name="DetaModel")
self.model_tester.create_and_check_deta_model(*config_and_inputs)

def test_deta_freeze_backbone(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
config_and_inputs = self.model_tester.prepare_config_and_inputs(model_class_name="DetaModel")
self.model_tester.create_and_check_deta_freeze_backbone(*config_and_inputs)

def test_deta_unfreeze_backbone(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
config_and_inputs = self.model_tester.prepare_config_and_inputs(model_class_name="DetaModel")
self.model_tester.create_and_check_deta_unfreeze_backbone(*config_and_inputs)

def test_deta_object_detection_head_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
config_and_inputs = self.model_tester.prepare_config_and_inputs(model_class_name="DetaForObjectDetection")
self.model_tester.create_and_check_deta_object_detection_head_model(*config_and_inputs)

@unittest.skip(reason="DETA does not use inputs_embeds")
