diff --git a/deepspeed/runtime/pipe/module.py b/deepspeed/runtime/pipe/module.py
index 9fbd91f750a9..49fa2807c355 100644
--- a/deepspeed/runtime/pipe/module.py
+++ b/deepspeed/runtime/pipe/module.py
@@ -116,7 +116,9 @@ def forward(self, inputs):
         partition_method (str, optional): The method upon which the layers are partitioned. Defaults to 'parameters'.
         activation_checkpoint_interval (int, optional): The granularity activation checkpointing in terms of number of layers. 0 disables activation checkpointing.
         activation_checkpoint_func (callable, optional): The function to use for activation checkpointing. Defaults to ``deepspeed.checkpointing.checkpoint``.
-        checkpointable_layers(list, optional): Checkpointable layers may not be checkpointed. Defaults to None which does not additional filtering.
+        checkpointable_layers (list[str], optional): List of layer class names that are eligible for checkpointing. For GPT models,
+            ParallelTransformerLayerPipe is always checkpointed regardless of this list. If None, all layers with parameters are
+            considered checkpointable. Defaults to None.
         dynamic_shape: Allows dynamic shapes of inputs. This might have a performance impact.
     """
 
@@ -650,9 +652,17 @@ def _is_checkpointable(self, funcs):
         # because only non_reentrant_checkpoint can accept inputs with requires_grad=False
         # otherwise, the backward of the embedding layer won't receive gradients.
         if self.__class__.__name__ in ('GPTModelPipe', 'GPT2ModelPipe'):
-            return all('ParallelTransformerLayerPipe' in f.__class__.__name__ for f in funcs)
+            # For GPT models, checkpoint both transformer layers and any additional
+            # layers specified in checkpointable_layers (if provided)
+            return all('ParallelTransformerLayerPipe' in f.__class__.__name__ or (
+                self.checkpointable_layers is not None and f.__class__.__name__ in self.checkpointable_layers)
+                       for f in funcs)
+
         if self.checkpointable_layers is not None:
+            # For non-GPT models, only checkpoint layers specified in checkpointable_layers
             return all(f.__class__.__name__ in self.checkpointable_layers for f in funcs)
+
+        # Default behavior: checkpoint any layer that has parameters
         params = [f.parameters() for f in funcs if isinstance(f, torch.nn.Module)]
         return any(len(list(p)) > 0 for p in params)
 
diff --git a/tests/unit/runtime/activation_checkpointing/test_activation_checkpointing.py b/tests/unit/runtime/activation_checkpointing/test_activation_checkpointing.py
index 22a61003b31e..dd3bcd7fb6bd 100644
--- a/tests/unit/runtime/activation_checkpointing/test_activation_checkpointing.py
+++ b/tests/unit/runtime/activation_checkpointing/test_activation_checkpointing.py
@@ -8,6 +8,7 @@
 import pytest
 import torch
 import deepspeed
+from deepspeed.pipe import PipelineModule, LayerSpec
 from deepspeed.accelerator import get_accelerator
 from copy import deepcopy
 from unit.common import DistributedTest
@@ -259,3 +260,52 @@ def test_ckpt_non_tensor_output_ordering(self, non_tensor_output):
         else:
             ordering += [torch.is_tensor(non_tensor_output)]
         _test_activation_checkpoint_ordering(module, ordering, inputs)
+
+
+class TestCheckpointableLayersConfig(DistributedTest):
+    world_size = 1
+
+    def test_gpt2_checkpointable_layers(self):
+        if get_accelerator().device_name() == "cpu":
+            pytest.skip("CPU accelerator does not support this test yet")
+
+        # Create a simple topology for testing
+        from deepspeed.runtime.pipe.topology import PipeModelDataParallelTopology
+        topo = PipeModelDataParallelTopology(num_pp=1, num_mp=1, num_dp=1)
+
+        # Create test classes that we want to checkpoint
+        class TestTransformerLayer(torch.nn.Module):
+
+            def forward(self, x):
+                return x
+
+        class ParallelTransformerLayerPipe(TestTransformerLayer):
+            pass
+
+        class GMLPBlock(TestTransformerLayer):
+            pass
+
+        # Create a mock GPT2 model with different layer types
+        class TestGPT2ModelPipe(PipelineModule):
+
+            def __init__(self):
+                self.layers_spec = [
+                    LayerSpec(ParallelTransformerLayerPipe),
+                    LayerSpec(GMLPBlock),
+                    LayerSpec(torch.nn.Linear, 10, 10),  # Should not be checkpointed
+                ]
+
+                super().__init__(layers=self.layers_spec,
+                                 topology=topo,
+                                 checkpointable_layers=["GMLPBlock", "ParallelTransformerLayerPipe"])
+
+        model = TestGPT2ModelPipe()
+        model.to(get_accelerator().device_name())
+
+        # Build layers manually for testing
+        layers = [spec.build() for spec in model.layers_spec]
+
+        # Test that _is_checkpointable returns correct values
+        assert model._is_checkpointable([layers[0]]) == True  # ParallelTransformerLayerPipe
+        assert model._is_checkpointable([layers[1]]) == True  # GMLPBlock
+        assert model._is_checkpointable([layers[2]]) == False  # Linear layer
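
Usage note (illustrative, not part of the diff): a minimal sketch of how the clarified checkpointable_layers filter behaves for a non-GPT pipeline. The MyBlock and MyHead layer classes are hypothetical, and the snippet assumes a single-process distributed context is already available (e.g. launched via the deepspeed launcher with world size 1); only stages whose class name appears in checkpointable_layers are treated as checkpointable, while unlisted layers fall back to the behavior shown above.

import torch
import deepspeed
from deepspeed.pipe import PipelineModule, LayerSpec


class MyBlock(torch.nn.Module):  # hypothetical layer we want checkpointed

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.linear(x)


class MyHead(torch.nn.Module):  # hypothetical layer we do not want checkpointed

    def forward(self, x):
        return x


# Assumes the usual distributed env vars are set (single-process launch).
deepspeed.init_distributed()

model = PipelineModule(layers=[LayerSpec(MyBlock), LayerSpec(MyHead)],
                       num_stages=1,
                       activation_checkpoint_interval=1,
                       checkpointable_layers=["MyBlock"])

# With the fixed logic, only layers listed in checkpointable_layers qualify.
assert model._is_checkpointable([MyBlock()]) is True
assert model._is_checkpointable([MyHead()]) is False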