diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 236072fa..62aac95b 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -918,7 +918,7 @@ def run(args: argparse.Namespace) -> None: distributed=args.distributed, ) logging.info("Error-table on TRAIN and VALID:\n" + str(table_train_valid)) - if not test_data_loader: + if not test_data_loader: table_test = create_error_table( table_type=args.error_table, all_data_loaders=test_data_loader, diff --git a/tests/test_run_train.py b/tests/test_run_train.py index 9ae21e83..0cdb7cfa 100644 --- a/tests/test_run_train.py +++ b/tests/test_run_train.py @@ -468,7 +468,7 @@ def test_run_train_foundation_multihead(tmp_path, fitting_configs): fitting_configs_dft = [] fitting_configs_mp2 = [] for i, c in enumerate(fitting_configs): - if i == 0 or i == 1: + if i in (0, 1): c_dft = c.copy() c_dft.info["head"] = "DFT" fitting_configs_dft.append(c_dft)