Skip to content

Commit

Permalink
[Feat] add dataloader with different batch size
Browse files Browse the repository at this point in the history
  • Loading branch information
fedebotu committed Oct 2, 2024
1 parent 6677303 commit 1540cd8
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,10 +195,17 @@ def test(
print(f"Loading {dataset}")
td_test = env.load_data(dataset) # this also adds the bks cost
dataloader = get_dataloader(td_test, batch_size=opts.batch_size)
td_test = env.reset(td_test).to(device)

start = time.time()
out = test(policy, td_test, env, device=device)
res = []
for batch in dataloader:
td_test = env.reset(batch).to(device)
o = test(policy, td_test, env, device=device)
res.append(o)
out = {}
out["max_aug_reward"] = torch.cat([o["max_aug_reward"] for o in res])
out["gap_to_bks"] = torch.cat([o["gap_to_bks"] for o in res])

inference_time = time.time() - start

dataset_name = dataset.split("/")[-3].split(".")[0].upper()
Expand Down

0 comments on commit 1540cd8

Please sign in to comment.