diff --git a/tests/benchmarks/model_benchmarks/test_model_base.py b/tests/benchmarks/model_benchmarks/test_model_base.py
index 1b6af1775..760c8b727 100644
--- a/tests/benchmarks/model_benchmarks/test_model_base.py
+++ b/tests/benchmarks/model_benchmarks/test_model_base.py
@@ -4,6 +4,7 @@
 """Tests for BenchmarkRegistry module."""
 
 import json
+import sys
 
 from superbench.benchmarks import Platform, Framework, Precision, BenchmarkRegistry, BenchmarkType, ReturnCode
 from superbench.benchmarks.model_benchmarks import ModelBenchmark
@@ -148,9 +149,9 @@ def test_arguments_related_interfaces():
 
     # Test get_configurable_settings().
     settings = benchmark.get_configurable_settings()
-    expected_settings = (
-        """optional arguments:
-  --batch_size int      The number of batch size.
+
+    prefix = "options:" if sys.version_info >= (3, 10) else "optional arguments:"
+    args = """  --batch_size int      The number of batch size.
   --distributed_backend DistributedBackend
                         Distributed backends. E.g. nccl mpi gloo.
   --distributed_impl DistributedImpl
@@ -177,7 +178,8 @@ def test_arguments_related_interfaces():
   --run_count int       The run count of benchmark.
   --sample_count int    The number of data samples in dataset.
   --seq_len int         Sequence length."""
-    )
+
+    expected_settings = f"{prefix}\n{args}"
     assert (settings == expected_settings)
 
 
@@ -188,9 +190,9 @@ def test_preprocess():
     assert (benchmark._preprocess())
     assert (benchmark.return_code == ReturnCode.SUCCESS)
     settings = benchmark.get_configurable_settings()
-    expected_settings = (
-        """optional arguments:
-  --batch_size int      The number of batch size.
+
+    prefix = "options:" if sys.version_info >= (3, 10) else "optional arguments:"
+    args = """  --batch_size int      The number of batch size.
   --distributed_backend DistributedBackend
                         Distributed backends. E.g. nccl mpi gloo.
   --distributed_impl DistributedImpl
@@ -217,7 +219,8 @@ def test_preprocess():
   --run_count int       The run count of benchmark.
   --sample_count int    The number of data samples in dataset.
   --seq_len int         Sequence length."""
-    )
+
+    expected_settings = f"{prefix}\n{args}"
     assert (settings == expected_settings)
 
     # Negative case for _preprocess() - invalid precision.
diff --git a/tests/benchmarks/test_registry.py b/tests/benchmarks/test_registry.py
index fb36d1fe3..cd2777e58 100644
--- a/tests/benchmarks/test_registry.py
+++ b/tests/benchmarks/test_registry.py
@@ -4,6 +4,7 @@
 """Tests for BenchmarkRegistry module."""
 
 import re
+import sys
 
 from superbench.benchmarks import Platform, Framework, BenchmarkType, BenchmarkRegistry, ReturnCode
 from superbench.benchmarks.micro_benchmarks import MicroBenchmark
@@ -112,14 +113,16 @@ def test_get_benchmark_configurable_settings():
     context = BenchmarkRegistry.create_benchmark_context('accumulation', platform=Platform.CPU)
 
     settings = BenchmarkRegistry.get_benchmark_configurable_settings(context)
-    expected = """optional arguments:
-  --duration int        The elapsed time of benchmark in seconds.
+    prefix = "options:" if sys.version_info >= (3, 10) else "optional arguments:"
+    args = """  --duration int        The elapsed time of benchmark in seconds.
   --log_flushing        Real-time log flushing.
   --log_raw_data        Log raw data into file instead of saving it into result
                         object.
   --lower_bound int     The lower bound for accumulation.
   --run_count int       The run count of benchmark.
   --upper_bound int     The upper bound for accumulation."""
+
+    expected = f"{prefix}\n{args}"
     assert (settings == expected)
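
Context for the change above: starting with Python 3.10, argparse labels its default argument group "options:" instead of "optional arguments:", so tests that compare against literal help text have to build the expected header at run time. A minimal, self-contained sketch of the same pattern is shown below; the parser and the --run_count argument are illustrative only, not taken from the superbench code.

import argparse
import sys

# argparse renamed the default group in Python 3.10:
# "optional arguments:" became "options:".
expected_header = "options:" if sys.version_info >= (3, 10) else "optional arguments:"

parser = argparse.ArgumentParser()
parser.add_argument("--run_count", type=int, help="The run count of benchmark.")

# The generated help text should contain the version-appropriate header.
assert expected_header in parser.format_help()

Branching on sys.version_info keeps a single assertion working across interpreter versions instead of pinning the test suite to pre-3.10 argparse output.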