Support gradient checkpointing in Qwen2VL ViT (#34724)
* Support gradient checkpointing in Qwen2VL ViT

* Enable gradient checkpoint tests for Qwen2VL

* [run-slow] qwen2_vl
li-plus authored Nov 19, 2024
1 parent 1a0cd69 commit 0db91c3
Showing 2 changed files with 7 additions and 19 deletions.
src/transformers/models/qwen2_vl/modeling_qwen2_vl.py (7 additions, 1 deletion)
@@ -1000,6 +1000,7 @@ def __init__(self, config) -> None:
         self.merger = PatchMerger(
             dim=config.hidden_size, context_dim=config.embed_dim, spatial_merge_size=config.spatial_merge_size
         )
+        self.gradient_checkpointing = False

     def get_dtype(self) -> torch.dtype:
         return self.blocks[0].mlp.fc2.weight.dtype
@@ -1046,7 +1047,12 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
         cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

         for blk in self.blocks:
-            hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
+            if self.gradient_checkpointing and self.training:
+                hidden_states = self._gradient_checkpointing_func(
+                    blk.__call__, hidden_states, cu_seqlens, rotary_pos_emb
+                )
+            else:
+                hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)

         return self.merger(hidden_states)

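For reference, `_gradient_checkpointing_func` is not defined in this diff: `PreTrainedModel.gradient_checkpointing_enable()` installs it (bound to `torch.utils.checkpoint.checkpoint`) and flips `gradient_checkpointing` to `True` on every submodule that declares the flag, which is why adding the attribute in `__init__` is enough to opt the vision tower in. A minimal usage sketch, where the checkpoint name and dtype are illustrative and not part of this commit:

```python
import torch
from transformers import Qwen2VLForConditionalGeneration

model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct", torch_dtype=torch.bfloat16
)
# Sets gradient_checkpointing = True and binds _gradient_checkpointing_func
# on every supporting submodule, now including the Qwen2VL ViT.
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
model.train()  # checkpointing only activates when self.training is True
```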
tests/models/qwen2_vl/test_modeling_qwen2_vl.py (0 additions, 18 deletions)
@@ -285,24 +285,6 @@ def test_mismatching_num_image_tokens(self):
             image_grid_thw = torch.cat([image_grid_thw, image_grid_thw], dim=0)
             _ = model(input_ids=input_ids, pixel_values=pixel_values, image_grid_thw=image_grid_thw)

-    @unittest.skip(
-        reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
-    )
-    def test_training_gradient_checkpointing(self):
-        pass
-
-    @unittest.skip(
-        reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
-    )
-    def test_training_gradient_checkpointing_use_reentrant(self):
-        pass
-
-    @unittest.skip(
-        reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
-    )
-    def test_training_gradient_checkpointing_use_reentrant_false(self):
-        pass
-
     @unittest.skip(reason="Feedforward chunking is not yet supported")
     def test_feed_forward_chunking(self):
         pass
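The three skips are removed rather than rewritten because the tests themselves live in the shared `ModelTesterMixin`; deleting the overrides simply lets the common suite run for Qwen2VL. Roughly, those tests exercise the pattern below (a simplified sketch, not the actual test code, which also covers both reentrant modes via `gradient_checkpointing_kwargs`):

```python
import torch

def check_gradient_checkpointing(model, inputs):
    # Simplified version of what the re-enabled common tests verify.
    model.gradient_checkpointing_enable()
    model.train()
    loss = model(**inputs).loss
    loss.backward()
    # With the fix above, gradients must now reach the ViT parameters too.
    for name, param in model.named_parameters():
        if param.requires_grad:
            assert param.grad is not None, f"no grad for {name}"
```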
