Skip to content

Commit

Permalink
[Docs] improve documentation; add reproducibility notice
Browse files Browse the repository at this point in the history
  • Loading branch information
fedebotu committed Sep 24, 2024
1 parent 70c40e0 commit ebd0c1f
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 9 deletions.
10 changes: 8 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,16 @@ Note on legacy FFSP code: the initial version we made was not yet integrated in

### Testing

You may run the `test.py` script to evaluate the model, e.g. with:
You may run the `test.py` script to evaluate the model, e.g. with greedy decoding:

```bash
python test.py --problem hcvrp --decode_type greedy --batch_size 128 --sample_size 1
python test.py --problem hcvrp --decode_type greedy --batch_size 128
```

(Note: we measure inference time with a single instance — batch size 1 — although a larger batch size makes the overall evaluation faster.) Alternatively, run with sampling:

```bash
python test.py --problem hcvrp --decode_type sampling --batch_size 1 --sample_size 1280
```


Expand Down
17 changes: 10 additions & 7 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,10 @@
parser.add_argument(
"--sample_size",
type=int,
default=1,
default=None,
help="Number of samples to use for sampling decoding",
)
parser.add_argument("--batch_size", type=int, default=128)
parser.add_argument("--batch_size", type=int, default=1)
parser.add_argument("--checkpoint", type=str, default=None)
parser.add_argument("--device", type=str, default="cuda")

Expand All @@ -68,12 +68,15 @@
), "Problem must be specified if checkpoint is not provided"
checkpoint_path = f"./checkpoints/{problem}/parco.ckpt"
if decode_type == "greedy":
assert sample_size == 1
assert (sample_size == 1 or sample_size is None), "Greedy decoding only uses 1 sample"
if opts.datasets is None:
assert problem is not None, "Problem must be specified if dataset is not provided"
data_paths = [f"./data/{problem}/{f}" for f in os.listdir(f"./data/{problem}")]
else:
data_paths = [opts.datasets] if isinstance(opts.datasets, str) else opts.datasets
if decode_type == "sampling":
assert sample_size is not None, "Sample size must be specified for sampling decoding with --sample_size"
assert batch_size == 1, "Only batch_size=1 is supported for sampling decoding currently"
data_paths = sorted(data_paths) # Sort for consistency

# Load the checkpoint as usual
Expand All @@ -85,7 +88,7 @@
policy = model.policy.to(device).eval() # Use mixed precision if supported

for dataset in data_paths:
tour_lengths = []
costs = []
inference_times = []
eval_steps = []

Expand All @@ -108,15 +111,15 @@
end_time = time.time()
inference_time = end_time - start_time
if decode_type == "greedy":
tour_lengths.append(-out["reward"].mean().item())
costs.append(-out["reward"].mean().item())
else:
tour_lengths.extend(
costs.extend(
-out["reward"].reshape(-1, sample_size).max(dim=-1)[0]
)
inference_times.append(inference_time)
eval_steps.append(out["steps"])

print(f"Average tour length: {sum(tour_lengths)/len(tour_lengths):.4f}")
print(f"Average cost: {sum(costs)/len(costs):.4f}")
print(
f"Per step inference time: {sum(inference_times)/len(inference_times):.4f}s"
)
Expand Down

0 comments on commit ebd0c1f

Please sign in to comment.