Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Benchmarks - Add LLaMA-2 Models #668

Merged
merged 30 commits into from
Nov 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
36bbf10
add llama init template
dpower4 Nov 18, 2024
697138a
add llama2 unit test
dpower4 Nov 18, 2024
9355e22
fix dims for llama2 unit test
dpower4 Nov 18, 2024
60eae36
update transformers version for LLamaConfig
dpower4 Nov 18, 2024
dadb56a
update docs
dpower4 Nov 18, 2024
d2731a8
update opset for torch onnx conversion
dpower4 Nov 19, 2024
3644985
format and lint
dpower4 Nov 19, 2024
52f4900
remove remnant
dpower4 Nov 19, 2024
f826676
lint fix
dpower4 Nov 20, 2024
6a41087
replace py 3.6 with 3.10 and update cuda to 12.4 for unit test
dpower4 Nov 20, 2024
5b816b4
remove 3.6 from setup, codecov and docs
dpower4 Nov 20, 2024
f322c98
add llama fp8 unit test for better code coverage
dpower4 Nov 20, 2024
b28ee17
llama fp8 precision test only, to reduce memory required
dpower4 Nov 20, 2024
e6f6be3
lint fix
dpower4 Nov 20, 2024
297a229
remove deprecated NaN usage for numpy>2.0
dpower4 Nov 21, 2024
5f72f51
fix argparse formatting related test cases failure for 3.10
dpower4 Nov 21, 2024
8bbe326
fix lint
dpower4 Nov 21, 2024
50452ef
fix lint
dpower4 Nov 21, 2024
0b1da4f
add llama2 to tensorrt
dpower4 Nov 21, 2024
97b3d72
Merge branch 'main' into feat/llama
dpower4 Nov 22, 2024
e423aec
add more params to llama config
Nov 22, 2024
d850210
fix lint
dpower4 Nov 22, 2024
b08f9e3
Merge branch 'main' into feat/llama
abuccts Nov 22, 2024
54c3e85
Merge branch 'main' into feat/llama
abuccts Nov 27, 2024
670dc76
llama test: use fp16 instead of fp8 to relax cuda CC req.
dpower4 Nov 27, 2024
27c788c
fix comment and lint
dpower4 Nov 27, 2024
00d09ba
fix precision arg as float16
dpower4 Nov 27, 2024
bd47fc3
limit tokenizers version to < 0.20.3 as 0.20.4 doesnt support py3.8
dpower4 Nov 27, 2024
bed3e01
address review comments
dpower4 Nov 28, 2024
1570707
Merge branch 'main' into feat/llama
abuccts Nov 28, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/superbench-config.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,8 @@ A list of models to run, only supported in model-benchmark.
shufflenet_v2_x0_5 | shufflenet_v2_x1_0 | shufflenet_v2_x1_5 | shufflenet_v2_x2_0 |
squeezenet1_0 | squeezenet1_1 |
vgg11 | vgg11_bn | vgg13 | vgg13_bn | vgg16 | vgg16_bn | vgg19_bn | vgg19 |
bert-base | bert-large | gpt2-small | gpt2-medium | gpt2-large | gpt2-xl ]
bert-base | bert-large | gpt2-small | gpt2-medium | gpt2-large | gpt2-xl |
llama2-7b | llama2-13b | llama2-70b ]
```
* default value: `[ ]`

Expand Down
1 change: 1 addition & 0 deletions docs/user-tutorial/benchmarks/model-benchmarks.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ id: model-benchmarks
Run training or inference tasks with single or half precision for deep learning models,
including the following categories:
* GPT: gpt2-small, gpt2-medium, gpt2-large and gpt2-xl
* LLAMA: llama2-7b, llama2-13b, llama2-70b
* BERT: bert-base and bert-large
* LSTM
* CNN, listed in [`torchvision.models`](https://pytorch.org/vision/0.8/models.html), including:
Expand Down
41 changes: 41 additions & 0 deletions examples/benchmarks/pytorch_llama2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""Model benchmark example for Llama2-7b (32-layer, 4096-hidden, 32-heads, 7B parameters).

Commands to run:
python3 examples/benchmarks/pytorch_llama2.py (Single GPU)
python3 -m torch.distributed.launch --use_env --nproc_per_node=8 examples/benchmarks/pytorch_llama2.py \
--distributed (Distributed)
"""

import argparse

from superbench.benchmarks import Platform, Framework, BenchmarkRegistry
from superbench.common.utils import logger

if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'--distributed', action='store_true', default=False, help='Whether to enable distributed training.'
)
args = parser.parse_args()

# Specify the model name and benchmark parameters.
model_name = 'llama2-7b'
parameters = '--batch_size 1 --duration 120 --seq_len 512 --precision float16'
if args.distributed:
parameters += ' --distributed_impl ddp --distributed_backend nccl'

# Create context for Llama2 benchmark and run it for 120 seconds.
context = BenchmarkRegistry.create_benchmark_context(
model_name, platform=Platform.CUDA, parameters=parameters, framework=Framework.PYTORCH
)

benchmark = BenchmarkRegistry.launch_benchmark(context)
if benchmark:
logger.info(
'benchmark: {}, return code: {}, result: {}'.format(
benchmark.name, benchmark.return_code, benchmark.result
)
)
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,9 +209,10 @@ def run(self):
'yapf==0.31.0',
],
'torch': [
'tokenizers<=0.20.3',
'torch>=1.7.0a0',
'torchvision>=0.8.0a0',
'transformers>=4.3.3, <4.23.0',
'transformers>=4.28.0',
],
'ort': [
'onnx>=1.10.2',
Expand Down
3 changes: 2 additions & 1 deletion superbench/benchmarks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ def get_configurable_settings(self):
Return:
All configurable settings in raw string.
"""
return self._parser.format_help().strip()
message = self._parser.format_help().strip()
return message

def parse_args(self, ignore_invalid=False):
"""Parse the arguments.
Expand Down
40 changes: 37 additions & 3 deletions superbench/benchmarks/micro_benchmarks/_export_torch_to_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@
import torch.hub
import torch.onnx
import torchvision.models
from transformers import BertConfig, GPT2Config
from transformers import BertConfig, GPT2Config, LlamaConfig

from superbench.benchmarks.model_benchmarks.pytorch_bert import BertBenchmarkModel
from superbench.benchmarks.model_benchmarks.pytorch_gpt2 import GPT2BenchmarkModel
from superbench.benchmarks.model_benchmarks.pytorch_lstm import LSTMBenchmarkModel
from superbench.benchmarks.model_benchmarks.pytorch_llama import LlamaBenchmarkModel


class torch2onnxExporter():
Expand Down Expand Up @@ -87,6 +88,39 @@ def __init__(self):
),
self.num_classes,
),
'llama2-7b':
lambda: LlamaBenchmarkModel(
LlamaConfig(
hidden_size=4096,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=32,
intermediate_size=11008,
),
self.num_classes,
),
'llama2-13b':
lambda: LlamaBenchmarkModel(
LlamaConfig(
hidden_size=5120,
num_hidden_layers=40,
num_attention_heads=40,
num_key_value_heads=40,
intermediate_size=13824,
),
self.num_classes,
),
'llama2-70b':
lambda: LlamaBenchmarkModel(
LlamaConfig(
hidden_size=8192,
num_hidden_layers=80,
num_attention_heads=64,
num_key_value_heads=8,
intermediate_size=28672,
),
self.num_classes,
),
}
self._onnx_model_path = Path(torch.hub.get_dir()) / 'onnx'
self._onnx_model_path.mkdir(parents=True, exist_ok=True)
Expand Down Expand Up @@ -138,7 +172,7 @@ def export_torchvision_model(self, model_name, batch_size=1):
model,
dummy_input,
file_name,
opset_version=10,
opset_version=14,
abuccts marked this conversation as resolved.
Show resolved Hide resolved
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
input_names=['input'],
output_names=['output'],
Expand Down Expand Up @@ -179,7 +213,7 @@ def export_benchmark_model(self, model_name, batch_size=1, seq_length=512):
model,
dummy_input,
file_name,
opset_version=10,
opset_version=14,
do_constant_folding=True,
input_names=['input'],
output_names=['output'],
Expand Down
2 changes: 1 addition & 1 deletion superbench/benchmarks/model_benchmarks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@
from superbench.benchmarks.model_benchmarks.pytorch_lstm import PytorchLSTM
from superbench.benchmarks.model_benchmarks.megatron_gpt3 import MegatronGPT

__all__ = ['ModelBenchmark', 'PytorchBERT', 'PytorchGPT2', 'PytorchCNN', 'PytorchLSTM', 'MegatronGPT']
__all__ = ['ModelBenchmark', 'PytorchBERT', 'PytorchGPT2', 'PytorchCNN', 'PytorchLSTM', 'MegatronGPT', 'PytorchLlama']
Loading
Loading