Merge branch 'master' into DeepSpeedCheckpoint
loadams authored May 15, 2024
2 parents c5e5ade + 488a823 commit c30711d
Showing 21 changed files with 532 additions and 116 deletions.
4 changes: 2 additions & 2 deletions MANIFEST.in
@@ -2,8 +2,8 @@ include *.txt README.md
include deepspeed/inference/v2/kernels/ragged_ops/libs/*.so
include deepspeed/inference/v2/kernels/cutlass_ops/libs/*.so
recursive-include requirements *.txt
recursive-include deepspeed *.cpp *.h *.cu *.hip *.tr *.cuh *.cc *.json
recursive-include csrc *.cpp *.h *.cu *.tr *.cuh *.cc
recursive-include deepspeed *.cpp *.h *.hpp *.cu *.hip *.tr *.cuh *.cc *.json
recursive-include csrc *.cpp *.h *.hpp *.cu *.tr *.cuh *.cc
recursive-include op_builder *.py
recursive-include benchmarks *.py
recursive-include accelerator *.py
11 changes: 9 additions & 2 deletions deepspeed/module_inject/auto_tp.py
Expand Up @@ -13,7 +13,7 @@
from deepspeed import comm as dist
from .layers import LinearAllreduce, LinearLayer, LmHeadLinearAllreduce
from deepspeed.accelerator import get_accelerator
from .fusedqkv_utils import require_tp_fused_qkvw, prepare_tp_fused_qkvw
from .fusedqkv_utils import require_tp_fused_qkvw, prepare_tp_fused_qkvw, shard_chunk_mlp
from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list


@@ -133,7 +133,8 @@ def is_load_module(module):
load_layers = [nn.Linear, nn.Embedding, nn.LayerNorm]
load_layer_names = [
"LPLayerNorm", "SharedEmbedding", "OPTLearnedPositionalEmbedding", "LlamaRMSNorm", "FalconLinear",
"MistralRMSNorm", "T5LayerNorm", "MixtralRMSNorm"
"MistralRMSNorm", "T5LayerNorm", "MixtralRMSNorm", "Phi3RotaryEmbedding", "Phi3SuScaledRotaryEmbedding",
"Phi3RMSNorm"
]
return module.__class__ in load_layers or module._get_name() in load_layer_names

@@ -306,6 +307,8 @@ def tp_parser(model):
# Mixtral-7x8b used w2*act(w1*w3) linear. need to replace w2 to linearallreduce.
elif 'w2' in layer and 'Mixtral' in str(type(module)):
gem_list = gem_list + [layer]
elif 'self_attn.dense' in layer and 'Phi' in str(type(module)):
gem_list = gem_list + [layer]

layer_list = []
if gem_list != []:
@@ -328,6 +331,10 @@ def _replace(self, child, name, conv_linear_layer):
# For mixtral-7x8b, need to skip MoE gate linear replace.
if name == "block_sparse_moe.gate":
return child
# For Phi-3: shard the fused gate_up_proj weight (gate and up projections packed together) across ranks.
if 'gate_up_proj' in name:
weight, bias = shard_chunk_mlp(child.weight.data, child.bias, dist.get_rank(), dist.get_world_size())
return LinearLayer(weight=weight, bias=bias)
if name in self.all_reduce_linears:
# if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size]
# else [weight_shape[0], weight_shape[1] // mp_size]
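
The new `gate_up_proj` branch above relies on `shard_chunk_mlp` because Phi-3 packs the gate and up projections into a single fused weight; a plain row split would hand each rank a mix of unrelated gate and up rows. Below is a minimal standalone sketch of the chunk-then-shard idea — toy sizes, even `torch.chunk` splits standing in for `get_shard_size_list`, not the DeepSpeed API itself:

```python
# Sketch only: why Phi-3's fused gate_up_proj is chunked into [gate; up] halves
# before each half is sharded across ranks.
import torch

hidden, intermediate, world_size = 8, 16, 2  # toy sizes for illustration

# Fused weight as Phi-3 stores it: gate rows stacked on top of up rows.
w_gate = torch.arange(intermediate * hidden, dtype=torch.float32).reshape(intermediate, hidden)
w_up = -torch.arange(intermediate * hidden, dtype=torch.float32).reshape(intermediate, hidden)
fused = torch.cat([w_gate, w_up], dim=0)  # shape [2 * intermediate, hidden]

def shard_chunk_mlp_sketch(weight, rank, world_size):
    # Split back into gate and up halves, shard each half by rank, then
    # re-fuse so the per-rank layer still sees a [gate; up] layout.
    gate, up = weight.chunk(2, dim=0)
    return torch.cat([gate.chunk(world_size, dim=0)[rank],
                      up.chunk(world_size, dim=0)[rank]], dim=0)

for rank in range(world_size):
    shard = shard_chunk_mlp_sketch(fused, rank, world_size)
    rows = intermediate // world_size
    # Each rank holds its slice of gate rows followed by the matching up rows.
    assert torch.equal(shard[:rows], w_gate.chunk(world_size, dim=0)[rank])
    assert torch.equal(shard[rows:], w_up.chunk(world_size, dim=0)[rank])
```
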
40 changes: 39 additions & 1 deletion deepspeed/module_inject/fusedqkv_utils.py
@@ -4,7 +4,7 @@
# DeepSpeed Team
import torch
from deepspeed.utils.logging import warning_once
from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list, get_num_kv_heads, get_n_embd
from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list, get_num_kv_heads, get_n_embd, get_num_attention_heads


def split_by_qkvlist_and_refuse(qkv_list, split_size, split_dim=0, cat_dim=0):
@@ -42,6 +42,7 @@ def prepare_tp_fused_qkvw(module, src, mp_size, gpu_index):
"FalconDecoderLayer": 'bloomtype',
"GPTBigCodeBlock": 'bigcodetype',
"DecoderLayer": 'glmtype',
"Phi3DecoderLayer": "phi3type"
}

def _codegen_type_transpose(input, mp_size, codegen_mp_num=4):
@@ -93,6 +94,20 @@ def _bigcode_type_transpose(input, mp_size):
split_q = q.split(get_shard_size_list(shape[0], mp_size), dim=0)
return torch.cat((split_q[gpu_index], kv), dim=0)

def _phi3_type_transpose(input, mp_size):
num_kv_heads = get_num_kv_heads()
num_heads = get_num_attention_heads()
hidden_size = input.shape[1]
head_dim = hidden_size // num_heads
q_pos = input.shape[0] - 2 * num_kv_heads * head_dim
q = input[:q_pos]
k = input[q_pos:q_pos + num_kv_heads * head_dim]
v = input[q_pos + num_kv_heads * head_dim:]
split_q = q.split(get_shard_size_list(q.shape[0], mp_size), dim=0)
split_k = k.split(get_shard_size_list(k.shape[0], mp_size), dim=0)
split_v = v.split(get_shard_size_list(v.shape[0], mp_size), dim=0)
return torch.cat((split_q[gpu_index], split_k[gpu_index], split_v[gpu_index]), dim=0)

def _transpose_fused_qkvw(src, mp_size, fused_qkv_type=None, module=None):

# suppose num_heads=n, q(n)_w means the n-th q head linear weight, the weight format are as following
@@ -110,6 +125,8 @@ def _transpose_fused_qkvw(src, mp_size, fused_qkv_type=None, module=None):
return _qwen_type_transpose(src, mp_size, module)
elif fused_qkv_type == 'bigcodetype':
return _bigcode_type_transpose(src, mp_size)
elif fused_qkv_type == 'phi3type':
return _phi3_type_transpose(src, mp_size)

raise ValueError("unknown fused_qkv_type")

@@ -123,3 +140,24 @@ def _transpose_fused_qkvw(src, mp_size, fused_qkv_type=None, module=None):
warning_once(f"Unrecognized fusedkqv weight type, default to using bloom type,"
f"please check in prepare_tp_fused_qkvw() to avoid potential calculation errors")
return _bloom_type_transpose(src, mp_size)


# For phi3 with chunk mlp, adjust the weight order.
def shard_chunk_mlp(
weight,
bias,
rank,
world_size,
):
weight_gate, weight_states = weight.chunk(2, dim=0)
total_size = weight_gate.shape[0]
split_weight_gate = weight_gate.split(get_shard_size_list(total_size, world_size, "mlp"), dim=0)
split_weight_states = weight_states.split(get_shard_size_list(total_size, world_size, "mlp"), dim=0)
shard_weight = torch.cat((split_weight_gate[rank], split_weight_states[rank]), dim=0)
if bias is not None:
bias_gate, bias_states = bias.chunk(2, dim=0)
split_bias_gate = bias_gate.split(get_shard_size_list(total_size, world_size, "mlp"), dim=0)
split_bias_states = bias_states.split(get_shard_size_list(total_size, world_size, "mlp"), dim=0)
return shard_weight, torch.cat((split_bias_gate[rank], split_bias_states[rank]), dim=0)

return shard_weight, None
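
For reference, a standalone sketch of the `phi3type` layout that `_phi3_type_transpose` handles: the fused weight stacks all query rows, then key rows, then value rows, so each segment is sharded separately and re-concatenated per rank. The head counts are toy values and an even `torch.chunk` split stands in for `get_shard_size_list`; this is an illustration, not the DeepSpeed code path:

```python
# Sketch of splitting a Phi-3 style fused [q; k; v] weight for tensor parallelism.
import torch

num_heads, num_kv_heads, head_dim, mp_size = 8, 4, 4, 2  # toy values
hidden = num_heads * head_dim

q = torch.full((num_heads * head_dim, hidden), 1.0)
k = torch.full((num_kv_heads * head_dim, hidden), 2.0)
v = torch.full((num_kv_heads * head_dim, hidden), 3.0)
fused_qkv = torch.cat([q, k, v], dim=0)

def phi3_qkv_shard(fused, gpu_index):
    q_rows = fused.shape[0] - 2 * num_kv_heads * head_dim  # rows belonging to q
    q_part = fused[:q_rows]
    k_part = fused[q_rows:q_rows + num_kv_heads * head_dim]
    v_part = fused[q_rows + num_kv_heads * head_dim:]
    pick = lambda t: t.chunk(mp_size, dim=0)[gpu_index]  # even per-rank split
    return torch.cat([pick(q_part), pick(k_part), pick(v_part)], dim=0)

shard0 = phi3_qkv_shard(fused_qkv, gpu_index=0)
# Each rank keeps a [q; k; v] block of its own heads rather than a raw row slice.
assert shard0.shape[0] == (num_heads + 2 * num_kv_heads) * head_dim // mp_size
assert torch.equal(shard0[:num_heads * head_dim // mp_size], q[:num_heads * head_dim // mp_size])
```
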
6 changes: 5 additions & 1 deletion deepspeed/module_inject/replace_module.py
@@ -16,7 +16,7 @@
from .auto_tp import AutoTP, ReplaceWithTensorSlicing, Loading

from deepspeed import comm as dist
from deepspeed.module_inject.tp_shard import set_num_kv_heads, set_n_embd
from deepspeed.module_inject.tp_shard import set_num_kv_heads, set_n_embd, set_num_attention_heads

from .load_checkpoint import load_model_with_checkpoint
import time
@@ -290,6 +290,10 @@ def replace_wo_policy(module, all_reduce_linears, prefix="", state_dict=None):
# 4.2 set n_embd
set_n_embd(n_embd)

# 4.3 set attention_heads
if hasattr(model_config, 'num_attention_heads'):
set_num_attention_heads(getattr(model_config, 'num_attention_heads'))

# 5. Set linear policies
_autotp.update_linear_policies()

10 changes: 10 additions & 0 deletions deepspeed/module_inject/tp_shard.py
@@ -12,6 +12,11 @@ def set_num_kv_heads(num):
num_kv_heads = num


def set_num_attention_heads(num):
global num_attention_heads
num_attention_heads = num


def set_n_embd(num):
global n_embd
n_embd = num
@@ -22,6 +27,11 @@ def get_num_kv_heads():
return num_kv_heads


def get_num_attention_heads():
global num_attention_heads
return num_attention_heads


def get_shard_size(total_size, mp_size, name=None, rank=None):
global num_kv_heads
last_linear = ["lm_head", "embed_out"]
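
A small illustration of how the new `num_attention_heads` global is meant to flow, assuming a DeepSpeed build that contains this commit: `replace_module` sets it once from the model config, and the fused-QKV utilities read it back later. `DummyConfig` is a hypothetical stand-in for a Hugging Face model config:

```python
from deepspeed.module_inject.tp_shard import (set_num_attention_heads, get_num_attention_heads,
                                              set_num_kv_heads, get_num_kv_heads)

class DummyConfig:  # hypothetical stand-in for a model config object
    num_attention_heads = 32
    num_key_value_heads = 8

cfg = DummyConfig()
if hasattr(cfg, 'num_attention_heads'):  # mirrors the check in replace_wo_policy
    set_num_attention_heads(cfg.num_attention_heads)
set_num_kv_heads(cfg.num_key_value_heads)

# Later, e.g. inside _phi3_type_transpose, the globals are read back:
assert get_num_attention_heads() == 32 and get_num_kv_heads() == 8
```
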
92 changes: 92 additions & 0 deletions deepspeed/monitor/comet.py
@@ -0,0 +1,92 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from typing import TYPE_CHECKING, Any, Tuple, List, Dict, Optional

from .utils import check_comet_availability
from .monitor import Monitor

import deepspeed.comm as dist

if TYPE_CHECKING:
import comet_ml
from .config import CometConfig

Name = str
Value = Any
GlobalSamples = int
Event = Tuple[Name, Value, GlobalSamples]


class CometMonitor(Monitor):

def __init__(self, comet_config: "CometConfig"):
super().__init__(comet_config)
check_comet_availability()
import comet_ml

self.enabled = comet_config.enabled
self._samples_log_interval = comet_config.samples_log_interval
self._experiment: Optional["comet_ml.ExperimentBase"] = None

if self.enabled and dist.get_rank() == 0:
self._experiment = comet_ml.start(
api_key=comet_config.api_key,
project=comet_config.project,
workspace=comet_config.workspace,
experiment_key=comet_config.experiment_key,
mode=comet_config.mode,
online=comet_config.online,
)

if comet_config.experiment_name is not None:
self._experiment.set_name(comet_config.experiment_name)

self._events_log_scheduler = EventsLogScheduler(comet_config.samples_log_interval)

@property
def experiment(self) -> Optional["comet_ml.ExperimentBase"]:
return self._experiment

@property
def samples_log_interval(self) -> int:
return self._samples_log_interval

def write_events(self, event_list: List[Event]) -> None:
if not self.enabled or dist.get_rank() != 0:
return None

for event in event_list:
name = event[0]
value = event[1]
engine_global_samples = event[2]

if self._events_log_scheduler.needs_logging(name, engine_global_samples):
self._experiment.__internal_api__log_metric__(
name=name,
value=value,
step=engine_global_samples,
)


class EventsLogScheduler:

def __init__(self, samples_log_interval: int):
self._samples_log_interval = samples_log_interval
self._last_logged_events_samples: Dict[str, int] = {}

def needs_logging(self, name: str, current_sample: int) -> bool:
if name not in self._last_logged_events_samples:
self._last_logged_events_samples[name] = current_sample
return True

last_logged_sample = self._last_logged_events_samples[name]
samples_delta = current_sample - last_logged_sample

if samples_delta >= self._samples_log_interval:
self._last_logged_events_samples[name] = current_sample
return True

return False
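
A short usage sketch, assuming a DeepSpeed installation that includes this file, of how `EventsLogScheduler` throttles the `(name, value, global_samples)` events that `write_events` receives — a metric is forwarded only once every `samples_log_interval` samples:

```python
from deepspeed.monitor.comet import EventsLogScheduler

scheduler = EventsLogScheduler(samples_log_interval=100)

events = [("Train/loss", 2.31, 0), ("Train/loss", 2.10, 50), ("Train/loss", 1.95, 120)]
for name, value, samples in events:
    if scheduler.needs_logging(name, samples):
        # Logged at 0 and 120 samples; the event at 50 samples is skipped.
        print(f"log {name}={value} at {samples} samples")
```
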
69 changes: 67 additions & 2 deletions deepspeed/monitor/config.py
@@ -3,12 +3,14 @@

# DeepSpeed Team

from typing import Optional

from deepspeed.pydantic_v1 import root_validator
from deepspeed.runtime.config_utils import DeepSpeedConfigModel


def get_monitor_config(param_dict):
monitor_dict = {key: param_dict.get(key, {}) for key in ("tensorboard", "wandb", "csv_monitor")}
monitor_dict = {key: param_dict.get(key, {}) for key in ("tensorboard", "wandb", "csv_monitor", "comet")}
return DeepSpeedMonitorConfig(**monitor_dict)


@@ -60,12 +62,75 @@ class CSVConfig(DeepSpeedConfigModel):
""" Name for the current job. This will become a new directory inside `output_path`. """


class CometConfig(DeepSpeedConfigModel):
"""
Sets parameters for the Comet monitor. For logging data, Comet uses an
experiment object.
https://www.comet.com/docs/v2/api-and-sdk/python-sdk/reference/Experiment/
"""

enabled: bool = False
""" Whether logging to Comet is enabled. Requires `comet_ml` package is installed. """

samples_log_interval: int = 100
""" Metrics will be submitted to Comet after processing every `samples_log_intervas` samples"""

project: Optional[str] = None
"""
Comet project name. Can be set through .comet.config file or environment variable COMET_PROJECT_NAME
https://www.comet.com/docs/v2/guides/experiment-management/configure-sdk/#explore-comet-configuration-options
"""

workspace: Optional[str] = None
"""
Comet workspace name. Can be set through .comet.config file or environment variable COMET_WORKSPACE
https://www.comet.com/docs/v2/guides/experiment-management/configure-sdk/#explore-comet-configuration-options
"""

api_key: Optional[str] = None
"""
Comet API key. Can be set through .comet.config file or environment variable COMET_API_KEY
https://www.comet.com/docs/v2/guides/experiment-management/configure-sdk/#explore-comet-configuration-options
"""

experiment_name: Optional[str] = None
"""
The name for the Comet experiment to be used for logging.
Can be set through .comet.config file or environment variable COMET_EXPERIMENT_NAME
https://www.comet.com/docs/v2/guides/experiment-management/configure-sdk/#explore-comet-configuration-options
"""

experiment_key: Optional[str] = None
"""
The key for the Comet experiment to be used for logging. Must be an alphanumeric string whose length is between 32 and 50 characters.
Can be set through .comet.config or environment variable COMET_EXPERIMENT_KEY
https://www.comet.com/docs/v2/guides/experiment-management/configure-sdk/#explore-comet-configuration-options
"""

online: Optional[bool] = None
"""
If True, the data will be logged to the Comet server; otherwise it will be stored locally in an offline experiment.
Defaults to True.
"""

mode: Optional[str] = None
"""
Controls how the Comet experiment is started. Three options are possible:
- "get": Continue logging to an existing experiment identified by the `experiment_key` value.
- "create": Always creates of a new experiment, useful for HPO sweeps.
- "get_or_create" (default): Starts a fresh experiment if required, or persists logging to an existing one.
"""


class DeepSpeedMonitorConfig(DeepSpeedConfigModel):
"""Sets parameters for various monitoring methods."""

tensorboard: TensorBoardConfig = {}
""" TensorBoard monitor, requires `tensorboard` package is installed. """

comet: CometConfig = {}
""" Comet monitor, requires `comet_ml` package is installed """

wandb: WandbConfig = {}
""" WandB monitor, requires `wandb` package is installed. """

@@ -75,5 +140,5 @@ class DeepSpeedMonitorConfig(DeepSpeedConfigModel):
@root_validator
def check_enabled(cls, values):
values["enabled"] = values.get("tensorboard").enabled or values.get("wandb").enabled or values.get(
"csv_monitor").enabled
"csv_monitor").enabled or values.get("comet")
return values
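
For reference, a sketch of the DeepSpeed config section these fields map to (the `"comet"` key is what `get_monitor_config` now looks up); the project, workspace, and experiment names are placeholders, and the surrounding keys are only there to make the dict look like a complete config:

```python
ds_config = {
    "train_batch_size": 16,
    "comet": {
        "enabled": True,
        "samples_log_interval": 100,          # submit metrics every 100 samples
        "project": "my-deepspeed-runs",       # placeholder project name
        "workspace": "my-workspace",          # placeholder workspace
        "experiment_name": "zero3-baseline",  # optional; auto-generated if omitted
        "online": True,                       # False keeps an offline experiment
    },
}
# The API key is usually supplied via COMET_API_KEY or a .comet.config file
# rather than being written into the config dict.
```
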
6 changes: 6 additions & 0 deletions deepspeed/monitor/monitor.py
@@ -24,6 +24,7 @@ def write_events(self, event_list):
from .wandb import WandbMonitor
from .tensorboard import TensorBoardMonitor
from .csv_monitor import csvMonitor
from .comet import CometMonitor


class MonitorMaster(Monitor):
@@ -33,6 +34,7 @@ def __init__(self, monitor_config):
self.tb_monitor = None
self.wandb_monitor = None
self.csv_monitor = None
self.comet_monitor = None
self.enabled = monitor_config.enabled

if dist.get_rank() == 0:
@@ -42,6 +44,8 @@ def __init__(self, monitor_config):
self.wandb_monitor = WandbMonitor(monitor_config.wandb)
if monitor_config.csv_monitor.enabled:
self.csv_monitor = csvMonitor(monitor_config.csv_monitor)
if monitor_config.comet.enabled:
self.comet_monitor = CometMonitor(monitor_config.comet)

def write_events(self, event_list):
if dist.get_rank() == 0:
@@ -51,3 +55,5 @@ def write_events(self, event_list):
self.wandb_monitor.write_events(event_list)
if self.csv_monitor is not None:
self.csv_monitor.write_events(event_list)
if self.comet_monitor is not None:
self.comet_monitor.write_events(event_list)