diff --git a/README.md b/README.md
index 694ff48d361..4f07a63f764 100644
--- a/README.md
+++ b/README.md
@@ -252,6 +252,7 @@ Demonstration
 - Flexible network architecture thanks to Chainer and PyTorch
 - Flexible front-end processing thanks to [kaldiio](https://github.com/nttcslab-sp/kaldiio) and HDF5 support
 - Tensorboard-based monitoring
+- [DeepSpeed](https://github.com/microsoft/DeepSpeed)-based large-scale training
 
 ### ESPnet2
 See [ESPnet2](https://espnet.github.io/espnet/espnet2_tutorial.html).
diff --git a/egs2/an4/asr1/conf/deepspeed_zero2.json b/egs2/an4/asr1/conf/deepspeed_zero2.json
new file mode 100644
index 00000000000..060f599c42e
--- /dev/null
+++ b/egs2/an4/asr1/conf/deepspeed_zero2.json
@@ -0,0 +1,39 @@
+{
+    "train_micro_batch_size_per_gpu": 1,
+    "gradient_accumulation_steps": 1,
+    "gradient_clipping": 1.0,
+    "bf16": {
+        "enabled": true
+    },
+    "zero_optimization": {
+        "stage": 2,
+        "contiguous_gradients": true,
+        "overlap_comm": true,
+        "reduce_scatter": true,
+        "reduce_bucket_size": 5e8,
+        "allgather_bucket_size": 5e8
+    },
+    "optimizer": {
+        "type": "Adam",
+        "params": {
+            "lr": 0.001,
+            "betas": [
+                0.9,
+                0.95
+            ],
+            "eps": 1e-8,
+            "weight_decay": 3e-7,
+            "adam_w_mode": true
+        }
+    },
+    "scheduler": {
+        "type": "WarmupLR",
+        "params": {
+            "warmup_min_lr": 0,
+            "warmup_max_lr": 0.0001,
+            "warmup_num_steps": 30000
+        }
+    },
+    "wall_clock_breakdown": false,
+    "steps_per_print": 1000
+}
diff --git a/egs2/an4/asr1/conf/train_asr_transformer_deepspeed.yaml b/egs2/an4/asr1/conf/train_asr_transformer_deepspeed.yaml
new file mode 100644
index 00000000000..97f5a3a724f
--- /dev/null
+++ b/egs2/an4/asr1/conf/train_asr_transformer_deepspeed.yaml
@@ -0,0 +1,64 @@
+# A toy example of how DeepSpeed is used in ESPnet.
+# With DeepSpeed, users only need to specify the model- and dataloader-related items.
+# Other configs should be specified in the deepspeed_config file, such as:
+#   * optimization
+#   * training dtype or automatic mixed precision (AMP) setup
+#   * gradient accumulation
+#   * gradient clipping
+#   * model saving and loading
+#   * learning rate scheduler
+#   * ...
+#
+# With DeepSpeed, one can also use some advanced trainer features, such as:
+#   * ZeRO-1/2/3 optimization
+#   * parameter offload
+#   * activation checkpointing
+#   * ...
+# so that a very large model can be trained easily.
+#
+# The provided conf/deepspeed_zero2.json only contains a simple use case of DeepSpeed.
+# Based on the model architecture and cluster features, advanced users are encouraged to
+# tune the config file following the official documentation: https://deepspeed.readthedocs.io/en/latest/
+#
+# Note: the batch size-related setup follows the ESPnet dataloader settings rather than
+# those specified in the DeepSpeed config.
+#
+# Before training with DeepSpeed, make sure it has been installed.
+# DeepSpeed will compile some torch extensions when you use them for the first time, so make
+# sure ${CUDA_HOME} is set in your environment variables and points to a complete CUDA
+# installation that is compatible with your PyTorch CUDA version. The compatibility requirement
+# is only about the major CUDA version, e.g., CUDA 11.x versions are always compatible with each other.
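+#
+# As a rough sketch (the exact recipe options depend on your setup), training with this
+# config can be launched from the recipe directory with something like:
+#   ./asr.sh --ngpu 4 --asr_config conf/train_asr_transformer_deepspeed.yaml
+# plus the usual data-related options. Multiple GPUs (--ngpu > 1) or multiple nodes are
+# required, since the DeepSpeed trainer only runs in distributed mode.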
+
+use_deepspeed: true
+deepspeed_config: conf/deepspeed_zero2.json
+
+batch_type: folded
+batch_size: 64
+max_epoch: 200
+
+encoder: transformer
+encoder_conf:
+    output_size: 256
+    attention_heads: 4
+    linear_units: 2048
+    num_blocks: 12
+    dropout_rate: 0.1
+    positional_dropout_rate: 0.1
+    attention_dropout_rate: 0.0
+    input_layer: conv2d
+    normalize_before: true
+
+decoder: transformer
+decoder_conf:
+    attention_heads: 4
+    linear_units: 2048
+    num_blocks: 6
+    dropout_rate: 0.1
+    positional_dropout_rate: 0.1
+    self_attention_dropout_rate: 0.0
+    src_attention_dropout_rate: 0.0
+
+model_conf:
+    ctc_weight: 0.3
+    lsm_weight: 0.1
+    length_normalized_loss: false
diff --git a/egs2/librispeech_100/asr1/conf/deepspeed_zero2.json b/egs2/librispeech_100/asr1/conf/deepspeed_zero2.json
new file mode 100644
index 00000000000..d7d206bddca
--- /dev/null
+++ b/egs2/librispeech_100/asr1/conf/deepspeed_zero2.json
@@ -0,0 +1,39 @@
+{
+    "train_micro_batch_size_per_gpu": 1,
+    "gradient_accumulation_steps": 1,
+    "gradient_clipping": 1.0,
+    "bf16": {
+        "enabled": true
+    },
+    "zero_optimization": {
+        "stage": 2,
+        "contiguous_gradients": true,
+        "overlap_comm": true,
+        "reduce_scatter": true,
+        "reduce_bucket_size": 5e8,
+        "allgather_bucket_size": 5e8
+    },
+    "optimizer": {
+        "type": "Adam",
+        "params": {
+            "lr": 0.001,
+            "betas": [
+                0.9,
+                0.95
+            ],
+            "eps": 1e-8,
+            "weight_decay": 3e-7,
+            "adam_w_mode": true
+        }
+    },
+    "scheduler": {
+        "type": "WarmupLR",
+        "params": {
+            "warmup_min_lr": 0,
+            "warmup_max_lr": 0.0001,
+            "warmup_num_steps": 30000
+        }
+    },
+    "wall_clock_breakdown": false,
+    "steps_per_print": 1000
+}
diff --git a/espnet2/asr/ctc.py b/espnet2/asr/ctc.py
index 13621a47cf7..3c61cbee242 100644
--- a/espnet2/asr/ctc.py
+++ b/espnet2/asr/ctc.py
@@ -74,7 +74,7 @@ def __init__(
 
     def loss_fn(self, th_pred, th_target, th_ilen, th_olen) -> torch.Tensor:
         if self.ctc_type == "builtin" or self.ctc_type == "brctc":
-            th_pred = th_pred.log_softmax(2)
+            th_pred = th_pred.log_softmax(2).float()
             loss = self.ctc_loss(th_pred, th_target, th_ilen, th_olen)
             if self.ctc_type == "builtin":
                 size = th_pred.size(1)
@@ -91,7 +91,7 @@ def loss_fn(self, th_pred, th_target, th_ilen, th_olen) -> torch.Tensor:
         # builtin2 ignores nan losses using the logic below, while
         # builtin relies on the zero_infinity flag in pytorch CTC
         elif self.ctc_type == "builtin2":
-            th_pred = th_pred.log_softmax(2)
+            th_pred = th_pred.log_softmax(2).float()
             loss = self.ctc_loss(th_pred, th_target, th_ilen, th_olen)
 
             if loss.requires_grad and self.ignore_nan_grad:
diff --git a/espnet2/layers/stft.py b/espnet2/layers/stft.py
index 869c96dd29c..de466714c38 100644
--- a/espnet2/layers/stft.py
+++ b/espnet2/layers/stft.py
@@ -101,8 +101,9 @@ def forward(
                 onesided=self.onesided,
             )
             stft_kwargs["return_complex"] = True
-            output = torch.stft(input, **stft_kwargs)
-            output = torch.view_as_real(output)
+            # NOTE(Jinchuan) CuFFT is not compatible with bfloat16
+            output = torch.stft(input.float(), **stft_kwargs)
+            output = torch.view_as_real(output).type(input.dtype)
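+            # The STFT itself is computed in float32 and the result is cast back,
+            # so downstream modules keep the original (possibly bfloat16) dtype.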
         else:
             if self.training:
                 raise NotImplementedError(
diff --git a/espnet2/tasks/abs_task.py b/espnet2/tasks/abs_task.py
index 1bfb47cec73..3f84fc6d1e1 100644
--- a/espnet2/tasks/abs_task.py
+++ b/espnet2/tasks/abs_task.py
@@ -289,7 +289,6 @@ def build_model(cls, args: argparse.Namespace) -> AbsESPnetModel:
     @classmethod
     @typechecked
     def get_parser(cls) -> config_argparse.ArgumentParser:
-
         class ArgumentDefaultsRawTextHelpFormatter(
             argparse.RawTextHelpFormatter,
             argparse.ArgumentDefaultsHelpFormatter,
@@ -449,6 +448,18 @@ class ArgumentDefaultsRawTextHelpFormatter(
             type=str2bool,
             help="Enable sharded training provided by fairscale",
         )
+        group.add_argument(
+            "--use_deepspeed",
+            default=False,
+            type=str2bool,
+            help="Enable DeepSpeed for training",
+        )
+        group.add_argument(
+            "--deepspeed_config",
+            default=None,
+            type=str,
+            help="Path to the DeepSpeed training config file",
+        )
 
         group = parser.add_argument_group("cudnn mode related")
         group.add_argument(
@@ -1529,6 +1540,23 @@ def main_worker(cls, args: argparse.Namespace):
 
         # Don't give args to trainer.run() directly!!!
         # Instead of it, define "Options" object and build here.
+
+        if args.use_deepspeed:
+            if not distributed_option.distributed:
+                logging.warning(
+                    "DeepSpeed is for distributed training, e.g., --ngpu > 1. "
+                    "Falling back to the normal trainer."
+                )
+            elif cls.trainer != Trainer:
+                raise ValueError(
+                    "Only the default trainer is compatible with DeepSpeed"
+                )
+            else:
+                from espnet2.train.deepspeed_trainer import DeepSpeedTrainer
+
+                cls.trainer = DeepSpeedTrainer
+                distributed_option.init_deepspeed()
+
         trainer_options = cls.trainer.build_options(args)
         cls.trainer.run(
             model=model,
diff --git a/espnet2/torch_utils/device_funcs.py b/espnet2/torch_utils/device_funcs.py
index 7919e7d9232..3abe51e432d 100644
--- a/espnet2/torch_utils/device_funcs.py
+++ b/espnet2/torch_utils/device_funcs.py
@@ -28,6 +28,18 @@ def to_device(data, device=None, dtype=None, non_blocking=False, copy=False):
     elif isinstance(data, np.ndarray):
        return to_device(torch.from_numpy(data), device, dtype, non_blocking, copy)
     elif isinstance(data, torch.Tensor):
+        if dtype is not None:
+            dtype = str(dtype).removeprefix("torch.")
+            cur_dtype = str(data.dtype).removeprefix("torch.")
+
+            if not (
+                ("int" in dtype and "int" in cur_dtype)
+                or ("float" in dtype and "float" in cur_dtype)
+            ):
+                dtype = None  # avoid conversion between int and float.
+            else:
+                dtype = getattr(torch, dtype)
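+        # NOTE: the guard above only converts within the same dtype family, so
+        # integer tensors (e.g., token ids and sequence lengths) are never cast to
+        # the floating-point training dtype.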
+
         return data.to(device, dtype, non_blocking, copy)
     else:
         return data
diff --git a/espnet2/train/deepspeed_trainer.py b/espnet2/train/deepspeed_trainer.py
new file mode 100644
index 00000000000..667dba9ee37
--- /dev/null
+++ b/espnet2/train/deepspeed_trainer.py
@@ -0,0 +1,273 @@
+"""DeepSpeed Trainer Module."""
+
+import argparse
+import dataclasses
+import json
+import logging
+
+import torch
+import torch.distributed as dist
+
+try:
+    import deepspeed
+    from deepspeed import DeepSpeedEngine
+except ImportError:
+    logging.warning("deepspeed is not installed")
+    deepspeed = None
+    DeepSpeedEngine = None
+
+from pathlib import Path
+from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union
+
+import torch
+import torch.distributed as dist
+from torch.distributed import ReduceOp
+from typeguard import typechecked
+
+from espnet2.iterators.abs_iter_factory import AbsIterFactory
+from espnet2.torch_utils.device_funcs import to_device
+from espnet2.torch_utils.recursive_op import recursive_average
+from espnet2.torch_utils.set_all_random_seed import set_all_random_seed
+from espnet2.train.abs_espnet_model import AbsESPnetModel
+from espnet2.train.reporter import Reporter, SubReporter
+from espnet2.train.trainer import Trainer
+from espnet2.utils.build_dataclass import build_dataclass
+
+
+@dataclasses.dataclass
+class DeepSpeedTrainerOptions:
+    resume: bool
+    seed: int
+    train_dtype: Union[str, torch.dtype]
+    log_interval: Optional[int]
+    output_dir: Union[Path, str]
+    max_epoch: int
+    deepspeed_config: Union[Path, str]
+
+
+class DeepSpeedTrainer(Trainer):
+
+    @classmethod
+    @typechecked
+    def build_options(cls, args: argparse.Namespace) -> DeepSpeedTrainerOptions:
+        return build_dataclass(DeepSpeedTrainerOptions, args)
+
+    @staticmethod
+    @typechecked
+    def resume(
+        model: DeepSpeedEngine,
+        reporter: Reporter,
+        output_dir: Path,
+    ):
+        ckpts = [
+            item
+            for item in output_dir.iterdir()
+            if item.is_dir() and item.name.startswith("checkpoint_")
+        ]
+
+        if len(ckpts) == 0:
+            logging.info("Tried to resume, but found no checkpoint")
+            return
+
+        ckpt_num = max([int(item.name.split("_")[-1]) for item in ckpts])
+        ckpt_path = output_dir / f"checkpoint_{ckpt_num}"
+        logging.info(f"Resume training from {ckpt_path}")
+
+        _, client_states = model.load_checkpoint(ckpt_path)
+
+        reporter.load_state_dict(client_states["reporter"])
+
+    @classmethod
+    @typechecked
+    def run(
+        cls,
+        model: Union[AbsESPnetModel, DeepSpeedEngine],
+        train_iter_factory: AbsIterFactory,
+        valid_iter_factory: AbsIterFactory,
+        trainer_options: DeepSpeedTrainerOptions,
+        **kwargs,
+    ) -> None:
+
+        # (1) arguments needed in the previous trainer but not in this one; delete them
+        del kwargs
+
+        # (2) initialize deepspeed
+        if deepspeed is None:
+            raise ImportError("Cannot proceed as deepspeed is not installed")
+        deepspeed_config = json.load(open(trainer_options.deepspeed_config))
+        trainer_options.train_dtype = cls.setup_data_dtype(deepspeed_config)
+        model, _, _, _ = deepspeed.initialize(
+            model=model,
+            model_parameters=model.parameters(),
+            config=deepspeed_config,
+        )
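+        # NOTE: deepspeed.initialize() returns (engine, optimizer, dataloader,
+        # lr_scheduler); only the engine is kept here, since the optimizer and
+        # scheduler are built from the DeepSpeed JSON config and ESPnet supplies
+        # its own dataloaders.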
+
+        # (3) setup reporter, output_dir, dataloader etc.
+        output_dir = Path(trainer_options.output_dir)
+        reporter = Reporter()
+
+        # (4) resume
+        if trainer_options.resume:
+            cls.resume(
+                model=model,
+                reporter=reporter,
+                output_dir=output_dir,
+            )
+
+        # (5) loop on epochs
+        start_epoch = reporter.get_epoch() + 1
+        if start_epoch == trainer_options.max_epoch + 1:
+            logging.warning(
+                f"The training has already reached max_epoch: {start_epoch}"
+            )
+
+        for iepoch in range(start_epoch, trainer_options.max_epoch + 1):
+            set_all_random_seed(trainer_options.seed + iepoch)
+            reporter.set_epoch(iepoch)
+
+            # (5.1) train one epoch
+            with reporter.observe("train") as sub_reporter:
+                cls.train_one_epoch(
+                    model=model,
+                    iterator=train_iter_factory.build_iter(iepoch),
+                    reporter=sub_reporter,
+                    options=trainer_options,
+                )
+
+            # (5.2) valid one epoch
+            with reporter.observe("valid") as sub_reporter:
+                cls.valid_one_epoch(
+                    model=model,
+                    iterator=valid_iter_factory.build_iter(iepoch),
+                    reporter=sub_reporter,
+                    options=trainer_options,
+                )
+
+            # (5.3) save checkpoint
+            checkpoint_path = output_dir / f"checkpoint_{iepoch}"
+            model.save_checkpoint(
+                checkpoint_path,
+                tag=f"{iepoch}",
+                client_state={"reporter": reporter.state_dict()},
+            )
+
+            # (5.4) reporter
+            if dist.get_rank() == 0:
+                logging.info(reporter.log_message())
+                reporter.matplotlib_plot(output_dir / "images")
+
+    @classmethod
+    @typechecked
+    def train_one_epoch(
+        cls,
+        model,
+        iterator: Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
+        reporter: SubReporter,
+        options: DeepSpeedTrainerOptions,
+    ) -> None:
+        model.train()
+        iterator_stop = torch.tensor(0).cuda()
+
+        log_interval = options.log_interval
+        if log_interval is None:
+            try:
+                log_interval = max(len(iterator) // 20, 10)
+            except TypeError:
+                log_interval = 100
+
+        for iiter, (utt_id, batch) in enumerate(
+            reporter.measure_iter_time(iterator, "iter_time"), 1
+        ):
+            assert isinstance(batch, dict), type(batch)
+
+            with reporter.measure_time("step_time"):
+                # (0) ensure all ranks have not finished.
+                dist.all_reduce(iterator_stop, ReduceOp.SUM)
+                if iterator_stop > 0:
+                    break
+
+                # (1) forward
+                batch["utt_id"] = utt_id
+                batch = to_device(batch, "cuda", dtype=options.train_dtype)
+                loss, stats, weight = model(**batch)
+
+                # (2) all-reduce statistics and logging on model side
+                stats = {k: v for k, v in stats.items() if v is not None}
+                stats, weight = recursive_average(stats, weight, True)
+                reporter.register(stats, weight)
+
+                # (3) backward and logging on trainer side
+                loss = loss / weight * dist.get_world_size()
+                model.backward(loss)
+                model.step()
+
+                reporter.register(
+                    dict(
+                        grad_norm=model.get_global_grad_norm(),
+                        loss_scale=model.loss_scale(),
+                        learning_rate=model.get_lr()[0],
+                    )
+                )
+
+            reporter.next()
+            if iiter % log_interval == 0:
+                logging.info(reporter.log_message(-log_interval))
+
+        else:
+            iterator_stop.fill_(1)
+            dist.all_reduce(iterator_stop, ReduceOp.SUM)
+
+    @classmethod
+    @typechecked
+    @torch.no_grad()
+    def valid_one_epoch(
+        cls,
+        model,
+        iterator: Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
+        reporter: SubReporter,
+        options: DeepSpeedTrainerOptions,
+    ) -> None:
+        model.eval()
+        iterator_stop = torch.tensor(0).cuda()
+
+        for iiter, (utt_id, batch) in enumerate(iterator):
+            assert isinstance(batch, dict), type(batch)
+
+            # (0) ensure all ranks have not finished.
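+            # Each rank contributes 0 to the sum while it still has data; once any
+            # rank exhausts its iterator, the for-else below sets the flag to 1, so
+            # every rank sees a positive sum and stops together, keeping the
+            # collective calls aligned across ranks.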
+            dist.all_reduce(iterator_stop, ReduceOp.SUM)
+            if iterator_stop > 0:
+                break
+
+            # (1) forward
+            batch["utt_id"] = utt_id
+            batch = to_device(batch, "cuda", dtype=options.train_dtype)
+            loss, stats, weight = model(**batch)
+
+            # (2) all-reduce statistics and logging on model side
+            stats = {k: v for k, v in stats.items() if v is not None}
+            stats, weight = recursive_average(stats, weight, True)
+
+            reporter.register(stats, weight)
+            reporter.next()
+
+        else:
+            iterator_stop.fill_(1)
+            dist.all_reduce(iterator_stop, ReduceOp.SUM)
+
+    @classmethod
+    @typechecked
+    def setup_data_dtype(cls, deepspeed_config: Dict):
+        if "bf16" in deepspeed_config:
+            return torch.bfloat16
+
+        elif "fp16" in deepspeed_config:
+            return torch.float16
+
+        elif "amp" in deepspeed_config:
+            if torch.cuda.is_bf16_supported():
+                return torch.bfloat16
+            else:
+                return torch.float16
+
+        else:
+            return torch.float
diff --git a/espnet2/train/distributed_utils.py b/espnet2/train/distributed_utils.py
index 8036d691979..1c8efec2022 100644
--- a/espnet2/train/distributed_utils.py
+++ b/espnet2/train/distributed_utils.py
@@ -1,4 +1,5 @@
 import dataclasses
+import logging
 import os
 import socket
 from typing import Optional
@@ -108,6 +109,34 @@ def init_torch_distributed(self):
         if self.local_rank is not None and self.ngpu > 0:
             torch.cuda.set_device(self.local_rank)
 
+    def init_deepspeed(self):
+        try:
+            import deepspeed
+        except ImportError:
+            raise
+
+        if not torch.distributed.is_initialized():
+            raise ValueError(
+                "Should initialize torch distributed before initializing deepspeed"
+            )
+
+        # NOTE(Jinchuan): the torch distributed backend is initialized first, and
+        # deepspeed will then find and reuse that backend automatically.
+        os.environ["LOCAL_RANK"] = str(self.local_rank)
+        os.environ["RANK"] = str(self.dist_rank)
+        os.environ["WORLD_SIZE"] = str(self.dist_world_size)
+        if int(os.environ["OMP_NUM_THREADS"]) == 1:
+            logging.warning(
+                "\n=================================================================\n"
+                "Found OMP_NUM_THREADS=1 in environment variables. "
+                "With some advanced features, DeepSpeed can have a heavy CPU workload, "
+                "so OMP_NUM_THREADS=1 may not be sufficient. "
+                "Try increasing it in your path.sh.\n"
+                "================================================================="
+            )
+
+        deepspeed.init_distributed()
+
 
 def resolve_distributed_mode(args):
     # Note that args.distributed is set by only this function.
diff --git a/tools/extra_path.sh b/tools/extra_path.sh
index ac32987733d..ca21b271a10 100644
--- a/tools/extra_path.sh
+++ b/tools/extra_path.sh
@@ -37,3 +37,16 @@ export PATH="${TOOL_DIR}"/ffmpeg-release:"${PATH:-}"
 export LD_LIBRARY_PATH="${TOOL_DIR}"/lib:"${TOOL_DIR}"/lib64:"${LD_LIBRARY_PATH:-}"
 export LD_LIBRARY_PATH="${TOOL_DIR}"/espeak-ng/lib:"${LD_LIBRARY_PATH:-}"
 export PYTHONPATH="${TOOL_DIR}"/RawNet/python/RawNet3:"${TOOL_DIR}"/RawNet/python/RawNet3/models:"${PYTHONPATH:-}"
+
+# DeepSpeed related. Users should set CUDA_HOME by themselves; leave it empty to skip.
+CUDA_HOME=
+if [ -n "${CUDA_HOME}" ]; then
+    export LIBRARY_PATH=${CUDA_HOME}/lib64:${LIBRARY_PATH:-}
+    export LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH:-}
+    export PATH=${CUDA_HOME}/bin:${PATH:-}
+fi
+
+if [ -n "${CONDA_PREFIX-}" ]; then
+    export CFLAGS="-I${CONDA_PREFIX}/include"
+    export LDFLAGS="-L${CONDA_PREFIX}/lib"
+fi
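+
+# Optional sanity check (assumes DeepSpeed is already installed):
+#   python3 -c "import deepspeed; print(deepspeed.__version__)"
+#   ds_report   # summarizes which DeepSpeed ops can be built with the detected CUDA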