From 1540cd8cd5ec5a56d06e96c4e2f529dd3ad1423b Mon Sep 17 00:00:00 2001 From: Federico Berto Date: Wed, 2 Oct 2024 23:59:52 +0900 Subject: [PATCH] [Feat] add dataloader with different batch size --- test.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/test.py b/test.py index fecc086..5252e8f 100644 --- a/test.py +++ b/test.py @@ -195,10 +195,17 @@ def test( print(f"Loading {dataset}") td_test = env.load_data(dataset) # this also adds the bks cost dataloader = get_dataloader(td_test, batch_size=opts.batch_size) - td_test = env.reset(td_test).to(device) start = time.time() - out = test(policy, td_test, env, device=device) + res = [] + for batch in dataloader: + td_test = env.reset(batch).to(device) + o = test(policy, td_test, env, device=device) + res.append(o) + out = {} + out["max_aug_reward"] = torch.cat([o["max_aug_reward"] for o in res]) + out["gap_to_bks"] = torch.cat([o["gap_to_bks"] for o in res]) + inference_time = time.time() - start dataset_name = dataset.split("/")[-3].split(".")[0].upper()