diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
index 9c0d0b45ee8e51..eabae7b2b0df06 100644
--- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
+++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
@@ -1000,6 +1000,7 @@ def __init__(self, config) -> None:
         self.merger = PatchMerger(
             dim=config.hidden_size, context_dim=config.embed_dim, spatial_merge_size=config.spatial_merge_size
         )
+        self.gradient_checkpointing = False
 
     def get_dtype(self) -> torch.dtype:
         return self.blocks[0].mlp.fc2.weight.dtype
@@ -1046,7 +1047,12 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.
         cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
 
         for blk in self.blocks:
-            hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
+            if self.gradient_checkpointing and self.training:
+                hidden_states = self._gradient_checkpointing_func(
+                    blk.__call__, hidden_states, cu_seqlens, rotary_pos_emb
+                )
+            else:
+                hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
 
         return self.merger(hidden_states)
diff --git a/tests/models/qwen2_vl/test_modeling_qwen2_vl.py b/tests/models/qwen2_vl/test_modeling_qwen2_vl.py
index c3902c9e75bc66..f2a3719e17b4c6 100644
--- a/tests/models/qwen2_vl/test_modeling_qwen2_vl.py
+++ b/tests/models/qwen2_vl/test_modeling_qwen2_vl.py
@@ -285,24 +285,6 @@ def test_mismatching_num_image_tokens(self):
         image_grid_thw = torch.cat([image_grid_thw, image_grid_thw], dim=0)
         _ = model(input_ids=input_ids, pixel_values=pixel_values, image_grid_thw=image_grid_thw)
 
-    @unittest.skip(
-        reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
-    )
-    def test_training_gradient_checkpointing(self):
-        pass
-
-    @unittest.skip(
-        reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
-    )
-    def test_training_gradient_checkpointing_use_reentrant(self):
-        pass
-
-    @unittest.skip(
-        reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
-    )
-    def test_training_gradient_checkpointing_use_reentrant_false(self):
-        pass
-
     @unittest.skip(reason="Feedforward chunking is not yet supported")
     def test_feed_forward_chunking(self):
         pass
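
For reviewers: declaring `self.gradient_checkpointing = False` in `__init__` is what makes the vision tower visible to the standard `PreTrainedModel.gradient_checkpointing_enable()` API, which walks the module tree and flips the flag (and assigns `_gradient_checkpointing_func`) on every submodule that defines the attribute. A minimal usage sketch follows; the checkpoint id and dtype are illustrative, not part of this patch:

```python
import torch
from transformers import Qwen2VLForConditionalGeneration

# Checkpoint id and dtype chosen for illustration; any Qwen2-VL
# checkpoint would exercise the same code path.
model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-7B-Instruct", torch_dtype=torch.bfloat16
)

# gradient_checkpointing_enable() walks model.modules() and sets
# `gradient_checkpointing = True` plus `_gradient_checkpointing_func`
# on every module that defines the attribute, which after this patch
# includes the vision transformer's block loop.
model.gradient_checkpointing_enable()
model.train()  # checkpointing only activates when self.training is True
```

With the flag in place, each vision block's activations are recomputed during the backward pass instead of being stored, trading compute for memory. The three `test_training_gradient_checkpointing*` skips are removed so the common gradient-checkpointing tests now run against this path.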