Skip to content

Commit

Permalink
fix lint
Browse files Browse the repository at this point in the history
  • Loading branch information
dpower4 committed Dec 19, 2024
1 parent 64abec0 commit 793713c
Showing 1 changed file with 30 additions and 24 deletions.
54 changes: 30 additions & 24 deletions superbench/benchmarks/micro_benchmarks/_export_torch_to_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,10 @@
from transformers import MixtralConfig
from superbench.benchmarks.model_benchmarks.pytorch_mixtral import MixtralBenchmarkModel


class torch2onnxExporter():
"""PyTorch model to ONNX exporter."""

def __init__(self):
"""Constructor."""
self.num_classes = 100
Expand Down Expand Up @@ -132,32 +134,36 @@ def __init__(self):

# Only include Mixtral models if MixtralBenchmarkModel is available
if MixtralBenchmarkModel is not None:
self.benchmark_models.update({
'mixtral-8x7b': lambda: MixtralBenchmarkModel(
MixtralConfig(
hidden_size=4096,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=8,
intermediate_size=14336,
max_position_embeddings=32768,
router_aux_loss_coef=0.02,
self.benchmark_models.update(
{
'mixtral-8x7b':
lambda: MixtralBenchmarkModel(
MixtralConfig(
hidden_size=4096,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=8,
intermediate_size=14336,
max_position_embeddings=32768,
router_aux_loss_coef=0.02,
),
self.num_classes,
),
self.num_classes,
),
'mixtral-8x22b': lambda: MixtralBenchmarkModel(
MixtralConfig(
hidden_size=6144,
num_hidden_layers=56,
num_attention_heads=48,
num_key_value_heads=8,
intermediate_size=16384,
max_position_embeddings=65536,
router_aux_loss_coef=0.001,
'mixtral-8x22b':
lambda: MixtralBenchmarkModel(
MixtralConfig(
hidden_size=6144,
num_hidden_layers=56,
num_attention_heads=48,
num_key_value_heads=8,
intermediate_size=16384,
max_position_embeddings=65536,
router_aux_loss_coef=0.001,
),
self.num_classes,
),
self.num_classes,
),
})
}
)

self._onnx_model_path = Path(torch.hub.get_dir()) / 'onnx'
self._onnx_model_path.mkdir(parents=True, exist_ok=True)
Expand Down

0 comments on commit 793713c

Please sign in to comment.