
Commit

update docs
dpower4 committed Dec 18, 2024
1 parent deeedde · commit 5f5a75d
Showing 4 changed files with 33 additions and 3 deletions.
docs/superbench-config.mdx (2 additions, 1 deletion)
@@ -329,7 +329,8 @@ A list of models to run, only supported in model-benchmark.
 squeezenet1_0 | squeezenet1_1 |
 vgg11 | vgg11_bn | vgg13 | vgg13_bn | vgg16 | vgg16_bn | vgg19_bn | vgg19 |
 bert-base | bert-large | gpt2-small | gpt2-medium | gpt2-large | gpt2-xl |
-llama2-7b | llama2-13b | llama2-70b ]
+llama2-7b | llama2-13b | llama2-70b |
+mixtral-8x7b | mixtral-8x22b ]
 ```
 * default value: `[ ]`

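For context, the sketch below shows how the two new Mixtral entries might be selected in a SuperBench configuration; the key nesting (`superbench.benchmarks.model-benchmarks.models`) is an assumption here, not taken from this commit, so treat docs/superbench-config.mdx as the authoritative schema.

```python
# Illustrative only: parse a hypothetical config fragment that picks the newly
# added Mixtral models. Assumes PyYAML is installed; the key layout is assumed.
import yaml

fragment = """
superbench:
  benchmarks:
    model-benchmarks:
      models:
        - mixtral-8x7b
        - mixtral-8x22b
"""
config = yaml.safe_load(fragment)
print(config['superbench']['benchmarks']['model-benchmarks']['models'])
# ['mixtral-8x7b', 'mixtral-8x22b']
```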
docs/user-tutorial/benchmarks/model-benchmarks.md (1 addition, 0 deletions)
@@ -14,6 +14,7 @@ Run training or inference tasks with single or half precision for deep learning
 including the following categories:
 * GPT: gpt2-small, gpt2-medium, gpt2-large and gpt2-xl
 * LLAMA: llama2-7b, llama2-13b, llama2-70b
+* MoE: mixtral-8x7b, mixtral-8x22b
 * BERT: bert-base and bert-large
 * LSTM
 * CNN, listed in [`torchvision.models`](https://pytorch.org/vision/0.8/models.html), including:
superbench/benchmarks/micro_benchmarks/_export_torch_to_onnx.py (28 additions, 1 deletion)
@@ -9,12 +9,13 @@
 import torch.hub
 import torch.onnx
 import torchvision.models
-from transformers import BertConfig, GPT2Config, LlamaConfig
+from transformers import BertConfig, GPT2Config, LlamaConfig, MixtralConfig

 from superbench.benchmarks.model_benchmarks.pytorch_bert import BertBenchmarkModel
 from superbench.benchmarks.model_benchmarks.pytorch_gpt2 import GPT2BenchmarkModel
 from superbench.benchmarks.model_benchmarks.pytorch_lstm import LSTMBenchmarkModel
 from superbench.benchmarks.model_benchmarks.pytorch_llama import LlamaBenchmarkModel
+from superbench.benchmarks.model_benchmarks.pytorch_mixtral import MixtralBenchmarkModel


 class torch2onnxExporter():
@@ -121,6 +122,32 @@ def __init__(self):
                 ),
                 self.num_classes,
             ),
+            'mixtral-8x7b':
+            lambda: MixtralBenchmarkModel(
+                MixtralConfig(
+                    hidden_size=4096,
+                    num_hidden_layers=32,
+                    num_attention_heads=32,
+                    num_key_value_heads=8,
+                    intermediate_size=14336,
+                    max_position_embeddings=32768,
+                    router_aux_loss_coef=0.02,
+                ),
+                self.num_classes,
+            ),
+            'mixtral-8x22b':
+            lambda: MixtralBenchmarkModel(
+                MixtralConfig(
+                    hidden_size=6144,
+                    num_hidden_layers=56,
+                    num_attention_heads=48,
+                    num_key_value_heads=8,
+                    intermediate_size=16384,
+                    max_position_embeddings=65536,
+                    router_aux_loss_coef=0.001,
+                ),
+                self.num_classes,
+            ),
         }
         self._onnx_model_path = Path(torch.hub.get_dir()) / 'onnx'
         self._onnx_model_path.mkdir(parents=True, exist_ok=True)
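The registry above maps each model name to a zero-argument lambda, so a model is only constructed when its export is actually requested. As a standalone sanity check (not part of this commit), the sketch below rebuilds the same `MixtralConfig` used for the `mixtral-8x7b` entry and prints the mixture-of-experts fields that are left at their defaults; it only assumes the `transformers` package is installed.

```python
# Standalone sketch: reproduce the 'mixtral-8x7b' configuration from the registry
# above and inspect the MoE routing fields MixtralConfig fills in by default.
from transformers import MixtralConfig

cfg = MixtralConfig(
    hidden_size=4096,
    num_hidden_layers=32,
    num_attention_heads=32,
    num_key_value_heads=8,
    intermediate_size=14336,
    max_position_embeddings=32768,
    router_aux_loss_coef=0.02,
)
# Defaults of 8 local experts with top-2 routing match the published Mixtral 8x7B.
print(cfg.num_local_experts, cfg.num_experts_per_tok)
```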
superbench/benchmarks/model_benchmarks/__init__.py (2 additions, 1 deletion)
@@ -10,4 +10,5 @@
 from superbench.benchmarks.model_benchmarks.pytorch_lstm import PytorchLSTM
 from superbench.benchmarks.model_benchmarks.megatron_gpt3 import MegatronGPT

-__all__ = ['ModelBenchmark', 'PytorchBERT', 'PytorchGPT2', 'PytorchCNN', 'PytorchLSTM', 'MegatronGPT', 'PytorchLlama']
+__all__ = ['ModelBenchmark', 'PytorchBERT', 'PytorchGPT2', 'PytorchCNN', 'PytorchLSTM', 'MegatronGPT',
+           'PytorchLlama', 'PytorchMixtral']
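Since `__all__` only governs star-imports, a quick way to confirm the new name is advertised by the package is to inspect the list directly; this minimal sketch assumes SuperBench and its dependencies are installed (importing the module pulls in torch and transformers).

```python
# Minimal check, not part of this commit: the expanded __all__ should now
# advertise PytorchMixtral alongside the existing benchmark classes.
from superbench.benchmarks import model_benchmarks

print('PytorchMixtral' in model_benchmarks.__all__)  # expected: True
```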
