From 8c6a7c697e5e87ac02446a099c1991a4c4ced219 Mon Sep 17 00:00:00 2001 From: Eduardo Pacheco <69953243+EduardoPach@users.noreply.github.com> Date: Tue, 23 Apr 2024 10:56:14 +0200 Subject: [PATCH] [Grounding DINO] Add support for cross-attention in GroundingDinoMultiHeadAttention (#30364) * Added cross attention support * Fixed dtypes * Fixed assumption * Moved to decoder --- .../grounding_dino/modeling_grounding_dino.py | 15 ++++++++--- .../test_modeling_grounding_dino.py | 26 +++++++++++++++++++ 2 files changed, 38 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/grounding_dino/modeling_grounding_dino.py b/src/transformers/models/grounding_dino/modeling_grounding_dino.py index a98901015c94c6..83009c92504211 100644 --- a/src/transformers/models/grounding_dino/modeling_grounding_dino.py +++ b/src/transformers/models/grounding_dino/modeling_grounding_dino.py @@ -818,7 +818,7 @@ def forward( attention_masks = attention_masks[:, None, :, :] attention_masks = attention_masks.repeat(1, self.num_heads, 1, 1) - dtype = torch.float16 + dtype = hidden_states.dtype attention_masks = attention_masks.to(dtype=dtype) # fp16 compatibility attention_masks = (1.0 - attention_masks) * torch.finfo(dtype).min @@ -1425,12 +1425,11 @@ def forward( # Cross-Attention Text queries = self.with_pos_embed(hidden_states, position_embeddings) - hidden_states, text_cross_attn_weights = self.encoder_attn_text( queries=queries, keys=text_encoder_hidden_states, values=text_encoder_hidden_states, - # attention_mask=text_encoder_attention_mask, # TODO fix cross-attention mask here + attention_mask=text_encoder_attention_mask, output_attentions=True, ) @@ -1893,6 +1892,16 @@ def forward( intermediate = () intermediate_reference_points = () + if text_encoder_attention_mask is not None: + dtype = text_encoder_hidden_states.dtype + + text_encoder_attention_mask = text_encoder_attention_mask[:, None, None, :] + text_encoder_attention_mask = text_encoder_attention_mask.repeat( + 1, self.config.decoder_attention_heads, self.config.num_queries, 1 + ) + text_encoder_attention_mask = text_encoder_attention_mask.to(dtype=dtype) + text_encoder_attention_mask = text_encoder_attention_mask * torch.finfo(dtype).min + for idx, decoder_layer in enumerate(self.layers): num_coordinates = reference_points.shape[-1] if num_coordinates == 4: diff --git a/tests/models/grounding_dino/test_modeling_grounding_dino.py b/tests/models/grounding_dino/test_modeling_grounding_dino.py index 42486f92da9746..1231baff7c6c73 100644 --- a/tests/models/grounding_dino/test_modeling_grounding_dino.py +++ b/tests/models/grounding_dino/test_modeling_grounding_dino.py @@ -687,3 +687,29 @@ def test_inference_object_detection_head_equivalence_cpu_gpu(self): self.assertTrue(torch.allclose(results_cpu["scores"], result_gpu["scores"].cpu(), atol=1e-3)) self.assertTrue(torch.allclose(results_cpu["boxes"], result_gpu["boxes"].cpu(), atol=1e-3)) + + def test_cross_attention_mask(self): + model = GroundingDinoForObjectDetection.from_pretrained("IDEA-Research/grounding-dino-tiny").to(torch_device) + + processor = self.default_processor + image = prepare_img() + text1 = "a cat." + text2 = "a remote control." + text_batched = [text1, text2] + + encoding1 = processor(images=image, text=text1, return_tensors="pt").to(torch_device) + encoding2 = processor(images=image, text=text2, return_tensors="pt").to(torch_device) + # If we batch the text and cross attention masking is working the batched result should be equal to + # The singe text result + encoding_batched = processor( + images=[image] * len(text_batched), text=text_batched, padding="longest", return_tensors="pt" + ).to(torch_device) + + with torch.no_grad(): + outputs1 = model(**encoding1) + outputs2 = model(**encoding2) + outputs_batched = model(**encoding_batched) + + self.assertTrue(torch.allclose(outputs1.logits, outputs_batched.logits[:1], atol=1e-3)) + # For some reason 12 elements are > 1e-3, but the rest are fine + self.assertTrue(torch.allclose(outputs2.logits, outputs_batched.logits[1:], atol=1.8e-3))