Merge branch 'master' into onebit
tjruwase authored May 14, 2024
2 parents 4a1dd4c + 62ca317 commit 0cb887a
Showing 10 changed files with 233 additions and 106 deletions.
2 changes: 1 addition & 1 deletion MANIFEST.in
@@ -3,7 +3,7 @@ include deepspeed/inference/v2/kernels/ragged_ops/libs/*.so
include deepspeed/inference/v2/kernels/cutlass_ops/libs/*.so
recursive-include requirements *.txt
recursive-include deepspeed *.cpp *.h *.cu *.hip *.tr *.cuh *.cc *.json
recursive-include csrc *.cpp *.h *.cu *.tr *.cuh *.cc
recursive-include csrc *.cpp *.h *.hpp *.cu *.tr *.cuh *.cc
recursive-include op_builder *.py
recursive-include benchmarks *.py
recursive-include accelerator *.py
11 changes: 9 additions & 2 deletions deepspeed/module_inject/auto_tp.py
@@ -13,7 +13,7 @@
from deepspeed import comm as dist
from .layers import LinearAllreduce, LinearLayer, LmHeadLinearAllreduce
from deepspeed.accelerator import get_accelerator
from .fusedqkv_utils import require_tp_fused_qkvw, prepare_tp_fused_qkvw
from .fusedqkv_utils import require_tp_fused_qkvw, prepare_tp_fused_qkvw, shard_chunk_mlp
from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list


@@ -133,7 +133,8 @@ def is_load_module(module):
load_layers = [nn.Linear, nn.Embedding, nn.LayerNorm]
load_layer_names = [
"LPLayerNorm", "SharedEmbedding", "OPTLearnedPositionalEmbedding", "LlamaRMSNorm", "FalconLinear",
"MistralRMSNorm", "T5LayerNorm", "MixtralRMSNorm"
"MistralRMSNorm", "T5LayerNorm", "MixtralRMSNorm", "Phi3RotaryEmbedding", "Phi3SuScaledRotaryEmbedding",
"Phi3RMSNorm"
]
return module.__class__ in load_layers or module._get_name() in load_layer_names

@@ -306,6 +307,8 @@ def tp_parser(model):
# Mixtral-7x8b used w2*act(w1*w3) linear. need to replace w2 to linearallreduce.
elif 'w2' in layer and 'Mixtral' in str(type(module)):
gem_list = gem_list + [layer]
elif 'self_attn.dense' in layer and 'Phi' in str(type(module)):
gem_list = gem_list + [layer]

layer_list = []
if gem_list != []:
@@ -328,6 +331,10 @@ def _replace(self, child, name, conv_linear_layer):
# For mixtral-7x8b, need to skip MoE gate linear replace.
if name == "block_sparse_moe.gate":
return child
# for phi3.
if 'gate_up_proj' in name:
weight, bias = shard_chunk_mlp(child.weight.data, child.bias, dist.get_rank(), dist.get_world_size())
return LinearLayer(weight=weight, bias=bias)
if name in self.all_reduce_linears:
# if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size]
# else [weight_shape[0], weight_shape[1] // mp_size]
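Note on the `gate_up_proj` branch above: Phi-3 fuses the MLP gate and up projections into a single weight, so tensor-parallel sharding has to slice each half separately and re-fuse the per-rank pieces; that is what the `shard_chunk_mlp` helper (added in `fusedqkv_utils.py` below) provides. A minimal standalone sketch of the idea, with toy shapes and even splits assumed rather than DeepSpeed's `get_shard_size_list`:

```python
import torch

# Toy illustration: a fused [gate; up] weight of shape [2 * inter, hidden].
# Each rank must take the matching slice of BOTH halves and re-fuse them,
# rather than taking a contiguous slice of the fused matrix.
hidden, inter, world_size = 8, 16, 2
fused = torch.arange(2 * inter * hidden, dtype=torch.float32).reshape(2 * inter, hidden)

def shard_gate_up(fused_weight, rank, world_size):
    gate, up = fused_weight.chunk(2, dim=0)           # split the fused halves
    gate_shard = gate.chunk(world_size, dim=0)[rank]  # even split assumed
    up_shard = up.chunk(world_size, dim=0)[rank]
    return torch.cat((gate_shard, up_shard), dim=0)   # re-fuse the per-rank halves

shards = [shard_gate_up(fused, r, world_size) for r in range(world_size)]
assert all(s.shape == (2 * inter // world_size, hidden) for s in shards)
```

Naively chunking the fused matrix across ranks would give some ranks only gate rows and others only up rows, which is why the split-then-refuse ordering matters.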
40 changes: 39 additions & 1 deletion deepspeed/module_inject/fusedqkv_utils.py
@@ -4,7 +4,7 @@
# DeepSpeed Team
import torch
from deepspeed.utils.logging import warning_once
from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list, get_num_kv_heads, get_n_embd
from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list, get_num_kv_heads, get_n_embd, get_num_attention_heads


def split_by_qkvlist_and_refuse(qkv_list, split_size, split_dim=0, cat_dim=0):
@@ -42,6 +42,7 @@ def prepare_tp_fused_qkvw(module, src, mp_size, gpu_index):
"FalconDecoderLayer": 'bloomtype',
"GPTBigCodeBlock": 'bigcodetype',
"DecoderLayer": 'glmtype',
"Phi3DecoderLayer": "phi3type"
}

def _codegen_type_transpose(input, mp_size, codegen_mp_num=4):
@@ -93,6 +94,20 @@ def _bigcode_type_transpose(input, mp_size):
split_q = q.split(get_shard_size_list(shape[0], mp_size), dim=0)
return torch.cat((split_q[gpu_index], kv), dim=0)

def _phi3_type_transpose(input, mp_size):
num_kv_heads = get_num_kv_heads()
num_heads = get_num_attention_heads()
hidden_size = input.shape[1]
head_dim = hidden_size // num_heads
q_pos = input.shape[0] - 2 * num_kv_heads * head_dim
q = input[:q_pos]
k = input[q_pos:q_pos + num_kv_heads * head_dim]
v = input[q_pos + num_kv_heads * head_dim:]
split_q = q.split(get_shard_size_list(q.shape[0], mp_size), dim=0)
split_k = k.split(get_shard_size_list(k.shape[0], mp_size), dim=0)
split_v = v.split(get_shard_size_list(v.shape[0], mp_size), dim=0)
return torch.cat((split_q[gpu_index], split_k[gpu_index], split_v[gpu_index]), dim=0)

def _transpose_fused_qkvw(src, mp_size, fused_qkv_type=None, module=None):

# suppose num_heads=n, q(n)_w means the n-th q head linear weight, the weight format are as following
@@ -110,6 +125,8 @@ def _transpose_fused_qkvw(src, mp_size, fused_qkv_type=None, module=None):
return _qwen_type_transpose(src, mp_size, module)
elif fused_qkv_type == 'bigcodetype':
return _bigcode_type_transpose(src, mp_size)
elif fused_qkv_type == 'phi3type':
return _phi3_type_transpose(src, mp_size)

raise ValueError("unknown fused_qkv_type")

@@ -123,3 +140,24 @@ def _transpose_fused_qkvw(src, mp_size, fused_qkv_type=None, module=None):
warning_once(f"Unrecognized fusedkqv weight type, default to using bloom type,"
f"please check in prepare_tp_fused_qkvw() to avoid potential calculation errors")
return _bloom_type_transpose(src, mp_size)


# For phi3 with chunk mlp, adjust the weight order.
def shard_chunk_mlp(
weight,
bias,
rank,
world_size,
):
weight_gate, weight_states = weight.chunk(2, dim=0)
total_size = weight_gate.shape[0]
split_weight_gate = weight_gate.split(get_shard_size_list(total_size, world_size, "mlp"), dim=0)
split_weight_states = weight_states.split(get_shard_size_list(total_size, world_size, "mlp"), dim=0)
shard_weight = torch.cat((split_weight_gate[rank], split_weight_states[rank]), dim=0)
if bias is not None:
bias_gate, bias_states = bias.chunk(2, dim=0)
split_bias_gate = bias_gate.split(get_shard_size_list(total_size, world_size, "mlp"), dim=0)
split_bias_states = bias_states.split(get_shard_size_list(total_size, world_size, "mlp"), dim=0)
return shard_weight, torch.cat((split_bias_gate[rank], split_bias_states[rank]), dim=0)

return shard_weight, None
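The new `phi3type` path above handles the analogous layout for attention: Phi-3 packs Q, K, and V into a single fused projection of `(num_heads + 2 * num_kv_heads) * head_dim` rows, and `_phi3_type_transpose` slices each segment per rank before re-concatenating. A rough standalone sketch with toy sizes (even splits assumed; the DeepSpeed code uses `get_shard_size_list` and the globals from `tp_shard.py` instead):

```python
import torch

# Toy fused QKV weight: Q rows first, then K, then V.
num_heads, num_kv_heads, head_dim, world_size = 8, 2, 4, 2
hidden = num_heads * head_dim
fused_qkv = torch.randn((num_heads + 2 * num_kv_heads) * head_dim, hidden)

def phi3_qkv_shard(w, rank, world_size):
    q_rows = w.shape[0] - 2 * num_kv_heads * head_dim
    q = w[:q_rows]
    k = w[q_rows:q_rows + num_kv_heads * head_dim]
    v = w[q_rows + num_kv_heads * head_dim:]

    def pick(t):
        return t.chunk(world_size, dim=0)[rank]  # even head split assumed

    # Each rank keeps its own slice of Q, K and V, re-fused in the same order.
    return torch.cat((pick(q), pick(k), pick(v)), dim=0)

shard = phi3_qkv_shard(fused_qkv, rank=0, world_size=world_size)
assert shard.shape[0] == (num_heads + 2 * num_kv_heads) * head_dim // world_size
```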
6 changes: 5 additions & 1 deletion deepspeed/module_inject/replace_module.py
@@ -16,7 +16,7 @@
from .auto_tp import AutoTP, ReplaceWithTensorSlicing, Loading

from deepspeed import comm as dist
from deepspeed.module_inject.tp_shard import set_num_kv_heads, set_n_embd
from deepspeed.module_inject.tp_shard import set_num_kv_heads, set_n_embd, set_num_attention_heads

from .load_checkpoint import load_model_with_checkpoint
import time
@@ -290,6 +290,10 @@ def replace_wo_policy(module, all_reduce_linears, prefix="", state_dict=None):
# 4.2 set n_embd
set_n_embd(n_embd)

# 4.3 set attention_heads
if hasattr(model_config, 'num_attention_heads'):
set_num_attention_heads(getattr(model_config, 'num_attention_heads'))

# 5. Set linear policies
_autotp.update_linear_policies()

10 changes: 10 additions & 0 deletions deepspeed/module_inject/tp_shard.py
@@ -12,6 +12,11 @@ def set_num_kv_heads(num):
num_kv_heads = num


def set_num_attention_heads(num):
global num_attention_heads
num_attention_heads = num


def set_n_embd(num):
global n_embd
n_embd = num
@@ -22,6 +27,11 @@ def get_num_kv_heads():
return num_kv_heads


def get_num_attention_heads():
global num_attention_heads
return num_attention_heads


def get_shard_size(total_size, mp_size, name=None, rank=None):
global num_kv_heads
last_linear = ["lm_head", "embed_out"]
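The `num_attention_heads` accessor added above follows the same pattern as the existing `num_kv_heads` and `n_embd` helpers: a setter stashes the value in a module-level global (called from `replace_module.py` during injection), and the getter assumes the setter ran first. A tiny standalone sketch of that contract:

```python
# Standalone sketch of the tp_shard accessor pattern (module-level globals).
def set_num_attention_heads(num):
    global num_attention_heads
    num_attention_heads = num

def get_num_attention_heads():
    return num_attention_heads   # NameError if the setter was never called

set_num_attention_heads(32)      # done once during model injection setup
assert get_num_attention_heads() == 32
```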
163 changes: 82 additions & 81 deletions deepspeed/runtime/compiler.py
@@ -83,84 +83,85 @@ def validate_enabled(cls, field_value, values):
return field_value


class CompiledModuleWrapper(torch.nn.Module):

def __init__(self, module, compile_config: Union[CompileConfig, None] = None):
super().__init__()

assert is_compile_supported(), "torch.compile is not supported on this version of PyTorch."

modules = self.__dict__.get('_modules')
modules['wrapped'] = module
self.__dict__['wrapped'] = module
self._is_compiled = False
self._backend = get_backend_fn(compile_config.backend)
self._compile_kwargs = compile_config.kwargs
self._compiler_fn = None

def __getattr__(self, name):
return getattr(self.__dict__['wrapped'], name)

def set_backend(self, backend: Union[str, Callable]):
"""Set the backend for torch.compile.
Args:
backend (Union[str, Callable]): backend name or a function that takes a torch.nn.Module and returns a compiled module.
You can directly pass a function that works as a backend.
See also `backend` field in `CompileConfig` for more details.
"""
self._backend = get_backend_fn(backend)

def set_torch_compile_kwargs(self, kwargs: Dict[str, Union[str, Any]]) -> None:
"""Set kwargs for torch.compile. Kwargs that are set in DeepSpeed config will be overwritten.
You can also pass a backend name with "backend" key to change the backend.
Args:
kwargs (Dict[str, Union[str, Any]]): kwargs passed to torch.compile.
"""

if "backend" in kwargs:
raise ValueError("backend cannot be set as compile kwargs. Use set_backend instead.")
self._compile_kwargs.update(kwargs)

def set_compiler_fn(self, compiler_fn: Callable) -> None:
"""Set a function to be used for compiling the module.
This function should take a torch.nn.Module as input and return a compiled module.
Note that other compile options are ignored when a compiler_fn is set.
Example:
```python
def my_compiler_fn(module: torch.nn.Module):
...
return torch.compile(module, ...)
engine.set_compiler_fn(my_compiler_fn)
```
"""
self._compiler_fn = compiler_fn

def forward(self, *args, **kwargs) -> Any:
if not self.is_compiled:
if self._compiler_fn is None:
self.__dict__['wrapped'] = torch.compile(self.wrapped, backend=self._backend, **self._compile_kwargs)
else:
self.__dict__['wrapped'] = self._compiler_fn(self.wrapped)
self._is_compiled = True

return self.__dict__['wrapped'](*args, **kwargs)

@property
def is_compiled(self) -> bool:
return self._is_compiled

@property
def backend(self) -> Union[str, Callable]:
return self._backend

@property
def torch_compile_kwargs(self) -> Dict[str, Any]:
return self._compile_kwargs

@property
def compiler_fn(self) -> Union[Callable, None]:
return self._compiler_fn
def CompiledModuleWrapper(mod, compile_config: Union[CompileConfig, None] = None):

class wrapper(mod.__class__):

def __init__(self, module, compile_config: Union[CompileConfig, None] = None):
self.__dict__ = module.__dict__.copy()

assert is_compile_supported(), "torch.compile is not supported on this version of PyTorch."

self.__dict__['wrapped'] = module
self._is_compiled = False
self._backend = get_backend_fn(compile_config.backend)
self._compile_kwargs = compile_config.kwargs
self._compiler_fn = None

def set_backend(self, backend: Union[str, Callable]):
"""Set the backend for torch.compile.
Args:
backend (Union[str, Callable]): backend name or a function that takes a torch.nn.Module and returns a compiled module.
You can directly pass a function that works as a backend.
See also `backend` field in `CompileConfig` for more details.
"""
self._backend = get_backend_fn(backend)

def set_torch_compile_kwargs(self, kwargs: Dict[str, Union[str, Any]]) -> None:
"""Set kwargs for torch.compile. Kwargs that are set in DeepSpeed config will be overwritten.
You can also pass a backend name with "backend" key to change the backend.
Args:
kwargs (Dict[str, Union[str, Any]]): kwargs passed to torch.compile.
"""

if "backend" in kwargs:
raise ValueError("backend cannot be set as compile kwargs. Use set_backend instead.")
self._compile_kwargs.update(kwargs)

def set_compiler_fn(self, compiler_fn: Callable) -> None:
"""Set a function to be used for compiling the module.
This function should take a torch.nn.Module as input and return a compiled module.
Note that other compile options are ignored when a compiler_fn is set.
Example:
```python
def my_compiler_fn(module: torch.nn.Module):
...
return torch.compile(module, ...)
engine.set_compiler_fn(my_compiler_fn)
```
"""
self._compiler_fn = compiler_fn

def forward(self, *args, **kwargs) -> Any:
if not self.is_compiled:
if self._compiler_fn is None:
self.__dict__['wrapped'] = torch.compile(self.wrapped,
backend=self._backend,
**self._compile_kwargs)
else:
self.__dict__['wrapped'] = self._compiler_fn(self.wrapped)
self._is_compiled = True

return self.__dict__['wrapped'](*args, **kwargs)

@property
def is_compiled(self) -> bool:
return self._is_compiled

@property
def backend(self) -> Union[str, Callable]:
return self._backend

@property
def torch_compile_kwargs(self) -> Dict[str, Any]:
return self._compile_kwargs

@property
def compiler_fn(self) -> Union[Callable, None]:
return self._compiler_fn

return wrapper(mod, compile_config)
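The rewrite above turns `CompiledModuleWrapper` from an `nn.Module` subclass into a factory that builds a dynamic subclass of the wrapped module's own class and copies its `__dict__`, so the wrapped model keeps passing `isinstance` checks against its original type and exposes its attributes directly; this is consistent with the special cases removed from `engine.py` and `pipe/engine.py` below. A minimal sketch of the pattern (illustrative only, assuming a PyTorch build with `torch.compile`; not the DeepSpeed class itself):

```python
import torch

def make_wrapper(module):

    # Subclass the wrapped module's own class at runtime so isinstance()
    # checks against the original class still pass.
    class Wrapper(module.__class__):

        def __init__(self, module):
            # Reuse the wrapped module's state instead of calling __init__ again.
            self.__dict__ = module.__dict__.copy()
            self.__dict__['wrapped'] = module
            self._is_compiled = False

        def forward(self, *args, **kwargs):
            if not self._is_compiled:
                self.__dict__['wrapped'] = torch.compile(self.__dict__['wrapped'])
                self._is_compiled = True
            return self.__dict__['wrapped'](*args, **kwargs)

    return Wrapper(module)

lin = torch.nn.Linear(4, 4)
wrapped = make_wrapper(lin)
assert isinstance(wrapped, torch.nn.Linear)   # the wrapper "is" the original class
```

Compilation stays lazy: as in the DeepSpeed version, `torch.compile` only runs on the first forward call.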
7 changes: 0 additions & 7 deletions deepspeed/runtime/engine.py
@@ -469,13 +469,6 @@ def __getattr__(self, name):
return getattr(self, name)
elif name in dir(_module):
return getattr(_module, name)
elif isinstance(_module, CompiledModuleWrapper):
try:
return getattr(_module, name)
except AttributeError:
raise AttributeError(
f"None of {type(self).__name__}, CompiledModuleWrapper, or the wrapped model has the attribute '{name}'"
)
else:
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

4 changes: 1 addition & 3 deletions deepspeed/runtime/pipe/engine.py
@@ -67,9 +67,7 @@ class PipelineEngine(DeepSpeedEngine):

def __init__(self, has_bool_tensors=False, *super_args, **super_kwargs):
super().__init__(*super_args, **super_kwargs)
assert isinstance(self.module, PipelineModule) \
or (hasattr(self.module, 'wrapped') and isinstance(self.module.wrapped, PipelineModule)), \
"model must base PipelineModule"
assert isinstance(self.module, PipelineModule), "model must base PipelineModule"

assert self.zero_optimization_stage(
) < ZeroStageEnum.gradients, "ZeRO-2 and ZeRO-3 are incompatible with pipeline parallelism"
[Diffs for the remaining 2 of the 10 changed files were not loaded in this view.]
