Merge pull request espnet#5856 from jctian98/deepspeed
Add DeepSpeed trainer for large-scale training
sw005320 authored Aug 26, 2024
2 parents 38cc9e8 + 7e5f289 commit b54ea65
Showing 11 changed files with 504 additions and 5 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -252,6 +252,7 @@ Demonstration
- Flexible network architecture thanks to Chainer and PyTorch
- Flexible front-end processing thanks to [kaldiio](https://github.com/nttcslab-sp/kaldiio) and HDF5 support
- Tensorboard-based monitoring
- [DeepSpeed](https://github.com/microsoft/DeepSpeed)-based large-scale training

### ESPnet2
See [ESPnet2](https://espnet.github.io/espnet/espnet2_tutorial.html).
39 changes: 39 additions & 0 deletions egs2/an4/asr1/conf/deepspeed_zero2.json
@@ -0,0 +1,39 @@
{
"train_micro_batch_size_per_gpu": 1,
"gradient_accumulation_steps": 1,
"gradient_clipping": 1.0,
"bf16": {
"enabled": true
},
"zero_optimization": {
"stage": 2,
"contiguous_gradients": true,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 5e8,
"allgather_bucket_size": 5e8
},
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.001,
"betas": [
0.9,
0.95
],
"eps": 1e-8,
"weight_decay": 3e-7,
"adam_w_mode": true
}
},
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": 0,
"warmup_max_lr": 0.0001,
"warmup_num_steps": 30000
}
},
"wall_clock_breakdown": false,
"steps_per_print": 1000
}
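
Neither the new DeepSpeedTrainer nor the code that reads this config appears in the excerpt above. As a point of reference only, the following is a minimal, hypothetical sketch of how a JSON config like conf/deepspeed_zero2.json is typically consumed through the standard deepspeed.initialize() API; the toy model and training loop are placeholders rather than ESPnet code, and the script has to be started with the deepspeed (or torchrun) launcher so that a process group exists.

import deepspeed
import torch

# Stand-in for an ESPnet model; assume the forward pass yields a scalar loss.
model = torch.nn.Linear(80, 10)

# Optimizer, scheduler, bf16, and ZeRO-2 settings all come from the JSON file.
model_engine, optimizer, _, scheduler = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config="conf/deepspeed_zero2.json",
)

for step in range(100):
    feats = torch.randn(4, 80, device=model_engine.device, dtype=torch.bfloat16)
    loss = model_engine(feats).float().pow(2).mean()  # dummy scalar loss
    model_engine.backward(loss)   # applies the configured clipping/accumulation
    model_engine.step()           # Adam update + WarmupLR scheduler step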
64 changes: 64 additions & 0 deletions egs2/an4/asr1/conf/train_asr_transformer_deepspeed.yaml
@@ -0,0 +1,64 @@
# A toy example of how DeepSpeed is used in ESPnet.
# With DeepSpeed, users only need to specify the model- and dataloader-related items.
# Other configuration should go in the deepspeed_config file, such as:
# * optimization
# * training dtype or automatic mixed precision (AMP) setup
# * gradient accumulation
# * gradient clip
# * model saving and loading
# * learning rate scheduler
# * ...
#
# With DeepSpeed, one can also use some advanced trainer features, such as:
# * ZeRO-1/2/3 optimization
# * parameter offload
# * activation checkpointing
# * ...
# These features make it easy to train very large models.
#
# The provided conf/deepspeed_zero2.json only demonstrates a simple use case of DeepSpeed.
# Based on the model architecture and cluster characteristics, advanced users are encouraged
# to tune the config file following the official documentation: https://deepspeed.readthedocs.io/en/latest/
#
# Note: the batch-size-related setup is determined by the ESPnet dataloader settings rather
# than by the values specified in the DeepSpeed config.
#
# Before training with DeepSpeed, make sure it has been installed.
# DeepSpeed compiles some torch extensions the first time they are used, so make sure
# ${CUDA_HOME} is set in your environment and points to a complete CUDA installation that is
# compatible with the CUDA version of your PyTorch build. The compatibility requirement only
# concerns the major CUDA version; e.g., CUDA 11.x releases are compatible with each other.

use_deepspeed: true
deepspeed_config: conf/deepspeed_zero2.json

batch_type: folded
batch_size: 64
max_epoch: 200

encoder: transformer
encoder_conf:
output_size: 256
attention_heads: 4
linear_units: 2048
num_blocks: 12
dropout_rate: 0.1
positional_dropout_rate: 0.1
attention_dropout_rate: 0.0
input_layer: conv2d
normalize_before: true

decoder: transformer
decoder_conf:
attention_heads: 4
linear_units: 2048
num_blocks: 6
dropout_rate: 0.1
positional_dropout_rate: 0.1
self_attention_dropout_rate: 0.0
src_attention_dropout_rate: 0.0

model_conf:
ctc_weight: 0.3
lsm_weight: 0.1
length_normalized_loss: false
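
The CUDA_HOME note in the header comments above is easy to get wrong on shared clusters. The following is a small editor-added sanity check, not part of this PR, that reflects that note: it only verifies that CUDA_HOME is set and that PyTorch was built with CUDA, and prints the major version that the toolkit under CUDA_HOME needs to match.

import os
import torch

cuda_home = os.environ.get("CUDA_HOME")
assert cuda_home, "Set CUDA_HOME to a full CUDA toolkit so DeepSpeed can JIT-compile its ops"
assert torch.version.cuda, "DeepSpeed training requires a CUDA-enabled PyTorch build"

major = torch.version.cuda.split(".")[0]
print(f"CUDA_HOME={cuda_home}; PyTorch was built with CUDA {torch.version.cuda}; "
      f"the toolkit under CUDA_HOME should be CUDA {major}.x")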
39 changes: 39 additions & 0 deletions egs2/librispeech_100/asr1/conf/deepspeed_zero2.json
@@ -0,0 +1,39 @@
{
"train_micro_batch_size_per_gpu": 1,
"gradient_accumulation_steps": 1,
"gradient_clipping": 1.0,
"bf16": {
"enabled": true
},
"zero_optimization": {
"stage": 2,
"contiguous_gradients": true,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 5e8,
"allgather_bucket_size": 5e8
},
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.001,
"betas": [
0.9,
0.95
],
"eps": 1e-8,
"weight_decay": 3e-7,
"adam_w_mode": true
}
},
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": 0,
"warmup_max_lr": 0.0001,
"warmup_num_steps": 30000
}
},
"wall_clock_breakdown": false,
"steps_per_print": 1000
}
4 changes: 2 additions & 2 deletions espnet2/asr/ctc.py
@@ -74,7 +74,7 @@ def __init__(

def loss_fn(self, th_pred, th_target, th_ilen, th_olen) -> torch.Tensor:
if self.ctc_type == "builtin" or self.ctc_type == "brctc":
th_pred = th_pred.log_softmax(2)
th_pred = th_pred.log_softmax(2).float()
loss = self.ctc_loss(th_pred, th_target, th_ilen, th_olen)
if self.ctc_type == "builtin":
size = th_pred.size(1)
@@ -91,7 +91,7 @@ def loss_fn(self, th_pred, th_target, th_ilen, th_olen) -> torch.Tensor:
# builtin2 ignores nan losses using the logic below, while
# builtin relies on the zero_infinity flag in pytorch CTC
elif self.ctc_type == "builtin2":
th_pred = th_pred.log_softmax(2)
th_pred = th_pred.log_softmax(2).float()
loss = self.ctc_loss(th_pred, th_target, th_ilen, th_olen)

if loss.requires_grad and self.ignore_nan_grad:
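The two one-line changes above cast the log-probabilities to float32 before the CTC loss so that, under bf16 training (as enabled in the DeepSpeed config), the loss itself is not computed in reduced precision. A standalone, editor-added sketch of the same pattern with toy shapes:

import torch
import torch.nn.functional as F

T, B, V, U = 50, 2, 30, 10                              # time, batch, vocab, target length
logits = torch.randn(T, B, V, dtype=torch.bfloat16)     # (T, B, V), as passed to loss_fn
targets = torch.randint(1, V, (B, U))
input_lengths = torch.full((B,), T, dtype=torch.long)
target_lengths = torch.full((B,), U, dtype=torch.long)

log_probs = logits.log_softmax(2).float()               # the added .float() cast
loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths,
                  blank=0, zero_infinity=True)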
5 changes: 3 additions & 2 deletions espnet2/layers/stft.py
@@ -101,8 +101,9 @@ def forward(
onesided=self.onesided,
)
stft_kwargs["return_complex"] = True
output = torch.stft(input, **stft_kwargs)
output = torch.view_as_real(output)
# NOTE(Jinchuan) CuFFT is not compatible with bfloat16
output = torch.stft(input.float(), **stft_kwargs)
output = torch.view_as_real(output).type(input.dtype)
else:
if self.training:
raise NotImplementedError(
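The change above works around the fact that cuFFT (and hence torch.stft) has no bfloat16 support: the waveform is cast to float32 for the transform and the real/imaginary output is cast back to the caller's dtype. A standalone, editor-added sketch of the same round-trip, with made-up STFT parameters:

import torch

def stft_any_dtype(wav: torch.Tensor, n_fft: int = 512, hop_length: int = 128) -> torch.Tensor:
    spec = torch.stft(
        wav.float(),                                   # FFT runs in float32
        n_fft=n_fft,
        hop_length=hop_length,
        window=torch.hann_window(n_fft, device=wav.device),
        return_complex=True,
    )
    # (..., freq, frames, 2) real/imag pairs, returned in the input dtype
    return torch.view_as_real(spec).type(wav.dtype)

out = stft_any_dtype(torch.randn(2, 16000, dtype=torch.bfloat16))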
30 changes: 29 additions & 1 deletion espnet2/tasks/abs_task.py
@@ -289,7 +289,6 @@ def build_model(cls, args: argparse.Namespace) -> AbsESPnetModel:
@classmethod
@typechecked
def get_parser(cls) -> config_argparse.ArgumentParser:

class ArgumentDefaultsRawTextHelpFormatter(
argparse.RawTextHelpFormatter,
argparse.ArgumentDefaultsHelpFormatter,
@@ -449,6 +448,18 @@ class ArgumentDefaultsRawTextHelpFormatter(
type=str2bool,
help="Enable sharded training provided by fairscale",
)
group.add_argument(
"--use_deepspeed",
default=False,
type=str2bool,
help="Enable deepspeed for training",
)
group.add_argument(
"--deepspeed_config",
default=None,
type=str,
help="deepspeed training config",
)

group = parser.add_argument_group("cudnn mode related")
group.add_argument(
@@ -1529,6 +1540,23 @@ def main_worker(cls, args: argparse.Namespace):

# Don't give args to trainer.run() directly!!!
# Instead of it, define "Options" object and build here.

if args.use_deepspeed:
if not distributed_option.distributed:
logging.warning(
"DeepSpeed is for distributed training. E.g., --ngpu > 1 "
"Switch back to the normal trainer."
)
elif cls.trainer != Trainer:
raise ValueError(
"only default trainer is compatible with deepspeed"
)
else:
from espnet2.train.deepspeed_trainer import DeepSpeedTrainer

cls.trainer = DeepSpeedTrainer
distributed_option.init_deepspeed()

trainer_options = cls.trainer.build_options(args)
cls.trainer.run(
model=model,
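Neither DeepSpeedTrainer nor DistributedOption.init_deepspeed() is included in this excerpt. Presumably init_deepspeed() defers process-group setup to DeepSpeed's own helper instead of calling torch.distributed.init_process_group() directly; a minimal editor-added sketch under that assumption:

import deepspeed

def init_deepspeed(dist_backend: str = "nccl") -> None:
    # Reads RANK / LOCAL_RANK / WORLD_SIZE / MASTER_ADDR / MASTER_PORT from the
    # launcher environment and initializes torch.distributed accordingly.
    deepspeed.init_distributed(dist_backend=dist_backend)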
12 changes: 12 additions & 0 deletions espnet2/torch_utils/device_funcs.py
@@ -28,6 +28,18 @@ def to_device(data, device=None, dtype=None, non_blocking=False, copy=False):
elif isinstance(data, np.ndarray):
return to_device(torch.from_numpy(data), device, dtype, non_blocking, copy)
elif isinstance(data, torch.Tensor):
if dtype is not None:
dtype = str(dtype).removeprefix("torch.")
cur_dtype = str(data.dtype).removeprefix("torch.")

if not (
("int" in dtype and "int" in cur_dtype)
or ("float" in dtype and "float" in cur_dtype)
):
dtype = None # avoid conversion between int and float.
else:
dtype = getattr(torch, dtype)

return data.to(device, dtype, non_blocking, copy)
else:
return data
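The guard above only allows float-to-float (or int-to-int) conversions, so passing a bf16 target dtype through to_device() will not silently corrupt integer fields such as token ids or sequence lengths. A small editor-added usage sketch of that behavior (assuming, as elsewhere in espnet2, that to_device() recurses into dicts):

import torch
from espnet2.torch_utils.device_funcs import to_device

batch = {
    "speech": torch.randn(2, 16000),                 # float32 -> bfloat16
    "speech_lengths": torch.tensor([16000, 12000]),  # int64 must stay int64
}
batch = to_device(batch, device="cpu", dtype=torch.bfloat16)
assert batch["speech"].dtype == torch.bfloat16
assert batch["speech_lengths"].dtype == torch.long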