diff --git a/superbench/benchmarks/model_benchmarks/__init__.py b/superbench/benchmarks/model_benchmarks/__init__.py index 4770625ba..741a73938 100644 --- a/superbench/benchmarks/model_benchmarks/__init__.py +++ b/superbench/benchmarks/model_benchmarks/__init__.py @@ -4,6 +4,7 @@ """A module containing all the e2e model related benchmarks.""" import sys +from typing import Optional from superbench.benchmarks.model_benchmarks.model_base import ModelBenchmark from superbench.benchmarks.model_benchmarks.pytorch_bert import PytorchBERT @@ -14,7 +15,7 @@ from superbench.benchmarks.model_benchmarks.pytorch_llama import PytorchLlama # Check for Python version > 3.7 and conditionally import PytorchMixtral -PytorchMixtral = None +PytorchMixtral: Optional[type] = None if sys.version_info > (3, 7): from superbench.benchmarks.model_benchmarks.pytorch_mixtral import PytorchMixtral