From 5f5a75d583488f079f47d4443a93c9640b3b8194 Mon Sep 17 00:00:00 2001
From: dilip patlolla
Date: Wed, 18 Dec 2024 13:27:22 -0800
Subject: [PATCH] Add Mixtral model benchmarks (mixtral-8x7b, mixtral-8x22b)
 and update docs

---
 docs/superbench-config.mdx                         |  3 +-
 .../benchmarks/model-benchmarks.md                 |  1 +
 .../micro_benchmarks/_export_torch_to_onnx.py      | 29 ++++++++++++++++++-
 .../benchmarks/model_benchmarks/__init__.py        |  4 ++-
 4 files changed, 34 insertions(+), 3 deletions(-)

diff --git a/docs/superbench-config.mdx b/docs/superbench-config.mdx
index 051abeda3..7bc8748a6 100644
--- a/docs/superbench-config.mdx
+++ b/docs/superbench-config.mdx
@@ -329,7 +329,8 @@ A list of models to run, only supported in model-benchmark.
     squeezenet1_0 | squeezenet1_1 |
     vgg11 | vgg11_bn | vgg13 | vgg13_bn | vgg16 | vgg16_bn | vgg19_bn | vgg19 |
     bert-base | bert-large | gpt2-small | gpt2-medium | gpt2-large | gpt2-xl |
-    llama2-7b | llama2-13b | llama2-70b ]
+    llama2-7b | llama2-13b | llama2-70b |
+    mixtral-8x7b | mixtral-8x22b ]
   ```

 * default value: `[ ]`

diff --git a/docs/user-tutorial/benchmarks/model-benchmarks.md b/docs/user-tutorial/benchmarks/model-benchmarks.md
index 71e8832cf..ba89ed6ff 100644
--- a/docs/user-tutorial/benchmarks/model-benchmarks.md
+++ b/docs/user-tutorial/benchmarks/model-benchmarks.md
@@ -14,6 +14,7 @@ Run training or inference tasks with single or half precision for deep learning
 including the following categories:
 * GPT: gpt2-small, gpt2-medium, gpt2-large and gpt2-xl
 * LLAMA: llama2-7b, llama2-13b, llama2-70b
+* MoE: mixtral-8x7b, mixtral-8x22b
 * BERT: bert-base and bert-large
 * LSTM
 * CNN, listed in [`torchvision.models`](https://pytorch.org/vision/0.8/models.html), including:

diff --git a/superbench/benchmarks/micro_benchmarks/_export_torch_to_onnx.py b/superbench/benchmarks/micro_benchmarks/_export_torch_to_onnx.py
index 0f28f4f6a..dc2f7f585 100644
--- a/superbench/benchmarks/micro_benchmarks/_export_torch_to_onnx.py
+++ b/superbench/benchmarks/micro_benchmarks/_export_torch_to_onnx.py
@@ -9,12 +9,13 @@
 import torch.hub
 import torch.onnx
 import torchvision.models
-from transformers import BertConfig, GPT2Config, LlamaConfig
+from transformers import BertConfig, GPT2Config, LlamaConfig, MixtralConfig

 from superbench.benchmarks.model_benchmarks.pytorch_bert import BertBenchmarkModel
 from superbench.benchmarks.model_benchmarks.pytorch_gpt2 import GPT2BenchmarkModel
 from superbench.benchmarks.model_benchmarks.pytorch_lstm import LSTMBenchmarkModel
 from superbench.benchmarks.model_benchmarks.pytorch_llama import LlamaBenchmarkModel
+from superbench.benchmarks.model_benchmarks.pytorch_mixtral import MixtralBenchmarkModel


 class torch2onnxExporter():
@@ -121,6 +122,32 @@ def __init__(self):
                 ),
                 self.num_classes,
             ),
+            'mixtral-8x7b':
+            lambda: MixtralBenchmarkModel(
+                MixtralConfig(
+                    hidden_size=4096,
+                    num_hidden_layers=32,
+                    num_attention_heads=32,
+                    num_key_value_heads=8,
+                    intermediate_size=14336,
+                    max_position_embeddings=32768,
+                    router_aux_loss_coef=0.02,
+                ),
+                self.num_classes,
+            ),
+            'mixtral-8x22b':
+            lambda: MixtralBenchmarkModel(
+                MixtralConfig(
+                    hidden_size=6144,
+                    num_hidden_layers=56,
+                    num_attention_heads=48,
+                    num_key_value_heads=8,
+                    intermediate_size=16384,
+                    max_position_embeddings=65536,
+                    router_aux_loss_coef=0.001,
+                ),
+                self.num_classes,
+            ),
         }
         self._onnx_model_path = Path(torch.hub.get_dir()) / 'onnx'
         self._onnx_model_path.mkdir(parents=True, exist_ok=True)

diff --git a/superbench/benchmarks/model_benchmarks/__init__.py b/superbench/benchmarks/model_benchmarks/__init__.py
index 0829c4d33..cc4967283 100644
--- a/superbench/benchmarks/model_benchmarks/__init__.py
+++ b/superbench/benchmarks/model_benchmarks/__init__.py
@@ -10,4 +10,6 @@
 from superbench.benchmarks.model_benchmarks.pytorch_lstm import PytorchLSTM
 from superbench.benchmarks.model_benchmarks.megatron_gpt3 import MegatronGPT
+from superbench.benchmarks.model_benchmarks.pytorch_mixtral import PytorchMixtral

-__all__ = ['ModelBenchmark', 'PytorchBERT', 'PytorchGPT2', 'PytorchCNN', 'PytorchLSTM', 'MegatronGPT', 'PytorchLlama']
+__all__ = ['ModelBenchmark', 'PytorchBERT', 'PytorchGPT2', 'PytorchCNN', 'PytorchLSTM', 'MegatronGPT',
+           'PytorchLlama', 'PytorchMixtral']
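
With this change applied, the two new model names plug into the `models` field of a model-benchmark entry in the SuperBench YAML configuration, which is exactly the list the `docs/superbench-config.mdx` hunk above extends. A minimal sketch of such an entry follows; the section layout mirrors the documented config schema, and the `parameters` values are illustrative rather than taken from this patch:

```yaml
# Illustrative SuperBench config fragment enabling the new Mixtral models.
superbench:
  benchmarks:
    model-benchmarks:
      frameworks:
        - pytorch
      models:
        - mixtral-8x7b
        - mixtral-8x22b
      parameters:
        batch_size: 1    # illustrative value
        num_steps: 64    # illustrative value
        precision:
          - float16
```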
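
Similarly, here is a minimal Python sketch of what the new exporter registry entries evaluate to when invoked. It mirrors the `'mixtral-8x7b'` lambda added in `_export_torch_to_onnx.py`, but substitutes tiny assumed dimensions so construction is cheap; it requires transformers >= 4.36, the first release shipping `MixtralConfig`:

```python
from transformers import MixtralConfig

from superbench.benchmarks.model_benchmarks.pytorch_mixtral import MixtralBenchmarkModel

# Same call shape as the registry lambda in the patch: a MixtralConfig plus the
# number of output classes. All sizes below are scaled-down stand-ins.
model = MixtralBenchmarkModel(
    MixtralConfig(
        hidden_size=64,               # patch uses 4096
        num_hidden_layers=2,          # patch uses 32
        num_attention_heads=8,        # patch uses 32
        num_key_value_heads=4,        # patch uses 8
        intermediate_size=128,        # patch uses 14336
        max_position_embeddings=512,  # patch uses 32768
        router_aux_loss_coef=0.02,    # value from the patch
    ),
    100,  # num_classes; the exporter passes its own self.num_classes here
)
model.eval()  # ONNX export is normally done in eval mode
```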