diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 21588fe3..b9bc85df 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -8,8 +8,10 @@ jobs:
fail-fast: true
max-parallel: 15
matrix:
+ # os: [ubuntu-latest, macos-latest, windows-latest, macos-13-xlarge]
+ # For Apple Silicon: https://github.com/actions/runner-images/issues/8439
os: [ubuntu-latest, macos-latest, windows-latest]
- python-version: ['3.8', '3.9', '3.10']
+ python-version: ['3.8', '3.9', '3.10', '3.11']
defaults:
run:
shell: bash
diff --git a/configs/env/cvrp.yaml b/configs/env/cvrp.yaml
index 2b98bd92..e8598620 100644
--- a/configs/env/cvrp.yaml
+++ b/configs/env/cvrp.yaml
@@ -1,4 +1,4 @@
-_target_: rl4co.envs.cvrp.CVRPEnv
+_target_: rl4co.envs.CVRPEnv
name: cvrp
num_loc: 20
diff --git a/configs/env/default.yaml b/configs/env/default.yaml
index a9ce9bf3..66fa95a6 100644
--- a/configs/env/default.yaml
+++ b/configs/env/default.yaml
@@ -1,4 +1,4 @@
-_target_: rl4co.envs.tsp.TSPEnv
+_target_: rl4co.envs.TSPEnv
name: tsp
num_loc: 20
diff --git a/configs/env/dpp.yaml b/configs/env/dpp.yaml
index 0456187c..51a3c74b 100644
--- a/configs/env/dpp.yaml
+++ b/configs/env/dpp.yaml
@@ -1,4 +1,4 @@
-_target_: rl4co.envs.dpp.DPPEnv
+_target_: rl4co.envs.DPPEnv
name: dpp
max_decaps: 20
diff --git a/configs/env/mdpp.yaml b/configs/env/mdpp.yaml
index 57194a58..df790181 100644
--- a/configs/env/mdpp.yaml
+++ b/configs/env/mdpp.yaml
@@ -1,4 +1,4 @@
-_target_: rl4co.envs.mdpp.MDPPEnv
+_target_: rl4co.envs.MDPPEnv
name: mdpp
max_decaps: 20
diff --git a/configs/env/mtsp.yaml b/configs/env/mtsp.yaml
index 50cadba5..e24e0dca 100644
--- a/configs/env/mtsp.yaml
+++ b/configs/env/mtsp.yaml
@@ -1,4 +1,4 @@
-_target_: rl4co.envs.mtsp.MTSPEnv
+_target_: rl4co.envs.MTSPEnv
name: mtsp
num_loc: 20
diff --git a/configs/env/op.yaml b/configs/env/op.yaml
index e71bcba1..08d8d86d 100644
--- a/configs/env/op.yaml
+++ b/configs/env/op.yaml
@@ -1,4 +1,4 @@
-_target_: rl4co.envs.op.OPEnv
+_target_: rl4co.envs.OPEnv
name: op
num_loc: 20
diff --git a/configs/env/pctsp.yaml b/configs/env/pctsp.yaml
index 3e92aac1..a05fc1f7 100644
--- a/configs/env/pctsp.yaml
+++ b/configs/env/pctsp.yaml
@@ -1,4 +1,4 @@
-_target_: rl4co.envs.pctsp.PCTSPEnv
+_target_: rl4co.envs.PCTSPEnv
name: pctsp
num_loc: 20
diff --git a/configs/env/pdp.yaml b/configs/env/pdp.yaml
index 71f12e01..ba5236a9 100644
--- a/configs/env/pdp.yaml
+++ b/configs/env/pdp.yaml
@@ -1,4 +1,4 @@
-_target_: rl4co.envs.pdp.PDPEnv
+_target_: rl4co.envs.PDPEnv
name: pdp
num_loc: 20
diff --git a/configs/env/sdvrp.yaml b/configs/env/sdvrp.yaml
index cb5f81a0..6ecdd4ce 100644
--- a/configs/env/sdvrp.yaml
+++ b/configs/env/sdvrp.yaml
@@ -1,4 +1,4 @@
-_target_: rl4co.envs.sdvrp.SDVRPEnv
+_target_: rl4co.envs.SDVRPEnv
name: sdvrp
num_loc: 20
diff --git a/configs/env/spctsp.yaml b/configs/env/spctsp.yaml
index ac1387e6..1a239237 100644
--- a/configs/env/spctsp.yaml
+++ b/configs/env/spctsp.yaml
@@ -1,4 +1,4 @@
-_target_: rl4co.envs.spctsp.SPCTSPEnv
+_target_: rl4co.envs.SPCTSPEnv
name: spctsp
num_loc: 20
diff --git a/configs/env/tsp.yaml b/configs/env/tsp.yaml
index 6d3b87fc..e853a203 100644
--- a/configs/env/tsp.yaml
+++ b/configs/env/tsp.yaml
@@ -1,4 +1,4 @@
-_target_: rl4co.envs.tsp.TSPEnv
+_target_: rl4co.envs.TSPEnv
name: tsp
diff --git a/configs/experiment/base.yaml b/configs/experiment/base.yaml
index 4fae8d25..4cfe47d0 100644
--- a/configs/experiment/base.yaml
+++ b/configs/experiment/base.yaml
@@ -17,9 +17,6 @@ defaults:
# that are automatically generated with seed following Kool et al. (2019).
env:
num_loc: 50
- data_dir: ${paths.root_dir}/data/tsp
- val_file: tsp${env.num_loc}_val_seed4321.npz
- test_file: tsp${env.num_loc}_test_seed1234.npz
# Logging: we use Wandb in this case
logger:
diff --git a/configs/trainer/default.yaml b/configs/trainer/default.yaml
index 84df21f4..e3344212 100644
--- a/configs/trainer/default.yaml
+++ b/configs/trainer/default.yaml
@@ -4,16 +4,9 @@ _target_: rl4co.utils.trainer.RL4COTrainer
default_root_dir: ${paths.output_dir}
gradient_clip_val: 1.0
-accelerator: "gpu"
+accelerator: "auto"
precision: "16-mixed"
-# Fast distributed training: comment out to use on single GPU
-# devices: 1 # change number of devices
-strategy:
- _target_: lightning.pytorch.strategies.DDPStrategy
- find_unused_parameters: True
- gradient_as_bucket_view: True
-
# perform a validation loop every N training epochs
check_val_every_n_epoch: 1
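
The trainer default above switches to `accelerator: "auto"` and drops the hard-coded DDP strategy. Multi-GPU DDP can still be requested explicitly; a minimal sketch, assuming `RL4COTrainer` forwards the standard Lightning `Trainer` arguments used in this config (the device count is hypothetical):

from lightning.pytorch.strategies import DDPStrategy
from rl4co.utils.trainer import RL4COTrainer

# Assumption: RL4COTrainer accepts the usual Lightning Trainer kwargs shown in this config
trainer = RL4COTrainer(
    accelerator="gpu",
    devices=2,  # hypothetical number of GPUs
    strategy=DDPStrategy(find_unused_parameters=True, gradient_as_bucket_view=True),
    precision="16-mixed",
    gradient_clip_val=1.0,
)
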
diff --git a/docs/_theme/rl4co/extensions/lightning.py b/docs/_theme/rl4co/extensions/lightning.py
index d0633b41..8cba7fda 100644
--- a/docs/_theme/rl4co/extensions/lightning.py
+++ b/docs/_theme/rl4co/extensions/lightning.py
@@ -13,10 +13,7 @@
# limitations under the License.
from docutils import nodes
from docutils.statemachine import StringList
-from sphinx.util.docutils import SphinxDirective
-
from pt_lightning_sphinx_theme.extensions.pytorch_tutorials import (
- cardnode,
CustomCalloutItemDirective,
CustomCardItemDirective,
DisplayItemDirective,
@@ -24,7 +21,9 @@
ReactGreeter,
SlackButton,
TwoColumns,
+ cardnode,
)
+from sphinx.util.docutils import SphinxDirective
class tutoriallistnode(nodes.General, nodes.Element):
diff --git a/docs/_theme/rl4co/extensions/pytorch_tutorials.py b/docs/_theme/rl4co/extensions/pytorch_tutorials.py
index 97dbc6e6..12b0e73d 100644
--- a/docs/_theme/rl4co/extensions/pytorch_tutorials.py
+++ b/docs/_theme/rl4co/extensions/pytorch_tutorials.py
@@ -34,9 +34,8 @@
from docutils import nodes
from docutils.parsers.rst import Directive, directives
from docutils.statemachine import StringList
-from sphinx.util.docutils import SphinxDirective
-
from pt_lightning_sphinx_theme.extensions.react import get_react_component_rst
+from sphinx.util.docutils import SphinxDirective
try:
FileNotFoundError
@@ -272,11 +271,23 @@ def run(self):
image_class = ""
if "image_center" in self.options:
- image = ""
+ image = (
+ ""
+ )
image_class = "image-center"
elif "image_right" in self.options:
- image = ""
+ image = (
+ ""
+ )
image_class = "image-right"
else:
image = ""
@@ -371,7 +382,11 @@ def run(self):
raise
# return []
callout_rst = get_react_component_rst(
- "LikeButtonWithTitle", width=width, margin=margin, title=title, padding=padding
+ "LikeButtonWithTitle",
+ width=width,
+ margin=margin,
+ title=title,
+ padding=padding,
)
callout_list = StringList(callout_rst.split("\n"))
callout = nodes.paragraph()
@@ -427,7 +442,9 @@ def run(self):
print(e)
raise
return []
- callout_rst = SLACK_TEMPLATE.format(align=align, title=title, margin=margin, width=width)
+ callout_rst = SLACK_TEMPLATE.format(
+ align=align, title=title, margin=margin, width=width
+ )
callout_list = StringList(callout_rst.split("\n"))
callout = nodes.paragraph()
self.state.nested_parse(callout_list, self.content_offset, callout)
diff --git a/pyproject.toml b/pyproject.toml
index 8b4feee5..bf0fbcf0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -23,14 +23,14 @@ dynamic = ["version"]
license = {file = "LICENSE"}
-# TODO: allow new Python versions https://github.com/kaist-silab/rl4co/issues/95
-requires-python = ">=3.8, <3.11" # https://github.com/kaist-silab/rl4co/issues/90
+requires-python = ">=3.8"
classifiers = [
"Programming Language :: Python",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
+ "Programming Language :: Python :: 3.11",
"License :: OSI Approved :: Apache Software License",
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
@@ -41,10 +41,9 @@ classifiers = [
# TODO: allow new TorchRL / TensorDict versions https://github.com/kaist-silab/rl4co/issues/95
dependencies = [
- "torch>=2.0.0,<2.1.0", # Possibly TorchRL problem on Windows with older version
- "torchrl==0.1.1",
- "tensordict==0.1.1",
- "lightning>=2.0.5",
+ "torchrl>=0.2.0",
+ "tensordict>=0.2.0",
+ "lightning>=2.1.0",
"hydra-core",
"hydra-colorlog",
"omegaconf",
diff --git a/rl4co/__init__.py b/rl4co/__init__.py
index 4dc8ce10..95407eb1 100644
--- a/rl4co/__init__.py
+++ b/rl4co/__init__.py
@@ -1 +1 @@
-__version__ = "0.2.4.dev1"
+__version__ = "0.3.0dev0"
diff --git a/rl4co/data/dataset.py b/rl4co/data/dataset.py
index 5d22c603..afcf8bee 100644
--- a/rl4co/data/dataset.py
+++ b/rl4co/data/dataset.py
@@ -1,3 +1,4 @@
+
import torch
from tensordict.tensordict import TensorDict
@@ -5,9 +6,40 @@
class TensorDictDataset(Dataset):
+ """Dataset compatible with TensorDicts.
+ Uses more CPU and has loading performance similar to the list-comprehension-based
+ :class:`TensorDictDatasetList`, but is more than 10x faster to instantiate.
+ """
+
+ def __init__(self, td: TensorDict):
+ self.data = td
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitems__(self, index):
+ # Tricks:
+ # - batched data loading with `__getitems__` for faster loading
+ # - avoid directly indexing TensorDicts for faster loading
+ return TensorDict(
+ {key: item[index] for key, item in self.data.items()},
+ batch_size=torch.Size([len(index)]),
+ _run_checks=False, # faster this way
+ )
+
+ def add_key(self, key, value):
+ return self.data.update({key: value}) # native method
+
+
+def tensordict_collate_fn(x):
+ """Equivalent to collating with `lambda x: x`"""
+ return x
+
+
+class TensorDictDatasetList(Dataset):
"""Dataset compatible with TensorDicts.
It is better to "disassemble" the TensorDict into a list of dicts.
- See :class:`tensordict_collate_fn` for more details.
+ See :class:`tensordict_collate_fn_list` for more details.
Note:
Check out the issue on tensordict for more details:
@@ -16,21 +48,24 @@ class TensorDictDataset(Dataset):
but uses > 3x more CPU.
"""
- def __init__(self, data: TensorDict):
+ def __init__(self, td: TensorDict):
+ self.data_len = td.batch_size[0]
self.data = [
- {key: value[i] for key, value in data.items()} for i in range(data.shape[0])
+ {key: value[i] for key, value in td.items()} for i in range(self.data_len)
]
def __len__(self):
- return len(self.data)
+ return self.data_len
def __getitem__(self, idx):
return self.data[idx]
+ def add_key(self, key, value):
+ return ExtraKeyDataset(self, value, key_name=key)
-def tensordict_collate_fn(batch):
- """Collate function compatible with TensorDicts.
- Reassemble the list of dicts into a TensorDict; seems to be way more efficient than using a TensorDictDataset.
+
+def tensordict_collate_fn_list(batch):
+ """Collate function compatible with TensorDicts that reassembles a list of dicts.
Note:
Check out the issue on tensordict for more details:
@@ -40,24 +75,28 @@ def tensordict_collate_fn(batch):
"""
return TensorDict(
{key: torch.stack([b[key] for b in batch]) for key in batch[0].keys()},
- batch_size=len(batch),
+ batch_size=torch.Size([len(batch)]),
+ device=batch[0].device,
+ _run_checks=False,
)
-class ExtraKeyDataset(Dataset):
+class ExtraKeyDataset(TensorDictDatasetList):
"""Dataset that includes an extra key to add to the data dict.
This is useful for adding a REINFORCE baseline reward to the data dict.
+ Note that this is faster to instantiate than using list comprehension.
"""
- def __init__(self, dataset: TensorDictDataset, extra: torch.Tensor):
+ def __init__(
+ self, dataset: TensorDictDatasetList, extra: torch.Tensor, key_name="extra"
+ ):
+ self.data_len = len(dataset)
+ assert self.data_len == len(extra), "Data and extra must be same length"
self.data = dataset.data
self.extra = extra
- assert len(self.data) == len(self.extra), "Data and extra must be same length"
-
- def __len__(self):
- return len(self.data)
+ self.key_name = key_name
def __getitem__(self, idx):
data = self.data[idx]
- data["extra"] = self.extra[idx]
+ data[self.key_name] = self.extra[idx]
return data
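
The new `TensorDictDataset` keeps the data as a single TensorDict, serves batches through `__getitems__`, and pairs with the identity collate function, while `add_key` either updates the TensorDict natively or, for the list-based variant, wraps the dataset in `ExtraKeyDataset`. A minimal usage sketch; the field names, sizes, and the "extra" key mirror how the REINFORCE rollout baseline uses them, but the tensors are dummy values:

import torch
from tensordict.tensordict import TensorDict
from torch.utils.data import DataLoader

from rl4co.data.dataset import (
    TensorDictDataset,
    TensorDictDatasetList,
    tensordict_collate_fn,
)

# Hypothetical data: 128 instances with 20 2D locations each
td = TensorDict({"locs": torch.rand(128, 20, 2)}, batch_size=[128])
dataset = TensorDictDataset(td)

# `__getitems__` already returns a batched TensorDict, so collation is a no-op
loader = DataLoader(dataset, batch_size=32, collate_fn=tensordict_collate_fn)
for batch in loader:
    print(batch["locs"].shape)  # torch.Size([32, 20, 2])

# Attaching an extra key (e.g. a baseline reward): native update vs. wrapper
dataset.add_key("extra", torch.rand(128))                              # updates the TensorDict in place
wrapped = TensorDictDatasetList(td).add_key("extra", torch.rand(128))  # returns an ExtraKeyDataset
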
diff --git a/rl4co/envs/__init__.py b/rl4co/envs/__init__.py
index 26eeecb0..1b4fce9f 100644
--- a/rl4co/envs/__init__.py
+++ b/rl4co/envs/__init__.py
@@ -1,25 +1,31 @@
# Base environment
-# Main Environments
-from rl4co.envs.atsp import ATSPEnv
from rl4co.envs.common.base import RL4COEnvBase
-from rl4co.envs.cvrp import CVRPEnv
-from rl4co.envs.dpp import DPPEnv
-from rl4co.envs.ffsp import FFSPEnv
-from rl4co.envs.mdpp import MDPPEnv
-from rl4co.envs.mtsp import MTSPEnv
-from rl4co.envs.op import OPEnv
-from rl4co.envs.pctsp import PCTSPEnv
-from rl4co.envs.pdp import PDPEnv
-from rl4co.envs.sdvrp import SDVRPEnv
-from rl4co.envs.smtwtp import SMTWTPEnv
-from rl4co.envs.spctsp import SPCTSPEnv
-from rl4co.envs.tsp import TSPEnv
+
+# EDA
+from rl4co.envs.eda import DPPEnv, MDPPEnv
+
+# Routing
+from rl4co.envs.routing import (
+ ATSPEnv,
+ CVRPEnv,
+ MTSPEnv,
+ OPEnv,
+ PCTSPEnv,
+ PDPEnv,
+ SDVRPEnv,
+ SPCTSPEnv,
+ TSPEnv,
+)
+
+# Scheduling
+from rl4co.envs.scheduling import FFSPEnv, SMTWTPEnv
# Register environments
ENV_REGISTRY = {
"atsp": ATSPEnv,
"cvrp": CVRPEnv,
"dpp": DPPEnv,
+ "ffsp": FFSPEnv,
"mdpp": MDPPEnv,
"mtsp": MTSPEnv,
"op": OPEnv,
diff --git a/rl4co/envs/common/base.py b/rl4co/envs/common/base.py
index 1607d413..85eb52a1 100644
--- a/rl4co/envs/common/base.py
+++ b/rl4co/envs/common/base.py
@@ -1,5 +1,5 @@
from os.path import join as pjoin
-from typing import Optional, Iterable
+from typing import Iterable, Optional
import torch
@@ -40,6 +40,7 @@ def __init__(
val_dataloader_names: list = None,
test_dataloader_names: list = None,
check_solution: bool = True,
+ _torchrl_mode: bool = False, # TODO
seed: int = None,
device: str = "cpu",
**kwargs,
@@ -47,6 +48,7 @@ def __init__(
super().__init__(device=device, batch_size=[])
self.data_dir = data_dir
self.train_file = pjoin(data_dir, train_file) if train_file is not None else None
+ self._torchrl_mode = _torchrl_mode
def get_files(f):
if f is not None:
@@ -85,6 +87,41 @@ def get_multiple_dataloader_names(f, names):
seed = torch.empty((), dtype=torch.int64).random_().item()
self.set_seed(seed)
+ def step(self, td: TensorDict) -> TensorDict:
+ """Step function to call at each step of the episode containing an action.
+ If `_torchrl_mode` is True, we call `_torchrl_step` instead, which sets the
+ `next` key of the TensorDict to the next state - this is the usual way to do it in TorchRL,
+ but it is inefficient in our case.
+ """
+ if not self._torchrl_mode:
+ # Default: just return the TensorDict without further checks; this is faster
+ td = self._step(td)
+ return {"next": td}
+ else:
+ # TorchRL mode: write the next state under the `next` key, as TorchRL expects
+ return self._torchrl_step(td)
+
+ def _torchrl_step(self, td: TensorDict) -> TensorDict:
+ """See :meth:`super().step` for more details.
+ This is the usual way to do it in TorchRL, but it is inefficient in our case.
+
+ Note:
+ Here we clone the TensorDict to avoid a recursion error, since we allow
+ directly updating the TensorDict in the step function
+ """
+ # sanity check
+ self._assert_tensordict_shape(td)
+ next_preset = td.get("next", None)
+
+ next_tensordict = self._step(
+ td.clone()
+ ) # NOTE: we clone to avoid recursion error
+ next_tensordict = self._step_proc_data(next_tensordict)
+ if next_preset is not None:
+ next_tensordict.update(next_preset.exclude(*next_tensordict.keys(True, True)))
+ td.set("next", next_tensordict)
+ return td
+
def _step(self, td: TensorDict) -> TensorDict:
"""Step function to call at each step of the episode containing an action.
Gives the next observation, reward, done
@@ -178,6 +215,13 @@ def _set_seed(self, seed: Optional[int]):
rng = torch.manual_seed(seed)
self.rng = rng
+ def to(self, device):
+ """Override `to` device method for safety against `None` device (may be found in `TensorDict`)"""
+ if device is None:
+ return self
+ else:
+ return super().to(device)
+
def __getstate__(self):
"""Return the state of the environment. By default, we want to avoid pickling
the random number generator directly as it is not allowed by `deepcopy`
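
With the in-place `_step` implementations, the base `step` above simply wraps the updated TensorDict as `{"next": td}` by default, and only uses the full TorchRL convention (clone plus `td.set("next", ...)`) when `_torchrl_mode=True`. A minimal sketch of the default contract; the random feasible action is illustrative only:

import torch
from rl4co.envs import TSPEnv

env = TSPEnv(num_loc=10)  # default: _torchrl_mode=False
td = env.reset(batch_size=[4])
td["action"] = td["action_mask"].float().multinomial(1).squeeze(-1)  # any feasible node
td = env.step(td)["next"]  # default mode: plain dict holding the updated TensorDict
assert td["done"].shape == (4,)
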
diff --git a/rl4co/envs/eda/__init__.py b/rl4co/envs/eda/__init__.py
new file mode 100644
index 00000000..da7f45e2
--- /dev/null
+++ b/rl4co/envs/eda/__init__.py
@@ -0,0 +1,2 @@
+from rl4co.envs.eda.dpp import DPPEnv
+from rl4co.envs.eda.mdpp import MDPPEnv
diff --git a/rl4co/envs/dpp.py b/rl4co/envs/eda/dpp.py
similarity index 95%
rename from rl4co/envs/dpp.py
rename to rl4co/envs/eda/dpp.py
index 8ac8a857..fe88572c 100644
--- a/rl4co/envs/dpp.py
+++ b/rl4co/envs/eda/dpp.py
@@ -100,30 +100,25 @@ def _step(self, td: TensorDict) -> TensorDict:
# Set done if i is greater than max_decaps
done = td["i"] >= self.max_decaps - 1
- # Calculate reward (we set to -inf since we calculate the reward outside based on the actions)
- reward = torch.ones_like(done) * float("-inf")
+ # The reward is calculated outside via get_reward for efficiency, so we set it to 0 here
+ reward = torch.zeros_like(done)
- # The output must be written in a ``"next"`` entry
- return TensorDict(
+ td.update(
{
- "next": {
- "locs": td["locs"],
- "probe": td["probe"],
- "i": td["i"] + 1,
- "action_mask": available,
- "keepout": td["keepout"],
- "reward": reward,
- "done": done,
- }
- },
- td.shape,
+ "i": td["i"] + 1,
+ "action_mask": available,
+ "reward": reward,
+ "done": done,
+ }
)
+ return td
def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict:
# Initialize locations
if batch_size is None:
batch_size = self.batch_size if td is None else td.batch_size
- self.device = td.device if td is not None else self.device
+ device = td.device if td is not None else self.device
+ self.to(device)
# We allow loading the initial observation from a dataset for faster loading
if td is None:
@@ -170,7 +165,6 @@ def _make_spec(self, td_params):
),
shape=(),
)
- self.input_spec = self.observation_spec.clone()
self.action_spec = BoundedTensorSpec(
shape=(1,),
dtype=torch.int64,
@@ -297,7 +291,7 @@ def _initial_impedance(self, probe):
return zout
def _decap_simulator(self, probe, solution, keepout=None):
- self.device = solution.device
+ self.to(self.device)
probe = probe.item()
diff --git a/rl4co/envs/mdpp.py b/rl4co/envs/eda/mdpp.py
similarity index 98%
rename from rl4co/envs/mdpp.py
rename to rl4co/envs/eda/mdpp.py
index 633cd5e9..25948fe0 100644
--- a/rl4co/envs/mdpp.py
+++ b/rl4co/envs/eda/mdpp.py
@@ -11,7 +11,7 @@
UnboundedDiscreteTensorSpec,
)
-from rl4co.envs.dpp import DPPEnv
+from rl4co.envs.eda.dpp import DPPEnv
from rl4co.utils.pylogger import get_pylogger
log = get_pylogger(__name__)
@@ -64,8 +64,12 @@ def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict
# Action mask is 0 if both action_mask (e.g. keepout) and probe are 0
action_mask = torch.logical_and(td_reset["action_mask"], ~td_reset["probe"])
# Keepout regions are the inverse of action_mask
- td_reset.set_("keepout", ~td_reset["action_mask"])
- td_reset.set_("action_mask", action_mask)
+ td_reset.update(
+ {
+ "keepout": ~td_reset["action_mask"],
+ "action_mask": action_mask,
+ }
+ )
return td_reset
def _make_spec(self, td_params):
@@ -95,7 +99,6 @@ def _make_spec(self, td_params):
),
shape=(),
)
- self.input_spec = self.observation_spec.clone()
self.action_spec = BoundedTensorSpec(
shape=(1,),
dtype=torch.int64,
diff --git a/rl4co/envs/routing/__init__.py b/rl4co/envs/routing/__init__.py
new file mode 100644
index 00000000..d7b815d5
--- /dev/null
+++ b/rl4co/envs/routing/__init__.py
@@ -0,0 +1,9 @@
+from rl4co.envs.routing.atsp import ATSPEnv
+from rl4co.envs.routing.cvrp import CVRPEnv
+from rl4co.envs.routing.mtsp import MTSPEnv
+from rl4co.envs.routing.op import OPEnv
+from rl4co.envs.routing.pctsp import PCTSPEnv
+from rl4co.envs.routing.pdp import PDPEnv
+from rl4co.envs.routing.sdvrp import SDVRPEnv
+from rl4co.envs.routing.spctsp import SPCTSPEnv
+from rl4co.envs.routing.tsp import TSPEnv
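
The environments are regrouped into `routing`, `eda`, and `scheduling` subpackages and re-exported from `rl4co.envs`, which is why the configs above switch `_target_` from e.g. `rl4co.envs.tsp.TSPEnv` to `rl4co.envs.TSPEnv`. A small sketch of the two access paths (`num_loc=20` is just the value used in the configs):

from rl4co.envs import ENV_REGISTRY, TSPEnv

env = TSPEnv(num_loc=20)       # flat import, matching the new `_target_` paths
env_cls = ENV_REGISTRY["tsp"]  # or look the class up by its registered name
assert env_cls is TSPEnv
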
diff --git a/rl4co/envs/atsp.py b/rl4co/envs/routing/atsp.py
similarity index 90%
rename from rl4co/envs/atsp.py
rename to rl4co/envs/routing/atsp.py
index fe451db9..dcfb720e 100644
--- a/rl4co/envs/atsp.py
+++ b/rl4co/envs/routing/atsp.py
@@ -20,7 +20,7 @@
class ATSPEnv(RL4COEnvBase):
"""
Asymmetric Traveling Salesman Problem environment
- At each step, the agent chooses a city to visit. The reward is the -infinite unless the agent visits all the cities.
+ At each step, the agent chooses a city to visit. The reward is 0 unless the agent visits all the cities.
In that case, the reward is (-)length of the path: maximizing the reward is equivalent to minimizing the path length.
Unlike the TSP, the distance matrix is asymmetric, i.e., the distance from A to B is not necessarily the same as the distance from B to A.
@@ -62,24 +62,20 @@ def _step(td: TensorDict) -> TensorDict:
# We are done there are no unvisited locations
done = torch.count_nonzero(available, dim=-1) <= 0
- # The reward is calculated outside via get_reward for efficiency, so we set it to -inf here
- reward = torch.ones_like(done) * float("-inf")
+ # The reward is calculated outside via get_reward for efficiency, so we set it to 0 here
+ reward = torch.zeros_like(done)
- # The output must be written in a ``"next"`` entry
- return TensorDict(
+ td.update(
{
- "next": {
- "cost_matrix": td["cost_matrix"],
- "first_node": first_node,
- "current_node": current_node,
- "i": td["i"] + 1,
- "action_mask": available,
- "reward": reward,
- "done": done,
- }
+ "first_node": first_node,
+ "current_node": current_node,
+ "i": td["i"] + 1,
+ "action_mask": available,
+ "reward": reward,
+ "done": done,
},
- td.shape,
)
+ return td
def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict:
# Initialize distance matrix
@@ -90,9 +86,8 @@ def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict
batch_size = (
self.batch_size if cost_matrix is None else cost_matrix.shape[:-2]
)
- self.device = device = (
- cost_matrix.device if cost_matrix is not None else self.device
- )
+ device = cost_matrix.device if cost_matrix is not None else self.device
+ self.to(device)
if cost_matrix is None:
cost_matrix = self.generate_data(batch_size=batch_size).to(device)[
"cost_matrix"
@@ -142,7 +137,6 @@ def _make_spec(self, td_params: TensorDict = None):
),
shape=(),
)
- self.input_spec = self.observation_spec.clone()
self.action_spec = BoundedTensorSpec(
shape=(1,),
dtype=torch.int64,
diff --git a/rl4co/envs/cvrp.py b/rl4co/envs/routing/cvrp.py
similarity index 92%
rename from rl4co/envs/cvrp.py
rename to rl4co/envs/routing/cvrp.py
index 2cde5f51..edd79d38 100644
--- a/rl4co/envs/cvrp.py
+++ b/rl4co/envs/routing/cvrp.py
@@ -41,7 +41,7 @@ class CVRPEnv(RL4COEnvBase):
"""Capacitated Vehicle Routing Problem (CVRP) environment.
At each step, the agent chooses a customer to visit depending on the current location and the remaining capacity.
When the agent visits a customer, the remaining capacity is updated. If the remaining capacity is not enough to
- visit any customer, the agent must go back to the depot. The reward is the -infinite unless the agent visits all the cities.
+ visit any customer, the agent must go back to the depot. The reward is 0 unless the agent visits all the cities.
In that case, the reward is (-)length of the path: maximizing the reward is equivalent to minimizing the path length.
Args:
@@ -103,25 +103,19 @@ def _step(self, td: TensorDict) -> TensorDict:
# SECTION: get done
done = visited.sum(-1) == visited.size(-1)
- reward = torch.ones_like(done) * float("-inf")
+ reward = torch.zeros_like(done)
- td_step = TensorDict(
+ td.update(
{
- "next": {
- "locs": td["locs"],
- "demand": td["demand"],
- "current_node": current_node,
- "used_capacity": used_capacity,
- "vehicle_capacity": td["vehicle_capacity"],
- "visited": visited,
- "reward": reward,
- "done": done,
- }
- },
- td.shape,
+ "current_node": current_node,
+ "used_capacity": used_capacity,
+ "visited": visited,
+ "reward": reward,
+ "done": done,
+ }
)
- td_step["next"].set("action_mask", self.get_action_mask(td_step["next"]))
- return td_step
+ td.set("action_mask", self.get_action_mask(td))
+ return td
def _reset(
self,
@@ -134,7 +128,7 @@ def _reset(
td = self.generate_data(batch_size=batch_size)
batch_size = [batch_size] if isinstance(batch_size, int) else batch_size
- self.device = td.device
+ self.to(td.device)
# Create reset TensorDict
td_reset = TensorDict(
@@ -226,15 +220,11 @@ def generate_data(self, batch_size) -> TensorDict:
# Demand sampling Following Kool et al. (2019)
# Generates a slightly different distribution than using torch.randint
demand = (
- (
- torch.FloatTensor(*batch_size, self.num_loc)
- .uniform_(self.min_demand - 1, self.max_demand - 1)
- .int()
- + 1
- )
- .float()
- .to(self.device)
- )
+ torch.FloatTensor(*batch_size, self.num_loc, device=self.device)
+ .uniform_(self.min_demand - 1, self.max_demand - 1)
+ .int()
+ + 1
+ ).float()
# Support for heterogeneous capacity if provided
if not isinstance(self.capacity, torch.Tensor):
@@ -250,6 +240,7 @@ def generate_data(self, batch_size) -> TensorDict:
"capacity": capacity,
},
batch_size=batch_size,
+ device=self.device,
)
@staticmethod
@@ -258,7 +249,7 @@ def load_data(fpath, batch_size=[]):
Normalize demand by capacity to be in [0, 1]
"""
td_load = load_npz_to_tensordict(fpath)
- td_load.set_("demand", td_load["demand"] / td_load["capacity"][:, None])
+ td_load.set("demand", td_load["demand"] / td_load["capacity"][:, None])
return td_load
def _make_spec(self, td_params: TensorDict):
@@ -286,7 +277,6 @@ def _make_spec(self, td_params: TensorDict):
),
shape=(),
)
- self.input_spec = self.observation_spec.clone()
self.action_spec = BoundedTensorSpec(
shape=(1,),
dtype=torch.int64,
diff --git a/rl4co/envs/mpdp.py b/rl4co/envs/routing/mpdp.py
similarity index 93%
rename from rl4co/envs/mpdp.py
rename to rl4co/envs/routing/mpdp.py
index 77e6c7cc..1a83eeb0 100644
--- a/rl4co/envs/mpdp.py
+++ b/rl4co/envs/routing/mpdp.py
@@ -22,7 +22,7 @@ class MPDPEnv(RL4COEnvBase):
The goal is to pick up and deliver all the packages while satisfying the precedence constraints.
When an agent goes back to the depot, a new agent is spawned. In the min-max version, the goal is to minimize the
maximum tour length among all agents.
- The reward is the -infinite unless the agent visits all the cities.
+ The reward is 0 unless the agent visits all the cities.
In that case, the reward is (-)length of the path: maximizing the reward is equivalent to minimizing the path length.
Args:
@@ -111,36 +111,25 @@ def _step(self, td: TensorDict) -> TensorDict:
# Get done and reward
done = visited.all(dim=-1, keepdim=True).squeeze(-1)
- reward = torch.ones_like(done) * float(
- "-inf"
- ) # reward calculated via `get_reward` for now
+ reward = torch.zeros_like(done)
- td_step = TensorDict(
+ td.update(
{
- "next": {
- "locs": td["locs"],
- "visited": visited,
- "lengths": td["lengths"],
- "count_depot": td["count_depot"],
- "agent_idx": agent_idx,
- "cur_coord": cur_coord,
- "to_delivery": to_delivery,
- "left_request": td["left_request"],
- "depot_distance": depot_distance,
- "remain_sum_paired_distance": remain_sum_paired_distance,
- "remain_pickup_max_distance": remain_pickup_max_distance,
- "remain_delivery_max_distance": remain_delivery_max_distance,
- "add_pd_distance": td["add_pd_distance"],
- "longest_lengths": td["longest_lengths"],
- "i": td["i"] + 1,
- "done": done,
- "reward": reward,
- }
- },
- td.shape,
+ "visited": visited,
+ "agent_idx": agent_idx,
+ "cur_coord": cur_coord,
+ "to_delivery": to_delivery,
+ "depot_distance": depot_distance,
+ "remain_sum_paired_distance": remain_sum_paired_distance,
+ "remain_pickup_max_distance": remain_pickup_max_distance,
+ "remain_delivery_max_distance": remain_delivery_max_distance,
+ "i": td["i"] + 1,
+ "done": done,
+ "reward": reward,
+ }
)
- td_step["next"].set("action_mask", self.get_action_mask(td_step["next"]))
- return td_step
+ td.set("action_mask", self.get_action_mask(td))
+ return td
def _reset(
self,
@@ -154,7 +143,7 @@ def _reset(
if td is None or td.is_empty():
td = self.generate_data(batch_size=batch_size)
- self.device = td.device
+ self.to(td.device)
# NOTE: this is a hack to get the agent_num
# agent_num = td["agent_num"][0].item() if agent_num is None else agent_num
@@ -425,7 +414,6 @@ def _make_spec(self, td_params: TensorDict):
dtype=torch.int64,
),
)
- self.input_spec = self.observation_spec.clone()
self.action_spec = BoundedTensorSpec(
shape=(1,),
dtype=torch.int64,
diff --git a/rl4co/envs/mtsp.py b/rl4co/envs/routing/mtsp.py
similarity index 94%
rename from rl4co/envs/mtsp.py
rename to rl4co/envs/routing/mtsp.py
index e85fa624..7b835589 100644
--- a/rl4co/envs/mtsp.py
+++ b/rl4co/envs/routing/mtsp.py
@@ -112,26 +112,22 @@ def _step(td: TensorDict) -> TensorDict:
# The reward is the negative of the max_subtour_length (minmax objective)
reward = -max_subtour_length
- # The output must be written in a ``"next"`` entry
- return TensorDict(
+ td.update(
{
- "next": {
- "locs": td["locs"],
- "num_agents": td["num_agents"],
- "max_subtour_length": max_subtour_length,
- "current_length": current_length,
- "agent_idx": cur_agent_idx,
- "first_node": first_node,
- "current_node": current_node,
- "i": td["i"] + 1,
- "action_mask": available,
- "reward": reward,
- "done": done,
- }
- },
- td.shape,
+ "max_subtour_length": max_subtour_length,
+ "current_length": current_length,
+ "agent_idx": cur_agent_idx,
+ "first_node": first_node,
+ "current_node": current_node,
+ "i": td["i"] + 1,
+ "action_mask": available,
+ "reward": reward,
+ "done": done,
+ }
)
+ return td
+
def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict:
# Initialize data
if batch_size is None:
@@ -214,7 +210,6 @@ def _make_spec(self, td_params: TensorDict):
),
shape=(),
)
- self.input_spec = self.observation_spec.clone()
self.action_spec = BoundedTensorSpec(
shape=(1,),
dtype=torch.int64,
diff --git a/rl4co/envs/op.py b/rl4co/envs/routing/op.py
similarity index 94%
rename from rl4co/envs/op.py
rename to rl4co/envs/routing/op.py
index 42d00a9b..70f06263 100644
--- a/rl4co/envs/op.py
+++ b/rl4co/envs/routing/op.py
@@ -90,29 +90,23 @@ def _step(self, td: TensorDict) -> TensorDict:
# Done if went back to depot (except if it's the first step, since we start at the depot)
done = (current_node.squeeze(-1) == 0) & (td["i"] > 0)
- # The reward is calculated outside via get_reward for efficiency, so we set it to -inf here
- reward = torch.ones_like(done) * float("-inf")
+ # The reward is calculated outside via get_reward for efficiency, so we set it to 0 here
+ reward = torch.zeros_like(done)
- td_step = TensorDict(
+ td.update(
{
- "next": {
- "locs": td["locs"],
- "prize": td["prize"],
- "tour_length": tour_length,
- "current_loc": current_loc,
- "max_length": td["max_length"],
- "current_node": current_node,
- "visited": visited,
- "current_total_prize": current_total_prize,
- "i": td["i"] + 1,
- "reward": reward,
- "done": done,
- }
- },
- td.shape,
+ "tour_length": tour_length,
+ "current_loc": current_loc,
+ "current_node": current_node,
+ "visited": visited,
+ "current_total_prize": current_total_prize,
+ "i": td["i"] + 1,
+ "reward": reward,
+ "done": done,
+ }
)
- td_step["next"].set("action_mask", self.get_action_mask(td_step["next"]))
- return td_step
+ td.set("action_mask", self.get_action_mask(td))
+ return td
def _reset(
self,
@@ -124,8 +118,7 @@ def _reset(
batch_size = self.batch_size if td is None else td["locs"].shape[:-2]
if td is None or td.is_empty():
td = self.generate_data(batch_size=batch_size)
- self.device = td.device
-
+ self.to(td.device)
# Add depot to locs
locs_with_depot = torch.cat((td["depot"][:, None, :], td["locs"]), -2)
@@ -322,7 +315,6 @@ def _make_spec(self, td_params: TensorDict):
),
shape=(),
)
- self.input_spec = self.observation_spec.clone()
self.action_spec = BoundedTensorSpec(
shape=(1,),
dtype=torch.int64,
diff --git a/rl4co/envs/pctsp.py b/rl4co/envs/routing/pctsp.py
similarity index 93%
rename from rl4co/envs/pctsp.py
rename to rl4co/envs/routing/pctsp.py
index b4c3204f..ca0f4863 100644
--- a/rl4co/envs/pctsp.py
+++ b/rl4co/envs/routing/pctsp.py
@@ -78,31 +78,26 @@ def _step(self, td: TensorDict) -> TensorDict:
# Update visited
visited = td["visited"].scatter(-1, current_node[..., None], 1)
- # Done and reward. Calculation is done outside hence set -inf
+ # Done and reward
done = (td["i"] > 0) & (current_node == 0)
- reward = torch.ones_like(cur_total_prize) * float("-inf")
- td_step = TensorDict(
+ # The reward is calculated outside via get_reward for efficiency, so we set it to 0 here
+ reward = torch.zeros_like(done)
+
+ # Update state
+ td.update(
{
- "next": {
- "locs": td["locs"],
- "current_node": current_node,
- "expected_prize": td["expected_prize"],
- "real_prize": td["real_prize"],
- "penalty": td["penalty"],
- "cur_total_prize": cur_total_prize,
- "cur_total_penalty": cur_total_penalty,
- "visited": visited,
- "prize_required": td["prize_required"],
- "i": td["i"] + 1,
- "reward": reward,
- "done": done,
- },
- },
- batch_size=td.batch_size,
+ "current_node": current_node,
+ "cur_total_prize": cur_total_prize,
+ "cur_total_penalty": cur_total_penalty,
+ "visited": visited,
+ "i": td["i"] + 1,
+ "reward": reward,
+ "done": done,
+ }
)
- td_step["next"].set("action_mask", self.get_action_mask(td_step["next"]))
- return td_step
+ td.set("action_mask", self.get_action_mask(td))
+ return td
def _reset(
self, td: Optional[TensorDict] = None, batch_size: Optional[list] = None
@@ -111,7 +106,7 @@ def _reset(
batch_size = self.batch_size if td is None else td["locs"].shape[:-2]
if td is None or td.is_empty():
td = self.generate_data(batch_size=batch_size)
- self.device = td.device
+ self.to(td.device)
locs = torch.cat([td["depot"][..., None, :], td["locs"]], dim=-2)
expected_prize = td["deterministic_prize"]
@@ -323,7 +318,6 @@ def _make_spec(self, td_params: TensorDict):
),
shape=(),
)
- self.input_spec = self.observation_spec.clone()
self.action_spec = BoundedTensorSpec(
shape=(1,),
dtype=torch.int64,
diff --git a/rl4co/envs/pdp.py b/rl4co/envs/routing/pdp.py
similarity index 93%
rename from rl4co/envs/pdp.py
rename to rl4co/envs/routing/pdp.py
index 7d6a309d..b845e093 100644
--- a/rl4co/envs/pdp.py
+++ b/rl4co/envs/routing/pdp.py
@@ -71,25 +71,22 @@ def _step(td: TensorDict) -> TensorDict:
# We are done there are no unvisited locations
done = torch.count_nonzero(available, dim=-1) == 0
- # The reward is calculated outside via get_reward for efficiency, so we set it to -inf here
- reward = torch.ones_like(done) * float("-inf")
+ # The reward is calculated outside via get_reward for efficiency, so we set it to 0 here
+ reward = torch.zeros_like(done)
- # The output must be written in a ``"next"`` entry
- return TensorDict(
+ # Update step
+ td.update(
{
- "next": {
- "locs": td["locs"],
- "current_node": current_node,
- "available": available,
- "to_deliver": to_deliver,
- "i": td["i"] + 1,
- "action_mask": action_mask,
- "reward": reward,
- "done": done,
- }
- },
- td.shape,
+ "current_node": current_node,
+ "available": available,
+ "to_deliver": to_deliver,
+ "i": td["i"] + 1,
+ "action_mask": action_mask,
+ "reward": reward,
+ "done": done,
+ }
)
+ return td
def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict:
if batch_size is None:
@@ -98,7 +95,7 @@ def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict
if td is None or td.is_empty():
td = self.generate_data(batch_size=batch_size)
- self.device = td.device
+ self.to(td.device)
locs = torch.cat((td["depot"][:, None, :], td["locs"]), -2)
@@ -170,7 +167,6 @@ def _make_spec(self, td_params: TensorDict):
),
shape=(),
)
- self.input_spec = self.observation_spec.clone()
self.action_spec = BoundedTensorSpec(
shape=(1,),
dtype=torch.int64,
diff --git a/rl4co/envs/sdvrp.py b/rl4co/envs/routing/sdvrp.py
similarity index 89%
rename from rl4co/envs/sdvrp.py
rename to rl4co/envs/routing/sdvrp.py
index 30d582ee..1fbf182d 100644
--- a/rl4co/envs/sdvrp.py
+++ b/rl4co/envs/routing/sdvrp.py
@@ -10,10 +10,11 @@
UnboundedDiscreteTensorSpec,
)
-from rl4co.envs.cvrp import CVRPEnv
from rl4co.utils.ops import gather_by_index
from rl4co.utils.pylogger import get_pylogger
+from .cvrp import CVRPEnv
+
log = get_pylogger(__name__)
@@ -84,27 +85,24 @@ def _step(self, td: TensorDict) -> TensorDict:
-1, current_node, -delivered_demand
)
- # Get done and reward (-inf since we get it outside)
+ # Get done
done = ~(demand_with_depot > 0).any(-1)
- reward = torch.ones_like(done) * float("-inf")
- td_step = TensorDict(
+ # The reward is calculated outside via get_reward for efficiency, so we set it to 0 here
+ reward = torch.zeros_like(done)
+
+ # Update state
+ td.update(
{
- "next": {
- "locs": td["locs"],
- "demand": td["demand"],
- "demand_with_depot": demand_with_depot,
- "current_node": current_node,
- "used_capacity": used_capacity,
- "vehicle_capacity": td["vehicle_capacity"],
- "reward": reward,
- "done": done,
- }
- },
- td.shape,
+ "demand_with_depot": demand_with_depot,
+ "current_node": current_node,
+ "used_capacity": used_capacity,
+ "reward": reward,
+ "done": done,
+ }
)
- td_step["next"].set("action_mask", self.get_action_mask(td_step["next"]))
- return td_step
+ td.set("action_mask", self.get_action_mask(td))
+ return td
def _reset(
self,
@@ -117,7 +115,7 @@ def _reset(
if td is None or td.is_empty():
td = self.generate_data(batch_size=batch_size)
- self.device = td["locs"].device
+ self.to(td.device)
# Create reset TensorDict
reset_td = TensorDict(
@@ -211,7 +209,6 @@ def _make_spec(self, td_params: TensorDict):
),
shape=(),
)
- self.input_spec = self.observation_spec.clone()
self.action_spec = BoundedTensorSpec(
shape=(1,),
dtype=torch.int64,
diff --git a/rl4co/envs/spctsp.py b/rl4co/envs/routing/spctsp.py
similarity index 95%
rename from rl4co/envs/spctsp.py
rename to rl4co/envs/routing/spctsp.py
index e6bfb757..e39a31b1 100644
--- a/rl4co/envs/spctsp.py
+++ b/rl4co/envs/routing/spctsp.py
@@ -1,6 +1,7 @@
-from rl4co.envs.pctsp import PCTSPEnv
from rl4co.utils.pylogger import get_pylogger
+from .pctsp import PCTSPEnv
+
log = get_pylogger(__name__)
diff --git a/rl4co/envs/tsp.py b/rl4co/envs/routing/tsp.py
similarity index 86%
rename from rl4co/envs/tsp.py
rename to rl4co/envs/routing/tsp.py
index 3d7d53a6..7db63541 100644
--- a/rl4co/envs/tsp.py
+++ b/rl4co/envs/routing/tsp.py
@@ -11,7 +11,6 @@
)
from rl4co.envs.common.base import RL4COEnvBase
-from rl4co.envs.common.utils import batch_to_scalar
from rl4co.utils.ops import gather_by_index, get_tour_length
from rl4co.utils.pylogger import get_pylogger
@@ -21,7 +20,7 @@
class TSPEnv(RL4COEnvBase):
"""
Traveling Salesman Problem environment
- At each step, the agent chooses a city to visit. The reward is the -infinite unless the agent visits all the cities.
+ At each step, the agent chooses a city to visit. The reward is 0 unless the agent visits all the cities.
In that case, the reward is (-)length of the path: maximizing the reward is equivalent to minimizing the path length.
Args:
@@ -50,41 +49,38 @@ def __init__(
@staticmethod
def _step(td: TensorDict) -> TensorDict:
current_node = td["action"]
- first_node = current_node if batch_to_scalar(td["i"]) == 0 else td["first_node"]
+ first_node = current_node if td["i"].all() == 0 else td["first_node"]
- # Set not visited to 0 (i.e., we visited the node)
+ # Set not visited to 0 (i.e., we visited the node)
available = td["action_mask"].scatter(
-1, current_node.unsqueeze(-1).expand_as(td["action_mask"]), 0
)
# We are done there are no unvisited locations
- done = torch.count_nonzero(available, dim=-1) <= 0
+ done = torch.sum(available, dim=-1) == 0
- # The reward is calculated outside via get_reward for efficiency, so we set it to -inf here
- reward = torch.ones_like(done) * float("-inf")
+ # The reward is calculated outside via get_reward for efficiency, so we set it to 0 here
+ reward = torch.zeros_like(done)
- # The output must be written in a ``"next"`` entry
- return TensorDict(
+ td.update(
{
- "next": {
- "locs": td["locs"],
- "first_node": first_node,
- "current_node": current_node,
- "i": td["i"] + 1,
- "action_mask": available,
- "reward": reward,
- "done": done,
- }
+ "first_node": first_node,
+ "current_node": current_node,
+ "i": td["i"] + 1,
+ "action_mask": available,
+ "reward": reward,
+ "done": done,
},
- td.shape,
)
+ return td
def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict:
# Initialize locations
init_locs = td["locs"] if td is not None else None
if batch_size is None:
batch_size = self.batch_size if init_locs is None else init_locs.shape[:-2]
- self.device = device = init_locs.device if init_locs is not None else self.device
+ device = init_locs.device if init_locs is not None else self.device
+ self.to(device)
if init_locs is None:
init_locs = self.generate_data(batch_size=batch_size).to(device)["locs"]
batch_size = [batch_size] if isinstance(batch_size, int) else batch_size
@@ -106,6 +102,7 @@ def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict
"current_node": current_node,
"i": i,
"action_mask": available,
+ "reward": torch.zeros((*batch_size, 1), dtype=torch.float32),
},
batch_size=batch_size,
)
@@ -137,7 +134,6 @@ def _make_spec(self, td_params):
),
shape=(),
)
- self.input_spec = self.observation_spec.clone()
self.action_spec = BoundedTensorSpec(
shape=(1,),
dtype=torch.int64,
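
Because the per-step reward is now left at 0 and the tour cost is computed afterwards via `get_reward`, a rollout collects the actions and evaluates them at the end. A hedged sketch with random but feasible action selection (batch size and `num_loc` are arbitrary):

import torch
from rl4co.envs import TSPEnv

env = TSPEnv(num_loc=10)
td = env.reset(batch_size=[2])
actions = []
for _ in range(env.num_loc):  # a TSP tour visits every node exactly once
    td["action"] = td["action_mask"].float().multinomial(1).squeeze(-1)
    actions.append(td["action"])
    td = env.step(td)["next"]
reward = env.get_reward(td, torch.stack(actions, dim=1))  # negative tour length, shape [2]
print(reward)
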
diff --git a/rl4co/envs/scheduling/__init__.py b/rl4co/envs/scheduling/__init__.py
new file mode 100644
index 00000000..a9c5144b
--- /dev/null
+++ b/rl4co/envs/scheduling/__init__.py
@@ -0,0 +1,2 @@
+from rl4co.envs.scheduling.ffsp import FFSPEnv
+from rl4co.envs.scheduling.smtwtp import SMTWTPEnv
diff --git a/rl4co/envs/ffsp.py b/rl4co/envs/scheduling/ffsp.py
similarity index 92%
rename from rl4co/envs/ffsp.py
rename to rl4co/envs/scheduling/ffsp.py
index ff97d5d7..fb1dd5fe 100644
--- a/rl4co/envs/ffsp.py
+++ b/rl4co/envs/scheduling/ffsp.py
@@ -139,30 +139,25 @@ def _step(self, td: TensorDict) -> TensorDict:
reward = td["reward"]
- return TensorDict(
+ # Update state
+ td.update(
{
- "next": {
- "stage_table": td["stage_table"],
- "machine_table": td["machine_table"],
- "time_idx": time_idx,
- "sub_time_idx": sub_time_idx,
- "batch_idx": batch_idx,
- "machine_idx": machine_idx,
- "schedule": schedule,
- "machine_wait_step": machine_wait_step,
- "job_location": job_location,
- "job_wait_step": job_wait_step,
- "job_duration": td["job_duration"],
- "reward": reward,
- "finish": finish,
- # Update variables
- "job_mask": job_mask,
- "stage_idx": stage_idx,
- "stage_machine_idx": stage_machine_idx,
- }
- },
- td.shape,
+ "time_idx": time_idx,
+ "sub_time_idx": sub_time_idx,
+ "batch_idx": batch_idx,
+ "machine_idx": machine_idx,
+ "schedule": schedule,
+ "machine_wait_step": machine_wait_step,
+ "job_location": job_location,
+ "job_wait_step": job_wait_step,
+ "reward": reward,
+ "finish": finish,
+ "job_mask": job_mask,
+ "stage_idx": stage_idx,
+ "stage_machine_idx": stage_machine_idx,
+ }
)
+ return td
def _reset(
self, td: Optional[TensorDict] = None, batch_size: Optional[list] = None
@@ -321,7 +316,6 @@ def _make_spec(self, td_params: TensorDict):
),
shape=(),
)
- self.input_spec = self.observation_spec.clone()
self.action_spec = BoundedTensorSpec(
shape=(1,),
dtype=torch.int64,
diff --git a/rl4co/envs/smtwtp.py b/rl4co/envs/scheduling/smtwtp.py
similarity index 90%
rename from rl4co/envs/smtwtp.py
rename to rl4co/envs/scheduling/smtwtp.py
index 41f5e5af..d65d8d07 100644
--- a/rl4co/envs/smtwtp.py
+++ b/rl4co/envs/scheduling/smtwtp.py
@@ -22,7 +22,7 @@ class SMTWTPEnv(RL4COEnvBase):
SMTWTP is a scheduling problem in which a set of jobs must be processed on a single machine.
Each job i has a processing time, a weight, and a due date. The objective is to minimize the sum of the weighted tardiness of all jobs,
where the weighted tardiness of a job is defined as the product of its weight and the duration by which its completion time exceeds its due date.
- At each step, the agent chooses a job to process. The reward is the -infinite unless the agent processes all the jobs.
+ At each step, the agent chooses a job to process. The reward is 0 unless the agent processes all the jobs.
In that case, the reward is (-)objective value of the processing order: maximizing the reward is equivalent to minimizing the objective.
Args:
@@ -80,25 +80,19 @@ def _step(td: TensorDict) -> TensorDict:
# We are done there are no unvisited locations
done = torch.count_nonzero(available, dim=-1) <= 0
- # The reward is calculated outside via get_reward for efficiency, so we set it to -inf here
- reward = torch.ones_like(done) * float("-inf")
+ # The reward is calculated outside via get_reward for efficiency, so we set it to 0 here
+ reward = torch.zeros_like(done)
- # The output must be written in a ``"next"`` entry
- return TensorDict(
+ td.update(
{
- "next": {
- "job_due_time": td["job_due_time"],
- "job_weight": td["job_weight"],
- "job_process_time": td["job_process_time"],
- "current_job": current_job,
- "current_time": current_time,
- "action_mask": available,
- "reward": reward,
- "done": done,
- }
- },
- td.shape,
+ "current_job": current_job,
+ "current_time": current_time,
+ "action_mask": available,
+ "reward": reward,
+ "done": done,
+ }
)
+ return td
def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict:
# Initialization
@@ -106,9 +100,8 @@ def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict
batch_size = self.batch_size if td is None else td["job_due_time"].shape[:-1]
batch_size = [batch_size] if isinstance(batch_size, int) else batch_size
- self.device = device = (
- td["job_due_time"].device if td is not None else self.device
- )
+ device = td["job_due_time"].device if td is not None else self.device
+ self.to(device)
td = self.generate_data(batch_size) if td is None else td
@@ -170,7 +163,6 @@ def _make_spec(self, td_params: TensorDict = None):
),
shape=(),
)
- self.input_spec = self.observation_spec.clone()
self.action_spec = BoundedTensorSpec(
shape=(1,),
dtype=torch.int64,
diff --git a/rl4co/models/__init__.py b/rl4co/models/__init__.py
index c923dbe7..b4b794db 100644
--- a/rl4co/models/__init__.py
+++ b/rl4co/models/__init__.py
@@ -1,7 +1,6 @@
from rl4co.models.zoo.active_search import ActiveSearch
from rl4co.models.zoo.am import AttentionModel, AttentionModelPolicy
from rl4co.models.zoo.common.autoregressive import AutoregressivePolicy
-
from rl4co.models.zoo.common.search import SearchBase
from rl4co.models.zoo.eas import EAS, EASEmb, EASLay
from rl4co.models.zoo.ham import (
diff --git a/rl4co/models/nn/utils.py b/rl4co/models/nn/utils.py
index acdc622e..6ab21c7c 100644
--- a/rl4co/models/nn/utils.py
+++ b/rl4co/models/nn/utils.py
@@ -19,7 +19,9 @@ def get_log_likelihood(log_p, actions, mask, return_sum: bool = True):
if mask is not None:
log_p[~mask] = 0
- assert (log_p > -1000).data.all(), "Logprobs should not be -inf, check sampling procedure!"
+ assert (
+ log_p > -1000
+ ).data.all(), "Logprobs should not be -inf, check sampling procedure!"
# Calculate log_likelihood
if return_sum:
diff --git a/rl4co/models/rl/common/base.py b/rl4co/models/rl/common/base.py
index 3ec0ce22..7a7d5b20 100644
--- a/rl4co/models/rl/common/base.py
+++ b/rl4co/models/rl/common/base.py
@@ -1,5 +1,5 @@
from functools import partial
-from typing import Any, Union, Iterable
+from typing import Any, Iterable, Union
import torch
import torch.nn as nn
@@ -149,7 +149,6 @@ def setup(self, stage="fit"):
self.data_cfg["test_data_size"], phase="test"
)
self.dataloader_names = None
-
self.setup_loggers()
self.post_setup_hook()
@@ -214,12 +213,16 @@ def configure_optimizers(self, parameters=None):
def log_metrics(self, metric_dict: dict, phase: str, dataloader_idx: int = None):
"""Log metrics to logger and progress bar"""
- metrics = getattr(self, f"{phase}_metrics")
+ metrics = getattr(self, f"{phase}_metrics")
dataloader_name = ""
if dataloader_idx is not None and self.dataloader_names is not None:
- dataloader_name = "/" + self.dataloader_names[dataloader_idx]
+ dataloader_name = "/" + self.dataloader_names[dataloader_idx]
metrics = {
- f"{phase}/{k}{dataloader_name}": v.mean() if isinstance(v, torch.Tensor) else v for k, v in metric_dict.items() if k in metrics
+ f"{phase}/{k}{dataloader_name}": v.mean()
+ if isinstance(v, torch.Tensor)
+ else v
+ for k, v in metric_dict.items()
+ if k in metrics
}
log_on_step = self.log_on_step if phase == "train" else False
on_epoch = False if phase == "train" else True
@@ -292,7 +295,10 @@ def _dataloader(self, dataset, batch_size, shuffle=False):
self.dataloader_names = list(dataset.keys())
else:
self.dataloader_names = [f"{i}" for i in range(len(dataset))]
- return [self._dataloader_single(ds, batch_size, shuffle) for ds in dataset.values()]
+ return [
+ self._dataloader_single(ds, batch_size, shuffle)
+ for ds in dataset.values()
+ ]
else:
return self._dataloader_single(dataset, batch_size, shuffle)
diff --git a/rl4co/models/rl/ppo/ppo.py b/rl4co/models/rl/ppo/ppo.py
index 6f3cc950..af7915ce 100644
--- a/rl4co/models/rl/ppo/ppo.py
+++ b/rl4co/models/rl/ppo/ppo.py
@@ -6,6 +6,7 @@
from torch.utils.data import DataLoader
+from rl4co.data.dataset import TensorDictDataset, tensordict_collate_fn
from rl4co.envs.common.base import RL4COEnvBase
from rl4co.models.rl.common.base import RL4COLitModule
from rl4co.utils.pylogger import get_pylogger
@@ -124,8 +125,8 @@ def shared_step(
):
# Evaluate old actions, log probabilities, and rewards
with torch.no_grad():
- td = self.env.reset(batch)
- out = self.policy(td, self.env, phase=phase, return_actions=True)
+ td = self.env.reset(batch)  # note: `td` is reused below for the dataloader, so the policy receives a clone
+ out = self.policy(td.clone(), self.env, phase=phase, return_actions=True)
if phase == "train":
batch_size = out["actions"].shape[0]
@@ -146,12 +147,17 @@ def shared_step(
td.set("reward", out["reward"])
td.set("action", out["actions"])
+ dataset = TensorDictDataset(td)
dataloader = DataLoader(
- td, batch_size=mini_batch_size, shuffle=True, collate_fn=lambda x: x
+ dataset,
+ batch_size=mini_batch_size,
+ shuffle=True,
+ collate_fn=tensordict_collate_fn,
)
for _ in range(self.ppo_cfg["ppo_epochs"]): # PPO inner epoch, K
for sub_td in dataloader:
+ previous_reward = sub_td["reward"].view(-1, 1)
ll, entropy = self.policy.evaluate_action(
sub_td, action=sub_td["action"]
)
@@ -163,7 +169,7 @@ def shared_step(
# Compute the advantage
value_pred = self.critic(sub_td) # [batch, 1]
- adv = sub_td["reward"].view(-1, 1) - value_pred.detach()
+ adv = previous_reward - value_pred.detach()
# Normalize advantage
if self.ppo_cfg["normalize_adv"]:
@@ -181,7 +187,7 @@ def shared_step(
).mean()
# compute value function loss
- value_loss = F.huber_loss(value_pred, sub_td["reward"].view(-1, 1))
+ value_loss = F.huber_loss(value_pred, previous_reward)
# compute total loss
loss = (
diff --git a/rl4co/models/rl/reinforce/baselines.py b/rl4co/models/rl/reinforce/baselines.py
index 8117edd1..8d345660 100644
--- a/rl4co/models/rl/reinforce/baselines.py
+++ b/rl4co/models/rl/reinforce/baselines.py
@@ -6,10 +6,9 @@
from scipy.stats import ttest_rel
from torch.utils.data import DataLoader
-from tqdm.auto import tqdm
from rl4co import utils
-from rl4co.data.dataset import ExtraKeyDataset, tensordict_collate_fn
+from rl4co.data.dataset import tensordict_collate_fn
from rl4co.models.rl.common.critic import CriticNetwork
log = utils.get_pylogger(__name__)
@@ -81,6 +80,7 @@ def eval(self, td, reward, env=None):
class MeanBaseline(REINFORCEBaseline):
"""Mean baseline: return mean of reward as baseline"""
+
def __new__(cls, **kw):
return ExponentialBaseline(beta=0.0, **kw)
@@ -158,13 +158,11 @@ class RolloutBaseline(REINFORCEBaseline):
Args:
bl_alpha: Alpha value for the baseline T-test
- progress_bar: Whether to show progress bar for rollout
"""
- def __init__(self, bl_alpha=0.05, progress_bar=False, **kw):
+ def __init__(self, bl_alpha=0.05, **kw):
super(RolloutBaseline, self).__init__()
self.bl_alpha = bl_alpha
- self.progress_bar = progress_bar
def setup(self, *args, **kw):
self._update_model(*args, **kw)
@@ -235,9 +233,7 @@ def eval_model(batch):
dl = DataLoader(dataset, batch_size=batch_size, collate_fn=tensordict_collate_fn)
- rewards = torch.cat(
- [eval_model(batch) for batch in tqdm(dl, disable=not self.progress_bar)], 0
- )
+ rewards = torch.cat([eval_model(batch) for batch in dl], 0)
return rewards
def wrap_dataset(self, dataset, env, batch_size=64, device="cpu", **kw):
@@ -253,7 +249,7 @@ def wrap_dataset(self, dataset, env, batch_size=64, device="cpu", **kw):
.detach()
.cpu()
)
- return ExtraKeyDataset(dataset, rewards)
+ return dataset.add_key("extra", rewards)
def __getstate__(self):
"""Do not include datasets in state to avoid pickling issues"""
diff --git a/rl4co/models/zoo/__init__.py b/rl4co/models/zoo/__init__.py
index c923dbe7..b4b794db 100644
--- a/rl4co/models/zoo/__init__.py
+++ b/rl4co/models/zoo/__init__.py
@@ -1,7 +1,6 @@
from rl4co.models.zoo.active_search import ActiveSearch
from rl4co.models.zoo.am import AttentionModel, AttentionModelPolicy
from rl4co.models.zoo.common.autoregressive import AutoregressivePolicy
-
from rl4co.models.zoo.common.search import SearchBase
from rl4co.models.zoo.eas import EAS, EASEmb, EASLay
from rl4co.models.zoo.ham import (
diff --git a/rl4co/models/zoo/ham/model.py b/rl4co/models/zoo/ham/model.py
index a95f558b..aa402e84 100644
--- a/rl4co/models/zoo/ham/model.py
+++ b/rl4co/models/zoo/ham/model.py
@@ -7,7 +7,7 @@
class HeterogeneousAttentionModel(REINFORCE):
- """Heterogenous Attention Model for solving the Pickup and Delivery Problem based on
+ """Heterogeneous Attention Model for solving the Pickup and Delivery Problem based on
REINFORCE: https://arxiv.org/abs/2110.02634.
Args:
@@ -20,7 +20,7 @@ class HeterogeneousAttentionModel(REINFORCE):
"""
def __init__(
- self,
+ self,
env: RL4COEnvBase,
policy: HeterogeneousAttentionModelPolicy = None,
baseline: Union[REINFORCEBaseline, str] = "rollout",
diff --git a/rl4co/models/zoo/ham/policy.py b/rl4co/models/zoo/ham/policy.py
index d2ae43c8..a9cd4ea1 100644
--- a/rl4co/models/zoo/ham/policy.py
+++ b/rl4co/models/zoo/ham/policy.py
@@ -1,4 +1,3 @@
-import torch.nn as nn
from rl4co.models.zoo.common.autoregressive import AutoregressivePolicy
from rl4co.models.zoo.ham.encoder import GraphHeterogeneousAttentionEncoder
@@ -41,4 +40,4 @@ def __init__(
num_heads=num_heads,
normalization=normalization,
**kwargs,
- )
\ No newline at end of file
+ )
diff --git a/rl4co/models/zoo/matnet/__init__.py b/rl4co/models/zoo/matnet/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/rl4co/models/zoo/matnet/decoder.py b/rl4co/models/zoo/matnet/decoder.py
new file mode 100644
index 00000000..e703bf5c
--- /dev/null
+++ b/rl4co/models/zoo/matnet/decoder.py
@@ -0,0 +1,52 @@
+from dataclasses import dataclass
+from typing import Tuple, Union
+
+import torch
+import torch.nn as nn
+from einops import rearrange
+from rl4co.models.zoo.common.autoregressive.decoder import AutoregressiveDecoder
+from rl4co.utils.ops import batchify, get_num_starts, select_start_nodes, unbatchify
+from tensordict import TensorDict
+from torch import Tensor
+
+
+@dataclass
+class PrecomputedCache:
+ node_embeddings: Tensor
+ graph_context: Union[Tensor, float]
+ glimpse_key: Tensor
+ glimpse_val: Tensor
+ logit_key: Tensor
+
+
+class MatNetDecoder(AutoregressiveDecoder):
+ def _precompute_cache(
+ self, embeddings: Tuple[Tensor, Tensor], num_starts: int = 0, td: TensorDict = None
+ ):
+ col_emb, row_emb = embeddings
+ (
+ glimpse_key_fixed,
+ glimpse_val_fixed,
+ logit_key,
+ ) = self.project_node_embeddings(
+ col_emb
+ ).chunk(3, dim=-1)
+
+ # Optionally disable the graph context from the initial embedding as done in POMO
+ if self.use_graph_context:
+ graph_context = unbatchify(
+ batchify(self.project_fixed_context(col_emb.mean(1)), num_starts),
+ num_starts,
+ )
+ else:
+ graph_context = 0
+
+ # Organize in a dataclass for easy access
+ return PrecomputedCache(
+ node_embeddings=row_emb,
+ graph_context=graph_context,
+ glimpse_key=glimpse_key_fixed,
+ glimpse_val=glimpse_val_fixed,
+ # logit_key=col_emb,
+ logit_key=logit_key,
+ )
\ No newline at end of file
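
`_precompute_cache` above splits the projected column embeddings into glimpse keys/values and logit keys and stores them, together with the row embeddings, in the `PrecomputedCache` dataclass. A small sketch of the resulting layout with dummy shapes (batch 2, 5 nodes, embedding dim 128; the tensors are random and purely illustrative):

import torch
from rl4co.models.zoo.matnet.decoder import PrecomputedCache

b, n, d = 2, 5, 128
cache = PrecomputedCache(
    node_embeddings=torch.rand(b, n, d),  # row embeddings
    graph_context=0,                      # graph context disabled, as in POMO
    glimpse_key=torch.rand(b, n, d),
    glimpse_val=torch.rand(b, n, d),
    logit_key=torch.rand(b, n, d),
)
print(cache.node_embeddings.shape)  # torch.Size([2, 5, 128])
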
diff --git a/rl4co/models/zoo/matnet/encoder.py b/rl4co/models/zoo/matnet/encoder.py
new file mode 100644
index 00000000..273baa31
--- /dev/null
+++ b/rl4co/models/zoo/matnet/encoder.py
@@ -0,0 +1,309 @@
+import math
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from rl4co.models.nn.ops import Normalization
+from tensordict import TensorDict
+
+
+class MatNetCrossMHA(nn.Module):
+ def __init__(
+ self,
+ embedding_dim: int,
+ num_heads: int,
+ bias: bool = True,
+ mixer_hidden_dim: int = 16,
+ mix1_init: float = (1 / 2) ** (1 / 2),
+ mix2_init: float = (1 / 16) ** (1 / 2),
+ ):
+ super().__init__()
+ self.embedding_dim = embedding_dim
+ self.num_heads = num_heads
+ assert (
+ self.embedding_dim % num_heads == 0
+ ), "embedding_dim must be divisible by num_heads"
+ self.head_dim = self.embedding_dim // num_heads
+
+ self.Wq = nn.Linear(embedding_dim, embedding_dim, bias=bias)
+ self.Wkv = nn.Linear(embedding_dim, 2 * embedding_dim, bias=bias)
+
+ # Score mixer
+ # Taken from the official MatNet implementation
+ # https://github.com/yd-kwon/MatNet/blob/main/ATSP/ATSP_MatNet/ATSPModel_LIB.py#L72
+        mix_W1 = torch.distributions.Uniform(low=-mix1_init, high=mix1_init).sample(
+ (num_heads, 2, mixer_hidden_dim)
+ )
+        mix_b1 = torch.distributions.Uniform(low=-mix1_init, high=mix1_init).sample(
+ (num_heads, mixer_hidden_dim)
+ )
+ self.mix_W1 = nn.Parameter(mix_W1)
+ self.mix_b1 = nn.Parameter(mix_b1)
+
+        mix_W2 = torch.distributions.Uniform(low=-mix2_init, high=mix2_init).sample(
+ (num_heads, mixer_hidden_dim, 1)
+ )
+        mix_b2 = torch.distributions.Uniform(low=-mix2_init, high=mix2_init).sample(
+ (num_heads, 1)
+ )
+ self.mix_W2 = nn.Parameter(mix_W2)
+ self.mix_b2 = nn.Parameter(mix_b2)
+
+ self.out_proj = nn.Linear(embedding_dim, embedding_dim, bias=bias)
+
+ def forward(self, q_input, kv_input, dmat):
+ """
+
+ Args:
+ q_input (Tensor): [b, m, d]
+ kv_input (Tensor): [b, n, d]
+ dmat (Tensor): [b, m, n]
+
+ Returns:
+ Tensor: [b, m, d]
+ """
+
+ b, m, n = dmat.shape
+
+ q = rearrange(
+ self.Wq(q_input), "b m (h d) -> b h m d", h=self.num_heads
+ ) # [b, h, m, d]
+ k, v = rearrange(
+ self.Wkv(kv_input), "b n (two h d) -> two b h n d", two=2, h=self.num_heads
+ ).unbind(
+ dim=0
+ ) # [b, h, n, d]
+
+ scale = math.sqrt(q.size(-1)) # scale factor
+ attn_scores = torch.matmul(q, k.transpose(2, 3)) / scale # [b, h, m, n]
+ mix_attn_scores = torch.stack(
+ [attn_scores, dmat[:, None, :, :].expand(b, self.num_heads, m, n)], dim=-1
+ ) # [b, h, m, n, 2]
+
+ mix_attn_scores = (
+ (
+ torch.matmul(
+ F.relu(
+ torch.matmul(mix_attn_scores.transpose(1, 2), self.mix_W1)
+ + self.mix_b1[None, None, :, None, :]
+ ),
+ self.mix_W2,
+ )
+ + self.mix_b2[None, None, :, None, :]
+ )
+ .transpose(1, 2)
+ .squeeze(-1)
+ ) # [b, h, m, n]
+
+ attn_probs = F.softmax(mix_attn_scores, dim=-1)
+ out = torch.matmul(attn_probs, v)
+ return self.out_proj(rearrange(out, "b h s d -> b s (h d)"))
+
+
+class MatNetMHA(nn.Module):
+ def __init__(self, embedding_dim: int, num_heads: int, bias: bool = True):
+ super().__init__()
+ self.row_encoding_block = MatNetCrossMHA(embedding_dim, num_heads, bias)
+ self.col_encoding_block = MatNetCrossMHA(embedding_dim, num_heads, bias)
+
+ def forward(self, row_emb, col_emb, dmat):
+ """
+ Args:
+ row_emb (Tensor): [b, m, d]
+ col_emb (Tensor): [b, n, d]
+ dmat (Tensor): [b, m, n]
+
+ Returns:
+ Updated row_emb (Tensor): [b, m, d]
+ Updated col_emb (Tensor): [b, n, d]
+ """
+
+ updated_row_emb = self.row_encoding_block(row_emb, col_emb, dmat)
+ updated_col_emb = self.col_encoding_block(
+ col_emb, row_emb, dmat.transpose(-2, -1)
+ )
+ return updated_row_emb, updated_col_emb
+
+
+class MatNetMHALayer(nn.Module):
+ def __init__(
+ self,
+ embedding_dim: int,
+ num_heads: int,
+ bias: bool = True,
+ feed_forward_hidden: int = 512,
+ normalization: Optional[str] = "instance",
+ ):
+ super().__init__()
+ self.MHA = MatNetMHA(embedding_dim, num_heads, bias)
+
+ self.F_a = nn.ModuleDict(
+ {
+ "norm1": Normalization(embedding_dim, normalization),
+ "ffn": nn.Sequential(
+ nn.Linear(embedding_dim, feed_forward_hidden),
+ nn.ReLU(),
+ nn.Linear(feed_forward_hidden, embedding_dim),
+ ),
+ "norm2": Normalization(embedding_dim, normalization),
+ }
+ )
+
+ self.F_b = nn.ModuleDict(
+ {
+ "norm1": Normalization(embedding_dim, normalization),
+ "ffn": nn.Sequential(
+ nn.Linear(embedding_dim, feed_forward_hidden),
+ nn.ReLU(),
+ nn.Linear(feed_forward_hidden, embedding_dim),
+ ),
+ "norm2": Normalization(embedding_dim, normalization),
+ }
+ )
+
+ def forward(self, row_emb, col_emb, dmat):
+ """
+ Args:
+ row_emb (Tensor): [b, m, d]
+ col_emb (Tensor): [b, n, d]
+ dmat (Tensor): [b, m, n]
+
+ Returns:
+ Updated row_emb (Tensor): [b, m, d]
+ Updated col_emb (Tensor): [b, n, d]
+ """
+
+ row_emb_out, col_emb_out = self.MHA(row_emb, col_emb, dmat)
+
+ row_emb_out = self.F_a["norm1"](row_emb + row_emb_out)
+ row_emb_out = self.F_a["norm2"](row_emb_out + self.F_a["ffn"](row_emb_out))
+
+ col_emb_out = self.F_b["norm1"](col_emb + col_emb_out)
+ col_emb_out = self.F_b["norm2"](col_emb_out + self.F_b["ffn"](col_emb_out))
+ return row_emb_out, col_emb_out
+
+
+class MatNetMHANetwork(nn.Module):
+ def __init__(
+ self,
+ embedding_dim: int = 128,
+ num_heads: int = 8,
+ num_layers: int = 3,
+ normalization: str = "batch",
+ feed_forward_hidden: int = 512,
+ ):
+ super().__init__()
+ self.layers = nn.ModuleList(
+ [
+ MatNetMHALayer(
+ num_heads=num_heads,
+ embedding_dim=embedding_dim,
+ feed_forward_hidden=feed_forward_hidden,
+ normalization=normalization,
+ )
+ for _ in range(num_layers)
+ ]
+ )
+
+ def forward(self, row_emb, col_emb, dmat):
+ """
+ Args:
+ row_emb (Tensor): [b, m, d]
+ col_emb (Tensor): [b, n, d]
+ dmat (Tensor): [b, m, n]
+
+ Returns:
+ Updated row_emb (Tensor): [b, m, d]
+ Updated col_emb (Tensor): [b, n, d]
+ """
+
+ for layer in self.layers:
+ row_emb, col_emb = layer(row_emb, col_emb, dmat)
+ return row_emb, col_emb
+
+
+class MatNetATSPInitEmbedding(nn.Module):
+ """
+ Preparing the initial row and column embeddings for ATSP.
+
+ Reference:
+ https://github.com/yd-kwon/MatNet/blob/782698b60979effe2e7b61283cca155b7cdb727f/ATSP/ATSP_MatNet/ATSPModel.py#L51
+
+
+ """
+
+ def __init__(self, embedding_dim: int, mode: str = "RandomOneHot") -> None:
+ super().__init__()
+
+ self.embedding_dim = embedding_dim
+ assert mode in {
+ "RandomOneHot",
+ "Random",
+ }, "mode must be one of ['RandomOneHot', 'Random']"
+ self.mode = mode
+
+ self.dmat_proj = nn.Linear(1, 2 * embedding_dim, bias=False)
+ self.row_proj = nn.Linear(embedding_dim * 4, embedding_dim, bias=False)
+ self.col_proj = nn.Linear(embedding_dim * 4, embedding_dim, bias=False)
+
+ def forward(self, td: TensorDict):
+ dmat = td["cost_matrix"] # [b, n, n]
+ b, n, _ = dmat.shape
+
+ row_emb = torch.zeros(b, n, self.embedding_dim, device=dmat.device)
+
+ if self.mode == "RandomOneHot":
+ # MatNet uses one-hot encoding for column embeddings
+ # https://github.com/yd-kwon/MatNet/blob/782698b60979effe2e7b61283cca155b7cdb727f/ATSP/ATSP_MatNet/ATSPModel.py#L60
+
+ col_emb = torch.zeros(b, n, self.embedding_dim, device=dmat.device)
+ rand = torch.rand(b, n)
+ rand_idx = rand.argsort(dim=1)
+ b_idx = torch.arange(b)[:, None].expand(b, n)
+ n_idx = torch.arange(n)[None, :].expand(b, n)
+ col_emb[b_idx, n_idx, rand_idx] = 1.0
+
+ elif self.mode == "Random":
+ col_emb = torch.rand(b, n, self.embedding_dim, device=dmat.device)
+ else:
+ raise NotImplementedError
+
+ return row_emb, col_emb, dmat
+
+
+class MatNetEncoder(nn.Module):
+ def __init__(
+ self,
+ embedding_dim: int = 256,
+ num_heads: int = 16,
+ num_layers: int = 5,
+ normalization: str = "instance",
+ feed_forward_hidden: int = 512,
+ init_embedding: nn.Module = None,
+ init_embedding_kwargs: dict = None,
+ ):
+ super().__init__()
+
+ if init_embedding is None:
+ init_embedding = MatNetATSPInitEmbedding(
+                embedding_dim, **(init_embedding_kwargs or {})
+ )
+
+ self.init_embedding = init_embedding
+ self.net = MatNetMHANetwork(
+ embedding_dim=embedding_dim,
+ num_heads=num_heads,
+ num_layers=num_layers,
+ normalization=normalization,
+ feed_forward_hidden=feed_forward_hidden,
+ )
+
+ def forward(self, td):
+ row_emb, col_emb, dmat = self.init_embedding(td)
+ row_emb, col_emb = self.net(row_emb, col_emb, dmat)
+
+ embedding = (row_emb, col_emb)
+ init_embedding = None
+ return embedding, init_embedding # match output signature for the AR policy class
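
As a quick smoke test, the new encoder can be exercised on random cost matrices. This is a sketch assuming the modules added by this patch are importable; only CPU tensors are used.

import torch
from tensordict import TensorDict

from rl4co.models.zoo.matnet.encoder import MatNetEncoder

# Encode a batch of random ATSP cost matrices.
batch, n = 4, 20
td = TensorDict({"cost_matrix": torch.rand(batch, n, n)}, batch_size=[batch])

encoder = MatNetEncoder(
    embedding_dim=256,
    num_heads=16,
    num_layers=5,
    init_embedding_kwargs={"mode": "RandomOneHot"},
)
(row_emb, col_emb), _ = encoder(td)
print(row_emb.shape, col_emb.shape)  # both torch.Size([4, 20, 256])

Passing init_embedding_kwargs explicitly selects the one-hot column initialization used by the original MatNet.
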
diff --git a/rl4co/models/zoo/matnet/model.py b/rl4co/models/zoo/matnet/model.py
new file mode 100644
index 00000000..1af9cace
--- /dev/null
+++ b/rl4co/models/zoo/matnet/model.py
@@ -0,0 +1,39 @@
+from typing import Any, Union
+from rl4co.models.zoo.matnet.policy import MatNetPolicy
+
+import torch.nn as nn
+
+from rl4co.models.zoo.pomo.model import POMO
+from rl4co.envs.common.base import RL4COEnvBase
+
+
+class MatNet(POMO):
+ def __init__(
+ self,
+ env: RL4COEnvBase,
+ policy: Union[nn.Module, MatNetPolicy] = None,
+ optimizer_kwargs: dict = {"lr": 4 * 1e-4, "weight_decay": 1e-6},
+ lr_scheduler: str = "MultiStepLR",
+ lr_scheduler_kwargs: dict = {"milestones": [2001, 2101], "gamma": 0.1},
+ use_dihedral_8: bool = False,
+ num_starts: int = None,
+ train_data_size: int = 10_000,
+ batch_size: int = 200,
+ policy_params: dict = {},
+ model_params: dict = {},
+ ):
+ if policy is None:
+ policy = MatNetPolicy(env_name=env.name, **policy_params)
+
+ super(MatNet, self).__init__(
+ env=env,
+ policy=policy,
+ optimizer_kwargs=optimizer_kwargs,
+ lr_scheduler=lr_scheduler,
+ lr_scheduler_kwargs=lr_scheduler_kwargs,
+ use_dihedral_8=use_dihedral_8,
+ num_starts=num_starts,
+ train_data_size=train_data_size,
+ batch_size=batch_size,
+ **model_params,
+ )
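
A hypothetical end-to-end usage of the new model wrapper is sketched below. It assumes an ATSP environment exposing a cost_matrix field is available (e.g. rl4co.envs.ATSPEnv, which is not added by this diff) and that the package is installed.

from rl4co.envs import ATSPEnv  # assumed to exist; not part of this patch
from rl4co.models.zoo.matnet.model import MatNet
from rl4co.utils import RL4COTrainer

env = ATSPEnv(num_loc=20)
model = MatNet(env, train_data_size=10_000, batch_size=200)

# RL4COTrainer forwards keyword arguments to the underlying Lightning Trainer.
trainer = RL4COTrainer(max_epochs=1, accelerator="cpu", devices=1)
trainer.fit(model)

Since MatNet subclasses POMO, training goes through the multistart/augmentation logic of the POMO shared_step shown further down in this patch.
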
diff --git a/rl4co/models/zoo/matnet/policy.py b/rl4co/models/zoo/matnet/policy.py
new file mode 100644
index 00000000..8b4e1761
--- /dev/null
+++ b/rl4co/models/zoo/matnet/policy.py
@@ -0,0 +1,61 @@
+from rl4co.models.zoo.common.autoregressive import AutoregressivePolicy
+from rl4co.models.zoo.matnet.encoder import MatNetEncoder
+from rl4co.models.zoo.matnet.decoder import MatNetDecoder
+from rl4co.utils.pylogger import get_pylogger
+
+log = get_pylogger(__name__)
+
+
+class MatNetPolicy(AutoregressivePolicy):
+ """MatNet Policy from Kwon et al., 2021.
+ Reference: https://arxiv.org/abs/2106.11113
+
+ Warning:
+ This implementation is under development and subject to change.
+
+ Args:
+ env_name: Name of the environment used to initialize embeddings
+ embedding_dim: Dimension of the node embeddings
+ num_encoder_layers: Number of layers in the encoder
+ num_heads: Number of heads in the attention layers
+ normalization: Normalization type in the attention layers
+ **kwargs: keyword arguments passed to the `AutoregressivePolicy`
+
+    Default parameters are adopted from the original implementation.
+ """
+
+ def __init__(
+ self,
+ env_name: str,
+ embedding_dim: int = 256,
+ num_encoder_layers: int = 5,
+ num_heads: int = 16,
+ normalization: str = "instance",
+ init_embedding_kwargs: dict = {"mode": "RandomOneHot"},
+ use_graph_context: bool = False,
+ **kwargs,
+ ):
+ if env_name not in ["atsp"]:
+ log.error(f"env_name {env_name} is not originally implemented in MatNet")
+
+ super(MatNetPolicy, self).__init__(
+ env_name=env_name,
+ encoder=MatNetEncoder(
+ embedding_dim=embedding_dim,
+ num_heads=num_heads,
+ num_layers=num_encoder_layers,
+ normalization=normalization,
+ init_embedding_kwargs=init_embedding_kwargs,
+ ),
+ decoder=MatNetDecoder(
+ env_name=env_name,
+ embedding_dim=embedding_dim,
+ num_heads=num_heads,
+ use_graph_context=use_graph_context,
+ ),
+ embedding_dim=embedding_dim,
+ num_encoder_layers=num_encoder_layers,
+ num_heads=num_heads,
+ normalization=normalization,
+ **kwargs,
+ )
diff --git a/rl4co/models/zoo/mdam/__init__.py b/rl4co/models/zoo/mdam/__init__.py
index 0dcc6521..2b7a14da 100644
--- a/rl4co/models/zoo/mdam/__init__.py
+++ b/rl4co/models/zoo/mdam/__init__.py
@@ -1,2 +1,2 @@
+from .model import MDAM
from .policy import MDAMPolicy
-from .model import MDAM
\ No newline at end of file
diff --git a/rl4co/models/zoo/mdam/decoder.py b/rl4co/models/zoo/mdam/decoder.py
index 87fd0dee..8e1b9daf 100644
--- a/rl4co/models/zoo/mdam/decoder.py
+++ b/rl4co/models/zoo/mdam/decoder.py
@@ -1,15 +1,14 @@
import math
-from typing import Union
from dataclasses import dataclass
-from tensordict import TensorDict
+from typing import Union
import torch
import torch.nn as nn
import torch.nn.functional as F
+from tensordict import TensorDict
from rl4co.envs import RL4COEnvBase
-
from rl4co.models.nn.attention import LogitAttention
from rl4co.models.nn.env_embeddings import env_context_embedding, env_dynamic_embedding
from rl4co.models.nn.utils import decode_probs, get_log_likelihood
@@ -67,8 +66,7 @@ def __init__(
self.project_node_embeddings = nn.ModuleList(self.project_node_embeddings)
self.project_fixed_context = [
- nn.Linear(embedding_dim, embedding_dim, bias=False)
- for _ in range(num_paths)
+ nn.Linear(embedding_dim, embedding_dim, bias=False) for _ in range(num_paths)
]
self.project_fixed_context = nn.ModuleList(self.project_fixed_context)
@@ -79,8 +77,7 @@ def __init__(
self.project_step_context = nn.ModuleList(self.project_step_context)
self.project_out = [
- nn.Linear(embedding_dim, embedding_dim, bias=False)
- for _ in range(num_paths)
+ nn.Linear(embedding_dim, embedding_dim, bias=False) for _ in range(num_paths)
]
self.project_out = nn.ModuleList(self.project_out)
@@ -108,15 +105,15 @@ def __init__(
self.shrink_size = shrink_size
def forward(
- self,
- td: TensorDict,
- encoded_inputs: torch.Tensor,
- env: Union[str, RL4COEnvBase],
- attn,
- V,
- h_old,
- **decoder_kwargs
- ):
+ self,
+ td: TensorDict,
+ encoded_inputs: torch.Tensor,
+ env: Union[str, RL4COEnvBase],
+ attn,
+ V,
+ h_old,
+ **decoder_kwargs,
+ ):
# SECTION: Decoder first step: calculate for the decoder divergence loss
# Cost list and log likelihood list along with path
output_list = []
@@ -261,7 +258,9 @@ def _get_log_p(self, fixed, td, path_index, normalize=True):
step_context = self.context[path_index](
fixed.node_embeddings, td
) # [batch, embed_dim]
- glimpse_q = fixed.graph_context + step_context.unsqueeze(1).to(fixed.graph_context.device)
+ glimpse_q = fixed.graph_context + step_context.unsqueeze(1).to(
+ fixed.graph_context.device
+ )
# Compute keys and values for the nodes
(
diff --git a/rl4co/models/zoo/mdam/model.py b/rl4co/models/zoo/mdam/model.py
index b0696cf3..5ab935a6 100644
--- a/rl4co/models/zoo/mdam/model.py
+++ b/rl4co/models/zoo/mdam/model.py
@@ -1,4 +1,3 @@
-
from typing import Union
from rl4co.envs.common.base import RL4COEnvBase
@@ -8,10 +7,10 @@
class MDAM(REINFORCE):
- """ Multi-Decoder Attention Model (MDAM) is a model
- to train multiple diverse policies, which effectively increases the chance of finding
+ """Multi-Decoder Attention Model (MDAM) is a model
+ to train multiple diverse policies, which effectively increases the chance of finding
good solutions compared with existing methods that train only one policy.
- Reference link: https://arxiv.org/abs/2012.10638;
+ Reference link: https://arxiv.org/abs/2012.10638;
Implementation reference: https://github.com/liangxinedu/MDAM.
Args:
@@ -24,15 +23,15 @@ class MDAM(REINFORCE):
"""
def __init__(
- self,
- env: RL4COEnvBase,
- policy: MDAMPolicy = None,
- baseline: Union[REINFORCEBaseline, str] = "rollout",
- policy_kwargs={},
- baseline_kwargs={},
- **kwargs
- ):
+ self,
+ env: RL4COEnvBase,
+ policy: MDAMPolicy = None,
+ baseline: Union[REINFORCEBaseline, str] = "rollout",
+ policy_kwargs={},
+ baseline_kwargs={},
+ **kwargs,
+ ):
if policy is None:
- policy = MDAMPolicy(env.name, **policy_kwargs)
+ policy = MDAMPolicy(env.name, **policy_kwargs)
- super().__init__(env, policy, baseline, baseline_kwargs, **kwargs)
\ No newline at end of file
+ super().__init__(env, policy, baseline, baseline_kwargs, **kwargs)
diff --git a/rl4co/models/zoo/mdam/policy.py b/rl4co/models/zoo/mdam/policy.py
index 7a8b7c04..30299dd5 100644
--- a/rl4co/models/zoo/mdam/policy.py
+++ b/rl4co/models/zoo/mdam/policy.py
@@ -1,26 +1,25 @@
-import torch.nn as nn
from typing import Union
from tensordict import TensorDict
-from rl4co.envs import RL4COEnvBase, get_env
+from rl4co.envs import RL4COEnvBase, get_env
from rl4co.models.nn.env_embeddings import env_init_embedding
+from rl4co.models.zoo.common.autoregressive import AutoregressivePolicy
from rl4co.models.zoo.mdam.decoder import Decoder
from rl4co.models.zoo.mdam.encoder import GraphAttentionEncoder
-from rl4co.models.zoo.common.autoregressive import AutoregressivePolicy
from rl4co.utils.pylogger import get_pylogger
log = get_pylogger(__name__)
class MDAMPolicy(AutoregressivePolicy):
- """ Multi-Decoder Attention Model (MDAM) policy.
+ """Multi-Decoder Attention Model (MDAM) policy.
Args:
"""
-
+
def __init__(
- self,
+ self,
env_name: str,
embedding_dim: int = 128,
num_encoder_layers: int = 3,
@@ -35,13 +34,13 @@ def __init__(
embed_dim=embedding_dim,
num_layers=num_encoder_layers,
normalization=normalization,
- **kwargs
+ **kwargs,
),
decoder=Decoder(
env_name=env_name,
embedding_dim=embedding_dim,
num_heads=num_heads,
- **kwargs
+ **kwargs,
),
embedding_dim=embedding_dim,
num_encoder_layers=num_encoder_layers,
@@ -84,4 +83,4 @@ def forward(
"entropy": kl_divergence,
"actions": actions if return_actions else None,
}
- return out
\ No newline at end of file
+ return out
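
For intuition, the quantity MDAM logs under "entropy" is a divergence term between its decoders: the num_paths decoders are encouraged to produce different first-step action distributions. The sketch below is illustrative only, not the exact rl4co computation.

import torch
import torch.nn.functional as F

# Pairwise KL divergence between the first-step policies of the decoder paths.
num_paths, batch, num_nodes = 5, 2, 20
logits = torch.randn(num_paths, batch, num_nodes)  # one logit set per decoder path
log_p = F.log_softmax(logits, dim=-1)

kl = torch.zeros(())
for i in range(num_paths):
    for j in range(num_paths):
        # KL(p_i || p_j), summed over actions and averaged over the batch
        kl = kl + (log_p[i].exp() * (log_p[i] - log_p[j])).sum(-1).mean()
kl_divergence = kl / (num_paths * num_paths)
print(kl_divergence)
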
diff --git a/rl4co/models/zoo/pomo/model.py b/rl4co/models/zoo/pomo/model.py
index eefff8ae..1d707df2 100644
--- a/rl4co/models/zoo/pomo/model.py
+++ b/rl4co/models/zoo/pomo/model.py
@@ -56,7 +56,9 @@ def __init__(
for phase in ["train", "val", "test"]:
self.set_decode_type_multistart(phase)
- def shared_step(self, batch: Any, batch_idx: int, phase: str, dataloader_idx: int = None):
+ def shared_step(
+ self, batch: Any, batch_idx: int, phase: str, dataloader_idx: int = None
+ ):
td = self.env.reset(batch)
n_aug, n_start = self.num_augment, self.num_starts
n_start = get_num_starts(td) if n_start is None else n_start
@@ -102,10 +104,10 @@ def shared_step(self, batch: Any, batch_idx: int, phase: str, dataloader_idx: in
out.update({"max_aug_reward": max_aug_reward})
if out.get("actions", None) is not None:
- actions_ = out["best_multistart_actions"] if n_start > 1 else out["actions"]
- out.update(
- {"best_aug_actions": gather_by_index(actions_, max_idxs)}
+ actions_ = (
+ out["best_multistart_actions"] if n_start > 1 else out["actions"]
)
+ out.update({"best_aug_actions": gather_by_index(actions_, max_idxs)})
metrics = self.log_metrics(out, phase, dataloader_idx=dataloader_idx)
return {"loss": out.get("loss", None), **metrics}
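
The reformatted block above picks one trajectory per instance after the multistart/augmentation rollouts. A shape-level sketch of that selection, with starts and augmentations simplified into a single candidate dimension, is shown below (the actual code goes through unbatchify and gather_by_index).

import torch

# Each instance has several candidate rollouts; keep the best reward and
# the actions that produced it.
batch, n_candidates, seq_len = 4, 16, 20
reward = torch.randn(batch, n_candidates)
actions = torch.randint(0, 20, (batch, n_candidates, seq_len))

max_reward, max_idxs = reward.max(dim=1)
best_actions = actions.gather(
    1, max_idxs[:, None, None].expand(batch, 1, seq_len)
).squeeze(1)
print(max_reward.shape, best_actions.shape)  # torch.Size([4]) torch.Size([4, 20])
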
diff --git a/rl4co/tasks/train.py b/rl4co/tasks/train.py
index 8d04a01c..6b628470 100644
--- a/rl4co/tasks/train.py
+++ b/rl4co/tasks/train.py
@@ -9,11 +9,12 @@
from lightning.pytorch.loggers import Logger
from omegaconf import DictConfig
-pyrootutils.setup_root(__file__, indicator=".gitignore", pythonpath=True)
-
from rl4co import utils
from rl4co.utils import RL4COTrainer
+pyrootutils.setup_root(__file__, indicator=".gitignore", pythonpath=True)
+
+
log = utils.get_pylogger(__name__)
diff --git a/rl4co/utils/ops.py b/rl4co/utils/ops.py
index 2034d0de..669e4a12 100644
--- a/rl4co/utils/ops.py
+++ b/rl4co/utils/ops.py
@@ -6,7 +6,6 @@
from torch import Tensor
-# @torch.jit.script
def _batchify_single(
x: Union[Tensor, TensorDict], repeats: int
) -> Union[Tensor, TensorDict]:
@@ -61,7 +60,6 @@ def unbatchify(
return x
-# @torch.jit.script
def gather_by_index(src, idx, dim=1, squeeze=True):
"""Gather elements from src by index idx along specified dim
@@ -76,13 +74,13 @@ def gather_by_index(src, idx, dim=1, squeeze=True):
return src.gather(dim, idx).squeeze() if squeeze else src.gather(dim, idx)
-# @torch.jit.script
+@torch.jit.script
def get_distance(x: Tensor, y: Tensor):
"""Euclidean distance between two tensors of shape `[..., n, dim]`"""
return (x - y).norm(p=2, dim=-1)
-# @torch.jit.script
+@torch.jit.script
def get_tour_length(ordered_locs):
"""Compute the total tour distance for a batch of ordered tours.
Computes the L2 norm between each pair of consecutive nodes in the tour and sums them up.
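
The now-scripted get_tour_length follows the semantics stated in its docstring. A minimal standalone check of that computation (assuming a closed tour, i.e. including the edge back to the starting node) could look like:

import torch

ordered_locs = torch.rand(3, 10, 2)                       # [batch, n, 2] visiting order
next_locs = torch.roll(ordered_locs, shifts=-1, dims=-2)  # successor of each node
tour_length = (ordered_locs - next_locs).norm(p=2, dim=-1).sum(-1)
print(tour_length.shape)  # torch.Size([3])
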
diff --git a/rl4co/utils/optim_helpers.py b/rl4co/utils/optim_helpers.py
index f784a62b..46367a37 100644
--- a/rl4co/utils/optim_helpers.py
+++ b/rl4co/utils/optim_helpers.py
@@ -1,7 +1,6 @@
import inspect
import torch
-import torch.nn as nn
from torch.optim import Optimizer
diff --git a/rl4co/utils/trainer.py b/rl4co/utils/trainer.py
index a76b4e5f..790437c3 100644
--- a/rl4co/utils/trainer.py
+++ b/rl4co/utils/trainer.py
@@ -68,7 +68,7 @@ def __init__(
except AttributeError:
pass
- # Configure DDP automatically
+ # Configure DDP automatically if multiple GPUs are available
if auto_configure_ddp and strategy == "auto":
if devices == "auto":
n_devices = num_cuda_devices()
@@ -77,7 +77,11 @@ def __init__(
else:
n_devices = devices
if n_devices > 1:
- log.info("Configuring DDP strategy automatically")
+ log.info(
+ "Configuring DDP strategy automatically with {} GPUs".format(
+ n_devices
+ )
+ )
strategy = DDPStrategy(
find_unused_parameters=True, # We set to True due to RL envs
gradient_as_bucket_view=True, # https://pytorch-lightning.readthedocs.io/en/stable/advanced/advanced_gpu.html#ddp-optimizations
@@ -89,7 +93,9 @@ def __init__(
# Check if gradient_clip_val is set to None
if gradient_clip_val is None:
- log.warning("gradient_clip_val is set to None. This may lead to unstable training.")
+ log.warning(
+ "gradient_clip_val is set to None. This may lead to unstable training."
+ )
# We should reload dataloaders every epoch for RL training
if reload_dataloaders_every_n_epochs != 1: