
Benchmarks - Add LLaMA-2 Models (#668)
Added a LLaMA benchmark covering training and inference, following the existing PyTorch model implementations such as gpt2 and lstm.

- added a llama fp8 unit test for better code coverage while reducing the memory required (see the sketch after this list)
- updated the transformers requirement to >= 4.28.0, which provides LlamaConfig
- pinned the tokenizers version to <= 0.20.3 to avoid the 0.20.4 [issues](huggingface/tokenizers#1691) with Python 3.8
- added llama2 to TensorRT
- llama2 tests were not added to test_tensorrt_inference_performance.py due to the large memory requirement on the worker GPU; they were validated separately on a GH200
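
As a rough illustration of the fp8 unit-test path (a minimal sketch, not the exact test added here; the `fp8_hybrid` precision token and the tiny step counts are assumptions chosen to keep memory low):

```python
from superbench.benchmarks import BenchmarkRegistry, Framework, Platform

# Tiny configuration to exercise the fp8 path with low memory; parameter names
# follow the existing model benchmarks, and 'fp8_hybrid' is an assumption here.
context = BenchmarkRegistry.create_benchmark_context(
    'llama2-7b',
    platform=Platform.CUDA,
    parameters='--batch_size 1 --seq_len 32 --num_warmup 1 --num_steps 2 --precision fp8_hybrid',
    framework=Framework.PYTORCH,
)
benchmark = BenchmarkRegistry.launch_benchmark(context)
```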

---------

Co-authored-by: dpatlolla <[email protected]>
dpower4 and dpatlolla authored Nov 28, 2024
1 parent 4e6935a commit 249e21c
Showing 9 changed files with 408 additions and 7 deletions.
3 changes: 2 additions & 1 deletion docs/superbench-config.mdx
@@ -328,7 +328,8 @@ A list of models to run, only supported in model-benchmark.
shufflenet_v2_x0_5 | shufflenet_v2_x1_0 | shufflenet_v2_x1_5 | shufflenet_v2_x2_0 |
squeezenet1_0 | squeezenet1_1 |
vgg11 | vgg11_bn | vgg13 | vgg13_bn | vgg16 | vgg16_bn | vgg19_bn | vgg19 |
-bert-base | bert-large | gpt2-small | gpt2-medium | gpt2-large | gpt2-xl ]
+bert-base | bert-large | gpt2-small | gpt2-medium | gpt2-large | gpt2-xl |
+llama2-7b | llama2-13b | llama2-70b ]
```
* default value: `[ ]`

1 change: 1 addition & 0 deletions docs/user-tutorial/benchmarks/model-benchmarks.md
@@ -13,6 +13,7 @@ id: model-benchmarks
Run training or inference tasks with single or half precision for deep learning models,
including the following categories:
* GPT: gpt2-small, gpt2-medium, gpt2-large and gpt2-xl
* LLAMA: llama2-7b, llama2-13b, llama2-70b
* BERT: bert-base and bert-large
* LSTM
* CNN, listed in [`torchvision.models`](https://pytorch.org/vision/0.8/models.html), including:
41 changes: 41 additions & 0 deletions examples/benchmarks/pytorch_llama2.py
@@ -0,0 +1,41 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""Model benchmark example for Llama2-7b (32-layer, 4096-hidden, 32-heads, 7B parameters).
Commands to run:
python3 examples/benchmarks/pytorch_llama2.py (Single GPU)
python3 -m torch.distributed.launch --use_env --nproc_per_node=8 examples/benchmarks/pytorch_llama2.py \
--distributed (Distributed)
"""

import argparse

from superbench.benchmarks import Platform, Framework, BenchmarkRegistry
from superbench.common.utils import logger

if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'--distributed', action='store_true', default=False, help='Whether to enable distributed training.'
)
args = parser.parse_args()

# Specify the model name and benchmark parameters.
model_name = 'llama2-7b'
parameters = '--batch_size 1 --duration 120 --seq_len 512 --precision float16'
if args.distributed:
parameters += ' --distributed_impl ddp --distributed_backend nccl'

# Create context for Llama2 benchmark and run it for 120 seconds.
context = BenchmarkRegistry.create_benchmark_context(
model_name, platform=Platform.CUDA, parameters=parameters, framework=Framework.PYTORCH
)

benchmark = BenchmarkRegistry.launch_benchmark(context)
if benchmark:
logger.info(
'benchmark: {}, return code: {}, result: {}'.format(
benchmark.name, benchmark.return_code, benchmark.result
)
)
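
For completeness, a small follow-up sketch showing how the outcome of the run above can be checked (`ReturnCode` is exported by `superbench.benchmarks`; the exact metric names in `result` depend on precision and model action):

```python
from superbench.benchmarks import ReturnCode

# 'result' maps metric names to the collected values, e.g. per-step throughput.
if benchmark and benchmark.return_code == ReturnCode.SUCCESS:
    print(benchmark.result)
```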
3 changes: 2 additions & 1 deletion setup.py
@@ -209,9 +209,10 @@ def run(self):
'yapf==0.31.0',
],
'torch': [
+'tokenizers<=0.20.3',
'torch>=1.7.0a0',
'torchvision>=0.8.0a0',
-'transformers>=4.3.3, <4.23.0',
+'transformers>=4.28.0',
],
'ort': [
'onnx>=1.10.2',
3 changes: 2 additions & 1 deletion superbench/benchmarks/base.py
@@ -89,7 +89,8 @@ def get_configurable_settings(self):
Return:
All configurable settings in raw string.
"""
-return self._parser.format_help().strip()
+message = self._parser.format_help().strip()
+return message

def parse_args(self, ignore_invalid=False):
"""Parse the arguments.
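
As a usage note for the getter above (a sketch, assuming `benchmark` is an already-constructed benchmark instance):

```python
# Dump every configurable setting of a benchmark instance as argparse help text.
print(benchmark.get_configurable_settings())
```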
40 changes: 37 additions & 3 deletions superbench/benchmarks/micro_benchmarks/_export_torch_to_onnx.py
@@ -9,11 +9,12 @@
import torch.hub
import torch.onnx
import torchvision.models
-from transformers import BertConfig, GPT2Config
+from transformers import BertConfig, GPT2Config, LlamaConfig

from superbench.benchmarks.model_benchmarks.pytorch_bert import BertBenchmarkModel
from superbench.benchmarks.model_benchmarks.pytorch_gpt2 import GPT2BenchmarkModel
from superbench.benchmarks.model_benchmarks.pytorch_lstm import LSTMBenchmarkModel
+from superbench.benchmarks.model_benchmarks.pytorch_llama import LlamaBenchmarkModel


class torch2onnxExporter():
@@ -87,6 +88,39 @@ def __init__(self):
),
self.num_classes,
),
'llama2-7b':
lambda: LlamaBenchmarkModel(
LlamaConfig(
hidden_size=4096,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=32,
intermediate_size=11008,
),
self.num_classes,
),
'llama2-13b':
lambda: LlamaBenchmarkModel(
LlamaConfig(
hidden_size=5120,
num_hidden_layers=40,
num_attention_heads=40,
num_key_value_heads=40,
intermediate_size=13824,
),
self.num_classes,
),
'llama2-70b':
lambda: LlamaBenchmarkModel(
LlamaConfig(
hidden_size=8192,
num_hidden_layers=80,
num_attention_heads=64,
num_key_value_heads=8,
intermediate_size=28672,
),
self.num_classes,
),
}
self._onnx_model_path = Path(torch.hub.get_dir()) / 'onnx'
self._onnx_model_path.mkdir(parents=True, exist_ok=True)
@@ -138,7 +172,7 @@ def export_torchvision_model(self, model_name, batch_size=1):
model,
dummy_input,
file_name,
-opset_version=10,
+opset_version=14,
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
input_names=['input'],
output_names=['output'],
@@ -179,7 +213,7 @@ def export_benchmark_model(self, model_name, batch_size=1, seq_length=512):
model,
dummy_input,
file_name,
-opset_version=10,
+opset_version=14,
do_constant_folding=True,
input_names=['input'],
output_names=['output'],
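
A quick sketch of exercising the exporter with the new entries (signature per `export_benchmark_model` above; treating the return value as the ONNX file path is an assumption, and building the full 7B model is memory-heavy, as noted in the commit message):

```python
from superbench.benchmarks.micro_benchmarks._export_torch_to_onnx import torch2onnxExporter

# Uses the 'llama2-7b' entry registered in the diff above.
exporter = torch2onnxExporter()
onnx_path = exporter.export_benchmark_model('llama2-7b', batch_size=1, seq_length=512)
print(onnx_path)
```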
2 changes: 1 addition & 1 deletion superbench/benchmarks/model_benchmarks/__init__.py
@@ -10,4 +10,4 @@
from superbench.benchmarks.model_benchmarks.pytorch_lstm import PytorchLSTM
from superbench.benchmarks.model_benchmarks.megatron_gpt3 import MegatronGPT

-__all__ = ['ModelBenchmark', 'PytorchBERT', 'PytorchGPT2', 'PytorchCNN', 'PytorchLSTM', 'MegatronGPT']
+__all__ = ['ModelBenchmark', 'PytorchBERT', 'PytorchGPT2', 'PytorchCNN', 'PytorchLSTM', 'MegatronGPT', 'PytorchLlama']
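
With `PytorchLlama` added to `__all__` (its import statement sits in the collapsed part of this hunk), downstream code can pick the new benchmark class up directly, e.g.:

```python
from superbench.benchmarks.model_benchmarks import PytorchLlama
```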