fix checkpointable layers logic and docstring. Add unit test.
Quentin-Anthony committed Dec 16, 2024
1 parent da771ed commit 6334229
Showing 2 changed files with 62 additions and 2 deletions.
14 changes: 12 additions & 2 deletions deepspeed/runtime/pipe/module.py
@@ -116,7 +116,9 @@ def forward(self, inputs):
partition_method (str, optional): The method upon which the layers are partitioned. Defaults to 'parameters'.
activation_checkpoint_interval (int, optional): The granularity activation checkpointing in terms of number of layers. 0 disables activation checkpointing.
activation_checkpoint_func (callable, optional): The function to use for activation checkpointing. Defaults to ``deepspeed.checkpointing.checkpoint``.
checkpointable_layers(list, optional): Checkpointable layers may not be checkpointed. Defaults to None which does not additional filtering.
checkpointable_layers (list[str], optional): List of layer class names that are eligible for checkpointing. For GPT models,
ParallelTransformerLayerPipe is always checkpointed regardless of this list. If None, all layers with parameters are
considered checkpointable. Defaults to None.
dynamic_shape: Allows dynamic shapes of inputs. This might have a performance impact.
"""

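For context, a minimal sketch of how the documented arguments fit together (not taken from the commit): it assumes an already-initialized torch.distributed process group, which DeepSpeed's pipeline engine requires, and MyBlock is a hypothetical layer class used only for illustration.

import torch
from deepspeed.pipe import PipelineModule, LayerSpec

class MyBlock(torch.nn.Module):
    """Hypothetical pipeline layer used only for illustration."""

    def __init__(self, dim):
        super().__init__()
        self.proj = torch.nn.Linear(dim, dim)

    def forward(self, x):
        return self.proj(x)

# Assumes torch.distributed is already set up (e.g. via deepspeed.init_distributed()).
model = PipelineModule(
    layers=[LayerSpec(MyBlock, 128) for _ in range(4)],
    num_stages=1,
    activation_checkpoint_interval=1,
    # Per the clarified docstring: only class names listed here are treated as
    # checkpointable; None falls back to "any layer with parameters".
    checkpointable_layers=["MyBlock"],
)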
@@ -650,9 +652,17 @@ def _is_checkpointable(self, funcs):
# because only non_reentrant_checkpoint can accept inputs with requires_grad=False
# otherwise, the backward of the embedding layer won't receive gradients.
if self.__class__.__name__ in ('GPTModelPipe', 'GPT2ModelPipe'):
return all('ParallelTransformerLayerPipe' in f.__class__.__name__ for f in funcs)
# For GPT models, checkpoint both transformer layers and any additional
# layers specified in checkpointable_layers (if provided)
return all('ParallelTransformerLayerPipe' in f.__class__.__name__ or (
self.checkpointable_layers is not None and f.__class__.__name__ in self.checkpointable_layers)
for f in funcs)

if self.checkpointable_layers is not None:
# For non-GPT models, only checkpoint layers specified in checkpointable_layers
return all(f.__class__.__name__ in self.checkpointable_layers for f in funcs)

# Default behavior: checkpoint any layer that has parameters
params = [f.parameters() for f in funcs if isinstance(f, torch.nn.Module)]
return any(len(list(p)) > 0 for p in params)

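To make the three branches above easy to see in isolation, here is a standalone paraphrase of the eligibility check (a sketch mirroring the diff, not the DeepSpeed code path itself); the example calls at the bottom are illustrative only.

import torch

def is_checkpointable(funcs, pipeline_cls_name, checkpointable_layers=None):
    # GPT pipelines: ParallelTransformerLayerPipe is always eligible, and any
    # extra class names passed via checkpointable_layers are accepted too.
    if pipeline_cls_name in ('GPTModelPipe', 'GPT2ModelPipe'):
        return all('ParallelTransformerLayerPipe' in f.__class__.__name__
                   or (checkpointable_layers is not None
                       and f.__class__.__name__ in checkpointable_layers)
                   for f in funcs)
    # Other pipelines with an explicit list: only listed class names are eligible.
    if checkpointable_layers is not None:
        return all(f.__class__.__name__ in checkpointable_layers for f in funcs)
    # Default: eligible if at least one callable is an nn.Module with parameters.
    params = [f.parameters() for f in funcs if isinstance(f, torch.nn.Module)]
    return any(len(list(p)) > 0 for p in params)

linear = torch.nn.Linear(8, 8)
print(is_checkpointable([linear], 'GPT2ModelPipe'))              # False: not a transformer layer, no list
print(is_checkpointable([linear], 'GPT2ModelPipe', ['Linear']))  # True: explicitly listed
print(is_checkpointable([linear], 'MyModelPipe'))                # True: default path, has parameters
print(is_checkpointable([torch.nn.ReLU()], 'MyModelPipe'))       # False: default path, no parameters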
@@ -8,6 +8,7 @@
import pytest
import torch
import deepspeed
from deepspeed.pipe import PipelineModule, LayerSpec
from deepspeed.accelerator import get_accelerator
from copy import deepcopy
from unit.common import DistributedTest
@@ -259,3 +260,52 @@ def test_ckpt_non_tensor_output_ordering(self, non_tensor_output):
else:
ordering += [torch.is_tensor(non_tensor_output)]
_test_activation_checkpoint_ordering(module, ordering, inputs)


class TestCheckpointableLayersConfig(DistributedTest):
world_size = 1

def test_gpt2_checkpointable_layers(self):
if get_accelerator().device_name() == "cpu":
pytest.skip("CPU accelerator does not support this test yet")

# Create a simple topology for testing
from deepspeed.runtime.pipe.topology import PipeModelDataParallelTopology
topo = PipeModelDataParallelTopology(num_pp=1, num_mp=1)

# Create test classes that we want to checkpoint
class TestTransformerLayer(torch.nn.Module):

def forward(self, x):
return x

class ParallelTransformerLayerPipe(TestTransformerLayer):
pass

class GMLPBlock(TestTransformerLayer):
pass

# Create a mock GPT2 model with different layer types
class TestGPT2ModelPipe(PipelineModule):

def __init__(self):
self.layers_spec = [
LayerSpec(ParallelTransformerLayerPipe),
LayerSpec(GMLPBlock),
LayerSpec(torch.nn.Linear, 10, 10), # Should not be checkpointed
]

super().__init__(layers=self.layers_spec,
topology=topo,
checkpointable_layers=["GMLPBlock", "ParallelTransformerLayerPipe"])

model = TestGPT2ModelPipe()
model.to(get_accelerator().device_name())

# Build layers manually for testing
layers = [spec.build() for spec in model.layers_spec]

# Test that _is_checkpointable returns correct values
assert model._is_checkpointable([layers[0]]) == True # ParallelTransformerLayerPipe
assert model._is_checkpointable([layers[1]]) == True # GMLPBlock
assert model._is_checkpointable([layers[2]]) == False # Linear layer

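Note that the assertions above exercise the explicit checkpointable_layers branch, since the class name 'TestGPT2ModelPipe' is not in ('GPTModelPipe', 'GPT2ModelPipe'). For completeness, a small self-contained sketch (not part of the commit) of the GPT-specific branch, where ParallelTransformerLayerPipe passes even without being listed and a listed class such as GMLPBlock is accepted as an extra:

class ParallelTransformerLayerPipe:
    pass

class GMLPBlock:
    pass

def gpt_branch(funcs, checkpointable_layers=None):
    # Mirrors the GPT-model branch of _is_checkpointable shown in the diff above.
    return all('ParallelTransformerLayerPipe' in f.__class__.__name__
               or (checkpointable_layers is not None
                   and f.__class__.__name__ in checkpointable_layers)
               for f in funcs)

print(gpt_branch([ParallelTransformerLayerPipe()]))      # True: always checkpointed
print(gpt_branch([GMLPBlock()]))                         # False: not listed
print(gpt_branch([GMLPBlock()], ["GMLPBlock"]))          # True: listed as an extra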