From ebd0c1f6bf4c052a1961c96dc49f8d3523932d9f Mon Sep 17 00:00:00 2001 From: Federico Berto Date: Tue, 24 Sep 2024 23:48:46 +0900 Subject: [PATCH] [Docs] improve documentation; add reproducibility notice --- README.md | 10 ++++++++-- test.py | 17 ++++++++++------- 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 38cf1f2..2c6ce3a 100644 --- a/README.md +++ b/README.md @@ -64,10 +64,16 @@ Note on legacy FFSP code: the initial version we made was not yet integrated in ### Testing -You may run the `test.py` script to evaluate the model, e.g. with: +You may run the `test.py` script to evaluate the model, e.g. with greedy decoding: ```bash -python test.py --problem hcvrp --decode_type greedy --batch_size 128 --sample_size 1 +python test.py --problem hcvrp --decode_type greedy --batch_size 128 +``` + +(note: we measure time with single instance -- batch size 1, but larger makes the overall evaluation faster), or with sampling: + +```bash +python test.py --problem hcvrp --decode_type sampling --batch_size 1 --sample_size 1280 ``` diff --git a/test.py b/test.py index 2d27dd2..4ae7fc5 100644 --- a/test.py +++ b/test.py @@ -40,10 +40,10 @@ parser.add_argument( "--sample_size", type=int, - default=1, + default=None, help="Number of samples to use for sampling decoding", ) - parser.add_argument("--batch_size", type=int, default=128) + parser.add_argument("--batch_size", type=int, default=1) parser.add_argument("--checkpoint", type=str, default=None) parser.add_argument("--device", type=str, default="cuda") @@ -68,12 +68,15 @@ ), "Problem must be specified if checkpoint is not provided" checkpoint_path = f"./checkpoints/{problem}/parco.ckpt" if decode_type == "greedy": - assert sample_size == 1 + assert (sample_size == 1 or sample_size is None), "Greedy decoding only uses 1 sample" if opts.datasets is None: assert problem is not None, "Problem must be specified if dataset is not provided" data_paths = [f"./data/{problem}/{f}" for f in os.listdir(f"./data/{problem}")] else: data_paths = [opts.datasets] if isinstance(opts.datasets, str) else opts.datasets + if decode_type == "sampling": + assert sample_size is not None, "Sample size must be specified for sampling decoding with --sample_size" + assert batch_size == 1, "Only batch_size=1 is supported for sampling decoding currently" data_paths = sorted(data_paths) # Sort for consistency # Load the checkpoint as usual @@ -85,7 +88,7 @@ policy = model.policy.to(device).eval() # Use mixed precision if supported for dataset in data_paths: - tour_lengths = [] + costs = [] inference_times = [] eval_steps = [] @@ -108,15 +111,15 @@ end_time = time.time() inference_time = end_time - start_time if decode_type == "greedy": - tour_lengths.append(-out["reward"].mean().item()) + costs.append(-out["reward"].mean().item()) else: - tour_lengths.extend( + costs.extend( -out["reward"].reshape(-1, sample_size).max(dim=-1)[0] ) inference_times.append(inference_time) eval_steps.append(out["steps"]) - print(f"Average tour length: {sum(tour_lengths)/len(tour_lengths):.4f}") + print(f"Average cost: {sum(costs)/len(costs):.4f}") print( f"Per step inference time: {sum(inference_times)/len(inference_times):.4f}s" )