diff --git a/superbench/benchmarks/model_benchmarks/__init__.py b/superbench/benchmarks/model_benchmarks/__init__.py index b24879097..4770625ba 100644 --- a/superbench/benchmarks/model_benchmarks/__init__.py +++ b/superbench/benchmarks/model_benchmarks/__init__.py @@ -13,9 +13,12 @@ from superbench.benchmarks.model_benchmarks.megatron_gpt3 import MegatronGPT from superbench.benchmarks.model_benchmarks.pytorch_llama import PytorchLlama -__all__ = ['ModelBenchmark', 'PytorchBERT', 'PytorchGPT2', 'PytorchCNN', 'PytorchLSTM', 'MegatronGPT', 'PytorchLlama'] - # Check for Python version > 3.7 and conditionally import PytorchMixtral +PytorchMixtral = None if sys.version_info > (3, 7): from superbench.benchmarks.model_benchmarks.pytorch_mixtral import PytorchMixtral + +__all__ = ['ModelBenchmark', 'PytorchBERT', 'PytorchGPT2', 'PytorchCNN', 'PytorchLSTM', 'MegatronGPT', 'PytorchLlama'] + +if PytorchMixtral is not None: __all__.append('PytorchMixtral')