diff --git a/superbench/benchmarks/micro_benchmarks/_export_torch_to_onnx.py b/superbench/benchmarks/micro_benchmarks/_export_torch_to_onnx.py index 00bd58fe9..9cc2dd65f 100644 --- a/superbench/benchmarks/micro_benchmarks/_export_torch_to_onnx.py +++ b/superbench/benchmarks/micro_benchmarks/_export_torch_to_onnx.py @@ -24,8 +24,10 @@ from transformers import MixtralConfig from superbench.benchmarks.model_benchmarks.pytorch_mixtral import MixtralBenchmarkModel + class torch2onnxExporter(): """PyTorch model to ONNX exporter.""" + def __init__(self): """Constructor.""" self.num_classes = 100 @@ -132,32 +134,36 @@ def __init__(self): # Only include Mixtral models if MixtralBenchmarkModel is available if MixtralBenchmarkModel is not None: - self.benchmark_models.update({ - 'mixtral-8x7b': lambda: MixtralBenchmarkModel( - MixtralConfig( - hidden_size=4096, - num_hidden_layers=32, - num_attention_heads=32, - num_key_value_heads=8, - intermediate_size=14336, - max_position_embeddings=32768, - router_aux_loss_coef=0.02, + self.benchmark_models.update( + { + 'mixtral-8x7b': + lambda: MixtralBenchmarkModel( + MixtralConfig( + hidden_size=4096, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=8, + intermediate_size=14336, + max_position_embeddings=32768, + router_aux_loss_coef=0.02, + ), + self.num_classes, ), - self.num_classes, - ), - 'mixtral-8x22b': lambda: MixtralBenchmarkModel( - MixtralConfig( - hidden_size=6144, - num_hidden_layers=56, - num_attention_heads=48, - num_key_value_heads=8, - intermediate_size=16384, - max_position_embeddings=65536, - router_aux_loss_coef=0.001, + 'mixtral-8x22b': + lambda: MixtralBenchmarkModel( + MixtralConfig( + hidden_size=6144, + num_hidden_layers=56, + num_attention_heads=48, + num_key_value_heads=8, + intermediate_size=16384, + max_position_embeddings=65536, + router_aux_loss_coef=0.001, + ), + self.num_classes, ), - self.num_classes, - ), - }) + } + ) self._onnx_model_path = Path(torch.hub.get_dir()) / 'onnx' self._onnx_model_path.mkdir(parents=True, exist_ok=True)