diff --git a/src/transformers/models/rt_detr/modeling_rt_detr.py b/src/transformers/models/rt_detr/modeling_rt_detr.py
index 26cf843357c641..7b546fcb4905de 100644
--- a/src/transformers/models/rt_detr/modeling_rt_detr.py
+++ b/src/transformers/models/rt_detr/modeling_rt_detr.py
@@ -1656,7 +1656,11 @@ def unfreeze_backbone(self):
             param.requires_grad_(True)
 
     @lru_cache(maxsize=32)
-    def generate_anchors(self, spatial_shapes=None, grid_size=0.05, dtype=torch.float32, device="cpu"):
+    def generate_anchors(self, spatial_shapes=None, grid_size=0.05):
+        # We always generate anchors in float32 to preserve equivalence between
+        # dynamic and static anchor inference
+        dtype = torch.float32
+
         if spatial_shapes is None:
             spatial_shapes = [
                 [int(self.config.anchor_image_size[0] / s), int(self.config.anchor_image_size[1] / s)]
@@ -1674,7 +1678,7 @@ def generate_anchors(self, spatial_shapes=None, grid_size=0.05, dtype=torch.floa
             anchors.append(torch.concat([grid_xy, wh], -1).reshape(-1, height * width, 4))
         # define the valid range for anchor coordinates
         eps = 1e-2
-        anchors = torch.concat(anchors, 1).to(device)
+        anchors = torch.concat(anchors, 1)
         valid_mask = ((anchors > eps) * (anchors < 1 - eps)).all(-1, keepdim=True)
         anchors = torch.log(anchors / (1 - anchors))
         anchors = torch.where(valid_mask, anchors, torch.inf)
@@ -1769,15 +1773,15 @@ def forward(
 
         # Prepare encoder inputs (by flattening)
         source_flatten = []
-        spatial_shapes = []
+        spatial_shapes_list = []
         for level, source in enumerate(sources):
             batch_size, num_channels, height, width = source.shape
             spatial_shape = (height, width)
-            spatial_shapes.append(spatial_shape)
+            spatial_shapes_list.append(spatial_shape)
             source = source.flatten(2).transpose(1, 2)
             source_flatten.append(source)
         source_flatten = torch.cat(source_flatten, 1)
-        spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=source_flatten.device)
+        spatial_shapes = torch.as_tensor(spatial_shapes_list, dtype=torch.long, device=source_flatten.device)
         level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
 
         # prepare denoising training
@@ -1805,9 +1809,14 @@ def forward(
 
         # prepare input for decoder
         if self.training or self.config.anchor_image_size is None:
-            anchors, valid_mask = self.generate_anchors(spatial_shapes, device=device, dtype=dtype)
+            # Pass spatial_shapes as a tuple to make it hashable, so that
+            # lru_cache works for generate_anchors()
+            spatial_shapes_tuple = tuple(spatial_shapes_list)
+            anchors, valid_mask = self.generate_anchors(spatial_shapes_tuple)
         else:
-            anchors, valid_mask = self.anchors.to(device, dtype), self.valid_mask.to(device, dtype)
+            anchors, valid_mask = self.anchors, self.valid_mask
+
+        anchors, valid_mask = anchors.to(device, dtype), valid_mask.to(device, dtype)
 
         # use the valid_mask to selectively retain values in the feature map where the mask is `True`
         memory = valid_mask.to(source_flatten.dtype) * source_flatten
diff --git a/tests/models/rt_detr/test_modeling_rt_detr.py b/tests/models/rt_detr/test_modeling_rt_detr.py
index 2d3d48dba33125..7de4d3b869d12b 100644
--- a/tests/models/rt_detr/test_modeling_rt_detr.py
+++ b/tests/models/rt_detr/test_modeling_rt_detr.py
@@ -16,6 +16,7 @@
 
 import inspect
 import math
+import tempfile
 import unittest
 
 from parameterized import parameterized
@@ -630,6 +631,48 @@ def test_inference_with_different_dtypes(self, torch_dtype_str):
             with torch.no_grad():
                 _ = model(**self._prepare_for_class(inputs_dict, model_class))
 
+    @parameterized.expand(["float32", "float16", "bfloat16"])
+    @require_torch_gpu
+    @slow
+    def test_inference_equivalence_for_static_and_dynamic_anchors(self, torch_dtype_str):
+        torch_dtype = {
+            "float32": torch.float32,
+            "float16": torch.float16,
+            "bfloat16": torch.bfloat16,
+        }[torch_dtype_str]
+
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        h, w = inputs_dict["pixel_values"].shape[-2:]
+
+        # convert inputs to the desired dtype
+        for key, tensor in inputs_dict.items():
+            if tensor.dtype == torch.float32:
+                inputs_dict[key] = tensor.to(torch_dtype)
+
+        for model_class in self.all_model_classes:
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                model_class(config).save_pretrained(tmpdirname)
+                model_static = model_class.from_pretrained(
+                    tmpdirname, anchor_image_size=[h, w], device_map=torch_device, torch_dtype=torch_dtype
+                ).eval()
+                model_dynamic = model_class.from_pretrained(
+                    tmpdirname, anchor_image_size=None, device_map=torch_device, torch_dtype=torch_dtype
+                ).eval()
+
+            self.assertIsNotNone(model_static.config.anchor_image_size)
+            self.assertIsNone(model_dynamic.config.anchor_image_size)
+
+            with torch.no_grad():
+                outputs_static = model_static(**self._prepare_for_class(inputs_dict, model_class))
+                outputs_dynamic = model_dynamic(**self._prepare_for_class(inputs_dict, model_class))
+
+            self.assertTrue(
+                torch.allclose(
+                    outputs_static.last_hidden_state, outputs_dynamic.last_hidden_state, rtol=1e-4, atol=1e-4
+                ),
+                f"Max diff: {(outputs_static.last_hidden_state - outputs_dynamic.last_hidden_state).abs().max()}",
+            )
+
 
 
 TOLERANCE = 1e-4