diff --git a/tests/runner/test_runner.py b/tests/runner/test_runner.py
index 250942267..82ee21dda 100644
--- a/tests/runner/test_runner.py
+++ b/tests/runner/test_runner.py
@@ -41,6 +41,22 @@ def test_set_logger(self):
         expected_log_file = Path(self.runner._sb_output_dir) / 'sb-run.log'
         self.assertTrue(expected_log_file.is_file())
 
+    def test_validate_sb_config(self):
+        """Test validate_sb_config."""
+        self.runner._SuperBenchRunner__validate_sb_config()
+        self.assertIn('env', self.runner._sb_config.superbench)
+        for name in self.runner._sb_benchmarks:
+            self.assertIn('modes', self.runner._sb_config.superbench.benchmarks[name])
+            for mode in self.runner._sb_config.superbench.benchmarks[name].modes:
+                self.assertIn('env', mode)
+                if mode.name == 'local':
+                    self.assertIn('proc_num', mode)
+                    self.assertIn('prefix', mode)
+                if mode.name == 'torch.distributed':
+                    self.assertIn('proc_num', mode)
+                if mode.name == 'mpi':
+                    self.assertIn('mca', mode)
+
     def test_get_failure_count(self):
         """Test get_failure_count."""
         self.assertEqual(0, self.runner.get_failure_count())
@@ -410,3 +426,24 @@ def test_generate_metric_name(self):
                     test_case['run_count'], test_case['curr_rank'], test_case['curr_run']
                 ), test_case['expected']
             )
+
+    def test_run_proc_timeout(self):
+        """Test run_proc_timeout."""
+        self.runner._sb_benchmarks = {
+            'benchmark1': {'timeout': 120},
+            'benchmark2': {'timeout': None},
+            'benchmark3': {'timeout': 30},
+        }
+
+        test_cases = [
+            ('benchmark1', 120),
+            ('benchmark2', None),
+            ('benchmark3', 60),
+        ]
+
+        for benchmark_name, expected_timeout in test_cases:
+            with self.subTest(benchmark_name=benchmark_name):
+                timeout = self.runner._sb_benchmarks[benchmark_name].get('timeout', None)
+                if isinstance(timeout, int):
+                    timeout = max(timeout, 60)
+                self.assertEqual(timeout, expected_timeout)