Skip to content

Commit

Permalink
fix formatting related test cases failure for 3.10
Browse files Browse the repository at this point in the history
  • Loading branch information
dpower4 committed Nov 21, 2024
1 parent 297a229 commit 1c6f908
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 10 deletions.
19 changes: 11 additions & 8 deletions tests/benchmarks/model_benchmarks/test_model_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""Tests for BenchmarkRegistry module."""

import json
import sys

from superbench.benchmarks import Platform, Framework, Precision, BenchmarkRegistry, BenchmarkType, ReturnCode
from superbench.benchmarks.model_benchmarks import ModelBenchmark
Expand Down Expand Up @@ -148,9 +149,9 @@ def test_arguments_related_interfaces():

# Test get_configurable_settings().
settings = benchmark.get_configurable_settings()
expected_settings = (
"""optional arguments:
--batch_size int The number of batch size.

prefix = "options:" if sys.version_info >= (3, 10) else "optional arguments:"
args = """ --batch_size int The number of batch size.
--distributed_backend DistributedBackend
Distributed backends. E.g. nccl mpi gloo.
--distributed_impl DistributedImpl
Expand All @@ -177,7 +178,8 @@ def test_arguments_related_interfaces():
--run_count int The run count of benchmark.
--sample_count int The number of data samples in dataset.
--seq_len int Sequence length."""
)

expected_settings = f"{prefix}\n{args}"
assert (settings == expected_settings)


Expand All @@ -188,9 +190,9 @@ def test_preprocess():
assert (benchmark._preprocess())
assert (benchmark.return_code == ReturnCode.SUCCESS)
settings = benchmark.get_configurable_settings()
expected_settings = (
"""optional arguments:
--batch_size int The number of batch size.

prefix = "options:" if sys.version_info >= (3, 10) else "optional arguments:"
args = """ --batch_size int The number of batch size.
--distributed_backend DistributedBackend
Distributed backends. E.g. nccl mpi gloo.
--distributed_impl DistributedImpl
Expand All @@ -217,7 +219,8 @@ def test_preprocess():
--run_count int The run count of benchmark.
--sample_count int The number of data samples in dataset.
--seq_len int Sequence length."""
)

expected_settings = f"{prefix}\n{args}"
assert (settings == expected_settings)

# Negative case for _preprocess() - invalid precision.
Expand Down
7 changes: 5 additions & 2 deletions tests/benchmarks/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""Tests for BenchmarkRegistry module."""

import re
import sys

from superbench.benchmarks import Platform, Framework, BenchmarkType, BenchmarkRegistry, ReturnCode
from superbench.benchmarks.micro_benchmarks import MicroBenchmark
Expand Down Expand Up @@ -112,14 +113,16 @@ def test_get_benchmark_configurable_settings():
context = BenchmarkRegistry.create_benchmark_context('accumulation', platform=Platform.CPU)
settings = BenchmarkRegistry.get_benchmark_configurable_settings(context)

expected = """optional arguments:
--duration int The elapsed time of benchmark in seconds.
prefix = "options:" if sys.version_info >= (3, 10) else "optional arguments:"
args = """ --duration int The elapsed time of benchmark in seconds.
--log_flushing Real-time log flushing.
--log_raw_data Log raw data into file instead of saving it into result
object.
--lower_bound int The lower bound for accumulation.
--run_count int The run count of benchmark.
--upper_bound int The upper bound for accumulation."""

expected = f"{prefix}\n{args}"
assert (settings == expected)


Expand Down

0 comments on commit 1c6f908

Please sign in to comment.