[Grounding DINO] Add support for cross-attention in GroundingDinoMultiHeadAttention (#30364)

* Added cross attention support

* Fixed dtypes

* Fixed assumption

* Moved to decoder
EduardoPach authored and Ita Zaporozhets committed May 14, 2024
1 parent d62e085 commit 8c6a7c6
Showing 2 changed files with 38 additions and 3 deletions.
src/transformers/models/grounding_dino/modeling_grounding_dino.py (15 changes: 12 additions & 3 deletions)
@@ -818,7 +818,7 @@ def forward(
attention_masks = attention_masks[:, None, :, :]
attention_masks = attention_masks.repeat(1, self.num_heads, 1, 1)

- dtype = torch.float16
+ dtype = hidden_states.dtype
attention_masks = attention_masks.to(dtype=dtype) # fp16 compatibility
attention_masks = (1.0 - attention_masks) * torch.finfo(dtype).min
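The change above swaps the hardcoded torch.float16 for the dtype of the incoming hidden states before the 0/1 attention mask is turned into an additive bias. A minimal standalone sketch of that conversion, with a hypothetical helper name that is not part of the model code:

import torch

def mask_to_additive_bias(attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # 1 = attend (bias 0), 0 = masked out (bias becomes the most negative value representable in `dtype`)
    return (1.0 - attention_mask.to(dtype)) * torch.finfo(dtype).min

mask = torch.tensor([[1, 1, 0]])
print(mask_to_additive_bias(mask, torch.bfloat16))  # bias matches the activations' dtype (fp16, bf16 or fp32)

Tying the bias to hidden_states.dtype keeps it in the same precision the layer actually runs in.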

@@ -1425,12 +1425,11 @@ def forward(

# Cross-Attention Text
queries = self.with_pos_embed(hidden_states, position_embeddings)

hidden_states, text_cross_attn_weights = self.encoder_attn_text(
queries=queries,
keys=text_encoder_hidden_states,
values=text_encoder_hidden_states,
- # attention_mask=text_encoder_attention_mask, # TODO fix cross-attention mask here
+ attention_mask=text_encoder_attention_mask,
output_attentions=True,
)
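With the TODO resolved, the text cross-attention now receives the real text_encoder_attention_mask. As a generic illustration (not the GroundingDinoMultiHeadAttention implementation), an additive mask of this kind is simply added to the attention scores before the softmax:

import torch
import torch.nn.functional as F

def masked_cross_attention_weights(queries, keys, additive_mask=None):
    # queries: (batch, heads, num_queries, head_dim); keys: (batch, heads, text_len, head_dim)
    scores = queries @ keys.transpose(-1, -2) / keys.shape[-1] ** 0.5
    if additive_mask is not None:
        scores = scores + additive_mask  # 0 where attending is allowed, very negative at padded text positions
    return F.softmax(scores, dim=-1)

q = torch.randn(1, 8, 4, 32)
k = torch.randn(1, 8, 6, 32)
bias = torch.zeros(1, 8, 4, 6)
bias[..., -2:] = torch.finfo(q.dtype).min  # pretend the last two text tokens are padding
print(masked_cross_attention_weights(q, k, bias)[0, 0, 0])  # the last two weights are ~0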

@@ -1893,6 +1892,16 @@ def forward(
intermediate = ()
intermediate_reference_points = ()

if text_encoder_attention_mask is not None:
dtype = text_encoder_hidden_states.dtype

text_encoder_attention_mask = text_encoder_attention_mask[:, None, None, :]
text_encoder_attention_mask = text_encoder_attention_mask.repeat(
1, self.config.decoder_attention_heads, self.config.num_queries, 1
)
text_encoder_attention_mask = text_encoder_attention_mask.to(dtype=dtype)
text_encoder_attention_mask = text_encoder_attention_mask * torch.finfo(dtype).min

for idx, decoder_layer in enumerate(self.layers):
num_coordinates = reference_points.shape[-1]
if num_coordinates == 4:
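The new block in the decoder broadcasts the per-token text mask to the (batch, heads, queries, text_length) shape the cross-attention expects and converts it to an additive bias in the text encoder's dtype. Because the mask is multiplied by torch.finfo(dtype).min directly, nonzero entries end up heavily penalized while zeros are left untouched, so the mask reaching the decoder is assumed to already flag the positions to suppress. A shape-only sketch with made-up sizes:

import torch

# illustrative sizes, not values taken from the model config
batch_size, num_heads, num_queries, text_len = 2, 8, 900, 16
dtype = torch.float16

text_mask = torch.zeros(batch_size, text_len)
text_mask[1, 10:] = 1  # pretend the second sample is padded after 10 text tokens

bias = text_mask[:, None, None, :].repeat(1, num_heads, num_queries, 1).to(dtype)
bias = bias * torch.finfo(dtype).min  # zeros stay 0, flagged positions become very negative
print(bias.shape)  # torch.Size([2, 8, 900, 16])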
tests/models/grounding_dino/test_modeling_grounding_dino.py (26 changes: 26 additions & 0 deletions)
@@ -687,3 +687,29 @@ def test_inference_object_detection_head_equivalence_cpu_gpu(self):

self.assertTrue(torch.allclose(results_cpu["scores"], result_gpu["scores"].cpu(), atol=1e-3))
self.assertTrue(torch.allclose(results_cpu["boxes"], result_gpu["boxes"].cpu(), atol=1e-3))

def test_cross_attention_mask(self):
model = GroundingDinoForObjectDetection.from_pretrained("IDEA-Research/grounding-dino-tiny").to(torch_device)

processor = self.default_processor
image = prepare_img()
text1 = "a cat."
text2 = "a remote control."
text_batched = [text1, text2]

encoding1 = processor(images=image, text=text1, return_tensors="pt").to(torch_device)
encoding2 = processor(images=image, text=text2, return_tensors="pt").to(torch_device)
# If we batch the text and cross-attention masking is working, the batched result should be equal to
# the single text result
encoding_batched = processor(
images=[image] * len(text_batched), text=text_batched, padding="longest", return_tensors="pt"
).to(torch_device)

with torch.no_grad():
outputs1 = model(**encoding1)
outputs2 = model(**encoding2)
outputs_batched = model(**encoding_batched)

self.assertTrue(torch.allclose(outputs1.logits, outputs_batched.logits[:1], atol=1e-3))
# For some reason 12 elements are > 1e-3, but the rest are fine
self.assertTrue(torch.allclose(outputs2.logits, outputs_batched.logits[1:], atol=1.8e-3))
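For reference, the batched-versus-single comparison in the new test can be sketched outside the test harness. This assumes a transformers version that ships Grounding DINO and substitutes a blank PIL image for the test fixture, so it is an illustration rather than a replacement for the test above:

import torch
from PIL import Image
from transformers import AutoProcessor, GroundingDinoForObjectDetection

processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-tiny")
model = GroundingDinoForObjectDetection.from_pretrained("IDEA-Research/grounding-dino-tiny")

image = Image.new("RGB", (640, 480))  # placeholder image; the test uses a real COCO image
texts = ["a cat.", "a remote control."]

single = processor(images=image, text=texts[0], return_tensors="pt")
batched = processor(images=[image, image], text=texts, padding="longest", return_tensors="pt")

with torch.no_grad():
    logits_single = model(**single).logits
    logits_batched = model(**batched).logits

# With cross-attention masking in place, padding the shorter prompt should not change its logits
print(torch.allclose(logits_single, logits_batched[:1], atol=1e-3))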
