enable py3.7 checks for mixtral

dpower4 committed Dec 19, 2024
1 parent 0e4e9c6 commit 64abec0
Showing 4 changed files with 52 additions and 29 deletions.
superbench/benchmarks/micro_benchmarks/_export_torch_to_onnx.py (37 additions, 27 deletions)
@@ -3,20 +3,26 @@

 """Export PyTorch models to ONNX format."""

+import sys
 from pathlib import Path

 from packaging import version
 import torch.hub
 import torch.onnx
 import torchvision.models
-from transformers import BertConfig, GPT2Config, LlamaConfig, MixtralConfig
+from transformers import BertConfig, GPT2Config, LlamaConfig

 from superbench.benchmarks.model_benchmarks.pytorch_bert import BertBenchmarkModel
 from superbench.benchmarks.model_benchmarks.pytorch_gpt2 import GPT2BenchmarkModel
 from superbench.benchmarks.model_benchmarks.pytorch_lstm import LSTMBenchmarkModel
 from superbench.benchmarks.model_benchmarks.pytorch_llama import LlamaBenchmarkModel
-from superbench.benchmarks.model_benchmarks.pytorch_mixtral import MixtralBenchmarkModel
+
+# Check Python version and skip Mixtral if Python is 3.7 or lower
+if sys.version_info < (3, 8):
+    MixtralBenchmarkModel = None
+else:
+    from transformers import MixtralConfig
+    from superbench.benchmarks.model_benchmarks.pytorch_mixtral import MixtralBenchmarkModel


 class torch2onnxExporter():
     """PyTorch model to ONNX exporter."""
@@ -122,33 +128,37 @@ def __init__(self):
                 ),
                 self.num_classes,
             ),
-            'mixtral-8x7b':
-            lambda: MixtralBenchmarkModel(
-                MixtralConfig(
-                    hidden_size=4096,
-                    num_hidden_layers=32,
-                    num_attention_heads=32,
-                    num_key_value_heads=8,
-                    intermediate_size=14336,
-                    max_position_embeddings=32768,
-                    router_aux_loss_coef=0.02,
-                ),
-                self.num_classes,
-            ),
-            'mixtral-8x22b':
-            lambda: MixtralBenchmarkModel(
-                MixtralConfig(
-                    hidden_size=6144,
-                    num_hidden_layers=56,
-                    num_attention_heads=48,
-                    num_key_value_heads=8,
-                    intermediate_size=16384,
-                    max_position_embeddings=65536,
-                    router_aux_loss_coef=0.001,
-                ),
-                self.num_classes,
-            ),
         }

+        # Only include Mixtral models if MixtralBenchmarkModel is available
+        if MixtralBenchmarkModel is not None:
+            self.benchmark_models.update({
+                'mixtral-8x7b': lambda: MixtralBenchmarkModel(
+                    MixtralConfig(
+                        hidden_size=4096,
+                        num_hidden_layers=32,
+                        num_attention_heads=32,
+                        num_key_value_heads=8,
+                        intermediate_size=14336,
+                        max_position_embeddings=32768,
+                        router_aux_loss_coef=0.02,
+                    ),
+                    self.num_classes,
+                ),
+                'mixtral-8x22b': lambda: MixtralBenchmarkModel(
+                    MixtralConfig(
+                        hidden_size=6144,
+                        num_hidden_layers=56,
+                        num_attention_heads=48,
+                        num_key_value_heads=8,
+                        intermediate_size=16384,
+                        max_position_embeddings=65536,
+                        router_aux_loss_coef=0.001,
+                    ),
+                    self.num_classes,
+                ),
+            })
+
         self._onnx_model_path = Path(torch.hub.get_dir()) / 'onnx'
         self._onnx_model_path.mkdir(parents=True, exist_ok=True)

superbench/benchmarks/model_benchmarks/__init__.py (7 additions, 1 deletion)
@@ -3,16 +3,22 @@

 """A module containing all the e2e model related benchmarks."""

+import sys
+
 from superbench.benchmarks.model_benchmarks.model_base import ModelBenchmark
 from superbench.benchmarks.model_benchmarks.pytorch_bert import PytorchBERT
 from superbench.benchmarks.model_benchmarks.pytorch_gpt2 import PytorchGPT2
 from superbench.benchmarks.model_benchmarks.pytorch_cnn import PytorchCNN
 from superbench.benchmarks.model_benchmarks.pytorch_lstm import PytorchLSTM
 from superbench.benchmarks.model_benchmarks.megatron_gpt3 import MegatronGPT
 from superbench.benchmarks.model_benchmarks.pytorch_llama import PytorchLlama
-from superbench.benchmarks.model_benchmarks.pytorch_mixtral import PytorchMixtral

 __all__ = [
-    'ModelBenchmark', 'PytorchBERT', 'PytorchGPT2', 'PytorchCNN', 'PytorchLSTM', 'MegatronGPT', 'PytorchLlama',
-    'PytorchMixtral'
+    'ModelBenchmark', 'PytorchBERT', 'PytorchGPT2', 'PytorchCNN', 'PytorchLSTM', 'MegatronGPT', 'PytorchLlama'
 ]
+
+# Check for Python 3.8 or newer and conditionally import PytorchMixtral
+if sys.version_info >= (3, 8):
+    from superbench.benchmarks.model_benchmarks.pytorch_mixtral import PytorchMixtral
+    __all__.append('PytorchMixtral')
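For context, a minimal sketch of how downstream code might consume this conditional export; the helper function is illustrative and not part of the package:

import sys

if sys.version_info >= (3, 8):
    from superbench.benchmarks.model_benchmarks import PytorchMixtral
else:
    # Mixtral needs transformers' MixtralConfig, which is not available in the
    # transformers releases that still support Python 3.7.
    PytorchMixtral = None


def mixtral_supported():
    """Illustrative helper: report whether the Mixtral benchmark can run on this interpreter."""
    return PytorchMixtral is not None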
superbench/benchmarks/model_benchmarks/pytorch_mixtral.py (2 additions, 0 deletions)
@@ -255,6 +255,7 @@ def _inference_step(self, precision):


 # Register Mixtral benchmark with 8x7b parameters.
+# Ref: https://huggingface.co/mistralai/Mixtral-8x7B-v0.1/blob/main/config.json
 BenchmarkRegistry.register_benchmark(
     'pytorch-mixtral-8x7b',
     PytorchMixtral,
@@ -263,6 +264,7 @@ def _inference_step(self, precision):
 )

 # Register Mixtral benchmark with 8x22b parameters.
+# Ref: https://huggingface.co/mistralai/Mixtral-8x22B-v0.1/blob/main/config.json
 BenchmarkRegistry.register_benchmark(
     'pytorch-mixtral-8x22b',
     PytorchMixtral,
tests/benchmarks/model_benchmarks/test_pytorch_mixtral.py (6 additions, 1 deletion)
@@ -3,9 +3,14 @@

 """Tests for mixtral model benchmarks."""

+import sys
+
 from tests.helper import decorator
 from superbench.benchmarks import BenchmarkRegistry, Platform, Framework, BenchmarkType, ReturnCode
-from superbench.benchmarks.model_benchmarks.pytorch_mixtral import PytorchMixtral
+
+# Check for Python version 3.8 or greater and conditionally import PytorchMixtral
+if sys.version_info >= (3, 8):
+    from superbench.benchmarks.model_benchmarks.pytorch_mixtral import PytorchMixtral


 @decorator.cuda_test
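As a side note (an illustrative pattern, not something this commit adds), a guarded import like the one above is commonly paired with a module-level pytest skip so the tests are reported as skipped, rather than failing at import time, on older interpreters:

import sys

import pytest

# Skip every test in this module on Python 3.7 or older, where PytorchMixtral
# cannot be imported.
pytestmark = pytest.mark.skipif(sys.version_info < (3, 8), reason='Mixtral benchmarks require Python 3.8+')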
