diff --git a/docs/superbench-config.mdx b/docs/superbench-config.mdx index 102b8d69f..051abeda3 100644 --- a/docs/superbench-config.mdx +++ b/docs/superbench-config.mdx @@ -328,7 +328,8 @@ A list of models to run, only supported in model-benchmark. shufflenet_v2_x0_5 | shufflenet_v2_x1_0 | shufflenet_v2_x1_5 | shufflenet_v2_x2_0 | squeezenet1_0 | squeezenet1_1 | vgg11 | vgg11_bn | vgg13 | vgg13_bn | vgg16 | vgg16_bn | vgg19_bn | vgg19 | - bert-base | bert-large | gpt2-small | gpt2-medium | gpt2-large | gpt2-xl ] + bert-base | bert-large | gpt2-small | gpt2-medium | gpt2-large | gpt2-xl | + llama2-7b | llama2-13b | llama2-70b ] ``` * default value: `[ ]` diff --git a/docs/user-tutorial/benchmarks/model-benchmarks.md b/docs/user-tutorial/benchmarks/model-benchmarks.md index 34fdf4c70..71e8832cf 100644 --- a/docs/user-tutorial/benchmarks/model-benchmarks.md +++ b/docs/user-tutorial/benchmarks/model-benchmarks.md @@ -13,6 +13,7 @@ id: model-benchmarks Run training or inference tasks with single or half precision for deep learning models, including the following categories: * GPT: gpt2-small, gpt2-medium, gpt2-large and gpt2-xl +* LLAMA: llama2-7b, llama2-13b, llama2-70b * BERT: bert-base and bert-large * LSTM * CNN, listed in [`torchvision.models`](https://pytorch.org/vision/0.8/models.html), including: diff --git a/examples/benchmarks/pytorch_llama2.py b/examples/benchmarks/pytorch_llama2.py new file mode 100644 index 000000000..2290ba1a5 --- /dev/null +++ b/examples/benchmarks/pytorch_llama2.py @@ -0,0 +1,41 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Model benchmark example for Llama2-7b (32-layer, 4096-hidden, 32-heads, 7B parameters). + +Commands to run: + python3 examples/benchmarks/pytorch_llama2.py (Single GPU) + python3 -m torch.distributed.launch --use_env --nproc_per_node=8 examples/benchmarks/pytorch_llama2.py \ + --distributed (Distributed) +""" + +import argparse + +from superbench.benchmarks import Platform, Framework, BenchmarkRegistry +from superbench.common.utils import logger + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument( + '--distributed', action='store_true', default=False, help='Whether to enable distributed training.' + ) + args = parser.parse_args() + + # Specify the model name and benchmark parameters. + model_name = 'llama2-7b' + parameters = '--batch_size 1 --duration 120 --seq_len 512 --precision float16' + if args.distributed: + parameters += ' --distributed_impl ddp --distributed_backend nccl' + + # Create context for Llama2 benchmark and run it for 120 seconds. + context = BenchmarkRegistry.create_benchmark_context( + model_name, platform=Platform.CUDA, parameters=parameters, framework=Framework.PYTORCH + ) + + benchmark = BenchmarkRegistry.launch_benchmark(context) + if benchmark: + logger.info( + 'benchmark: {}, return code: {}, result: {}'.format( + benchmark.name, benchmark.return_code, benchmark.result + ) + ) diff --git a/setup.py b/setup.py index cf9779a08..c6e2d1fe3 100644 --- a/setup.py +++ b/setup.py @@ -209,9 +209,10 @@ def run(self): 'yapf==0.31.0', ], 'torch': [ + 'tokenizers<=0.20.3', 'torch>=1.7.0a0', 'torchvision>=0.8.0a0', - 'transformers>=4.3.3, <4.23.0', + 'transformers>=4.28.0', ], 'ort': [ 'onnx>=1.10.2', diff --git a/superbench/benchmarks/base.py b/superbench/benchmarks/base.py index 014103744..8e6e58bfe 100644 --- a/superbench/benchmarks/base.py +++ b/superbench/benchmarks/base.py @@ -89,7 +89,8 @@ def get_configurable_settings(self): Return: All configurable settings in raw string. """ - return self._parser.format_help().strip() + message = self._parser.format_help().strip() + return message def parse_args(self, ignore_invalid=False): """Parse the arguments. diff --git a/superbench/benchmarks/micro_benchmarks/_export_torch_to_onnx.py b/superbench/benchmarks/micro_benchmarks/_export_torch_to_onnx.py index 1e37b793d..0f28f4f6a 100644 --- a/superbench/benchmarks/micro_benchmarks/_export_torch_to_onnx.py +++ b/superbench/benchmarks/micro_benchmarks/_export_torch_to_onnx.py @@ -9,11 +9,12 @@ import torch.hub import torch.onnx import torchvision.models -from transformers import BertConfig, GPT2Config +from transformers import BertConfig, GPT2Config, LlamaConfig from superbench.benchmarks.model_benchmarks.pytorch_bert import BertBenchmarkModel from superbench.benchmarks.model_benchmarks.pytorch_gpt2 import GPT2BenchmarkModel from superbench.benchmarks.model_benchmarks.pytorch_lstm import LSTMBenchmarkModel +from superbench.benchmarks.model_benchmarks.pytorch_llama import LlamaBenchmarkModel class torch2onnxExporter(): @@ -87,6 +88,39 @@ def __init__(self): ), self.num_classes, ), + 'llama2-7b': + lambda: LlamaBenchmarkModel( + LlamaConfig( + hidden_size=4096, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=32, + intermediate_size=11008, + ), + self.num_classes, + ), + 'llama2-13b': + lambda: LlamaBenchmarkModel( + LlamaConfig( + hidden_size=5120, + num_hidden_layers=40, + num_attention_heads=40, + num_key_value_heads=40, + intermediate_size=13824, + ), + self.num_classes, + ), + 'llama2-70b': + lambda: LlamaBenchmarkModel( + LlamaConfig( + hidden_size=8192, + num_hidden_layers=80, + num_attention_heads=64, + num_key_value_heads=8, + intermediate_size=28672, + ), + self.num_classes, + ), } self._onnx_model_path = Path(torch.hub.get_dir()) / 'onnx' self._onnx_model_path.mkdir(parents=True, exist_ok=True) @@ -138,7 +172,7 @@ def export_torchvision_model(self, model_name, batch_size=1): model, dummy_input, file_name, - opset_version=10, + opset_version=14, operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK, input_names=['input'], output_names=['output'], @@ -179,7 +213,7 @@ def export_benchmark_model(self, model_name, batch_size=1, seq_length=512): model, dummy_input, file_name, - opset_version=10, + opset_version=14, do_constant_folding=True, input_names=['input'], output_names=['output'], diff --git a/superbench/benchmarks/model_benchmarks/__init__.py b/superbench/benchmarks/model_benchmarks/__init__.py index eda0c4985..0829c4d33 100644 --- a/superbench/benchmarks/model_benchmarks/__init__.py +++ b/superbench/benchmarks/model_benchmarks/__init__.py @@ -10,4 +10,4 @@ from superbench.benchmarks.model_benchmarks.pytorch_lstm import PytorchLSTM from superbench.benchmarks.model_benchmarks.megatron_gpt3 import MegatronGPT -__all__ = ['ModelBenchmark', 'PytorchBERT', 'PytorchGPT2', 'PytorchCNN', 'PytorchLSTM', 'MegatronGPT'] +__all__ = ['ModelBenchmark', 'PytorchBERT', 'PytorchGPT2', 'PytorchCNN', 'PytorchLSTM', 'MegatronGPT', 'PytorchLlama'] diff --git a/superbench/benchmarks/model_benchmarks/pytorch_llama.py b/superbench/benchmarks/model_benchmarks/pytorch_llama.py new file mode 100644 index 000000000..7161aeb83 --- /dev/null +++ b/superbench/benchmarks/model_benchmarks/pytorch_llama.py @@ -0,0 +1,265 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Module of the Pytorch Llama2 model.""" + +import torch +from transformers import LlamaModel, LlamaConfig +try: + import transformer_engine.pytorch as te + from transformer_engine.common.recipe import Format, DelayedScaling +except ImportError: + te = None + +from superbench.common.utils import logger +from superbench.benchmarks import BenchmarkRegistry, Precision +from superbench.benchmarks.model_benchmarks.model_base import Optimizer +from superbench.benchmarks.model_benchmarks.pytorch_base import PytorchBase +from superbench.benchmarks.model_benchmarks.random_dataset import TorchRandomDataset + + +class LlamaBenchmarkModel(torch.nn.Module): + """The Llama model for benchmarking.""" + def __init__(self, config, num_classes): + """Constructor. + + Args: + config (LlamaConfig): Configurations of Llama model. + num_classes (int): The number of objects for classification. + """ + super().__init__() + self._llama = LlamaModel(config) + self._linear = torch.nn.Linear(config.hidden_size, num_classes) + + def forward(self, input): + """Forward propagation function. + + Args: + input (torch.LongTensor): Indices of input sequence tokens in the vocabulary, + shape (batch_size, sequence_length). + + Return: + result (torch.FloatTensor): Last layer hidden-state of the first token of the sequence + (classification token) further processed by a Linear layer, shape (batch_size, hidden_size). + """ + outputs = self._llama(input) + result = self._linear(outputs[0]) + return result + + +class PytorchLlama(PytorchBase): + """The Llama benchmark class.""" + def __init__(self, name, parameters=''): + """Constructor. + + Args: + name (str): benchmark name. + parameters (str): benchmark parameters. + """ + super().__init__(name, parameters) + self._config = None + self._fp8_recipe = None + self._supported_precision = [ + Precision.FLOAT32, + Precision.FLOAT16, + Precision.FP8_HYBRID, + Precision.FP8_E4M3, + ] + self._optimizer_type = Optimizer.ADAMW + self._loss_fn = torch.nn.CrossEntropyLoss() + + def add_parser_arguments(self): + """Add the Llama-specified arguments. + + Llama2 model reference: https://huggingface.co/docs/transformers/model_doc/llama2 + """ + super().add_parser_arguments() + + self._parser.add_argument('--num_classes', type=int, default=100, required=False, help='Num of class.') + self._parser.add_argument('--hidden_size', type=int, default=1280, required=False, help='Hidden size.') + self._parser.add_argument( + '--num_hidden_layers', type=int, default=36, required=False, help='The number of hidden layers.' + ) + self._parser.add_argument( + '--num_attention_heads', type=int, default=20, required=False, help='The number of attention heads.' + ) + self._parser.add_argument( + '--intermediate_size', + type=int, + default=11008, + required=False, + help='Dimension of the MLP representations.' + ) + self._parser.add_argument('--seq_len', type=int, default=512, required=False, help='Sequence length.') + self._parser.add_argument( + '--num_key_value_heads', + type=int, + default=None, + required=False, + help='The number of key_value heads that should be used to implement Grouped Query Attention.' + ) + + def _generate_dataset(self): + """Generate dataset for benchmarking according to shape info. + + Return: + True if dataset is created successfully. + """ + self._dataset = TorchRandomDataset( + [self._args.sample_count, self._args.seq_len], self._world_size, dtype=torch.long + ) + if len(self._dataset) == 0: + logger.error('Generate random dataset failed - model: {}'.format(self._name)) + return False + + return True + + def _create_model(self, precision): + """Construct the model for benchmarking. + + Args: + precision (Precision): precision of model and input data, such as float32, float16. + """ + self._config = LlamaConfig( + hidden_size=self._args.hidden_size, + num_hidden_layers=self._args.num_hidden_layers, + num_attention_heads=self._args.num_attention_heads, + num_key_value_heads=self._args.num_key_value_heads, + intermediate_size=self._args.intermediate_size, + max_position_embeddings=4096, # Maximum sequence length that llama2 supports + rms_norm_eps=1e-05, # Llama2 default for epsilon used by the rms normalization layers + ) + + enable_fp8 = precision.name.startswith('FP8_') + if enable_fp8 and te is None: + logger.error( + f'Create model with fp8 failed - model: {self._name}, precision: {precision},' + ' message: Cannot find transformer_engine.' + ) + return False + if enable_fp8 and not self._gpu_available: + logger.error( + f'Create model with fp8 failed - model: {self._name}, precision: {precision},' + ' message: FP8 is only supported on GPU.' + ) + return False + + try: + self._model = LlamaBenchmarkModel(self._config, self._args.num_classes) + if enable_fp8: + self._fp8_recipe = DelayedScaling( + fp8_format=Format[precision.name.strip('FP8_')], + amax_history_len=16, + amax_compute_algo='max', + ) + self._to_te_model(self._model.to(dtype=torch.float16)) + else: + self._model = self._model.to(dtype=getattr(torch, precision.value)) + if self._gpu_available: + self._model = self._model.cuda() + except BaseException as e: + logger.error( + 'Create model with specified precision failed - model: {}, precision: {}, message: {}.'.format( + self._name, precision, str(e) + ) + ) + return False + + self._target = torch.LongTensor(self._args.batch_size).random_(self._args.num_classes) + if self._gpu_available: + self._target = self._target.cuda() + + return True + + def _train_step(self, precision): + """Define the training process. + + Args: + precision (Precision): precision of model and input data, such as float32, float16. + + Return: + The step-time list of every training step. + """ + duration = [] + curr_step = 0 + check_frequency = 100 + while True: + for idx, sample in enumerate(self._dataloader): + start = self._timer() + if self._gpu_available: + sample = sample.cuda() + self._optimizer.zero_grad() + if self._fp8_recipe is not None: + with te.fp8_autocast(enabled=True, fp8_recipe=self._fp8_recipe): + output = self._model(sample) + else: + output = self._model(sample) + loss = self._loss_fn(output[range(self._args.batch_size), -1], self._target) + loss.backward() + self._optimizer.step() + end = self._timer() + curr_step += 1 + if curr_step > self._args.num_warmup: + # Save the step time of every training/inference step, unit is millisecond. + duration.append((end - start) * 1000) + self._log_step_time(curr_step, precision, duration) + if self._is_finished(curr_step, end, check_frequency): + return duration + + def _inference_step(self, precision): + """Define the inference process. + + Args: + precision (Precision): precision of model and input data, + such as float32, float16. + + Return: + The latency list of every inference operation. + """ + duration = [] + curr_step = 0 + with torch.no_grad(): + self._model.eval() + while True: + for idx, sample in enumerate(self._dataloader): + start = self._timer() + if self._gpu_available: + sample = sample.cuda() + if self._fp8_recipe is not None: + with te.fp8_autocast(enabled=True, fp8_recipe=self._fp8_recipe): + self._model(sample) + else: + self._model(sample) + end = self._timer() + curr_step += 1 + if curr_step > self._args.num_warmup: + # Save the step time of every training/inference step, unit is millisecond. + duration.append((end - start) * 1000) + self._log_step_time(curr_step, precision, duration) + if self._is_finished(curr_step, end): + return duration + + +# Register Llama2 benchmark with 7b parameters. +BenchmarkRegistry.register_benchmark( + 'pytorch-llama2-7b', + PytorchLlama, + parameters='--hidden_size=4096 --num_hidden_layers=32 --num_attention_heads=32 --num_key_value_heads=32 \ + --intermediate_size=11008' +) + +# Register Llama2 benchmark with 13b parameters. +BenchmarkRegistry.register_benchmark( + 'pytorch-llama2-13b', + PytorchLlama, + parameters='--hidden_size=5120 --num_hidden_layers=40 --num_attention_heads=40 --num_key_value_heads=40 \ + --intermediate_size=13824' +) + +# Register Llama2 benchmark with 70b parameters. +BenchmarkRegistry.register_benchmark( + 'pytorch-llama2-70b', + PytorchLlama, + parameters='--hidden_size=8192 --num_hidden_layers=80 --num_attention_heads=64 --num_key_value_heads=8 \ + --intermediate_size=28672' +) diff --git a/tests/benchmarks/model_benchmarks/test_pytorch_llama.py b/tests/benchmarks/model_benchmarks/test_pytorch_llama.py new file mode 100644 index 000000000..a9a03d7b9 --- /dev/null +++ b/tests/benchmarks/model_benchmarks/test_pytorch_llama.py @@ -0,0 +1,57 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Tests for Llama model benchmarks.""" + +from tests.helper import decorator +from superbench.benchmarks import BenchmarkRegistry, Platform, Framework, BenchmarkType, ReturnCode +from superbench.benchmarks.model_benchmarks.pytorch_llama import PytorchLlama + + +@decorator.cuda_test +@decorator.pytorch_test +def test_pytorch_llama_7b(): + """Test pytorch-llama2-7b benchmark for fp16 train and inference.""" + context = BenchmarkRegistry.create_benchmark_context( + 'llama2-7b', + platform=Platform.CUDA, + parameters='--batch_size 1 --seq_len 32 --num_warmup 1 --num_steps 2 --precision float16 \ + --model_action train inference', + framework=Framework.PYTORCH + ) + + assert (BenchmarkRegistry.is_benchmark_context_valid(context)) + + benchmark = BenchmarkRegistry.launch_benchmark(context) + + # Check basic information. + assert (benchmark) + assert (isinstance(benchmark, PytorchLlama)) + assert (benchmark.name == 'pytorch-llama2-7b') + assert (benchmark.type == BenchmarkType.MODEL) + + # Check predefined parameters of llama2 7b model. + assert (benchmark._args.hidden_size == 4096) + assert (benchmark._args.num_hidden_layers == 32) + assert (benchmark._args.num_attention_heads == 32) + + # Check parameters specified in BenchmarkContext. + assert (benchmark._args.batch_size == 1) + assert (benchmark._args.num_classes == 100) + assert (benchmark._args.seq_len == 32) + assert (benchmark._args.num_warmup == 1) + assert (benchmark._args.num_steps == 2) + + # Test Dataset. + assert (len(benchmark._dataset) == benchmark._args.sample_count * benchmark._world_size) + + # Check results and metrics. + assert (benchmark.run_count == 1) + assert (benchmark.return_code == ReturnCode.SUCCESS) + + for metric in [ + 'fp16_train_step_time', 'fp16_train_throughput', 'fp16_inference_step_time', 'fp16_inference_throughput' + ]: + assert (len(benchmark.raw_data[metric]) == benchmark.run_count) + assert (len(benchmark.raw_data[metric][0]) == benchmark._args.num_steps) + assert (len(benchmark.result[metric]) == benchmark.run_count)