Commit

Merge branch 'master' into stage3-use-new-grad-acc-api
deepcharm authored Dec 18, 2024
2 parents 1f2dced + 4cd1d97 · commit 4ebec9d
Showing 16 changed files with 38 additions and 26 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/cpu-torch-latest.yml
@@ -42,7 +42,7 @@ jobs:
 git clone https://github.com/huggingface/transformers
 cd transformers
 # if needed switch to the last known good SHA until transformers@master is fixed
-git checkout 6c3f168b3
+# git checkout 6c3f168b3
 git rev-parse --short HEAD
 pip install .
1 change: 1 addition & 0 deletions .github/workflows/no-torch.yml
@@ -4,6 +4,7 @@ on:
   workflow_dispatch:
   pull_request:
     paths:
+      - 'accelerator/**'
       - '.github/workflows/no-torch.yml'
       - 'op_builder/**'
   schedule:
2 changes: 1 addition & 1 deletion .github/workflows/nv-torch-latest-v100.yml
@@ -38,7 +38,7 @@ jobs:
 git clone https://github.com/huggingface/transformers
 cd transformers
 # if needed switch to the last known good SHA until transformers@master is fixed
-git checkout 6c3f168b3
+# git checkout 6c3f168b3
 git rev-parse --short HEAD
 pip install .
2 changes: 1 addition & 1 deletion .github/workflows/nv-torch-nightly-v100.yml
@@ -37,7 +37,7 @@ jobs:
 git clone https://github.com/huggingface/transformers
 cd transformers
 # if needed switch to the last known good SHA until transformers@master is fixed
-git checkout 6c3f168b3
+# git checkout 6c3f168b3
 git rev-parse --short HEAD
 pip install .
8 changes: 7 additions & 1 deletion accelerator/cpu_accelerator.py
@@ -3,9 +3,15 @@
 
 # DeepSpeed Team
 
-import torch
 from .abstract_accelerator import DeepSpeedAccelerator
 
+# During setup stage torch may not be installed, pass on no torch will
+# allow op builder related API to be executed.
+try:
+    import torch
+except ImportError as e:
+    pass
+
 try:
     import oneccl_bindings_for_pytorch  # noqa: F401 # type: ignore
     oneccl_imported_p = True
3 changes: 3 additions & 0 deletions deepspeed/inference/config.py
@@ -40,6 +40,9 @@ class DeepSpeedTPConfig(DeepSpeedConfigModel):
     tp_size: int = 1
     """ Number of devices to split the model across using tensor parallelism. """
 
+    tp_grain_size: int = 64
+    "Desired MLP/lm_head tp size granularity. DNN library favors tensor size in granularity of power of 2, we pick 64 as a default size."
+
     mpu: object = None
     """
     A model parallelism unit object that implements
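
For context, tp_grain_size sits under the tensor_parallel section of the inference config, so it would be set roughly as below. A minimal sketch, assuming the usual deepspeed.init_inference entry point; the model name and values are illustrative and not part of this commit:

```python
# Minimal sketch: passing tp_grain_size through the tensor_parallel config.
# Model choice and values are placeholders, not recommendations.
import torch
import deepspeed
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")
engine = deepspeed.init_inference(
    model,
    dtype=torch.float16,
    tensor_parallel={
        "tp_size": 4,         # split the model across 4 devices
        "tp_grain_size": 64,  # shard MLP/lm_head dimensions in multiples of 64
    },
)
```
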
2 changes: 2 additions & 0 deletions deepspeed/launcher/multinode_runner.py
@@ -104,6 +104,8 @@ def get_cmd(self, environment, active_resources):
             deepspeed_launch.append("--no_local_rank")
         if self.args.save_pid:
             deepspeed_launch += ["--save_pid", f"{os.getpid()}"]
+        if self.args.enable_each_rank_log:
+            deepspeed_launch.append(f"--enable_each_rank_log={self.args.enable_each_rank_log}")
         if self.args.elastic_training:
             deepspeed_launch.append("--enable_elastic_training")
             deepspeed_launch.append(f"--max_elastic_nodes={self.args.max_elastic_nodes}")
5 changes: 4 additions & 1 deletion deepspeed/module_inject/replace_module.py
@@ -17,7 +17,7 @@
 from .layers import TensorParallelOcShardConv2d, TensorParallelIcShardConv2d
 
 from deepspeed import comm as dist
-from deepspeed.module_inject.tp_shard import set_num_kv_heads, set_n_embd, set_num_attention_heads
+from deepspeed.module_inject.tp_shard import set_num_kv_heads, set_n_embd, set_num_attention_heads, set_tp_grain_size
 
 from .load_checkpoint import load_model_with_checkpoint
 import time
@@ -303,6 +303,9 @@ def replace_wo_policy(module, all_reduce_linears, prefix="", state_dict=None):
         if hasattr(model_config, 'num_attention_heads'):
             set_num_attention_heads(getattr(model_config, 'num_attention_heads'))
 
+        # 4.4 set tp_grain_size
+        set_tp_grain_size(config.tensor_parallel.tp_grain_size)
+
         # 5. Set linear policies
         _autotp.update_linear_policies()
 
11 changes: 8 additions & 3 deletions deepspeed/module_inject/tp_shard.py
@@ -22,6 +22,11 @@ def set_n_embd(num):
     n_embd = num
 
 
+def set_tp_grain_size(num):
+    global tp_grain_size
+    tp_grain_size = num
+
+
 def get_num_kv_heads():
     global num_kv_heads
     if 'num_kv_heads' in globals():
@@ -45,9 +50,9 @@ def get_shard_size(total_size, mp_size, name=None, rank=None):
         my_slices = (num_kv_heads // mp_size) + (1 if rank < (num_kv_heads % mp_size) else 0)
         return total_size * my_slices // num_kv_heads
     else:
-        if total_size >= 64:
-            grain_size = total_size // 64
-            return (grain_size // mp_size + (1 if rank < (grain_size % mp_size) else 0)) * 64
+        if total_size >= tp_grain_size:
+            grain_size = total_size // tp_grain_size
+            return (grain_size // mp_size + (1 if rank < (grain_size % mp_size) else 0)) * tp_grain_size
         else:
             return total_size // mp_size + (1 if rank < (total_size % mp_size) else 0)
 
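
The rewritten else-branch generalizes the hard-coded 64 to the configurable tp_grain_size: shard widths are rounded to whole grains, and leftover grains go to the lowest ranks. A self-contained sketch of that arithmetic, with a hypothetical helper name and example numbers that are not part of the patch:

```python
# Standalone sketch of the grain-aligned split in the else-branch above.
def shard_size_sketch(total_size: int, mp_size: int, rank: int, tp_grain_size: int = 64) -> int:
    if total_size >= tp_grain_size:
        grains = total_size // tp_grain_size
        # Whole grains are spread across ranks; the remainder goes to the lowest ranks.
        return (grains // mp_size + (1 if rank < (grains % mp_size) else 0)) * tp_grain_size
    # Tensors narrower than one grain fall back to a plain near-even split.
    return total_size // mp_size + (1 if rank < (total_size % mp_size) else 0)

# Example: an 11008-wide MLP dimension split 3 ways at 64-granularity gives
# 172 grains -> 58/57/57 -> shard widths 3712, 3648, 3648 (sum 11008).
print([shard_size_sketch(11008, 3, r) for r in range(3)])  # [3712, 3648, 3648]
```
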
5 changes: 5 additions & 0 deletions deepspeed/runtime/engine.py
@@ -3738,6 +3738,11 @@ def offload_states(self,
         assert self.zero_optimization_stage(
         ) == ZeroStageEnum.weights, "Moving buffers across devices is supported only for ZeRO stage 3."
 
+        opt_offload_config = self.zero_offload_optimizer()
+        assert opt_offload_config is None or opt_offload_config.device == OffloadDeviceEnum.none, "Moving states across devices is not supported for offloaded optimizer states."
+        param_offload_config = self.zero_offload_param()
+        assert param_offload_config is None or param_offload_config.device == OffloadDeviceEnum.none, "Moving states across devices is not supported for offloaded parameters."
+
         assert not self.zero_offload_param(), "Moving states across devices is not supported for offloaded parameters."
 
         if device == OffloadDeviceEnum.none:
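
These asserts spell out that offload_states is only usable when the ZeRO-3 config does not already offload optimizer states or parameters. A rough usage sketch under that assumption; the engine construction is elided and the helper name is a placeholder, so treat it as illustrative rather than authoritative:

```python
# Rough sketch: freeing device memory between training phases with a ZeRO-3
# engine whose config sets neither offload_optimizer nor offload_param.
from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum

# `engine` is assumed to come from deepspeed.initialize(...) with ZeRO stage 3.
engine.offload_states(device=OffloadDeviceEnum.cpu, pin_memory=True, non_blocking=True)
run_memory_hungry_eval()  # placeholder for work that needs the freed memory
engine.reload_states(non_blocking=True)
```
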
2 changes: 0 additions & 2 deletions tests/unit/ops/transformer/inference/test_bias_geglu.py
@@ -15,8 +15,6 @@
 if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
     pytest.skip("Inference ops are not available on this system", allow_module_level=True)
 
-torch_minor_version = None
-
 
 def run_bias_geglu_reference(activations, bias):
     # Expected behavior is that of casting to float32 internally
2 changes: 0 additions & 2 deletions tests/unit/ops/transformer/inference/test_bias_gelu.py
@@ -16,8 +16,6 @@
 if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
     pytest.skip("Inference ops are not available on this system", allow_module_level=True)
 
-torch_minor_version = None
-
 
 def run_bias_gelu_reference(activations, bias):
     # Expected behavior is that of casting to float32 internally and using the tanh approximation
2 changes: 0 additions & 2 deletions tests/unit/ops/transformer/inference/test_bias_relu.py
@@ -15,8 +15,6 @@
 if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
     pytest.skip("Inference ops are not available on this system", allow_module_level=True)
 
-torch_minor_version = None
-
 
 def run_bias_relu_reference(activations, bias):
     # Expected behavior is that of casting to float32 internally
14 changes: 5 additions & 9 deletions tests/unit/ops/transformer/inference/test_gelu.py
@@ -9,12 +9,11 @@
 from deepspeed.ops.op_builder import InferenceBuilder
 from deepspeed.ops.transformer import DeepSpeedInferenceConfig
 from deepspeed.ops.transformer.inference.op_binding.bias_gelu import BiasGeluOp
+from deepspeed.utils.torch import required_torch_version
 
 if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
     pytest.skip("Inference ops are not available on this system", allow_module_level=True)
 
-torch_minor_version = None
-
 
 def allclose(x, y):
     assert x.dtype == y.dtype
@@ -23,14 +22,11 @@ def allclose(x, y):
 
 
 def version_appropriate_gelu(activations):
-    global torch_minor_version
-    if torch_minor_version is None:
-        torch_minor_version = int(torch.__version__.split('.')[1])
-    # If torch version = 1.12
-    if torch_minor_version < 12:
-        return torch.nn.functional.gelu(activations)
-    else:
+    # gelu behavior changes (correctly) in torch 1.12
+    if required_torch_version(min_version=1.12):
         return torch.nn.functional.gelu(activations, approximate='tanh')
+    else:
+        return torch.nn.functional.gelu(activations)
 
 
 def run_gelu_reference(activations):
1 change: 0 additions & 1 deletion tests/unit/ops/transformer/inference/test_matmul.py
@@ -12,7 +12,6 @@
     pytest.skip("Inference ops are not available on this system", allow_module_level=True)
 
 inference_module = None
-torch_minor_version = None
 
 
 def allclose(x, y):
2 changes: 0 additions & 2 deletions tests/unit/ops/transformer/inference/test_softmax.py
@@ -11,8 +11,6 @@
 if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
     pytest.skip("Inference ops are not available on this system", allow_module_level=True)
 
-torch_minor_version = None
-
 
 def allclose(x, y):
     assert x.dtype == y.dtype
