diff --git a/notebooks/tutorials/5-test-on-tsplib.ipynb b/notebooks/tutorials/5-test-on-tsplib.ipynb new file mode 100644 index 00000000..43dd7f23 --- /dev/null +++ b/notebooks/tutorials/5-test-on-tsplib.ipynb @@ -0,0 +1,632 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Test Model on TSPLib\n", + "\n", + "In this notebook, we will test the trained model's performance on the TSPLib benchmark. We will use the trained model from the previous notebook.\n", + "\n", + "[TSPLib](http://comopt.ifi.uni-heidelberg.de/software/TSPLIB95/) is a library of sample instances for the TSP (and related problems) from various sources and of various types. In the TSPLib, there are several problems, including *TSP, HCP, ATSP*, etc. In this notebook, we will focus on testing the model on the TSP problem.\n", + "\n", + "## Before we start\n", + "\n", + "Before we test the model on TSPLib dataset, we need to prepare the dataset first by the following steps:\n", + "\n", + "**Step 1**. You may come to [here](http://comopt.ifi.uni-heidelberg.de/software/TSPLIB95/tsp/) to download the *symmetric traveling salesman problem* data in TSPLib dataset and unzip to a folder;\n", + "\n", + "Note that the downloaded data is `gzip` file with the following file tree:\n", + "```\n", + ".\n", + "└── ALL_tsp/\n", + " ├── a280.opt.tour.gz\n", + " ├── a280.tsp.gz\n", + " ├── ali535.tsp.gz\n", + " └── ... (other problems)\n", + "```\n", + "We need to unzip the `gzip` file to get the `.tsp` and `.opt.tour` files. We can use the following command to unzip them to the same folder:\n", + "```bash\n", + "find . -name \"*.gz\" -exec gunzip {} +\n", + "```\n", + "\n", + "After doing this, we will get the following file tree:\n", + "```\n", + ".\n", + "└── ALL_tsp/\n", + " ├── a280.opt.tour\n", + " ├── a280.tsp\n", + " ├── ali535.tsp\n", + " └── ... (other problems)\n", + "```\n", + "\n", + "**Step 2**. To read the TSPLib problem and optimal solution, we choose to use the `tsplib95` package. Use `pip install tsplib95` to install the package." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Installation\n", + "\n", + "Uncomment the following line to install the package from PyPI. Remember to choose a GPU runtime for faster training!\n", + "\n", + "> Note: You may need to restart the runtime in Colab after this\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# !pip install rl4co[graph] # include torch-geometric\n", + "\n", + "## NOTE: to install latest version from Github (may be unstable) install from source instead:\n", + "# !pip install git+https://github.com/ai4co/rl4co.git" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Install the `tsplib95` package\n", + "# !pip install tsplib95" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Imports" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/cbhua/.local/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "\n", + "import os\n", + "import re\n", + "import torch\n", + "\n", + "from rl4co.envs import TSPEnv, CVRPEnv\n", + "from rl4co.models.zoo.am import AttentionModel\n", + "from rl4co.utils.trainer import RL4COTrainer\n", + "from rl4co.models.nn.utils import get_log_likelihood\n", + "from rl4co.models.zoo import EAS, EASLay, EASEmb, ActiveSearch\n", + "\n", + "from math import ceil\n", + "from einops import repeat\n", + "from tsplib95.loaders import load_problem, load_solution" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load a trained model" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/cbhua/miniconda3/envs/rl4co-user/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:198: Attribute 'env' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['env'])`.\n", + "/home/cbhua/miniconda3/envs/rl4co-user/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:198: Attribute 'policy' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['policy'])`.\n", + "/home/cbhua/miniconda3/envs/rl4co-user/lib/python3.10/site-packages/lightning/pytorch/core/saving.py:177: Found keys that are not in the model state dict but in the checkpoint: ['baseline.baseline.model.encoder.init_embedding.init_embed.weight', 'baseline.baseline.model.encoder.init_embedding.init_embed.bias', 'baseline.baseline.model.encoder.net.layers.0.0.module.Wqkv.weight', 'baseline.baseline.model.encoder.net.layers.0.0.module.Wqkv.bias', 'baseline.baseline.model.encoder.net.layers.0.0.module.out_proj.weight', 'baseline.baseline.model.encoder.net.layers.0.0.module.out_proj.bias', 'baseline.baseline.model.encoder.net.layers.0.1.normalizer.weight', 'baseline.baseline.model.encoder.net.layers.0.1.normalizer.bias', 'baseline.baseline.model.encoder.net.layers.0.1.normalizer.running_mean', 'baseline.baseline.model.encoder.net.layers.0.1.normalizer.running_var', 'baseline.baseline.model.encoder.net.layers.0.1.normalizer.num_batches_tracked', 'baseline.baseline.model.encoder.net.layers.0.2.module.0.weight', 'baseline.baseline.model.encoder.net.layers.0.2.module.0.bias', 'baseline.baseline.model.encoder.net.layers.0.2.module.2.weight', 'baseline.baseline.model.encoder.net.layers.0.2.module.2.bias', 'baseline.baseline.model.encoder.net.layers.0.3.normalizer.weight', 'baseline.baseline.model.encoder.net.layers.0.3.normalizer.bias', 'baseline.baseline.model.encoder.net.layers.0.3.normalizer.running_mean', 'baseline.baseline.model.encoder.net.layers.0.3.normalizer.running_var', 'baseline.baseline.model.encoder.net.layers.0.3.normalizer.num_batches_tracked', 'baseline.baseline.model.encoder.net.layers.1.0.module.Wqkv.weight', 'baseline.baseline.model.encoder.net.layers.1.0.module.Wqkv.bias', 'baseline.baseline.model.encoder.net.layers.1.0.module.out_proj.weight', 'baseline.baseline.model.encoder.net.layers.1.0.module.out_proj.bias', 'baseline.baseline.model.encoder.net.layers.1.1.normalizer.weight', 'baseline.baseline.model.encoder.net.layers.1.1.normalizer.bias', 'baseline.baseline.model.encoder.net.layers.1.1.normalizer.running_mean', 'baseline.baseline.model.encoder.net.layers.1.1.normalizer.running_var', 'baseline.baseline.model.encoder.net.layers.1.1.normalizer.num_batches_tracked', 'baseline.baseline.model.encoder.net.layers.1.2.module.0.weight', 'baseline.baseline.model.encoder.net.layers.1.2.module.0.bias', 'baseline.baseline.model.encoder.net.layers.1.2.module.2.weight', 'baseline.baseline.model.encoder.net.layers.1.2.module.2.bias', 'baseline.baseline.model.encoder.net.layers.1.3.normalizer.weight', 'baseline.baseline.model.encoder.net.layers.1.3.normalizer.bias', 'baseline.baseline.model.encoder.net.layers.1.3.normalizer.running_mean', 'baseline.baseline.model.encoder.net.layers.1.3.normalizer.running_var', 'baseline.baseline.model.encoder.net.layers.1.3.normalizer.num_batches_tracked', 'baseline.baseline.model.encoder.net.layers.2.0.module.Wqkv.weight', 'baseline.baseline.model.encoder.net.layers.2.0.module.Wqkv.bias', 'baseline.baseline.model.encoder.net.layers.2.0.module.out_proj.weight', 'baseline.baseline.model.encoder.net.layers.2.0.module.out_proj.bias', 'baseline.baseline.model.encoder.net.layers.2.1.normalizer.weight', 'baseline.baseline.model.encoder.net.layers.2.1.normalizer.bias', 'baseline.baseline.model.encoder.net.layers.2.1.normalizer.running_mean', 'baseline.baseline.model.encoder.net.layers.2.1.normalizer.running_var', 'baseline.baseline.model.encoder.net.layers.2.1.normalizer.num_batches_tracked', 'baseline.baseline.model.encoder.net.layers.2.2.module.0.weight', 'baseline.baseline.model.encoder.net.layers.2.2.module.0.bias', 'baseline.baseline.model.encoder.net.layers.2.2.module.2.weight', 'baseline.baseline.model.encoder.net.layers.2.2.module.2.bias', 'baseline.baseline.model.encoder.net.layers.2.3.normalizer.weight', 'baseline.baseline.model.encoder.net.layers.2.3.normalizer.bias', 'baseline.baseline.model.encoder.net.layers.2.3.normalizer.running_mean', 'baseline.baseline.model.encoder.net.layers.2.3.normalizer.running_var', 'baseline.baseline.model.encoder.net.layers.2.3.normalizer.num_batches_tracked', 'baseline.baseline.model.decoder.context_embedding.W_placeholder', 'baseline.baseline.model.decoder.context_embedding.project_context.weight', 'baseline.baseline.model.decoder.project_node_embeddings.weight', 'baseline.baseline.model.decoder.project_fixed_context.weight', 'baseline.baseline.model.decoder.logit_attention.project_out.weight']\n" + ] + } + ], + "source": [ + "# Load from checkpoint; alternatively, simply instantiate a new model\n", + "# Note the model is trained for TSP problem\n", + "checkpoint_path = \"../tsp-20.ckpt\" # modify the path to your checkpoint file\n", + "\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "\n", + "# load checkpoint\n", + "# checkpoint = torch.load(checkpoint_path)\n", + "\n", + "lit_model = AttentionModel.load_from_checkpoint(checkpoint_path, load_baseline=False)\n", + "policy, env = lit_model.policy, lit_model.env\n", + "policy = policy.to(device)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load tsp problems\n", + "\n", + "Note that in the TSPLib, only part of the problems have optimal solutions with the same problem name but with `.opt.tour` suffix. For example, `a280.tsp` has the optimal solution `a280.opt.tour`. " + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "# Load the problem from TSPLib\n", + "tsplib_dir = './tsplib'# modify this to the directory of your prepared files\n", + "files = os.listdir(tsplib_dir)\n", + "problem_files_full = [file for file in files if file.endswith('.tsp')]\n", + "\n", + "# Load the optimal solution files from TSPLib\n", + "solution_files = [file for file in files if file.endswith('.opt.tour')] " + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "# Utils function\n", + "def normalize_coord(coord:torch.Tensor) -> torch.Tensor:\n", + " x, y = coord[:, 0], coord[:, 1]\n", + " x_min, x_max = x.min(), x.max()\n", + " y_min, y_max = y.min(), y.max()\n", + " \n", + " x_scaled = (x - x_min) / (x_max - x_min) \n", + " y_scaled = (y - y_min) / (y_max - y_min)\n", + " coord_scaled = torch.stack([x_scaled, y_scaled], dim=1)\n", + " return coord_scaled " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Test the greedy\n", + "\n", + "Note that run all experiments will take long time and require large VRAM. For simple testing, we can use a subset of the problems. " + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "# Customized problem cases\n", + "problem_files_custom = [\n", + " \"eil51.tsp\", \"berlin52.tsp\", \"st70.tsp\", \"eil76.tsp\", \n", + " \"pr76.tsp\", \"rat99.tsp\", \"kroA100.tsp\", \"kroB100.tsp\", \n", + " \"kroC100.tsp\", \"kroD100.tsp\", \"kroE100.tsp\", \"rd100.tsp\", \n", + " \"eil101.tsp\", \"lin105.tsp\", \"pr124.tsp\", \"bier127.tsp\", \n", + " \"ch130.tsp\", \"pr136.tsp\", \"pr144.tsp\", \"kroA150.tsp\", \n", + " \"kroB150.tsp\", \"pr152.tsp\", \"u159.tsp\", \"rat195.tsp\", \n", + " \"kroA200.tsp\", \"ts225.tsp\", \"tsp225.tsp\", \"pr226.tsp\"\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_3883036/2632546508.py:5: DeprecationWarning: Call to deprecated function (or staticmethod) load_problem. (Will be removed in newer versions. Use `tsplib95.load` instead.) -- Deprecated since version 7.0.0.\n", + " problem = load_problem(os.path.join(tsplib_dir, problem_files[problem_idx]))\n", + "/tmp/ipykernel_3883036/2632546508.py:43: DeprecationWarning: Call to deprecated function (or staticmethod) load_solution. (Will be removed in newer versions. Use `tsplib95.load` instead.) -- Deprecated since version 7.0.0.\n", + " solution = load_solution(os.path.join(tsplib_dir, problem_files[problem_idx][:-4] + '.opt.tour'))\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Problem: eil51 Cost: 493 Optimal Cost: 426 \t Gap: 15.73%\n", + "problem: eil51 cost: 493 \n", + "problem: berlin52 cost: 8957 \n", + "Problem: st70 Cost: 806 Optimal Cost: 675 \t Gap: 19.41%\n", + "problem: st70 cost: 806 \n", + "Problem: eil76 Cost: 693 Optimal Cost: 538 \t Gap: 28.81%\n", + "problem: eil76 cost: 693 \n", + "Problem: pr76 Cost: 132292 Optimal Cost: 108159 \t Gap: 22.31%\n", + "problem: pr76 cost: 132292 \n", + "problem: rat99 cost: 2053 \n", + "Problem: kroA100 Cost: 30791 Optimal Cost: 21282 \t Gap: 44.68%\n", + "problem: kroA100 cost: 30791 \n", + "problem: kroB100 cost: 30347 \n", + "Problem: kroC100 Cost: 28339 Optimal Cost: 20749 \t Gap: 36.58%\n", + "problem: kroC100 cost: 28339 \n", + "Problem: kroD100 Cost: 27600 Optimal Cost: 21294 \t Gap: 29.61%\n", + "problem: kroD100 cost: 27600 \n", + "problem: kroE100 cost: 28396 \n", + "Problem: rd100 Cost: 10695 Optimal Cost: 7910 \t Gap: 35.21%\n", + "problem: rd100 cost: 10695 \n", + "problem: eil101 cost: 919 \n", + "Problem: lin105 Cost: 21796 Optimal Cost: 14379 \t Gap: 51.58%\n", + "problem: lin105 cost: 21796 \n", + "problem: pr124 cost: 75310 \n", + "problem: bier127 cost: 177471 \n", + "problem: ch130 cost: 8169 \n", + "problem: pr136 cost: 135974 \n", + "problem: pr144 cost: 71599 \n", + "problem: kroA150 cost: 40376 \n", + "problem: kroB150 cost: 37076 \n", + "problem: pr152 cost: 94805 \n", + "problem: u159 cost: 64768 \n", + "problem: rat195 cost: 4465 \n", + "problem: kroA200 cost: 44181 \n", + "problem: ts225 cost: 210475 \n", + "Problem: tsp225 Cost: 6212 Optimal Cost: 3919 \t Gap: 58.51%\n", + "problem: tsp225 cost: 6212 \n", + "problem: pr226 cost: 98849 \n" + ] + } + ], + "source": [ + "# problem_files = problem_files_full # if you want to test on all the problems\n", + "problem_files = problem_files_custom # if you want to test on the customized problems\n", + "\n", + "for problem_idx in range(len(problem_files)):\n", + " problem = load_problem(os.path.join(tsplib_dir, problem_files[problem_idx]))\n", + "\n", + " # NOTE: in some problem files (e.g. hk48), the node coordinates are not available\n", + " # we temporarily skip these problems\n", + " if not len(problem.node_coords):\n", + " continue\n", + "\n", + " # Load the node coordinates\n", + " coords = []\n", + " for _, v in problem.node_coords.items():\n", + " coords.append(v)\n", + " coords = torch.tensor(coords).float().to(device) # [n, 2]\n", + " coords_norm = normalize_coord(coords)\n", + "\n", + " # Prepare the tensordict\n", + " batch_size = 2\n", + " td = env.reset(batch_size=(batch_size,)).to(device)\n", + " td['locs'] = repeat(coords_norm, 'n d -> b n d', b=batch_size, d=2)\n", + " td['action_mask'] = torch.ones(batch_size, coords_norm.shape[0], dtype=torch.bool)\n", + "\n", + " # Get the solution from the policy\n", + " with torch.no_grad():\n", + " out = policy(\n", + " td.clone(), \n", + " decode_type=\"greedy\", \n", + " return_actions=True,\n", + " num_starts=0\n", + " )\n", + "\n", + " # Calculate the cost on the original scale\n", + " td['locs'] = repeat(coords, 'n d -> b n d', b=batch_size, d=2)\n", + " neg_reward = env.get_reward(td, out['actions'])\n", + " cost = ceil(-1 * neg_reward[0].item())\n", + "\n", + " # Check if there exists an optimal solution\n", + " try:\n", + " # Load the optimal solution\n", + " solution = load_solution(os.path.join(tsplib_dir, problem_files[problem_idx][:-4] + '.opt.tour'))\n", + " matches = re.findall(r'\\((.*?)\\)', solution.comment)\n", + "\n", + " # NOTE: in some problem solution file (e.g. ch130), the optimal cost is not writen with a brace\n", + " # we temporarily skip to calculate the gap for these problems\n", + " optimal_cost = int(matches[0])\n", + " gap = (cost - optimal_cost) / optimal_cost\n", + " print(f'Problem: {problem_files[problem_idx][:-4]:<10} Cost: {cost:<10} Optimal Cost: {optimal_cost:<10}\\t Gap: {gap:.2%}')\n", + " except:\n", + " continue\n", + " finally:\n", + " print(f'problem: {problem_files[problem_idx][:-4]:<10} cost: {cost:<10}')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Test the Augmentation" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_3883036/2898406631.py:13: DeprecationWarning: Call to deprecated function (or staticmethod) load_problem. (Will be removed in newer versions. Use `tsplib95.load` instead.) -- Deprecated since version 7.0.0.\n", + " problem = load_problem(os.path.join(tsplib_dir, problem_files[problem_idx]))\n", + "/tmp/ipykernel_3883036/2898406631.py:56: DeprecationWarning: Call to deprecated function (or staticmethod) load_solution. (Will be removed in newer versions. Use `tsplib95.load` instead.) -- Deprecated since version 7.0.0.\n", + " solution = load_solution(os.path.join(tsplib_dir, problem_files[problem_idx][:-4] + '.opt.tour'))\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Problem: eil51\t Cost: 457\t Optimal Cost: 426\t Gap: 7.28%\n", + "problem: eil51\t cost: 457\t\n", + "problem: berlin52\t cost: 8256\t\n", + "Problem: st70\t Cost: 777\t Optimal Cost: 675\t Gap: 15.11%\n", + "problem: st70\t cost: 777\t\n", + "Problem: eil76\t Cost: 652\t Optimal Cost: 538\t Gap: 21.19%\n", + "problem: eil76\t cost: 652\t\n", + "Problem: pr76\t Cost: 124939\t Optimal Cost: 108159\t Gap: 15.51%\n", + "problem: pr76\t cost: 124939\t\n", + "problem: rat99\t cost: 1614\t\n", + "Problem: kroA100\t Cost: 27694\t Optimal Cost: 21282\t Gap: 30.13%\n", + "problem: kroA100\t cost: 27694\t\n", + "problem: kroB100\t cost: 28244\t\n", + "Problem: kroC100\t Cost: 25032\t Optimal Cost: 20749\t Gap: 20.64%\n", + "problem: kroC100\t cost: 25032\t\n", + "Problem: kroD100\t Cost: 26811\t Optimal Cost: 21294\t Gap: 25.91%\n", + "problem: kroD100\t cost: 26811\t\n", + "problem: kroE100\t cost: 27831\t\n", + "Problem: rd100\t Cost: 9657\t Optimal Cost: 7910\t Gap: 22.09%\n", + "problem: rd100\t cost: 9657\t\n", + "problem: eil101\t cost: 781\t\n", + "Problem: lin105\t Cost: 18769\t Optimal Cost: 14379\t Gap: 30.53%\n", + "problem: lin105\t cost: 18769\t\n", + "problem: pr124\t cost: 72115\t\n", + "problem: bier127\t cost: 154518\t\n", + "problem: ch130\t cost: 7543\t\n", + "problem: pr136\t cost: 128134\t\n", + "problem: pr144\t cost: 69755\t\n", + "problem: kroA150\t cost: 35967\t\n", + "problem: kroB150\t cost: 35196\t\n", + "problem: pr152\t cost: 88602\t\n", + "problem: u159\t cost: 59484\t\n", + "problem: rat195\t cost: 3631\t\n", + "problem: kroA200\t cost: 42061\t\n", + "problem: ts225\t cost: 196545\t\n", + "Problem: tsp225\t Cost: 5680\t Optimal Cost: 3919\t Gap: 44.93%\n", + "problem: tsp225\t cost: 5680\t\n", + "problem: pr226\t cost: 94540\t\n" + ] + } + ], + "source": [ + "# problem_files = problem_files_full # if you want to test on all the problems\n", + "problem_files = problem_files_custom # if you want to test on the customized problems\n", + "\n", + "# Import augmented utils\n", + "from rl4co.data.transforms import (\n", + " StateAugmentation as SymmetricStateAugmentation)\n", + "from rl4co.utils.ops import batchify, unbatchify\n", + "\n", + "num_augment = 100\n", + "augmentation = SymmetricStateAugmentation(env.name, num_augment=num_augment)\n", + "\n", + "for problem_idx in range(len(problem_files)):\n", + " problem = load_problem(os.path.join(tsplib_dir, problem_files[problem_idx]))\n", + "\n", + " # NOTE: in some problem files (e.g. hk48), the node coordinates are not available\n", + " # we temporarily skip these problems\n", + " if not len(problem.node_coords):\n", + " continue\n", + "\n", + " # Load the node coordinates\n", + " coords = []\n", + " for _, v in problem.node_coords.items():\n", + " coords.append(v)\n", + " coords = torch.tensor(coords).float().to(device) # [n, 2]\n", + " coords_norm = normalize_coord(coords)\n", + "\n", + " # Prepare the tensordict\n", + " batch_size = 2\n", + " td = env.reset(batch_size=(batch_size,)).to(device)\n", + " td['locs'] = repeat(coords_norm, 'n d -> b n d', b=batch_size, d=2)\n", + " td['action_mask'] = torch.ones(batch_size, coords_norm.shape[0], dtype=torch.bool)\n", + "\n", + " # Augmentation\n", + " td = augmentation(td)\n", + "\n", + " # Get the solution from the policy\n", + " with torch.no_grad():\n", + " out = policy(\n", + " td.clone(), \n", + " decode_type=\"greedy\", \n", + " return_actions=True,\n", + " num_starts=0\n", + " )\n", + "\n", + " # Calculate the cost on the original scale\n", + " coords_repeat = repeat(coords, 'n d -> b n d', b=batch_size, d=2)\n", + " td['locs'] = batchify(coords_repeat, num_augment)\n", + " reward = env.get_reward(td, out['actions'])\n", + " reward = unbatchify(reward, num_augment)\n", + " cost = ceil(-1 * torch.max(reward).item())\n", + "\n", + " # Check if there exists an optimal solution\n", + " try:\n", + " # Load the optimal solution\n", + " solution = load_solution(os.path.join(tsplib_dir, problem_files[problem_idx][:-4] + '.opt.tour'))\n", + " matches = re.findall(r'\\((.*?)\\)', solution.comment)\n", + "\n", + " # NOTE: in some problem solution file (e.g. ch130), the optimal cost is not writen with a brace\n", + " # we temporarily skip to calculate the gap for these problems\n", + " optimal_cost = int(matches[0])\n", + " gap = (cost - optimal_cost) / optimal_cost\n", + " print(f'Problem: {problem_files[problem_idx][:-4]}\\t Cost: {cost}\\t Optimal Cost: {optimal_cost}\\t Gap: {gap:.2%}')\n", + " except:\n", + " continue\n", + " finally:\n", + " print(f'problem: {problem_files[problem_idx][:-4]}\\t cost: {cost}\\t')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Test the Sampling" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_3883036/2154301274.py:9: DeprecationWarning: Call to deprecated function (or staticmethod) load_problem. (Will be removed in newer versions. Use `tsplib95.load` instead.) -- Deprecated since version 7.0.0.\n", + " problem = load_problem(os.path.join(tsplib_dir, problem_files[problem_idx]))\n", + "/tmp/ipykernel_3883036/2154301274.py:53: DeprecationWarning: Call to deprecated function (or staticmethod) load_solution. (Will be removed in newer versions. Use `tsplib95.load` instead.) -- Deprecated since version 7.0.0.\n", + " solution = load_solution(os.path.join(tsplib_dir, problem_files[problem_idx][:-4] + '.opt.tour'))\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Problem: eil51\t Cost: 482\t Optimal Cost: 426\t Gap: 13.15%\n", + "problem: eil51\t cost: 482\t\n", + "problem: berlin52\t cost: 8955\t\n", + "Problem: st70\t Cost: 794\t Optimal Cost: 675\t Gap: 17.63%\n", + "problem: st70\t cost: 794\t\n", + "Problem: eil76\t Cost: 673\t Optimal Cost: 538\t Gap: 25.09%\n", + "problem: eil76\t cost: 673\t\n", + "Problem: pr76\t Cost: 127046\t Optimal Cost: 108159\t Gap: 17.46%\n", + "problem: pr76\t cost: 127046\t\n", + "problem: rat99\t cost: 1886\t\n", + "Problem: kroA100\t Cost: 29517\t Optimal Cost: 21282\t Gap: 38.69%\n", + "problem: kroA100\t cost: 29517\t\n", + "problem: kroB100\t cost: 28892\t\n", + "Problem: kroC100\t Cost: 26697\t Optimal Cost: 20749\t Gap: 28.67%\n", + "problem: kroC100\t cost: 26697\t\n", + "Problem: kroD100\t Cost: 27122\t Optimal Cost: 21294\t Gap: 27.37%\n", + "problem: kroD100\t cost: 27122\t\n", + "problem: kroE100\t cost: 28016\t\n", + "Problem: rd100\t Cost: 10424\t Optimal Cost: 7910\t Gap: 31.78%\n", + "problem: rd100\t cost: 10424\t\n", + "problem: eil101\t cost: 837\t\n", + "Problem: lin105\t Cost: 19618\t Optimal Cost: 14379\t Gap: 36.44%\n", + "problem: lin105\t cost: 19618\t\n", + "problem: pr124\t cost: 74699\t\n", + "problem: bier127\t cost: 170255\t\n", + "problem: ch130\t cost: 7985\t\n", + "problem: pr136\t cost: 129964\t\n", + "problem: pr144\t cost: 70477\t\n", + "problem: kroA150\t cost: 37185\t\n", + "problem: kroB150\t cost: 35172\t\n", + "problem: pr152\t cost: 97244\t\n", + "problem: u159\t cost: 59792\t\n", + "problem: rat195\t cost: 4325\t\n", + "problem: kroA200\t cost: 42059\t\n", + "problem: ts225\t cost: 205982\t\n", + "Problem: tsp225\t Cost: 5970\t Optimal Cost: 3919\t Gap: 52.33%\n", + "problem: tsp225\t cost: 5970\t\n", + "problem: pr226\t cost: 103135\t\n" + ] + } + ], + "source": [ + "# problem_files = problem_files_full # if you want to test on all the problems\n", + "problem_files = problem_files_custom # if you want to test on the customized problems\n", + "\n", + "# Parameters for sampling\n", + "num_samples = 100\n", + "softmax_temp = 0.05\n", + "\n", + "for problem_idx in range(len(problem_files)):\n", + " problem = load_problem(os.path.join(tsplib_dir, problem_files[problem_idx]))\n", + "\n", + " # NOTE: in some problem files (e.g. hk48), the node coordinates are not available\n", + " # we temporarily skip these problems\n", + " if not len(problem.node_coords):\n", + " continue\n", + "\n", + " # Load the node coordinates\n", + " coords = []\n", + " for _, v in problem.node_coords.items():\n", + " coords.append(v)\n", + " coords = torch.tensor(coords).float().to(device) # [n, 2]\n", + " coords_norm = normalize_coord(coords)\n", + "\n", + " # Prepare the tensordict\n", + " batch_size = 2\n", + " td = env.reset(batch_size=(batch_size,)).to(device)\n", + " td['locs'] = repeat(coords_norm, 'n d -> b n d', b=batch_size, d=2)\n", + " td['action_mask'] = torch.ones(batch_size, coords_norm.shape[0], dtype=torch.bool)\n", + "\n", + " # Sampling\n", + " td = batchify(td, num_samples)\n", + "\n", + " # Get the solution from the policy\n", + " with torch.no_grad():\n", + " out = policy(\n", + " td.clone(), \n", + " decode_type=\"sampling\", \n", + " return_actions=True,\n", + " num_starts=0,\n", + " softmax_temp=softmax_temp\n", + " )\n", + "\n", + " # Calculate the cost on the original scale\n", + " coords_repeat = repeat(coords, 'n d -> b n d', b=batch_size, d=2)\n", + " td['locs'] = batchify(coords_repeat, num_samples)\n", + " reward = env.get_reward(td, out['actions'])\n", + " reward = unbatchify(reward, num_samples)\n", + " cost = ceil(-1 * torch.max(reward).item())\n", + "\n", + " # Check if there exists an optimal solution\n", + " try:\n", + " # Load the optimal solution\n", + " solution = load_solution(os.path.join(tsplib_dir, problem_files[problem_idx][:-4] + '.opt.tour'))\n", + " matches = re.findall(r'\\((.*?)\\)', solution.comment)\n", + "\n", + " # NOTE: in some problem solution file (e.g. ch130), the optimal cost is not writen with a brace\n", + " # we temporarily skip to calculate the gap for these problems\n", + " optimal_cost = int(matches[0])\n", + " gap = (cost - optimal_cost) / optimal_cost\n", + " print(f'Problem: {problem_files[problem_idx][:-4]}\\t Cost: {cost}\\t Optimal Cost: {optimal_cost}\\t Gap: {gap:.2%}')\n", + " except:\n", + " continue\n", + " finally:\n", + " print(f'problem: {problem_files[problem_idx][:-4]}\\t cost: {cost}\\t')" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "rl4co-user", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}