reduce setting global variables to reduce torch compile graph breaks (microsoft#6541)

Setting global variables during training creates graph breaks when using
torch.compile (reading global variables does not). This commit reduces the
setting of global variables in the checkpointing flows.
There are two main uses of setting global variables:
1. Share data between functions
2. Establish that this is the first call to the code

In most cases, the data held in the global variables can be computed on
demand or set once to an initial state in a configure function. For the
"check that this is the first run" use case, the code was moved to the
configure function.
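
As an illustration, here is a minimal, hypothetical sketch of the pattern (the names below are not from the DeepSpeed code): the lazy first-call write inside the hot path is replaced by a one-time write in a configure() step, so the compiled function only reads globals.

```python
import torch

# Lazy pattern being removed: a global that is written on the first call
# inside the hot path. Under torch.compile, such a write can force a graph break.
_initialized = False
mp_rank = 0  # default value, mirroring the new module-level defaults

def forward_lazy(x):
    global _initialized
    if not _initialized:
        _initialized = True  # global write inside the traced region
    return x * (mp_rank + 1)

# Refactored pattern: one-time setup happens in configure(), outside anything
# torch.compile traces; the hot path only reads the global.
def configure(rank=None):
    global mp_rank
    if rank is not None:
        mp_rank = rank  # written once, in setup code

def forward_configured(x):
    return x * (mp_rank + 1)  # read-only access to the global

# Optional check (PyTorch 2.1+ API): report how many graph breaks Dynamo sees.
# Whether the lazy version actually breaks depends on the PyTorch version and
# on whether the global lives in another module, as it does in DeepSpeed.
x = torch.randn(4)
print(torch._dynamo.explain(forward_lazy)(x).graph_break_count)
print(torch._dynamo.explain(forward_configured)(x).graph_break_count)
```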

---------

Co-authored-by: Olatunji Ruwase <[email protected]>
Co-authored-by: Masahiro Tanaka <[email protected]>
Co-authored-by: Logan Adams <[email protected]>
4 people authored Oct 10, 2024
1 parent a1f98bd commit d7ca3d8
Showing 1 changed file with 37 additions and 63 deletions.
deepspeed/runtime/activation_checkpointing/checkpointing.py (37 additions & 63 deletions)
@@ -36,8 +36,10 @@

 # MP parameters
 mpu = None
-mp_rank = None
-mp_size = None
+
+#set default values
+mp_rank = 0
+mp_size = 1
 mp_group = None

 # Model Parameters
@@ -61,8 +63,6 @@

 # Default name for the model parallel rng tracker.
 _MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'
-transport_stream = None
-cuda_device = None


 def detach_variable(inputs, device=None):
@@ -518,35 +518,10 @@ def save_args_for_backward(*all_args):
         global mp_rank, mp_size, mp_group
         global contiguous_data_buffers, contiguous_size_buffers
         global data_offsets, size_offsets
-        if mp_rank is None:
-            if mpu is not None:
-                if hasattr(mpu, 'get_tensor_model_parallel_rank'):
-                    mp_rank = mpu.get_tensor_model_parallel_rank()
-                    mp_size = mpu.get_tensor_model_parallel_world_size()
-                    mp_group = mpu.get_tensor_model_parallel_group()
-                else:
-                    mp_rank = mpu.get_model_parallel_rank()
-                    mp_size = mpu.get_model_parallel_world_size()
-                    mp_group = mpu.get_model_parallel_group()
-            else:
-                mp_rank = 0
-                mp_size = 1
-                mp_group = None
-
-        global cuda_device, transport_stream, PARTITION_ACTIVATIONS, buffer_0, buffer_1, buffer_0_offset, buffer_1_offset
-
-        if cuda_device is None:
-            see_memory_usage("First Forward Beginning", force=False)
-            if dist.get_rank() == 0:
-                logger.info(f"Activation Checkpointing Information")
-                logger.info(f"----Partition Activations {PARTITION_ACTIVATIONS}, CPU CHECKPOINTING {CPU_CHECKPOINT}")
-                logger.info(
-                    f"----contiguous Memory Checkpointing {CONTIGUOUS_CHECKPOINTING} with {num_layers} total layers")
-                logger.info(f"----Synchronization {SYNCHRONIZE}")
-                logger.info(f"----Profiling time in checkpointing {PROFILE_TIME}")
+        global PARTITION_ACTIVATIONS, buffer_0, buffer_1, buffer_0_offset, buffer_1_offset

-            cuda_device = get_accelerator().current_device_name()
-            transport_stream = get_accelerator().Stream(device=cuda_device)
+        cuda_device = get_accelerator().current_device_name()
+        transport_stream = get_accelerator().Stream(device=cuda_device)

         if PARTITION_ACTIVATIONS:
             inputs = partition_activations(args, CPU_CHECKPOINT, CONTIGUOUS_CHECKPOINTING)
@@ -631,8 +606,9 @@ def backward(ctx, *grads):
             raise RuntimeError("Checkpointing is not compatible with .grad(), "
                                "please use .backward() if possible")

-        global cuda_device, transport_stream, PARTITION_ACTIVATIONS
-
+        global PARTITION_ACTIVATIONS
+        cuda_device = get_accelerator().current_device_name()
+        transport_stream = get_accelerator().Stream(device=cuda_device)
         # Rebuild deepspeed_saved_tensors
         for t in ctx.deepspeed_saved_tensors:
             if t is not None and hasattr(t, 'saved_data') and t.saved_data is not None:
@@ -764,35 +740,10 @@ def save_args_for_backward(*all_args):
     global mp_rank, mp_size, mp_group
     global contiguous_data_buffers, contiguous_size_buffers
     global data_offsets, size_offsets
-    if mp_rank is None:
-        if mpu is not None:
-            if hasattr(mpu, 'get_tensor_model_parallel_rank'):
-                mp_rank = mpu.get_tensor_model_parallel_rank()
-                mp_size = mpu.get_tensor_model_parallel_world_size()
-                mp_group = mpu.get_tensor_model_parallel_group()
-            else:
-                mp_rank = mpu.get_model_parallel_rank()
-                mp_size = mpu.get_model_parallel_world_size()
-                mp_group = mpu.get_model_parallel_group()
-        else:
-            mp_rank = 0
-            mp_size = 1
-            mp_group = None
-
-    global cuda_device, transport_stream, PARTITION_ACTIVATIONS, buffer_0, buffer_1, buffer_0_offset, buffer_1_offset
-
-    if cuda_device is None:
-        see_memory_usage("First Forward Beginning", force=False)
-        if dist.get_rank() == 0:
-            logger.info(f"Activation Checkpointing Information")
-            logger.info(f"----Partition Activations {PARTITION_ACTIVATIONS}, CPU CHECKPOINTING {CPU_CHECKPOINT}")
-            logger.info(
-                f"----contiguous Memory Checkpointing {CONTIGUOUS_CHECKPOINTING} with {num_layers} total layers")
-            logger.info(f"----Synchronization {SYNCHRONIZE}")
-            logger.info(f"----Profiling time in checkpointing {PROFILE_TIME}")
+    global PARTITION_ACTIVATIONS, buffer_0, buffer_1, buffer_0_offset, buffer_1_offset

-        cuda_device = get_accelerator().current_device_name()
-        transport_stream = get_accelerator().Stream(device=cuda_device)
+    cuda_device = get_accelerator().current_device_name()
+    transport_stream = get_accelerator().Stream(device=cuda_device)

     if PARTITION_ACTIVATIONS:
         inputs = partition_activations(args, CPU_CHECKPOINT, CONTIGUOUS_CHECKPOINTING)
@@ -899,7 +850,9 @@ def replay_unpack(none_value):
            raise RuntimeError("Checkpointing is not compatible with .grad(), "
                               "please use .backward() if possible")

-        global cuda_device, transport_stream, PARTITION_ACTIVATIONS
+        global PARTITION_ACTIVATIONS
+        cuda_device = get_accelerator().current_device_name()
+        transport_stream = get_accelerator().Stream(device=cuda_device)

         # gather inputs which is partitioned or checkpointed before first forward
         if PARTITION_ACTIVATIONS:
@@ -1152,6 +1105,27 @@ def configure(
     if CONTIGUOUS_CHECKPOINTING:
         assert num_layers is not None, "Must specify the number of layers with contiguous memory checkpointing"

+    global mp_rank, mp_size, mp_group
+
+    if mpu is not None:
+        if hasattr(mpu, 'get_tensor_model_parallel_rank'):
+            mp_rank = mpu.get_tensor_model_parallel_rank()
+            mp_size = mpu.get_tensor_model_parallel_world_size()
+            mp_group = mpu.get_tensor_model_parallel_group()
+        else:
+            mp_rank = mpu.get_model_parallel_rank()
+            mp_size = mpu.get_model_parallel_world_size()
+            mp_group = mpu.get_model_parallel_group()
+
+    #print configuration only once
+    see_memory_usage("After configuration", force=False)
+    if dist.get_rank() == 0:
+        logger.info(f"Activation Checkpointing Information")
+        logger.info(f"----Partition Activations {PARTITION_ACTIVATIONS}, CPU CHECKPOINTING {CPU_CHECKPOINT}")
+        logger.info(f"----contiguous Memory Checkpointing {CONTIGUOUS_CHECKPOINTING} with {num_layers} total layers")
+        logger.info(f"----Synchronization {SYNCHRONIZE}")
+        logger.info(f"----Profiling time in checkpointing {PROFILE_TIME}")
+

 def is_configured():
     """True if deepspeed activation checkpointing has been configured
