Add API to get devices of offload states (microsoft#6586)
This PR adds an API `deepspeed.runtime.zero.offload_states.get_state_devices`, which gets the devices of offload states, as suggested in this [comment](microsoft#6011 (comment)).
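
A minimal usage sketch (assuming `ds_engine` is a ZeRO stage-3 engine returned by `deepspeed.initialize`, and that its states were offloaded to CPU beforehand via `ds_engine.offload_states()`):

```python
import torch
from deepspeed.runtime.zero.offload_config import OffloadStateTypeEnum
from deepspeed.runtime.zero.offload_states import get_state_devices

# Check where the FP32 parameter partitions currently live.
devices = get_state_devices(ds_engine, OffloadStateTypeEnum.hp_params)
assert devices == {torch.device("cpu")}  # expected after offloading to CPU
```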

We could lift this up to `deepspeed.utils`, but we would need to resolve a circular import: user code -> `deepspeed.utils` -> `deepspeed.utils.offload_states` -> `deepspeed.runtime.zero` -> `deepspeed.runtime.zero.partition_parameters` -> `deepspeed.utils`.

This would require significant refactoring as long as `OffloadStateTypeEnum` lives in `deepspeed.runtime.zero`.

---------

Co-authored-by: Logan Adams <[email protected]>
Co-authored-by: Olatunji Ruwase <[email protected]>
3 people authored Oct 10, 2024
1 parent d7ca3d8 commit adec991
Showing 5 changed files with 110 additions and 16 deletions.
10 changes: 5 additions & 5 deletions deepspeed/runtime/utils.py
@@ -9,28 +9,28 @@
 """

 from collections.abc import Iterable
-from deepspeed.moe.utils import is_moe_param
 import os
 import psutil
 import gc
 from math import sqrt

+from numpy import prod
+
 import torch
-from deepspeed import comm as dist
+from torch.nn import functional as F
 try:
     from torch._six import inf
 except ModuleNotFoundError:
     from torch import inf

+from deepspeed import comm as dist
+from deepspeed.moe.utils import is_moe_param
 from deepspeed.utils import groups, logger
 from deepspeed.utils.bwc import (bwc_tensor_model_parallel_rank, bwc_pipeline_parallel_world_size,
                                  bwc_pipeline_parallel_group)
 from deepspeed.runtime.constants import PIPE_REPLICATED
-from numpy import prod
 from deepspeed.accelerator import get_accelerator
-
 from deepspeed.module_inject.policy import transpose
-from torch.nn import functional as F

 torch_memory_reserved = get_accelerator().memory_reserved
 torch_max_memory_reserved = get_accelerator().max_memory_reserved
74 changes: 74 additions & 0 deletions deepspeed/runtime/zero/offload_states.py
@@ -0,0 +1,74 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from typing import Set
import torch

from deepspeed.accelerator import get_accelerator
from deepspeed.runtime.zero.offload_config import OffloadStateTypeEnum

from deepspeed.utils.tensor_fragment import safe_get_local_fp32_param, safe_get_local_optimizer_state


def _make_offload_state_key(key):
    return f"{key}_offload_buffer"


def offload_adam_states(optimizer, device, pin_memory: bool = False, non_blocking: bool = False):
    """Move optimizer states to device. Note that this assumes the state structure of DeepSpeed Adam."""

    def move_key(state, key):
        offload_buf_key = _make_offload_state_key(key)
        if offload_buf_key not in state:
            state[offload_buf_key] = torch.empty_like(state[key], device=device)
            if pin_memory:
                state[offload_buf_key] = get_accelerator().pin_memory(state[offload_buf_key])
        state[offload_buf_key].copy_(state[key], non_blocking=non_blocking)
        # Rebind the state tensor to the offload buffer on the target device.
        state[key].data = state[offload_buf_key]

    for _, state in optimizer.state.items():
        if "exp_avg" in state:
            move_key(state, "exp_avg")
        if "exp_avg_sq" in state:
            move_key(state, "exp_avg_sq")


def reload_adam_states(optimizer, device, non_blocking: bool = False):
    """Move optimizer states back to device. Note that this assumes the state structure of DeepSpeed Adam."""

    def move_back_key(state, key):
        state[key].data = state[_make_offload_state_key(key)].to(device, non_blocking=non_blocking)

    for _, state in optimizer.state.items():
        if "exp_avg" in state:
            move_back_key(state, "exp_avg")
        if "exp_avg_sq" in state:
            move_back_key(state, "exp_avg_sq")


def get_state_devices(model, state: OffloadStateTypeEnum) -> Set[torch.device]:
    """Retrieve the devices of the specified state of the model.

    Args:
        model (DeepSpeedEngine): The model whose device allocations are to be checked.
        state (OffloadStateTypeEnum): The specific state for which the devices should be retrieved.
    Returns:
        Set[torch.device]: A set of devices of the specified state.
    """
    if state == OffloadStateTypeEnum.hp_params:
        return set(safe_get_local_fp32_param(p).device for p in model.parameters())
    elif state == OffloadStateTypeEnum.lp_params:
        return set(p.ds_tensor.device for p in model.parameters())
    elif state == OffloadStateTypeEnum.lp_grads:
        return {model.optimizer.grad_partitions_flat_buffer.device}
    elif state == OffloadStateTypeEnum.optim_states:
        return set(safe_get_local_optimizer_state(p, "exp_avg").device for p in model.parameters()) | \
            set(safe_get_local_optimizer_state(p, "exp_avg_sq").device for p in model.parameters())
    elif state == OffloadStateTypeEnum.contiguous_grad_buffer:
        # The IPG bucket buffer is accessed via its name-mangled private attribute on DeepSpeedZeroOptimizer_Stage3.
        if model.optimizer._DeepSpeedZeroOptimizer_Stage3__ipg_bucket_flat_buffer is None:
            return set()
        return {model.optimizer._DeepSpeedZeroOptimizer_Stage3__ipg_bucket_flat_buffer.device}
3 changes: 2 additions & 1 deletion deepspeed/runtime/zero/stage3.py
@@ -18,12 +18,13 @@
 from deepspeed.utils import logger
 from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler
 from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced, all_to_all_quant_reduce
-from deepspeed.runtime.utils import inf, is_model_parallel_parameter, get_only_unique_item, offload_adam_states, reload_adam_states
+from deepspeed.runtime.utils import inf, is_model_parallel_parameter, get_only_unique_item
 from deepspeed.runtime.zero.partition_parameters import *
 from deepspeed.runtime.zero.config import ZeroStageEnum
 from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum
 from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload
 from deepspeed.runtime.zero.utils import apply_to_tensors_only, get_mapping_to_flat_buffer
+from deepspeed.runtime.zero.offload_states import offload_adam_states, reload_adam_states
 from deepspeed.ops.adam import DeepSpeedCPUAdam
 from deepspeed.runtime.swap_tensor.partitioned_param_swapper import PartitionedParamStatus
 from deepspeed.runtime.swap_tensor.optimizer_utils import OptimizerSwapper
16 changes: 16 additions & 0 deletions docs/code-docs/source/zero3.rst
@@ -509,3 +509,19 @@
    ...
    # Load states back to device memory
    ds_engine.reload_states()

``deepspeed.runtime.zero.offload_states.get_state_devices`` returns devices of the specified state.

.. code-block:: python

    def get_state_devices(model, state: OffloadStateTypeEnum) -> Set[torch.device]:
        """Retrieve the devices of the specified state of the model.

        Args:
            model (DeepSpeedEngine): The model whose device allocations are to be checked.
            state (OffloadStateTypeEnum): The specific state for which the devices should be retrieved.
        Returns:
            Set[torch.device]: A set of devices of the specified state.
        """
23 changes: 13 additions & 10 deletions tests/unit/runtime/zero/test_offload_states.py
@@ -15,19 +15,22 @@
 import deepspeed
 from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum
 from deepspeed.utils import safe_get_local_fp32_param, safe_get_local_optimizer_state
+from deepspeed.runtime.zero.offload_states import get_state_devices


 def validate_device(model, device: torch.device, include) -> None:
-    # Make sure the model parameters are offloaded
-    if include is None or OffloadStateTypeEnum.hp_params in include:
-        assert all(safe_get_local_fp32_param(p).device == device for p in model.parameters())
-    if include is None or OffloadStateTypeEnum.lp_params in include:
-        assert all(p.ds_tensor.device == device for p in model.parameters())
-    if include is None or OffloadStateTypeEnum.lp_grads in include:
-        assert model.optimizer.grad_partitions_flat_buffer.device == device
-    if include is None or OffloadStateTypeEnum.optim_states in include:
-        assert all(safe_get_local_optimizer_state(p, "exp_avg").device == device for p in model.parameters())
-        assert all(safe_get_local_optimizer_state(p, "exp_avg_sq").device == device for p in model.parameters())
+
+    def compare_device(state) -> bool:
+        devices = get_state_devices(model, state)
+        return len(devices) == 1 and device in devices
+
+    for state in OffloadStateTypeEnum:
+        if include is None or state in include:
+            if state == OffloadStateTypeEnum.contiguous_grad_buffer and device == torch.device("cpu"):
+                assert len(get_state_devices(model,
+                                             state)) == 0, f"State {state} must be removed after offload_states()"
+            else:
+                assert compare_device(state), f"State {state} is not on device {device}"


 def run_model(model, config_dict, hidden_dim, dtype, include, pin_memory, non_blocking):
