From ba61a64ab80484243a2b05496dde37d6b609ad9c Mon Sep 17 00:00:00 2001 From: Federico Berto Date: Sun, 5 Nov 2023 16:48:49 +0900 Subject: [PATCH 01/13] Refactoring; update to new TorchRL --- rl4co/envs/__init__.py | 25 ++++++-------- rl4co/envs/common/base.py | 44 ++++++++++++++++++++++++ rl4co/envs/eda/__init__.py | 2 ++ rl4co/envs/{ => eda}/dpp.py | 34 ++++++++----------- rl4co/envs/{ => eda}/mdpp.py | 11 +++--- rl4co/envs/routing/__init__.py | 9 +++++ rl4co/envs/{ => routing}/atsp.py | 30 ++++++++--------- rl4co/envs/{ => routing}/cvrp.py | 41 ++++++++++------------- rl4co/envs/{ => routing}/mpdp.py | 48 ++++++++++----------------- rl4co/envs/{ => routing}/mtsp.py | 32 ++++++++---------- rl4co/envs/{ => routing}/op.py | 40 +++++++++------------- rl4co/envs/{ => routing}/pctsp.py | 42 ++++++++++------------- rl4co/envs/{ => routing}/pdp.py | 32 ++++++++---------- rl4co/envs/{ => routing}/sdvrp.py | 38 ++++++++++----------- rl4co/envs/{ => routing}/spctsp.py | 2 +- rl4co/envs/{ => routing}/tsp.py | 37 ++++++++++----------- rl4co/envs/scheduling/__init__.py | 2 ++ rl4co/envs/{ => scheduling}/ffsp.py | 42 ++++++++++------------- rl4co/envs/{ => scheduling}/smtwtp.py | 32 ++++++++---------- 19 files changed, 265 insertions(+), 278 deletions(-) create mode 100644 rl4co/envs/eda/__init__.py rename rl4co/envs/{ => eda}/dpp.py (95%) rename rl4co/envs/{ => eda}/mdpp.py (98%) create mode 100644 rl4co/envs/routing/__init__.py rename rl4co/envs/{ => routing}/atsp.py (91%) rename rl4co/envs/{ => routing}/cvrp.py (93%) rename rl4co/envs/{ => routing}/mpdp.py (93%) rename rl4co/envs/{ => routing}/mtsp.py (94%) rename rl4co/envs/{ => routing}/op.py (94%) rename rl4co/envs/{ => routing}/pctsp.py (93%) rename rl4co/envs/{ => routing}/pdp.py (93%) rename rl4co/envs/{ => routing}/sdvrp.py (89%) rename rl4co/envs/{ => routing}/spctsp.py (95%) rename rl4co/envs/{ => routing}/tsp.py (87%) create mode 100644 rl4co/envs/scheduling/__init__.py rename rl4co/envs/{ => scheduling}/ffsp.py (92%) rename rl4co/envs/{ => scheduling}/smtwtp.py (91%) diff --git a/rl4co/envs/__init__.py b/rl4co/envs/__init__.py index 26eeecb0..eb9faef5 100644 --- a/rl4co/envs/__init__.py +++ b/rl4co/envs/__init__.py @@ -1,25 +1,22 @@ # Base environment -# Main Environments -from rl4co.envs.atsp import ATSPEnv from rl4co.envs.common.base import RL4COEnvBase -from rl4co.envs.cvrp import CVRPEnv -from rl4co.envs.dpp import DPPEnv -from rl4co.envs.ffsp import FFSPEnv -from rl4co.envs.mdpp import MDPPEnv -from rl4co.envs.mtsp import MTSPEnv -from rl4co.envs.op import OPEnv -from rl4co.envs.pctsp import PCTSPEnv -from rl4co.envs.pdp import PDPEnv -from rl4co.envs.sdvrp import SDVRPEnv -from rl4co.envs.smtwtp import SMTWTPEnv -from rl4co.envs.spctsp import SPCTSPEnv -from rl4co.envs.tsp import TSPEnv + +# EDA +from rl4co.envs.eda import DPPEnv, MDPPEnv + +# Routing +from rl4co.envs.routing import ATSPEnv, CVRPEnv, MTSPEnv, OPEnv, PCTSPEnv, PDPEnv, SDVRPEnv, SPCTSPEnv, TSPEnv + +# Scheduling +from rl4co.envs.scheduling import FFSPEnv, SMTWTPEnv + # Register environments ENV_REGISTRY = { "atsp": ATSPEnv, "cvrp": CVRPEnv, "dpp": DPPEnv, + "ffsp": FFSPEnv, "mdpp": MDPPEnv, "mtsp": MTSPEnv, "op": OPEnv, diff --git a/rl4co/envs/common/base.py b/rl4co/envs/common/base.py index 1607d413..b65774f0 100644 --- a/rl4co/envs/common/base.py +++ b/rl4co/envs/common/base.py @@ -40,6 +40,7 @@ def __init__( val_dataloader_names: list = None, test_dataloader_names: list = None, check_solution: bool = True, + _torchrl_mode: bool = False, # TODO seed: int = None, device: str = "cpu", **kwargs, @@ -47,6 +48,7 @@ def __init__( super().__init__(device=device, batch_size=[]) self.data_dir = data_dir self.train_file = pjoin(data_dir, train_file) if train_file is not None else None + self._torchrl_mode = _torchrl_mode def get_files(f): if f is not None: @@ -85,6 +87,41 @@ def get_multiple_dataloader_names(f, names): seed = torch.empty((), dtype=torch.int64).random_().item() self.set_seed(seed) + def step(self, td: TensorDict) -> TensorDict: + """Step function to call at each step of the episode containing an action. + If `_torchrl_mode` is True, we call `_torchrl_step` instead which set the + `next` key of the TensorDict to the next state - this is the usual way to do it in TorchRL, + but inefficient in our case + """ + if not self._torchrl_mode: + # Default: just return the TensorDict without farther checks etc is faster + td = self._step(td) + return {"next": td} + else: + # Since we simplify the syntax + return self._torchrl_step(td) + + def _torchrl_step(self, td: TensorDict) -> TensorDict: + """See :meth:`super().step` for more details. + This is the usual way to do it in TorchRL, but inefficient in our case + + Note: + Here we clone the TensorDict to avoid recursion error, since we allow + for directly updating the TensorDict in the step function + """ + # sanity check + self._assert_tensordict_shape(td) + next_preset = td.get("next", None) + + next_tensordict = self._step(td.clone()) # NOTE: we clone to avoid recursion error + next_tensordict = self._step_proc_data(next_tensordict) + if next_preset is not None: + next_tensordict.update( + next_preset.exclude(*next_tensordict.keys(True, True)) + ) + td.set("next", next_tensordict) + return td + def _step(self, td: TensorDict) -> TensorDict: """Step function to call at each step of the episode containing an action. Gives the next observation, reward, done @@ -177,6 +214,13 @@ def _set_seed(self, seed: Optional[int]): """Set the seed for the environment""" rng = torch.manual_seed(seed) self.rng = rng + + def to(self, device): + """Override `to` device method for safety against `None` device (may be found in `TensorDict`))""" + if device is None: + return self + else: + return super().to(device) def __getstate__(self): """Return the state of the environment. By default, we want to avoid pickling diff --git a/rl4co/envs/eda/__init__.py b/rl4co/envs/eda/__init__.py new file mode 100644 index 00000000..f306ecee --- /dev/null +++ b/rl4co/envs/eda/__init__.py @@ -0,0 +1,2 @@ +from rl4co.envs.eda.dpp import DPPEnv +from rl4co.envs.eda.mdpp import MDPPEnv \ No newline at end of file diff --git a/rl4co/envs/dpp.py b/rl4co/envs/eda/dpp.py similarity index 95% rename from rl4co/envs/dpp.py rename to rl4co/envs/eda/dpp.py index 8ac8a857..633af37f 100644 --- a/rl4co/envs/dpp.py +++ b/rl4co/envs/eda/dpp.py @@ -100,31 +100,26 @@ def _step(self, td: TensorDict) -> TensorDict: # Set done if i is greater than max_decaps done = td["i"] >= self.max_decaps - 1 - # Calculate reward (we set to -inf since we calculate the reward outside based on the actions) - reward = torch.ones_like(done) * float("-inf") - - # The output must be written in a ``"next"`` entry - return TensorDict( + # The reward is calculated outside via get_reward for efficiency, so we set it to 0 here + reward = torch.zeros_like(done) + + td.update( { - "next": { - "locs": td["locs"], - "probe": td["probe"], - "i": td["i"] + 1, - "action_mask": available, - "keepout": td["keepout"], - "reward": reward, - "done": done, - } - }, - td.shape, + "i": td["i"] + 1, + "action_mask": available, + "reward": reward, + "done": done, + } ) + return td def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict: # Initialize locations if batch_size is None: batch_size = self.batch_size if td is None else td.batch_size - self.device = td.device if td is not None else self.device - + device = td.device if td is not None else self.device + self.to(device) + # We allow loading the initial observation from a dataset for faster loading if td is None: td = self.generate_data(batch_size=batch_size) @@ -170,7 +165,6 @@ def _make_spec(self, td_params): ), shape=(), ) - self.input_spec = self.observation_spec.clone() self.action_spec = BoundedTensorSpec( shape=(1,), dtype=torch.int64, @@ -297,7 +291,7 @@ def _initial_impedance(self, probe): return zout def _decap_simulator(self, probe, solution, keepout=None): - self.device = solution.device + self.to(self.device) probe = probe.item() diff --git a/rl4co/envs/mdpp.py b/rl4co/envs/eda/mdpp.py similarity index 98% rename from rl4co/envs/mdpp.py rename to rl4co/envs/eda/mdpp.py index 633cd5e9..25948fe0 100644 --- a/rl4co/envs/mdpp.py +++ b/rl4co/envs/eda/mdpp.py @@ -11,7 +11,7 @@ UnboundedDiscreteTensorSpec, ) -from rl4co.envs.dpp import DPPEnv +from rl4co.envs.eda.dpp import DPPEnv from rl4co.utils.pylogger import get_pylogger log = get_pylogger(__name__) @@ -64,8 +64,12 @@ def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict # Action mask is 0 if both action_mask (e.g. keepout) and probe are 0 action_mask = torch.logical_and(td_reset["action_mask"], ~td_reset["probe"]) # Keepout regions are the inverse of action_mask - td_reset.set_("keepout", ~td_reset["action_mask"]) - td_reset.set_("action_mask", action_mask) + td_reset.update( + { + "keepout": ~td_reset["action_mask"], + "action_mask": action_mask, + } + ) return td_reset def _make_spec(self, td_params): @@ -95,7 +99,6 @@ def _make_spec(self, td_params): ), shape=(), ) - self.input_spec = self.observation_spec.clone() self.action_spec = BoundedTensorSpec( shape=(1,), dtype=torch.int64, diff --git a/rl4co/envs/routing/__init__.py b/rl4co/envs/routing/__init__.py new file mode 100644 index 00000000..d7b815d5 --- /dev/null +++ b/rl4co/envs/routing/__init__.py @@ -0,0 +1,9 @@ +from rl4co.envs.routing.atsp import ATSPEnv +from rl4co.envs.routing.cvrp import CVRPEnv +from rl4co.envs.routing.mtsp import MTSPEnv +from rl4co.envs.routing.op import OPEnv +from rl4co.envs.routing.pctsp import PCTSPEnv +from rl4co.envs.routing.pdp import PDPEnv +from rl4co.envs.routing.sdvrp import SDVRPEnv +from rl4co.envs.routing.spctsp import SPCTSPEnv +from rl4co.envs.routing.tsp import TSPEnv diff --git a/rl4co/envs/atsp.py b/rl4co/envs/routing/atsp.py similarity index 91% rename from rl4co/envs/atsp.py rename to rl4co/envs/routing/atsp.py index fe451db9..e00a5cee 100644 --- a/rl4co/envs/atsp.py +++ b/rl4co/envs/routing/atsp.py @@ -20,7 +20,7 @@ class ATSPEnv(RL4COEnvBase): """ Asymmetric Traveling Salesman Problem environment - At each step, the agent chooses a city to visit. The reward is the -infinite unless the agent visits all the cities. + At each step, the agent chooses a city to visit. The reward is 0 unless the agent visits all the cities. In that case, the reward is (-)length of the path: maximizing the reward is equivalent to minimizing the path length. Unlike the TSP, the distance matrix is asymmetric, i.e., the distance from A to B is not necessarily the same as the distance from B to A. @@ -62,24 +62,20 @@ def _step(td: TensorDict) -> TensorDict: # We are done there are no unvisited locations done = torch.count_nonzero(available, dim=-1) <= 0 - # The reward is calculated outside via get_reward for efficiency, so we set it to -inf here - reward = torch.ones_like(done) * float("-inf") + # The reward is calculated outside via get_reward for efficiency, so we set it to 0 here + reward = torch.zeros_like(done) - # The output must be written in a ``"next"`` entry - return TensorDict( + td.update( { - "next": { - "cost_matrix": td["cost_matrix"], - "first_node": first_node, - "current_node": current_node, - "i": td["i"] + 1, - "action_mask": available, - "reward": reward, - "done": done, - } + "first_node": first_node, + "current_node": current_node, + "i": td["i"] + 1, + "action_mask": available, + "reward": reward, + "done": done, }, - td.shape, ) + return td def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict: # Initialize distance matrix @@ -90,9 +86,10 @@ def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict batch_size = ( self.batch_size if cost_matrix is None else cost_matrix.shape[:-2] ) - self.device = device = ( + device = ( cost_matrix.device if cost_matrix is not None else self.device ) + self.to(device) if cost_matrix is None: cost_matrix = self.generate_data(batch_size=batch_size).to(device)[ "cost_matrix" @@ -142,7 +139,6 @@ def _make_spec(self, td_params: TensorDict = None): ), shape=(), ) - self.input_spec = self.observation_spec.clone() self.action_spec = BoundedTensorSpec( shape=(1,), dtype=torch.int64, diff --git a/rl4co/envs/cvrp.py b/rl4co/envs/routing/cvrp.py similarity index 93% rename from rl4co/envs/cvrp.py rename to rl4co/envs/routing/cvrp.py index 2cde5f51..14d43c20 100644 --- a/rl4co/envs/cvrp.py +++ b/rl4co/envs/routing/cvrp.py @@ -41,7 +41,7 @@ class CVRPEnv(RL4COEnvBase): """Capacitated Vehicle Routing Problem (CVRP) environment. At each step, the agent chooses a customer to visit depending on the current location and the remaining capacity. When the agent visits a customer, the remaining capacity is updated. If the remaining capacity is not enough to - visit any customer, the agent must go back to the depot. The reward is the -infinite unless the agent visits all the cities. + visit any customer, the agent must go back to the depot. The reward is 0 unless the agent visits all the cities. In that case, the reward is (-)length of the path: maximizing the reward is equivalent to minimizing the path length. Args: @@ -103,25 +103,19 @@ def _step(self, td: TensorDict) -> TensorDict: # SECTION: get done done = visited.sum(-1) == visited.size(-1) - reward = torch.ones_like(done) * float("-inf") + reward = torch.zeros_like(done) - td_step = TensorDict( + td.update( { - "next": { - "locs": td["locs"], - "demand": td["demand"], - "current_node": current_node, - "used_capacity": used_capacity, - "vehicle_capacity": td["vehicle_capacity"], - "visited": visited, - "reward": reward, - "done": done, - } - }, - td.shape, + "current_node": current_node, + "used_capacity": used_capacity, + "visited": visited, + "reward": reward, + "done": done, + } ) - td_step["next"].set("action_mask", self.get_action_mask(td_step["next"])) - return td_step + td.set("action_mask", self.get_action_mask(td)) + return td def _reset( self, @@ -134,7 +128,7 @@ def _reset( td = self.generate_data(batch_size=batch_size) batch_size = [batch_size] if isinstance(batch_size, int) else batch_size - self.device = td.device + self.to(td.device) # Create reset TensorDict td_reset = TensorDict( @@ -210,7 +204,7 @@ def check_solution_validity(td: TensorDict, actions: torch.Tensor): assert ( used_cap <= td["vehicle_capacity"] + 1e-5 ).all(), "Used more than capacity" - + def generate_data(self, batch_size) -> TensorDict: # Batch size input check batch_size = [batch_size] if isinstance(batch_size, int) else batch_size @@ -227,13 +221,12 @@ def generate_data(self, batch_size) -> TensorDict: # Generates a slightly different distribution than using torch.randint demand = ( ( - torch.FloatTensor(*batch_size, self.num_loc) + torch.FloatTensor(*batch_size, self.num_loc, device=self.device) .uniform_(self.min_demand - 1, self.max_demand - 1) .int() + 1 ) .float() - .to(self.device) ) # Support for heterogeneous capacity if provided @@ -250,7 +243,8 @@ def generate_data(self, batch_size) -> TensorDict: "capacity": capacity, }, batch_size=batch_size, - ) + device=self.device, + ) @staticmethod def load_data(fpath, batch_size=[]): @@ -258,7 +252,7 @@ def load_data(fpath, batch_size=[]): Normalize demand by capacity to be in [0, 1] """ td_load = load_npz_to_tensordict(fpath) - td_load.set_("demand", td_load["demand"] / td_load["capacity"][:, None]) + td_load.set("demand", td_load["demand"] / td_load["capacity"][:, None]) return td_load def _make_spec(self, td_params: TensorDict): @@ -286,7 +280,6 @@ def _make_spec(self, td_params: TensorDict): ), shape=(), ) - self.input_spec = self.observation_spec.clone() self.action_spec = BoundedTensorSpec( shape=(1,), dtype=torch.int64, diff --git a/rl4co/envs/mpdp.py b/rl4co/envs/routing/mpdp.py similarity index 93% rename from rl4co/envs/mpdp.py rename to rl4co/envs/routing/mpdp.py index 77e6c7cc..1a83eeb0 100644 --- a/rl4co/envs/mpdp.py +++ b/rl4co/envs/routing/mpdp.py @@ -22,7 +22,7 @@ class MPDPEnv(RL4COEnvBase): The goal is to pick up and deliver all the packages while satisfying the precedence constraints. When an agent goes back to the depot, a new agent is spawned. In the min-max version, the goal is to minimize the maximum tour length among all agents. - The reward is the -infinite unless the agent visits all the cities. + The reward is 0 unless the agent visits all the cities. In that case, the reward is (-)length of the path: maximizing the reward is equivalent to minimizing the path length. Args: @@ -111,36 +111,25 @@ def _step(self, td: TensorDict) -> TensorDict: # Get done and reward done = visited.all(dim=-1, keepdim=True).squeeze(-1) - reward = torch.ones_like(done) * float( - "-inf" - ) # reward calculated via `get_reward` for now + reward = torch.zeros_like(done) - td_step = TensorDict( + td.update( { - "next": { - "locs": td["locs"], - "visited": visited, - "lengths": td["lengths"], - "count_depot": td["count_depot"], - "agent_idx": agent_idx, - "cur_coord": cur_coord, - "to_delivery": to_delivery, - "left_request": td["left_request"], - "depot_distance": depot_distance, - "remain_sum_paired_distance": remain_sum_paired_distance, - "remain_pickup_max_distance": remain_pickup_max_distance, - "remain_delivery_max_distance": remain_delivery_max_distance, - "add_pd_distance": td["add_pd_distance"], - "longest_lengths": td["longest_lengths"], - "i": td["i"] + 1, - "done": done, - "reward": reward, - } - }, - td.shape, + "visited": visited, + "agent_idx": agent_idx, + "cur_coord": cur_coord, + "to_delivery": to_delivery, + "depot_distance": depot_distance, + "remain_sum_paired_distance": remain_sum_paired_distance, + "remain_pickup_max_distance": remain_pickup_max_distance, + "remain_delivery_max_distance": remain_delivery_max_distance, + "i": td["i"] + 1, + "done": done, + "reward": reward, + } ) - td_step["next"].set("action_mask", self.get_action_mask(td_step["next"])) - return td_step + td.set("action_mask", self.get_action_mask(td)) + return td def _reset( self, @@ -154,7 +143,7 @@ def _reset( if td is None or td.is_empty(): td = self.generate_data(batch_size=batch_size) - self.device = td.device + self.to(td.device) # NOTE: this is a hack to get the agent_num # agent_num = td["agent_num"][0].item() if agent_num is None else agent_num @@ -425,7 +414,6 @@ def _make_spec(self, td_params: TensorDict): dtype=torch.int64, ), ) - self.input_spec = self.observation_spec.clone() self.action_spec = BoundedTensorSpec( shape=(1,), dtype=torch.int64, diff --git a/rl4co/envs/mtsp.py b/rl4co/envs/routing/mtsp.py similarity index 94% rename from rl4co/envs/mtsp.py rename to rl4co/envs/routing/mtsp.py index e85fa624..46f37849 100644 --- a/rl4co/envs/mtsp.py +++ b/rl4co/envs/routing/mtsp.py @@ -112,25 +112,22 @@ def _step(td: TensorDict) -> TensorDict: # The reward is the negative of the max_subtour_length (minmax objective) reward = -max_subtour_length - # The output must be written in a ``"next"`` entry - return TensorDict( + td.update( { - "next": { - "locs": td["locs"], - "num_agents": td["num_agents"], - "max_subtour_length": max_subtour_length, - "current_length": current_length, - "agent_idx": cur_agent_idx, - "first_node": first_node, - "current_node": current_node, - "i": td["i"] + 1, - "action_mask": available, - "reward": reward, - "done": done, - } - }, - td.shape, + "max_subtour_length": max_subtour_length, + "current_length": current_length, + "agent_idx": cur_agent_idx, + "first_node": first_node, + "current_node": current_node, + "i": td["i"] + 1, + "action_mask": available, + "reward": reward, + "done": done, + } ) + + return td + def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict: # Initialize data @@ -214,7 +211,6 @@ def _make_spec(self, td_params: TensorDict): ), shape=(), ) - self.input_spec = self.observation_spec.clone() self.action_spec = BoundedTensorSpec( shape=(1,), dtype=torch.int64, diff --git a/rl4co/envs/op.py b/rl4co/envs/routing/op.py similarity index 94% rename from rl4co/envs/op.py rename to rl4co/envs/routing/op.py index 42d00a9b..485964d3 100644 --- a/rl4co/envs/op.py +++ b/rl4co/envs/routing/op.py @@ -90,29 +90,23 @@ def _step(self, td: TensorDict) -> TensorDict: # Done if went back to depot (except if it's the first step, since we start at the depot) done = (current_node.squeeze(-1) == 0) & (td["i"] > 0) - # The reward is calculated outside via get_reward for efficiency, so we set it to -inf here - reward = torch.ones_like(done) * float("-inf") - - td_step = TensorDict( + # The reward is calculated outside via get_reward for efficiency, so we set it to 0 here + reward = torch.zeros_like(done) + + td.update( { - "next": { - "locs": td["locs"], - "prize": td["prize"], - "tour_length": tour_length, - "current_loc": current_loc, - "max_length": td["max_length"], - "current_node": current_node, - "visited": visited, - "current_total_prize": current_total_prize, - "i": td["i"] + 1, - "reward": reward, - "done": done, - } - }, - td.shape, + "tour_length": tour_length, + "current_loc": current_loc, + "current_node": current_node, + "visited": visited, + "current_total_prize": current_total_prize, + "i": td["i"] + 1, + "reward": reward, + "done": done, + } ) - td_step["next"].set("action_mask", self.get_action_mask(td_step["next"])) - return td_step + td.set("action_mask", self.get_action_mask(td)) + return td def _reset( self, @@ -124,8 +118,7 @@ def _reset( batch_size = self.batch_size if td is None else td["locs"].shape[:-2] if td is None or td.is_empty(): td = self.generate_data(batch_size=batch_size) - self.device = td.device - + self.to(td.device) # Add depot to locs locs_with_depot = torch.cat((td["depot"][:, None, :], td["locs"]), -2) @@ -322,7 +315,6 @@ def _make_spec(self, td_params: TensorDict): ), shape=(), ) - self.input_spec = self.observation_spec.clone() self.action_spec = BoundedTensorSpec( shape=(1,), dtype=torch.int64, diff --git a/rl4co/envs/pctsp.py b/rl4co/envs/routing/pctsp.py similarity index 93% rename from rl4co/envs/pctsp.py rename to rl4co/envs/routing/pctsp.py index b4c3204f..f001d754 100644 --- a/rl4co/envs/pctsp.py +++ b/rl4co/envs/routing/pctsp.py @@ -78,31 +78,26 @@ def _step(self, td: TensorDict) -> TensorDict: # Update visited visited = td["visited"].scatter(-1, current_node[..., None], 1) - # Done and reward. Calculation is done outside hence set -inf + # Done and reward done = (td["i"] > 0) & (current_node == 0) - reward = torch.ones_like(cur_total_prize) * float("-inf") - - td_step = TensorDict( + + # The reward is calculated outside via get_reward for efficiency, so we set it to 0 here + reward = torch.zeros_like(done) + + # Update state + td.update( { - "next": { - "locs": td["locs"], - "current_node": current_node, - "expected_prize": td["expected_prize"], - "real_prize": td["real_prize"], - "penalty": td["penalty"], - "cur_total_prize": cur_total_prize, - "cur_total_penalty": cur_total_penalty, - "visited": visited, - "prize_required": td["prize_required"], - "i": td["i"] + 1, - "reward": reward, - "done": done, - }, - }, - batch_size=td.batch_size, + "current_node": current_node, + "cur_total_prize": cur_total_prize, + "cur_total_penalty": cur_total_penalty, + "visited": visited, + "i": td["i"] + 1, + "reward": reward, + "done": done, + } ) - td_step["next"].set("action_mask", self.get_action_mask(td_step["next"])) - return td_step + td.set("action_mask", self.get_action_mask(td)) + return td def _reset( self, td: Optional[TensorDict] = None, batch_size: Optional[list] = None @@ -111,7 +106,7 @@ def _reset( batch_size = self.batch_size if td is None else td["locs"].shape[:-2] if td is None or td.is_empty(): td = self.generate_data(batch_size=batch_size) - self.device = td.device + self.to(td.device) locs = torch.cat([td["depot"][..., None, :], td["locs"]], dim=-2) expected_prize = td["deterministic_prize"] @@ -323,7 +318,6 @@ def _make_spec(self, td_params: TensorDict): ), shape=(), ) - self.input_spec = self.observation_spec.clone() self.action_spec = BoundedTensorSpec( shape=(1,), dtype=torch.int64, diff --git a/rl4co/envs/pdp.py b/rl4co/envs/routing/pdp.py similarity index 93% rename from rl4co/envs/pdp.py rename to rl4co/envs/routing/pdp.py index 7d6a309d..b845e093 100644 --- a/rl4co/envs/pdp.py +++ b/rl4co/envs/routing/pdp.py @@ -71,25 +71,22 @@ def _step(td: TensorDict) -> TensorDict: # We are done there are no unvisited locations done = torch.count_nonzero(available, dim=-1) == 0 - # The reward is calculated outside via get_reward for efficiency, so we set it to -inf here - reward = torch.ones_like(done) * float("-inf") + # The reward is calculated outside via get_reward for efficiency, so we set it to 0 here + reward = torch.zeros_like(done) - # The output must be written in a ``"next"`` entry - return TensorDict( + # Update step + td.update( { - "next": { - "locs": td["locs"], - "current_node": current_node, - "available": available, - "to_deliver": to_deliver, - "i": td["i"] + 1, - "action_mask": action_mask, - "reward": reward, - "done": done, - } - }, - td.shape, + "current_node": current_node, + "available": available, + "to_deliver": to_deliver, + "i": td["i"] + 1, + "action_mask": action_mask, + "reward": reward, + "done": done, + } ) + return td def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict: if batch_size is None: @@ -98,7 +95,7 @@ def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict if td is None or td.is_empty(): td = self.generate_data(batch_size=batch_size) - self.device = td.device + self.to(td.device) locs = torch.cat((td["depot"][:, None, :], td["locs"]), -2) @@ -170,7 +167,6 @@ def _make_spec(self, td_params: TensorDict): ), shape=(), ) - self.input_spec = self.observation_spec.clone() self.action_spec = BoundedTensorSpec( shape=(1,), dtype=torch.int64, diff --git a/rl4co/envs/sdvrp.py b/rl4co/envs/routing/sdvrp.py similarity index 89% rename from rl4co/envs/sdvrp.py rename to rl4co/envs/routing/sdvrp.py index 30d582ee..415d61c5 100644 --- a/rl4co/envs/sdvrp.py +++ b/rl4co/envs/routing/sdvrp.py @@ -10,9 +10,9 @@ UnboundedDiscreteTensorSpec, ) -from rl4co.envs.cvrp import CVRPEnv from rl4co.utils.ops import gather_by_index from rl4co.utils.pylogger import get_pylogger +from .cvrp import CVRPEnv log = get_pylogger(__name__) @@ -84,27 +84,24 @@ def _step(self, td: TensorDict) -> TensorDict: -1, current_node, -delivered_demand ) - # Get done and reward (-inf since we get it outside) + # Get done done = ~(demand_with_depot > 0).any(-1) - reward = torch.ones_like(done) * float("-inf") - - td_step = TensorDict( + + # The reward is calculated outside via get_reward for efficiency, so we set it to 0 here + reward = torch.zeros_like(done) + + # Update state + td.update( { - "next": { - "locs": td["locs"], - "demand": td["demand"], - "demand_with_depot": demand_with_depot, - "current_node": current_node, - "used_capacity": used_capacity, - "vehicle_capacity": td["vehicle_capacity"], - "reward": reward, - "done": done, - } - }, - td.shape, + "demand_with_depot": demand_with_depot, + "current_node": current_node, + "used_capacity": used_capacity, + "reward": reward, + "done": done, + } ) - td_step["next"].set("action_mask", self.get_action_mask(td_step["next"])) - return td_step + td.set("action_mask", self.get_action_mask(td)) + return td def _reset( self, @@ -117,7 +114,7 @@ def _reset( if td is None or td.is_empty(): td = self.generate_data(batch_size=batch_size) - self.device = td["locs"].device + self.to(td.device) # Create reset TensorDict reset_td = TensorDict( @@ -211,7 +208,6 @@ def _make_spec(self, td_params: TensorDict): ), shape=(), ) - self.input_spec = self.observation_spec.clone() self.action_spec = BoundedTensorSpec( shape=(1,), dtype=torch.int64, diff --git a/rl4co/envs/spctsp.py b/rl4co/envs/routing/spctsp.py similarity index 95% rename from rl4co/envs/spctsp.py rename to rl4co/envs/routing/spctsp.py index e6bfb757..2f0c2e5e 100644 --- a/rl4co/envs/spctsp.py +++ b/rl4co/envs/routing/spctsp.py @@ -1,5 +1,5 @@ -from rl4co.envs.pctsp import PCTSPEnv from rl4co.utils.pylogger import get_pylogger +from .pctsp import PCTSPEnv log = get_pylogger(__name__) diff --git a/rl4co/envs/tsp.py b/rl4co/envs/routing/tsp.py similarity index 87% rename from rl4co/envs/tsp.py rename to rl4co/envs/routing/tsp.py index 3d7d53a6..a4d7e784 100644 --- a/rl4co/envs/tsp.py +++ b/rl4co/envs/routing/tsp.py @@ -21,7 +21,7 @@ class TSPEnv(RL4COEnvBase): """ Traveling Salesman Problem environment - At each step, the agent chooses a city to visit. The reward is the -infinite unless the agent visits all the cities. + At each step, the agent chooses a city to visit. The reward is 0 unless the agent visits all the cities. In that case, the reward is (-)length of the path: maximizing the reward is equivalent to minimizing the path length. Args: @@ -50,41 +50,38 @@ def __init__( @staticmethod def _step(td: TensorDict) -> TensorDict: current_node = td["action"] - first_node = current_node if batch_to_scalar(td["i"]) == 0 else td["first_node"] + first_node = current_node if td["i"].all() == 0 else td["first_node"] - # Set not visited to 0 (i.e., we visited the node) + # # Set not visited to 0 (i.e., we visited the node) available = td["action_mask"].scatter( -1, current_node.unsqueeze(-1).expand_as(td["action_mask"]), 0 ) # We are done there are no unvisited locations - done = torch.count_nonzero(available, dim=-1) <= 0 + done = torch.sum(available, dim=-1) == 0 - # The reward is calculated outside via get_reward for efficiency, so we set it to -inf here - reward = torch.ones_like(done) * float("-inf") + # The reward is calculated outside via get_reward for efficiency, so we set it to 0 here + reward = torch.zeros_like(done) - # The output must be written in a ``"next"`` entry - return TensorDict( + td.update( { - "next": { - "locs": td["locs"], - "first_node": first_node, - "current_node": current_node, - "i": td["i"] + 1, - "action_mask": available, - "reward": reward, - "done": done, - } + "first_node": first_node, + "current_node": current_node, + "i": td["i"] + 1, + "action_mask": available, + "reward": reward, + "done": done, }, - td.shape, ) + return td def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict: # Initialize locations init_locs = td["locs"] if td is not None else None if batch_size is None: batch_size = self.batch_size if init_locs is None else init_locs.shape[:-2] - self.device = device = init_locs.device if init_locs is not None else self.device + device = init_locs.device if init_locs is not None else self.device + self.to(device) if init_locs is None: init_locs = self.generate_data(batch_size=batch_size).to(device)["locs"] batch_size = [batch_size] if isinstance(batch_size, int) else batch_size @@ -106,6 +103,7 @@ def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict "current_node": current_node, "i": i, "action_mask": available, + "reward": torch.zeros((*batch_size, 1), dtype=torch.float32), }, batch_size=batch_size, ) @@ -137,7 +135,6 @@ def _make_spec(self, td_params): ), shape=(), ) - self.input_spec = self.observation_spec.clone() self.action_spec = BoundedTensorSpec( shape=(1,), dtype=torch.int64, diff --git a/rl4co/envs/scheduling/__init__.py b/rl4co/envs/scheduling/__init__.py new file mode 100644 index 00000000..e161da27 --- /dev/null +++ b/rl4co/envs/scheduling/__init__.py @@ -0,0 +1,2 @@ +from rl4co.envs.scheduling.ffsp import FFSPEnv +from rl4co.envs.scheduling.smtwtp import SMTWTPEnv \ No newline at end of file diff --git a/rl4co/envs/ffsp.py b/rl4co/envs/scheduling/ffsp.py similarity index 92% rename from rl4co/envs/ffsp.py rename to rl4co/envs/scheduling/ffsp.py index ff97d5d7..428a4616 100644 --- a/rl4co/envs/ffsp.py +++ b/rl4co/envs/scheduling/ffsp.py @@ -138,31 +138,26 @@ def _step(self, td: TensorDict) -> TensorDict: job_mask[job_enable] = 0 reward = td["reward"] - - return TensorDict( + + # Updated state + td.update( { - "next": { - "stage_table": td["stage_table"], - "machine_table": td["machine_table"], - "time_idx": time_idx, - "sub_time_idx": sub_time_idx, - "batch_idx": batch_idx, - "machine_idx": machine_idx, - "schedule": schedule, - "machine_wait_step": machine_wait_step, - "job_location": job_location, - "job_wait_step": job_wait_step, - "job_duration": td["job_duration"], - "reward": reward, - "finish": finish, - # Update variables - "job_mask": job_mask, - "stage_idx": stage_idx, - "stage_machine_idx": stage_machine_idx, - } - }, - td.shape, + "time_idx": time_idx, + "sub_time_idx": sub_time_idx, + "batch_idx": batch_idx, + "machine_idx": machine_idx, + "schedule": schedule, + "machine_wait_step": machine_wait_step, + "job_location": job_location, + "job_wait_step": job_wait_step, + "reward": reward, + "finish": finish, + "job_mask": job_mask, + "stage_idx": stage_idx, + "stage_machine_idx": stage_machine_idx, + } ) + return td def _reset( self, td: Optional[TensorDict] = None, batch_size: Optional[list] = None @@ -321,7 +316,6 @@ def _make_spec(self, td_params: TensorDict): ), shape=(), ) - self.input_spec = self.observation_spec.clone() self.action_spec = BoundedTensorSpec( shape=(1,), dtype=torch.int64, diff --git a/rl4co/envs/smtwtp.py b/rl4co/envs/scheduling/smtwtp.py similarity index 91% rename from rl4co/envs/smtwtp.py rename to rl4co/envs/scheduling/smtwtp.py index 41f5e5af..979b7c1e 100644 --- a/rl4co/envs/smtwtp.py +++ b/rl4co/envs/scheduling/smtwtp.py @@ -22,7 +22,7 @@ class SMTWTPEnv(RL4COEnvBase): SMTWTP is a scheduling problem in which a set of jobs must be processed on a single machine. Each job i has a processing time, a weight, and a due date. The objective is to minimize the sum of the weighted tardiness of all jobs, where the weighted tardiness of a job is defined as the product of its weight and the duration by which its completion time exceeds its due date. - At each step, the agent chooses a job to process. The reward is the -infinite unless the agent processes all the jobs. + At each step, the agent chooses a job to process. The reward is 0 unless the agent processes all the jobs. In that case, the reward is (-)objective value of the processing order: maximizing the reward is equivalent to minimizing the objective. Args: @@ -80,25 +80,19 @@ def _step(td: TensorDict) -> TensorDict: # We are done there are no unvisited locations done = torch.count_nonzero(available, dim=-1) <= 0 - # The reward is calculated outside via get_reward for efficiency, so we set it to -inf here - reward = torch.ones_like(done) * float("-inf") + # The reward is calculated outside via get_reward for efficiency, so we set it to 0 here + reward = torch.zeros_like(done) - # The output must be written in a ``"next"`` entry - return TensorDict( + td.update( { - "next": { - "job_due_time": td["job_due_time"], - "job_weight": td["job_weight"], - "job_process_time": td["job_process_time"], - "current_job": current_job, - "current_time": current_time, - "action_mask": available, - "reward": reward, - "done": done, - } - }, - td.shape, + "current_job": current_job, + "current_time": current_time, + "action_mask": available, + "reward": reward, + "done": done, + } ) + return td def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict: # Initialization @@ -106,9 +100,10 @@ def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict batch_size = self.batch_size if td is None else td["job_due_time"].shape[:-1] batch_size = [batch_size] if isinstance(batch_size, int) else batch_size - self.device = device = ( + device = ( td["job_due_time"].device if td is not None else self.device ) + self.to(device) td = self.generate_data(batch_size) if td is None else td @@ -170,7 +165,6 @@ def _make_spec(self, td_params: TensorDict = None): ), shape=(), ) - self.input_spec = self.observation_spec.clone() self.action_spec = BoundedTensorSpec( shape=(1,), dtype=torch.int64, From 64c5c6be1b4bbfa4e723ceb0801e618230eba12e Mon Sep 17 00:00:00 2001 From: Federico Berto Date: Sun, 5 Nov 2023 16:51:00 +0900 Subject: [PATCH 02/13] Refactoring; update to new TorchRL --- rl4co/envs/__init__.py | 13 +++++++++++-- rl4co/envs/common/base.py | 22 +++++++++++----------- rl4co/envs/eda/__init__.py | 2 +- rl4co/envs/eda/dpp.py | 4 ++-- rl4co/envs/routing/atsp.py | 4 +--- rl4co/envs/routing/cvrp.py | 17 +++++++---------- rl4co/envs/routing/mtsp.py | 3 +-- rl4co/envs/routing/op.py | 4 ++-- rl4co/envs/routing/pctsp.py | 4 ++-- rl4co/envs/routing/sdvrp.py | 7 ++++--- rl4co/envs/routing/spctsp.py | 1 + rl4co/envs/routing/tsp.py | 1 - rl4co/envs/scheduling/__init__.py | 2 +- rl4co/envs/scheduling/ffsp.py | 2 +- rl4co/envs/scheduling/smtwtp.py | 4 +--- 15 files changed, 46 insertions(+), 44 deletions(-) diff --git a/rl4co/envs/__init__.py b/rl4co/envs/__init__.py index eb9faef5..1b4fce9f 100644 --- a/rl4co/envs/__init__.py +++ b/rl4co/envs/__init__.py @@ -5,12 +5,21 @@ from rl4co.envs.eda import DPPEnv, MDPPEnv # Routing -from rl4co.envs.routing import ATSPEnv, CVRPEnv, MTSPEnv, OPEnv, PCTSPEnv, PDPEnv, SDVRPEnv, SPCTSPEnv, TSPEnv +from rl4co.envs.routing import ( + ATSPEnv, + CVRPEnv, + MTSPEnv, + OPEnv, + PCTSPEnv, + PDPEnv, + SDVRPEnv, + SPCTSPEnv, + TSPEnv, +) # Scheduling from rl4co.envs.scheduling import FFSPEnv, SMTWTPEnv - # Register environments ENV_REGISTRY = { "atsp": ATSPEnv, diff --git a/rl4co/envs/common/base.py b/rl4co/envs/common/base.py index b65774f0..85eb52a1 100644 --- a/rl4co/envs/common/base.py +++ b/rl4co/envs/common/base.py @@ -1,5 +1,5 @@ from os.path import join as pjoin -from typing import Optional, Iterable +from typing import Iterable, Optional import torch @@ -40,7 +40,7 @@ def __init__( val_dataloader_names: list = None, test_dataloader_names: list = None, check_solution: bool = True, - _torchrl_mode: bool = False, # TODO + _torchrl_mode: bool = False, # TODO seed: int = None, device: str = "cpu", **kwargs, @@ -89,7 +89,7 @@ def get_multiple_dataloader_names(f, names): def step(self, td: TensorDict) -> TensorDict: """Step function to call at each step of the episode containing an action. - If `_torchrl_mode` is True, we call `_torchrl_step` instead which set the + If `_torchrl_mode` is True, we call `_torchrl_step` instead which set the `next` key of the TensorDict to the next state - this is the usual way to do it in TorchRL, but inefficient in our case """ @@ -100,11 +100,11 @@ def step(self, td: TensorDict) -> TensorDict: else: # Since we simplify the syntax return self._torchrl_step(td) - + def _torchrl_step(self, td: TensorDict) -> TensorDict: """See :meth:`super().step` for more details. This is the usual way to do it in TorchRL, but inefficient in our case - + Note: Here we clone the TensorDict to avoid recursion error, since we allow for directly updating the TensorDict in the step function @@ -113,12 +113,12 @@ def _torchrl_step(self, td: TensorDict) -> TensorDict: self._assert_tensordict_shape(td) next_preset = td.get("next", None) - next_tensordict = self._step(td.clone()) # NOTE: we clone to avoid recursion error + next_tensordict = self._step( + td.clone() + ) # NOTE: we clone to avoid recursion error next_tensordict = self._step_proc_data(next_tensordict) if next_preset is not None: - next_tensordict.update( - next_preset.exclude(*next_tensordict.keys(True, True)) - ) + next_tensordict.update(next_preset.exclude(*next_tensordict.keys(True, True))) td.set("next", next_tensordict) return td @@ -214,11 +214,11 @@ def _set_seed(self, seed: Optional[int]): """Set the seed for the environment""" rng = torch.manual_seed(seed) self.rng = rng - + def to(self, device): """Override `to` device method for safety against `None` device (may be found in `TensorDict`))""" if device is None: - return self + return self else: return super().to(device) diff --git a/rl4co/envs/eda/__init__.py b/rl4co/envs/eda/__init__.py index f306ecee..da7f45e2 100644 --- a/rl4co/envs/eda/__init__.py +++ b/rl4co/envs/eda/__init__.py @@ -1,2 +1,2 @@ from rl4co.envs.eda.dpp import DPPEnv -from rl4co.envs.eda.mdpp import MDPPEnv \ No newline at end of file +from rl4co.envs.eda.mdpp import MDPPEnv diff --git a/rl4co/envs/eda/dpp.py b/rl4co/envs/eda/dpp.py index 633af37f..fe88572c 100644 --- a/rl4co/envs/eda/dpp.py +++ b/rl4co/envs/eda/dpp.py @@ -102,7 +102,7 @@ def _step(self, td: TensorDict) -> TensorDict: # The reward is calculated outside via get_reward for efficiency, so we set it to 0 here reward = torch.zeros_like(done) - + td.update( { "i": td["i"] + 1, @@ -119,7 +119,7 @@ def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict batch_size = self.batch_size if td is None else td.batch_size device = td.device if td is not None else self.device self.to(device) - + # We allow loading the initial observation from a dataset for faster loading if td is None: td = self.generate_data(batch_size=batch_size) diff --git a/rl4co/envs/routing/atsp.py b/rl4co/envs/routing/atsp.py index e00a5cee..dcfb720e 100644 --- a/rl4co/envs/routing/atsp.py +++ b/rl4co/envs/routing/atsp.py @@ -86,9 +86,7 @@ def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict batch_size = ( self.batch_size if cost_matrix is None else cost_matrix.shape[:-2] ) - device = ( - cost_matrix.device if cost_matrix is not None else self.device - ) + device = cost_matrix.device if cost_matrix is not None else self.device self.to(device) if cost_matrix is None: cost_matrix = self.generate_data(batch_size=batch_size).to(device)[ diff --git a/rl4co/envs/routing/cvrp.py b/rl4co/envs/routing/cvrp.py index 14d43c20..edd79d38 100644 --- a/rl4co/envs/routing/cvrp.py +++ b/rl4co/envs/routing/cvrp.py @@ -204,7 +204,7 @@ def check_solution_validity(td: TensorDict, actions: torch.Tensor): assert ( used_cap <= td["vehicle_capacity"] + 1e-5 ).all(), "Used more than capacity" - + def generate_data(self, batch_size) -> TensorDict: # Batch size input check batch_size = [batch_size] if isinstance(batch_size, int) else batch_size @@ -220,14 +220,11 @@ def generate_data(self, batch_size) -> TensorDict: # Demand sampling Following Kool et al. (2019) # Generates a slightly different distribution than using torch.randint demand = ( - ( - torch.FloatTensor(*batch_size, self.num_loc, device=self.device) - .uniform_(self.min_demand - 1, self.max_demand - 1) - .int() - + 1 - ) - .float() - ) + torch.FloatTensor(*batch_size, self.num_loc, device=self.device) + .uniform_(self.min_demand - 1, self.max_demand - 1) + .int() + + 1 + ).float() # Support for heterogeneous capacity if provided if not isinstance(self.capacity, torch.Tensor): @@ -244,7 +241,7 @@ def generate_data(self, batch_size) -> TensorDict: }, batch_size=batch_size, device=self.device, - ) + ) @staticmethod def load_data(fpath, batch_size=[]): diff --git a/rl4co/envs/routing/mtsp.py b/rl4co/envs/routing/mtsp.py index 46f37849..7b835589 100644 --- a/rl4co/envs/routing/mtsp.py +++ b/rl4co/envs/routing/mtsp.py @@ -125,9 +125,8 @@ def _step(td: TensorDict) -> TensorDict: "done": done, } ) - + return td - def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict: # Initialize data diff --git a/rl4co/envs/routing/op.py b/rl4co/envs/routing/op.py index 485964d3..70f06263 100644 --- a/rl4co/envs/routing/op.py +++ b/rl4co/envs/routing/op.py @@ -92,7 +92,7 @@ def _step(self, td: TensorDict) -> TensorDict: # The reward is calculated outside via get_reward for efficiency, so we set it to 0 here reward = torch.zeros_like(done) - + td.update( { "tour_length": tour_length, @@ -102,7 +102,7 @@ def _step(self, td: TensorDict) -> TensorDict: "current_total_prize": current_total_prize, "i": td["i"] + 1, "reward": reward, - "done": done, + "done": done, } ) td.set("action_mask", self.get_action_mask(td)) diff --git a/rl4co/envs/routing/pctsp.py b/rl4co/envs/routing/pctsp.py index f001d754..ca0f4863 100644 --- a/rl4co/envs/routing/pctsp.py +++ b/rl4co/envs/routing/pctsp.py @@ -80,10 +80,10 @@ def _step(self, td: TensorDict) -> TensorDict: # Done and reward done = (td["i"] > 0) & (current_node == 0) - + # The reward is calculated outside via get_reward for efficiency, so we set it to 0 here reward = torch.zeros_like(done) - + # Update state td.update( { diff --git a/rl4co/envs/routing/sdvrp.py b/rl4co/envs/routing/sdvrp.py index 415d61c5..1fbf182d 100644 --- a/rl4co/envs/routing/sdvrp.py +++ b/rl4co/envs/routing/sdvrp.py @@ -12,6 +12,7 @@ from rl4co.utils.ops import gather_by_index from rl4co.utils.pylogger import get_pylogger + from .cvrp import CVRPEnv log = get_pylogger(__name__) @@ -84,12 +85,12 @@ def _step(self, td: TensorDict) -> TensorDict: -1, current_node, -delivered_demand ) - # Get done + # Get done done = ~(demand_with_depot > 0).any(-1) - + # The reward is calculated outside via get_reward for efficiency, so we set it to 0 here reward = torch.zeros_like(done) - + # Update state td.update( { diff --git a/rl4co/envs/routing/spctsp.py b/rl4co/envs/routing/spctsp.py index 2f0c2e5e..e39a31b1 100644 --- a/rl4co/envs/routing/spctsp.py +++ b/rl4co/envs/routing/spctsp.py @@ -1,4 +1,5 @@ from rl4co.utils.pylogger import get_pylogger + from .pctsp import PCTSPEnv log = get_pylogger(__name__) diff --git a/rl4co/envs/routing/tsp.py b/rl4co/envs/routing/tsp.py index a4d7e784..7db63541 100644 --- a/rl4co/envs/routing/tsp.py +++ b/rl4co/envs/routing/tsp.py @@ -11,7 +11,6 @@ ) from rl4co.envs.common.base import RL4COEnvBase -from rl4co.envs.common.utils import batch_to_scalar from rl4co.utils.ops import gather_by_index, get_tour_length from rl4co.utils.pylogger import get_pylogger diff --git a/rl4co/envs/scheduling/__init__.py b/rl4co/envs/scheduling/__init__.py index e161da27..a9c5144b 100644 --- a/rl4co/envs/scheduling/__init__.py +++ b/rl4co/envs/scheduling/__init__.py @@ -1,2 +1,2 @@ from rl4co.envs.scheduling.ffsp import FFSPEnv -from rl4co.envs.scheduling.smtwtp import SMTWTPEnv \ No newline at end of file +from rl4co.envs.scheduling.smtwtp import SMTWTPEnv diff --git a/rl4co/envs/scheduling/ffsp.py b/rl4co/envs/scheduling/ffsp.py index 428a4616..fb1dd5fe 100644 --- a/rl4co/envs/scheduling/ffsp.py +++ b/rl4co/envs/scheduling/ffsp.py @@ -138,7 +138,7 @@ def _step(self, td: TensorDict) -> TensorDict: job_mask[job_enable] = 0 reward = td["reward"] - + # Updated state td.update( { diff --git a/rl4co/envs/scheduling/smtwtp.py b/rl4co/envs/scheduling/smtwtp.py index 979b7c1e..d65d8d07 100644 --- a/rl4co/envs/scheduling/smtwtp.py +++ b/rl4co/envs/scheduling/smtwtp.py @@ -100,9 +100,7 @@ def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict batch_size = self.batch_size if td is None else td["job_due_time"].shape[:-1] batch_size = [batch_size] if isinstance(batch_size, int) else batch_size - device = ( - td["job_due_time"].device if td is not None else self.device - ) + device = td["job_due_time"].device if td is not None else self.device self.to(device) td = self.generate_data(batch_size) if td is None else td From 29103b3f6548cede3f16038f3d2aca835e56940c Mon Sep 17 00:00:00 2001 From: Federico Berto Date: Sun, 5 Nov 2023 16:51:51 +0900 Subject: [PATCH 03/13] Python 3.11 support --- pyproject.toml | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8b4feee5..bf0fbcf0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,14 +23,14 @@ dynamic = ["version"] license = {file = "LICENSE"} -# TODO: allow new Python versions https://github.com/kaist-silab/rl4co/issues/95 -requires-python = ">=3.8, <3.11" # https://github.com/kaist-silab/rl4co/issues/90 +requires-python = ">=3.8" classifiers = [ "Programming Language :: Python", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", "License :: OSI Approved :: Apache Software License", "Development Status :: 4 - Beta", "Intended Audience :: Developers", @@ -41,10 +41,9 @@ classifiers = [ # TODO: allow new TorchRL / TensorDict versions https://github.com/kaist-silab/rl4co/issues/95 dependencies = [ - "torch>=2.0.0,<2.1.0", # Possibly TorchRL problem on Windows with older version - "torchrl==0.1.1", - "tensordict==0.1.1", - "lightning>=2.0.5", + "torchrl>=0.2.0", + "tensordict>=0.2.0", + "lightning>=2.1.0", "hydra-core", "hydra-colorlog", "omegaconf", From b2af00fb41d14e9f5cb8c4f8640f9f05cc80abe4 Mon Sep 17 00:00:00 2001 From: Federico Berto Date: Sun, 5 Nov 2023 16:52:20 +0900 Subject: [PATCH 04/13] [Config] Better defaults --- configs/trainer/default.yaml | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/configs/trainer/default.yaml b/configs/trainer/default.yaml index 84df21f4..e3344212 100644 --- a/configs/trainer/default.yaml +++ b/configs/trainer/default.yaml @@ -4,16 +4,9 @@ _target_: rl4co.utils.trainer.RL4COTrainer default_root_dir: ${paths.output_dir} gradient_clip_val: 1.0 -accelerator: "gpu" +accelerator: "auto" precision: "16-mixed" -# Fast distributed training: comment out to use on single GPU -# devices: 1 # change number of devices -strategy: - _target_: lightning.pytorch.strategies.DDPStrategy - find_unused_parameters: True - gradient_as_bucket_view: True - # perform a validation loop every N training epochs check_val_every_n_epoch: 1 From ca78bb9f7156155e5344c745ec325125ffa146da Mon Sep 17 00:00:00 2001 From: Federico Berto Date: Sun, 5 Nov 2023 16:53:37 +0900 Subject: [PATCH 05/13] [Data] faster dataset instantiation --- rl4co/data/dataset.py | 69 +++++++++++++++++++++++++++++++++---------- 1 file changed, 54 insertions(+), 15 deletions(-) diff --git a/rl4co/data/dataset.py b/rl4co/data/dataset.py index 5d22c603..afcf8bee 100644 --- a/rl4co/data/dataset.py +++ b/rl4co/data/dataset.py @@ -1,3 +1,4 @@ + import torch from tensordict.tensordict import TensorDict @@ -5,9 +6,40 @@ class TensorDictDataset(Dataset): + """Dataset compatible with TensorDicts. + Uses more CPU and has similar performance in loading to list comprehension, but is faster in instantiation + than :class:`TensorDictDatasetList` (more than 10x faster). + """ + + def __init__(self, td: TensorDict): + self.data = td + + def __len__(self): + return len(self.data) + + def __getitems__(self, index): + # Tricks: + # - batched data loading with `__getitems__` for faster loading + # - avoid directly indexing TensorDicts for faster loading + return TensorDict( + {key: item[index] for key, item in self.data.items()}, + batch_size=torch.Size([len(index)]), + _run_checks=False, # faster this way + ) + + def add_key(self, key, value): + return self.data.update({key: value}) # native method + + +def tensordict_collate_fn(x): + """Equivalent to collating with `lambda x: x`""" + return x + + +class TensorDictDatasetList(Dataset): """Dataset compatible with TensorDicts. It is better to "disassemble" the TensorDict into a list of dicts. - See :class:`tensordict_collate_fn` for more details. + See :class:`tensordict_collate_fn_list` for more details. Note: Check out the issue on tensordict for more details: @@ -16,21 +48,24 @@ class TensorDictDataset(Dataset): but uses > 3x more CPU. """ - def __init__(self, data: TensorDict): + def __init__(self, td: TensorDict): + self.data_len = td.batch_size[0] self.data = [ - {key: value[i] for key, value in data.items()} for i in range(data.shape[0]) + {key: value[i] for key, value in td.items()} for i in range(self.data_len) ] def __len__(self): - return len(self.data) + return self.data_len def __getitem__(self, idx): return self.data[idx] + def add_key(self, key, value): + return ExtraKeyDataset(self, value, key_name=key) -def tensordict_collate_fn(batch): - """Collate function compatible with TensorDicts. - Reassemble the list of dicts into a TensorDict; seems to be way more efficient than using a TensorDictDataset. + +def tensordict_collate_fn_list(batch): + """Collate function compatible with TensorDicts that reassembles a list of dicts. Note: Check out the issue on tensordict for more details: @@ -40,24 +75,28 @@ def tensordict_collate_fn(batch): """ return TensorDict( {key: torch.stack([b[key] for b in batch]) for key in batch[0].keys()}, - batch_size=len(batch), + batch_size=torch.Size([len(batch)]), + device=batch[0].device, + _run_checks=False, ) -class ExtraKeyDataset(Dataset): +class ExtraKeyDataset(TensorDictDatasetList): """Dataset that includes an extra key to add to the data dict. This is useful for adding a REINFORCE baseline reward to the data dict. + Note that this is faster to instantiate than using list comprehension. """ - def __init__(self, dataset: TensorDictDataset, extra: torch.Tensor): + def __init__( + self, dataset: TensorDictDatasetList, extra: torch.Tensor, key_name="extra" + ): + self.data_len = len(dataset) + assert self.data_len == len(extra), "Data and extra must be same length" self.data = dataset.data self.extra = extra - assert len(self.data) == len(self.extra), "Data and extra must be same length" - - def __len__(self): - return len(self.data) + self.key_name = key_name def __getitem__(self, idx): data = self.data[idx] - data["extra"] = self.extra[idx] + data[self.key_name] = self.extra[idx] return data From 75fcf05a7f996638a66bbc78fd95043c9611f4af Mon Sep 17 00:00:00 2001 From: Federico Berto Date: Sun, 5 Nov 2023 16:54:34 +0900 Subject: [PATCH 06/13] [Optimization] add torch.jit.script --- rl4co/utils/ops.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/rl4co/utils/ops.py b/rl4co/utils/ops.py index 2034d0de..669e4a12 100644 --- a/rl4co/utils/ops.py +++ b/rl4co/utils/ops.py @@ -6,7 +6,6 @@ from torch import Tensor -# @torch.jit.script def _batchify_single( x: Union[Tensor, TensorDict], repeats: int ) -> Union[Tensor, TensorDict]: @@ -61,7 +60,6 @@ def unbatchify( return x -# @torch.jit.script def gather_by_index(src, idx, dim=1, squeeze=True): """Gather elements from src by index idx along specified dim @@ -76,13 +74,13 @@ def gather_by_index(src, idx, dim=1, squeeze=True): return src.gather(dim, idx).squeeze() if squeeze else src.gather(dim, idx) -# @torch.jit.script +@torch.jit.script def get_distance(x: Tensor, y: Tensor): """Euclidean distance between two tensors of shape `[..., n, dim]`""" return (x - y).norm(p=2, dim=-1) -# @torch.jit.script +@torch.jit.script def get_tour_length(ordered_locs): """Compute the total tour distance for a batch of ordered tours. Computes the L2 norm between each pair of consecutive nodes in the tour and sums them up. From 931005ae069738a1cadb968fcf5751da8d285e88 Mon Sep 17 00:00:00 2001 From: Federico Berto Date: Sun, 5 Nov 2023 16:56:42 +0900 Subject: [PATCH 07/13] [BugFix] PPO reward detach --- rl4co/models/rl/ppo/ppo.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/rl4co/models/rl/ppo/ppo.py b/rl4co/models/rl/ppo/ppo.py index 6f3cc950..af7915ce 100644 --- a/rl4co/models/rl/ppo/ppo.py +++ b/rl4co/models/rl/ppo/ppo.py @@ -6,6 +6,7 @@ from torch.utils.data import DataLoader +from rl4co.data.dataset import TensorDictDataset, tensordict_collate_fn from rl4co.envs.common.base import RL4COEnvBase from rl4co.models.rl.common.base import RL4COLitModule from rl4co.utils.pylogger import get_pylogger @@ -124,8 +125,8 @@ def shared_step( ): # Evaluate old actions, log probabilities, and rewards with torch.no_grad(): - td = self.env.reset(batch) - out = self.policy(td, self.env, phase=phase, return_actions=True) + td = self.env.reset(batch) # note: clone needed for dataloader + out = self.policy(td.clone(), self.env, phase=phase, return_actions=True) if phase == "train": batch_size = out["actions"].shape[0] @@ -146,12 +147,17 @@ def shared_step( td.set("reward", out["reward"]) td.set("action", out["actions"]) + dataset = TensorDictDataset(td) dataloader = DataLoader( - td, batch_size=mini_batch_size, shuffle=True, collate_fn=lambda x: x + dataset, + batch_size=mini_batch_size, + shuffle=True, + collate_fn=tensordict_collate_fn, ) for _ in range(self.ppo_cfg["ppo_epochs"]): # PPO inner epoch, K for sub_td in dataloader: + previous_reward = sub_td["reward"].view(-1, 1) ll, entropy = self.policy.evaluate_action( sub_td, action=sub_td["action"] ) @@ -163,7 +169,7 @@ def shared_step( # Compute the advantage value_pred = self.critic(sub_td) # [batch, 1] - adv = sub_td["reward"].view(-1, 1) - value_pred.detach() + adv = previous_reward - value_pred.detach() # Normalize advantage if self.ppo_cfg["normalize_adv"]: @@ -181,7 +187,7 @@ def shared_step( ).mean() # compute value function loss - value_loss = F.huber_loss(value_pred, sub_td["reward"].view(-1, 1)) + value_loss = F.huber_loss(value_pred, previous_reward) # compute total loss loss = ( From 7d4ec127e77fcf977b0285e23da0e91a55ed9cae Mon Sep 17 00:00:00 2001 From: Federico Berto Date: Sun, 5 Nov 2023 16:57:50 +0900 Subject: [PATCH 08/13] [Update] use `add_key` for dataset --- rl4co/models/rl/reinforce/baselines.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/rl4co/models/rl/reinforce/baselines.py b/rl4co/models/rl/reinforce/baselines.py index 8117edd1..8d345660 100644 --- a/rl4co/models/rl/reinforce/baselines.py +++ b/rl4co/models/rl/reinforce/baselines.py @@ -6,10 +6,9 @@ from scipy.stats import ttest_rel from torch.utils.data import DataLoader -from tqdm.auto import tqdm from rl4co import utils -from rl4co.data.dataset import ExtraKeyDataset, tensordict_collate_fn +from rl4co.data.dataset import tensordict_collate_fn from rl4co.models.rl.common.critic import CriticNetwork log = utils.get_pylogger(__name__) @@ -81,6 +80,7 @@ def eval(self, td, reward, env=None): class MeanBaseline(REINFORCEBaseline): """Mean baseline: return mean of reward as baseline""" + def __new__(cls, **kw): return ExponentialBaseline(beta=0.0, **kw) @@ -158,13 +158,11 @@ class RolloutBaseline(REINFORCEBaseline): Args: bl_alpha: Alpha value for the baseline T-test - progress_bar: Whether to show progress bar for rollout """ - def __init__(self, bl_alpha=0.05, progress_bar=False, **kw): + def __init__(self, bl_alpha=0.05, **kw): super(RolloutBaseline, self).__init__() self.bl_alpha = bl_alpha - self.progress_bar = progress_bar def setup(self, *args, **kw): self._update_model(*args, **kw) @@ -235,9 +233,7 @@ def eval_model(batch): dl = DataLoader(dataset, batch_size=batch_size, collate_fn=tensordict_collate_fn) - rewards = torch.cat( - [eval_model(batch) for batch in tqdm(dl, disable=not self.progress_bar)], 0 - ) + rewards = torch.cat([eval_model(batch) for batch in dl], 0) return rewards def wrap_dataset(self, dataset, env, batch_size=64, device="cpu", **kw): @@ -253,7 +249,7 @@ def wrap_dataset(self, dataset, env, batch_size=64, device="cpu", **kw): .detach() .cpu() ) - return ExtraKeyDataset(dataset, rewards) + return dataset.add_key("extra", rewards) def __getstate__(self): """Do not include datasets in state to avoid pickling issues""" From 74c48cad2ad167146aba391ae47770e737b5abc4 Mon Sep 17 00:00:00 2001 From: Federico Berto Date: Sun, 5 Nov 2023 16:58:43 +0900 Subject: [PATCH 09/13] [Config] update --- configs/env/cvrp.yaml | 2 +- configs/env/default.yaml | 2 +- configs/env/dpp.yaml | 2 +- configs/env/mdpp.yaml | 2 +- configs/env/mtsp.yaml | 2 +- configs/env/op.yaml | 2 +- configs/env/pctsp.yaml | 2 +- configs/env/pdp.yaml | 2 +- configs/env/sdvrp.yaml | 2 +- configs/env/spctsp.yaml | 2 +- configs/env/tsp.yaml | 2 +- configs/experiment/base.yaml | 3 --- 12 files changed, 11 insertions(+), 14 deletions(-) diff --git a/configs/env/cvrp.yaml b/configs/env/cvrp.yaml index 2b98bd92..e8598620 100644 --- a/configs/env/cvrp.yaml +++ b/configs/env/cvrp.yaml @@ -1,4 +1,4 @@ -_target_: rl4co.envs.cvrp.CVRPEnv +_target_: rl4co.envs.CVRPEnv name: cvrp num_loc: 20 diff --git a/configs/env/default.yaml b/configs/env/default.yaml index a9ce9bf3..66fa95a6 100644 --- a/configs/env/default.yaml +++ b/configs/env/default.yaml @@ -1,4 +1,4 @@ -_target_: rl4co.envs.tsp.TSPEnv +_target_: rl4co.envs.TSPEnv name: tsp num_loc: 20 diff --git a/configs/env/dpp.yaml b/configs/env/dpp.yaml index 0456187c..51a3c74b 100644 --- a/configs/env/dpp.yaml +++ b/configs/env/dpp.yaml @@ -1,4 +1,4 @@ -_target_: rl4co.envs.dpp.DPPEnv +_target_: rl4co.envs.DPPEnv name: dpp max_decaps: 20 diff --git a/configs/env/mdpp.yaml b/configs/env/mdpp.yaml index 57194a58..df790181 100644 --- a/configs/env/mdpp.yaml +++ b/configs/env/mdpp.yaml @@ -1,4 +1,4 @@ -_target_: rl4co.envs.mdpp.MDPPEnv +_target_: rl4co.envs.MDPPEnv name: mdpp max_decaps: 20 diff --git a/configs/env/mtsp.yaml b/configs/env/mtsp.yaml index 50cadba5..e24e0dca 100644 --- a/configs/env/mtsp.yaml +++ b/configs/env/mtsp.yaml @@ -1,4 +1,4 @@ -_target_: rl4co.envs.mtsp.MTSPEnv +_target_: rl4co.envs.MTSPEnv name: mtsp num_loc: 20 diff --git a/configs/env/op.yaml b/configs/env/op.yaml index e71bcba1..08d8d86d 100644 --- a/configs/env/op.yaml +++ b/configs/env/op.yaml @@ -1,4 +1,4 @@ -_target_: rl4co.envs.op.OPEnv +_target_: rl4co.envs.OPEnv name: op num_loc: 20 diff --git a/configs/env/pctsp.yaml b/configs/env/pctsp.yaml index 3e92aac1..a05fc1f7 100644 --- a/configs/env/pctsp.yaml +++ b/configs/env/pctsp.yaml @@ -1,4 +1,4 @@ -_target_: rl4co.envs.pctsp.PCTSPEnv +_target_: rl4co.envs.PCTSPEnv name: pctsp num_loc: 20 diff --git a/configs/env/pdp.yaml b/configs/env/pdp.yaml index 71f12e01..ba5236a9 100644 --- a/configs/env/pdp.yaml +++ b/configs/env/pdp.yaml @@ -1,4 +1,4 @@ -_target_: rl4co.envs.pdp.PDPEnv +_target_: rl4co.envs.PDPEnv name: pdp num_loc: 20 diff --git a/configs/env/sdvrp.yaml b/configs/env/sdvrp.yaml index cb5f81a0..6ecdd4ce 100644 --- a/configs/env/sdvrp.yaml +++ b/configs/env/sdvrp.yaml @@ -1,4 +1,4 @@ -_target_: rl4co.envs.sdvrp.SDVRPEnv +_target_: rl4co.envs.SDVRPEnv name: sdvrp num_loc: 20 diff --git a/configs/env/spctsp.yaml b/configs/env/spctsp.yaml index ac1387e6..1a239237 100644 --- a/configs/env/spctsp.yaml +++ b/configs/env/spctsp.yaml @@ -1,4 +1,4 @@ -_target_: rl4co.envs.spctsp.SPCTSPEnv +_target_: rl4co.envs.SPCTSPEnv name: spctsp num_loc: 20 diff --git a/configs/env/tsp.yaml b/configs/env/tsp.yaml index 6d3b87fc..e853a203 100644 --- a/configs/env/tsp.yaml +++ b/configs/env/tsp.yaml @@ -1,4 +1,4 @@ -_target_: rl4co.envs.tsp.TSPEnv +_target_: rl4co.envs.TSPEnv name: tsp diff --git a/configs/experiment/base.yaml b/configs/experiment/base.yaml index 4fae8d25..4cfe47d0 100644 --- a/configs/experiment/base.yaml +++ b/configs/experiment/base.yaml @@ -17,9 +17,6 @@ defaults: # that are automatically generated with seed following Kool et al. (2019). env: num_loc: 50 - data_dir: ${paths.root_dir}/data/tsp - val_file: tsp${env.num_loc}_val_seed4321.npz - test_file: tsp${env.num_loc}_test_seed1234.npz # Logging: we use Wandb in this case logger: From 0ac16cdf92fd95307598bc13d4a54ab5f2d4b07e Mon Sep 17 00:00:00 2001 From: Federico Berto Date: Sun, 5 Nov 2023 16:59:17 +0900 Subject: [PATCH 10/13] =?UTF-8?q?[Formatting]=20Black+Ruff=20=F0=9F=8E=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/_theme/rl4co/extensions/lightning.py | 5 ++- .../rl4co/extensions/pytorch_tutorials.py | 29 ++++++++++++---- rl4co/__init__.py | 2 +- rl4co/models/__init__.py | 1 - rl4co/models/nn/utils.py | 4 ++- rl4co/models/rl/common/base.py | 18 ++++++---- rl4co/models/zoo/__init__.py | 1 - rl4co/models/zoo/ham/model.py | 4 +-- rl4co/models/zoo/ham/policy.py | 3 +- rl4co/models/zoo/mdam/__init__.py | 2 +- rl4co/models/zoo/mdam/decoder.py | 33 +++++++++---------- rl4co/models/zoo/mdam/model.py | 27 ++++++++------- rl4co/models/zoo/mdam/policy.py | 17 +++++----- rl4co/models/zoo/pomo/model.py | 10 +++--- rl4co/tasks/train.py | 5 +-- rl4co/utils/optim_helpers.py | 1 - rl4co/utils/trainer.py | 12 +++++-- 17 files changed, 100 insertions(+), 74 deletions(-) diff --git a/docs/_theme/rl4co/extensions/lightning.py b/docs/_theme/rl4co/extensions/lightning.py index d0633b41..8cba7fda 100644 --- a/docs/_theme/rl4co/extensions/lightning.py +++ b/docs/_theme/rl4co/extensions/lightning.py @@ -13,10 +13,7 @@ # limitations under the License. from docutils import nodes from docutils.statemachine import StringList -from sphinx.util.docutils import SphinxDirective - from pt_lightning_sphinx_theme.extensions.pytorch_tutorials import ( - cardnode, CustomCalloutItemDirective, CustomCardItemDirective, DisplayItemDirective, @@ -24,7 +21,9 @@ ReactGreeter, SlackButton, TwoColumns, + cardnode, ) +from sphinx.util.docutils import SphinxDirective class tutoriallistnode(nodes.General, nodes.Element): diff --git a/docs/_theme/rl4co/extensions/pytorch_tutorials.py b/docs/_theme/rl4co/extensions/pytorch_tutorials.py index 97dbc6e6..12b0e73d 100644 --- a/docs/_theme/rl4co/extensions/pytorch_tutorials.py +++ b/docs/_theme/rl4co/extensions/pytorch_tutorials.py @@ -34,9 +34,8 @@ from docutils import nodes from docutils.parsers.rst import Directive, directives from docutils.statemachine import StringList -from sphinx.util.docutils import SphinxDirective - from pt_lightning_sphinx_theme.extensions.react import get_react_component_rst +from sphinx.util.docutils import SphinxDirective try: FileNotFoundError @@ -272,11 +271,23 @@ def run(self): image_class = "" if "image_center" in self.options: - image = "" + image = ( + "" + ) image_class = "image-center" elif "image_right" in self.options: - image = "" + image = ( + "" + ) image_class = "image-right" else: image = "" @@ -371,7 +382,11 @@ def run(self): raise # return [] callout_rst = get_react_component_rst( - "LikeButtonWithTitle", width=width, margin=margin, title=title, padding=padding + "LikeButtonWithTitle", + width=width, + margin=margin, + title=title, + padding=padding, ) callout_list = StringList(callout_rst.split("\n")) callout = nodes.paragraph() @@ -427,7 +442,9 @@ def run(self): print(e) raise return [] - callout_rst = SLACK_TEMPLATE.format(align=align, title=title, margin=margin, width=width) + callout_rst = SLACK_TEMPLATE.format( + align=align, title=title, margin=margin, width=width + ) callout_list = StringList(callout_rst.split("\n")) callout = nodes.paragraph() self.state.nested_parse(callout_list, self.content_offset, callout) diff --git a/rl4co/__init__.py b/rl4co/__init__.py index 4dc8ce10..95407eb1 100644 --- a/rl4co/__init__.py +++ b/rl4co/__init__.py @@ -1 +1 @@ -__version__ = "0.2.4.dev1" +__version__ = "0.3.0dev0" diff --git a/rl4co/models/__init__.py b/rl4co/models/__init__.py index c923dbe7..b4b794db 100644 --- a/rl4co/models/__init__.py +++ b/rl4co/models/__init__.py @@ -1,7 +1,6 @@ from rl4co.models.zoo.active_search import ActiveSearch from rl4co.models.zoo.am import AttentionModel, AttentionModelPolicy from rl4co.models.zoo.common.autoregressive import AutoregressivePolicy - from rl4co.models.zoo.common.search import SearchBase from rl4co.models.zoo.eas import EAS, EASEmb, EASLay from rl4co.models.zoo.ham import ( diff --git a/rl4co/models/nn/utils.py b/rl4co/models/nn/utils.py index acdc622e..6ab21c7c 100644 --- a/rl4co/models/nn/utils.py +++ b/rl4co/models/nn/utils.py @@ -19,7 +19,9 @@ def get_log_likelihood(log_p, actions, mask, return_sum: bool = True): if mask is not None: log_p[~mask] = 0 - assert (log_p > -1000).data.all(), "Logprobs should not be -inf, check sampling procedure!" + assert ( + log_p > -1000 + ).data.all(), "Logprobs should not be -inf, check sampling procedure!" # Calculate log_likelihood if return_sum: diff --git a/rl4co/models/rl/common/base.py b/rl4co/models/rl/common/base.py index 3ec0ce22..7a7d5b20 100644 --- a/rl4co/models/rl/common/base.py +++ b/rl4co/models/rl/common/base.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Any, Union, Iterable +from typing import Any, Iterable, Union import torch import torch.nn as nn @@ -149,7 +149,6 @@ def setup(self, stage="fit"): self.data_cfg["test_data_size"], phase="test" ) self.dataloader_names = None - self.setup_loggers() self.post_setup_hook() @@ -214,12 +213,16 @@ def configure_optimizers(self, parameters=None): def log_metrics(self, metric_dict: dict, phase: str, dataloader_idx: int = None): """Log metrics to logger and progress bar""" - metrics = getattr(self, f"{phase}_metrics") + metrics = getattr(self, f"{phase}_metrics") dataloader_name = "" if dataloader_idx is not None and self.dataloader_names is not None: - dataloader_name = "/" + self.dataloader_names[dataloader_idx] + dataloader_name = "/" + self.dataloader_names[dataloader_idx] metrics = { - f"{phase}/{k}{dataloader_name}": v.mean() if isinstance(v, torch.Tensor) else v for k, v in metric_dict.items() if k in metrics + f"{phase}/{k}{dataloader_name}": v.mean() + if isinstance(v, torch.Tensor) + else v + for k, v in metric_dict.items() + if k in metrics } log_on_step = self.log_on_step if phase == "train" else False on_epoch = False if phase == "train" else True @@ -292,7 +295,10 @@ def _dataloader(self, dataset, batch_size, shuffle=False): self.dataloader_names = list(dataset.keys()) else: self.dataloader_names = [f"{i}" for i in range(len(dataset))] - return [self._dataloader_single(ds, batch_size, shuffle) for ds in dataset.values()] + return [ + self._dataloader_single(ds, batch_size, shuffle) + for ds in dataset.values() + ] else: return self._dataloader_single(dataset, batch_size, shuffle) diff --git a/rl4co/models/zoo/__init__.py b/rl4co/models/zoo/__init__.py index c923dbe7..b4b794db 100644 --- a/rl4co/models/zoo/__init__.py +++ b/rl4co/models/zoo/__init__.py @@ -1,7 +1,6 @@ from rl4co.models.zoo.active_search import ActiveSearch from rl4co.models.zoo.am import AttentionModel, AttentionModelPolicy from rl4co.models.zoo.common.autoregressive import AutoregressivePolicy - from rl4co.models.zoo.common.search import SearchBase from rl4co.models.zoo.eas import EAS, EASEmb, EASLay from rl4co.models.zoo.ham import ( diff --git a/rl4co/models/zoo/ham/model.py b/rl4co/models/zoo/ham/model.py index a95f558b..aa402e84 100644 --- a/rl4co/models/zoo/ham/model.py +++ b/rl4co/models/zoo/ham/model.py @@ -7,7 +7,7 @@ class HeterogeneousAttentionModel(REINFORCE): - """Heterogenous Attention Model for solving the Pickup and Delivery Problem based on + """Heterogenous Attention Model for solving the Pickup and Delivery Problem based on REINFORCE: https://arxiv.org/abs/2110.02634. Args: @@ -20,7 +20,7 @@ class HeterogeneousAttentionModel(REINFORCE): """ def __init__( - self, + self, env: RL4COEnvBase, policy: HeterogeneousAttentionModelPolicy = None, baseline: Union[REINFORCEBaseline, str] = "rollout", diff --git a/rl4co/models/zoo/ham/policy.py b/rl4co/models/zoo/ham/policy.py index d2ae43c8..a9cd4ea1 100644 --- a/rl4co/models/zoo/ham/policy.py +++ b/rl4co/models/zoo/ham/policy.py @@ -1,4 +1,3 @@ -import torch.nn as nn from rl4co.models.zoo.common.autoregressive import AutoregressivePolicy from rl4co.models.zoo.ham.encoder import GraphHeterogeneousAttentionEncoder @@ -41,4 +40,4 @@ def __init__( num_heads=num_heads, normalization=normalization, **kwargs, - ) \ No newline at end of file + ) diff --git a/rl4co/models/zoo/mdam/__init__.py b/rl4co/models/zoo/mdam/__init__.py index 0dcc6521..2b7a14da 100644 --- a/rl4co/models/zoo/mdam/__init__.py +++ b/rl4co/models/zoo/mdam/__init__.py @@ -1,2 +1,2 @@ +from .model import MDAM from .policy import MDAMPolicy -from .model import MDAM \ No newline at end of file diff --git a/rl4co/models/zoo/mdam/decoder.py b/rl4co/models/zoo/mdam/decoder.py index 87fd0dee..8e1b9daf 100644 --- a/rl4co/models/zoo/mdam/decoder.py +++ b/rl4co/models/zoo/mdam/decoder.py @@ -1,15 +1,14 @@ import math -from typing import Union from dataclasses import dataclass -from tensordict import TensorDict +from typing import Union import torch import torch.nn as nn import torch.nn.functional as F +from tensordict import TensorDict from rl4co.envs import RL4COEnvBase - from rl4co.models.nn.attention import LogitAttention from rl4co.models.nn.env_embeddings import env_context_embedding, env_dynamic_embedding from rl4co.models.nn.utils import decode_probs, get_log_likelihood @@ -67,8 +66,7 @@ def __init__( self.project_node_embeddings = nn.ModuleList(self.project_node_embeddings) self.project_fixed_context = [ - nn.Linear(embedding_dim, embedding_dim, bias=False) - for _ in range(num_paths) + nn.Linear(embedding_dim, embedding_dim, bias=False) for _ in range(num_paths) ] self.project_fixed_context = nn.ModuleList(self.project_fixed_context) @@ -79,8 +77,7 @@ def __init__( self.project_step_context = nn.ModuleList(self.project_step_context) self.project_out = [ - nn.Linear(embedding_dim, embedding_dim, bias=False) - for _ in range(num_paths) + nn.Linear(embedding_dim, embedding_dim, bias=False) for _ in range(num_paths) ] self.project_out = nn.ModuleList(self.project_out) @@ -108,15 +105,15 @@ def __init__( self.shrink_size = shrink_size def forward( - self, - td: TensorDict, - encoded_inputs: torch.Tensor, - env: Union[str, RL4COEnvBase], - attn, - V, - h_old, - **decoder_kwargs - ): + self, + td: TensorDict, + encoded_inputs: torch.Tensor, + env: Union[str, RL4COEnvBase], + attn, + V, + h_old, + **decoder_kwargs, + ): # SECTION: Decoder first step: calculate for the decoder divergence loss # Cost list and log likelihood list along with path output_list = [] @@ -261,7 +258,9 @@ def _get_log_p(self, fixed, td, path_index, normalize=True): step_context = self.context[path_index]( fixed.node_embeddings, td ) # [batch, embed_dim] - glimpse_q = fixed.graph_context + step_context.unsqueeze(1).to(fixed.graph_context.device) + glimpse_q = fixed.graph_context + step_context.unsqueeze(1).to( + fixed.graph_context.device + ) # Compute keys and values for the nodes ( diff --git a/rl4co/models/zoo/mdam/model.py b/rl4co/models/zoo/mdam/model.py index b0696cf3..5ab935a6 100644 --- a/rl4co/models/zoo/mdam/model.py +++ b/rl4co/models/zoo/mdam/model.py @@ -1,4 +1,3 @@ - from typing import Union from rl4co.envs.common.base import RL4COEnvBase @@ -8,10 +7,10 @@ class MDAM(REINFORCE): - """ Multi-Decoder Attention Model (MDAM) is a model - to train multiple diverse policies, which effectively increases the chance of finding + """Multi-Decoder Attention Model (MDAM) is a model + to train multiple diverse policies, which effectively increases the chance of finding good solutions compared with existing methods that train only one policy. - Reference link: https://arxiv.org/abs/2012.10638; + Reference link: https://arxiv.org/abs/2012.10638; Implementation reference: https://github.com/liangxinedu/MDAM. Args: @@ -24,15 +23,15 @@ class MDAM(REINFORCE): """ def __init__( - self, - env: RL4COEnvBase, - policy: MDAMPolicy = None, - baseline: Union[REINFORCEBaseline, str] = "rollout", - policy_kwargs={}, - baseline_kwargs={}, - **kwargs - ): + self, + env: RL4COEnvBase, + policy: MDAMPolicy = None, + baseline: Union[REINFORCEBaseline, str] = "rollout", + policy_kwargs={}, + baseline_kwargs={}, + **kwargs, + ): if policy is None: - policy = MDAMPolicy(env.name, **policy_kwargs) + policy = MDAMPolicy(env.name, **policy_kwargs) - super().__init__(env, policy, baseline, baseline_kwargs, **kwargs) \ No newline at end of file + super().__init__(env, policy, baseline, baseline_kwargs, **kwargs) diff --git a/rl4co/models/zoo/mdam/policy.py b/rl4co/models/zoo/mdam/policy.py index 7a8b7c04..30299dd5 100644 --- a/rl4co/models/zoo/mdam/policy.py +++ b/rl4co/models/zoo/mdam/policy.py @@ -1,26 +1,25 @@ -import torch.nn as nn from typing import Union from tensordict import TensorDict -from rl4co.envs import RL4COEnvBase, get_env +from rl4co.envs import RL4COEnvBase, get_env from rl4co.models.nn.env_embeddings import env_init_embedding +from rl4co.models.zoo.common.autoregressive import AutoregressivePolicy from rl4co.models.zoo.mdam.decoder import Decoder from rl4co.models.zoo.mdam.encoder import GraphAttentionEncoder -from rl4co.models.zoo.common.autoregressive import AutoregressivePolicy from rl4co.utils.pylogger import get_pylogger log = get_pylogger(__name__) class MDAMPolicy(AutoregressivePolicy): - """ Multi-Decoder Attention Model (MDAM) policy. + """Multi-Decoder Attention Model (MDAM) policy. Args: """ - + def __init__( - self, + self, env_name: str, embedding_dim: int = 128, num_encoder_layers: int = 3, @@ -35,13 +34,13 @@ def __init__( embed_dim=embedding_dim, num_layers=num_encoder_layers, normalization=normalization, - **kwargs + **kwargs, ), decoder=Decoder( env_name=env_name, embedding_dim=embedding_dim, num_heads=num_heads, - **kwargs + **kwargs, ), embedding_dim=embedding_dim, num_encoder_layers=num_encoder_layers, @@ -84,4 +83,4 @@ def forward( "entropy": kl_divergence, "actions": actions if return_actions else None, } - return out \ No newline at end of file + return out diff --git a/rl4co/models/zoo/pomo/model.py b/rl4co/models/zoo/pomo/model.py index eefff8ae..1d707df2 100644 --- a/rl4co/models/zoo/pomo/model.py +++ b/rl4co/models/zoo/pomo/model.py @@ -56,7 +56,9 @@ def __init__( for phase in ["train", "val", "test"]: self.set_decode_type_multistart(phase) - def shared_step(self, batch: Any, batch_idx: int, phase: str, dataloader_idx: int = None): + def shared_step( + self, batch: Any, batch_idx: int, phase: str, dataloader_idx: int = None + ): td = self.env.reset(batch) n_aug, n_start = self.num_augment, self.num_starts n_start = get_num_starts(td) if n_start is None else n_start @@ -102,10 +104,10 @@ def shared_step(self, batch: Any, batch_idx: int, phase: str, dataloader_idx: in out.update({"max_aug_reward": max_aug_reward}) if out.get("actions", None) is not None: - actions_ = out["best_multistart_actions"] if n_start > 1 else out["actions"] - out.update( - {"best_aug_actions": gather_by_index(actions_, max_idxs)} + actions_ = ( + out["best_multistart_actions"] if n_start > 1 else out["actions"] ) + out.update({"best_aug_actions": gather_by_index(actions_, max_idxs)}) metrics = self.log_metrics(out, phase, dataloader_idx=dataloader_idx) return {"loss": out.get("loss", None), **metrics} diff --git a/rl4co/tasks/train.py b/rl4co/tasks/train.py index 8d04a01c..6b628470 100644 --- a/rl4co/tasks/train.py +++ b/rl4co/tasks/train.py @@ -9,11 +9,12 @@ from lightning.pytorch.loggers import Logger from omegaconf import DictConfig -pyrootutils.setup_root(__file__, indicator=".gitignore", pythonpath=True) - from rl4co import utils from rl4co.utils import RL4COTrainer +pyrootutils.setup_root(__file__, indicator=".gitignore", pythonpath=True) + + log = utils.get_pylogger(__name__) diff --git a/rl4co/utils/optim_helpers.py b/rl4co/utils/optim_helpers.py index f784a62b..46367a37 100644 --- a/rl4co/utils/optim_helpers.py +++ b/rl4co/utils/optim_helpers.py @@ -1,7 +1,6 @@ import inspect import torch -import torch.nn as nn from torch.optim import Optimizer diff --git a/rl4co/utils/trainer.py b/rl4co/utils/trainer.py index a76b4e5f..790437c3 100644 --- a/rl4co/utils/trainer.py +++ b/rl4co/utils/trainer.py @@ -68,7 +68,7 @@ def __init__( except AttributeError: pass - # Configure DDP automatically + # Configure DDP automatically if multiple GPUs are available if auto_configure_ddp and strategy == "auto": if devices == "auto": n_devices = num_cuda_devices() @@ -77,7 +77,11 @@ def __init__( else: n_devices = devices if n_devices > 1: - log.info("Configuring DDP strategy automatically") + log.info( + "Configuring DDP strategy automatically with {} GPUs".format( + n_devices + ) + ) strategy = DDPStrategy( find_unused_parameters=True, # We set to True due to RL envs gradient_as_bucket_view=True, # https://pytorch-lightning.readthedocs.io/en/stable/advanced/advanced_gpu.html#ddp-optimizations @@ -89,7 +93,9 @@ def __init__( # Check if gradient_clip_val is set to None if gradient_clip_val is None: - log.warning("gradient_clip_val is set to None. This may lead to unstable training.") + log.warning( + "gradient_clip_val is set to None. This may lead to unstable training." + ) # We should reload dataloaders every epoch for RL training if reload_dataloaders_every_n_epochs != 1: From 8894d795be3429eaca510f5e1f4df012a3e4ddfb Mon Sep 17 00:00:00 2001 From: Federico Berto Date: Sun, 5 Nov 2023 17:03:42 +0900 Subject: [PATCH 11/13] [Tests] Python 3.11 + Apple silicon --- .github/workflows/tests.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 21588fe3..7f10d8fc 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -8,8 +8,8 @@ jobs: fail-fast: true max-parallel: 15 matrix: - os: [ubuntu-latest, macos-latest, windows-latest] - python-version: ['3.8', '3.9', '3.10'] + os: [ubuntu-latest, macos-latest, windows-latest, macos-13-xlarge] + python-version: ['3.8', '3.9', '3.10', '3.11'] defaults: run: shell: bash From 27ac959aa33c7cb89774e8ebddf74f24fc4b95d3 Mon Sep 17 00:00:00 2001 From: Federico Berto Date: Sun, 5 Nov 2023 17:05:41 +0900 Subject: [PATCH 12/13] [Tests] Remove Apple silicon temporarily (paid account needed) --- .github/workflows/tests.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 7f10d8fc..b9bc85df 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -8,7 +8,9 @@ jobs: fail-fast: true max-parallel: 15 matrix: - os: [ubuntu-latest, macos-latest, windows-latest, macos-13-xlarge] + # os: [ubuntu-latest, macos-latest, windows-latest, macos-13-xlarge] + # For Apple Silicon: https://github.com/actions/runner-images/issues/8439 + os: [ubuntu-latest, macos-latest, windows-latest] python-version: ['3.8', '3.9', '3.10', '3.11'] defaults: run: From 7b61d6cc969e7b459781eacbb817af74536dc17f Mon Sep 17 00:00:00 2001 From: junyoung park Date: Tue, 7 Nov 2023 15:49:06 +0900 Subject: [PATCH 13/13] [Feat] Implement MatNet --- rl4co/models/zoo/matnet/__init__.py | 0 rl4co/models/zoo/matnet/decoder.py | 52 +++++ rl4co/models/zoo/matnet/encoder.py | 309 ++++++++++++++++++++++++++++ rl4co/models/zoo/matnet/model.py | 39 ++++ rl4co/models/zoo/matnet/policy.py | 61 ++++++ 5 files changed, 461 insertions(+) create mode 100644 rl4co/models/zoo/matnet/__init__.py create mode 100644 rl4co/models/zoo/matnet/decoder.py create mode 100644 rl4co/models/zoo/matnet/encoder.py create mode 100644 rl4co/models/zoo/matnet/model.py create mode 100644 rl4co/models/zoo/matnet/policy.py diff --git a/rl4co/models/zoo/matnet/__init__.py b/rl4co/models/zoo/matnet/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/rl4co/models/zoo/matnet/decoder.py b/rl4co/models/zoo/matnet/decoder.py new file mode 100644 index 00000000..e703bf5c --- /dev/null +++ b/rl4co/models/zoo/matnet/decoder.py @@ -0,0 +1,52 @@ +from dataclasses import dataclass +from typing import Tuple, Union + +import torch +import torch.nn as nn +from einops import rearrange +from rl4co.models.zoo.common.autoregressive.decoder import AutoregressiveDecoder +from rl4co.utils.ops import batchify, get_num_starts, select_start_nodes, unbatchify +from tensordict import TensorDict +from torch import Tensor + + +@dataclass +class PrecomputedCache: + node_embeddings: Tensor + graph_context: Union[Tensor, float] + glimpse_key: Tensor + glimpse_val: Tensor + logit_key: Tensor + + +class MatNetDecoder(AutoregressiveDecoder): + def _precompute_cache( + self, embeddings: Tuple[Tensor, Tensor], num_starts: int = 0, td: TensorDict = None + ): + col_emb, row_emb = embeddings + ( + glimpse_key_fixed, + glimpse_val_fixed, + logit_key, + ) = self.project_node_embeddings( + col_emb + ).chunk(3, dim=-1) + + # Optionally disable the graph context from the initial embedding as done in POMO + if self.use_graph_context: + graph_context = unbatchify( + batchify(self.project_fixed_context(col_emb.mean(1)), num_starts), + num_starts, + ) + else: + graph_context = 0 + + # Organize in a dataclass for easy access + return PrecomputedCache( + node_embeddings=row_emb, + graph_context=graph_context, + glimpse_key=glimpse_key_fixed, + glimpse_val=glimpse_val_fixed, + # logit_key=col_emb, + logit_key=logit_key, + ) \ No newline at end of file diff --git a/rl4co/models/zoo/matnet/encoder.py b/rl4co/models/zoo/matnet/encoder.py new file mode 100644 index 00000000..273baa31 --- /dev/null +++ b/rl4co/models/zoo/matnet/encoder.py @@ -0,0 +1,309 @@ +import math +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from rl4co.models.nn.ops import Normalization +from tensordict import TensorDict + + +class MatNetCrossMHA(nn.Module): + def __init__( + self, + embedding_dim: int, + num_heads: int, + bias: bool = True, + mixer_hidden_dim: int = 16, + mix1_init: float = (1 / 2) ** (1 / 2), + mix2_init: float = (1 / 16) ** (1 / 2), + ): + super().__init__() + self.embedding_dim = embedding_dim + self.num_heads = num_heads + assert ( + self.embedding_dim % num_heads == 0 + ), "embedding_dim must be divisible by num_heads" + self.head_dim = self.embedding_dim // num_heads + + self.Wq = nn.Linear(embedding_dim, embedding_dim, bias=bias) + self.Wkv = nn.Linear(embedding_dim, 2 * embedding_dim, bias=bias) + + # Score mixer + # Taken from the official MatNet implementation + # https://github.com/yd-kwon/MatNet/blob/main/ATSP/ATSP_MatNet/ATSPModel_LIB.py#L72 + mix_W1 = torch.torch.distributions.Uniform(low=-mix1_init, high=mix1_init).sample( + (num_heads, 2, mixer_hidden_dim) + ) + mix_b1 = torch.torch.distributions.Uniform(low=-mix1_init, high=mix1_init).sample( + (num_heads, mixer_hidden_dim) + ) + self.mix_W1 = nn.Parameter(mix_W1) + self.mix_b1 = nn.Parameter(mix_b1) + + mix_W2 = torch.torch.distributions.Uniform(low=-mix2_init, high=mix2_init).sample( + (num_heads, mixer_hidden_dim, 1) + ) + mix_b2 = torch.torch.distributions.Uniform(low=-mix2_init, high=mix2_init).sample( + (num_heads, 1) + ) + self.mix_W2 = nn.Parameter(mix_W2) + self.mix_b2 = nn.Parameter(mix_b2) + + self.out_proj = nn.Linear(embedding_dim, embedding_dim, bias=bias) + + def forward(self, q_input, kv_input, dmat): + """ + + Args: + q_input (Tensor): [b, m, d] + kv_input (Tensor): [b, n, d] + dmat (Tensor): [b, m, n] + + Returns: + Tensor: [b, m, d] + """ + + b, m, n = dmat.shape + + q = rearrange( + self.Wq(q_input), "b m (h d) -> b h m d", h=self.num_heads + ) # [b, h, m, d] + k, v = rearrange( + self.Wkv(kv_input), "b n (two h d) -> two b h n d", two=2, h=self.num_heads + ).unbind( + dim=0 + ) # [b, h, n, d] + + scale = math.sqrt(q.size(-1)) # scale factor + attn_scores = torch.matmul(q, k.transpose(2, 3)) / scale # [b, h, m, n] + mix_attn_scores = torch.stack( + [attn_scores, dmat[:, None, :, :].expand(b, self.num_heads, m, n)], dim=-1 + ) # [b, h, m, n, 2] + + mix_attn_scores = ( + ( + torch.matmul( + F.relu( + torch.matmul(mix_attn_scores.transpose(1, 2), self.mix_W1) + + self.mix_b1[None, None, :, None, :] + ), + self.mix_W2, + ) + + self.mix_b2[None, None, :, None, :] + ) + .transpose(1, 2) + .squeeze(-1) + ) # [b, h, m, n] + + attn_probs = F.softmax(mix_attn_scores, dim=-1) + out = torch.matmul(attn_probs, v) + return self.out_proj(rearrange(out, "b h s d -> b s (h d)")) + + +class MatNetMHA(nn.Module): + def __init__(self, embedding_dim: int, num_heads: int, bias: bool = True): + super().__init__() + self.row_encoding_block = MatNetCrossMHA(embedding_dim, num_heads, bias) + self.col_encoding_block = MatNetCrossMHA(embedding_dim, num_heads, bias) + + def forward(self, row_emb, col_emb, dmat): + """ + Args: + row_emb (Tensor): [b, m, d] + col_emb (Tensor): [b, n, d] + dmat (Tensor): [b, m, n] + + Returns: + Updated row_emb (Tensor): [b, m, d] + Updated col_emb (Tensor): [b, n, d] + """ + + updated_row_emb = self.row_encoding_block(row_emb, col_emb, dmat) + updated_col_emb = self.col_encoding_block( + col_emb, row_emb, dmat.transpose(-2, -1) + ) + return updated_row_emb, updated_col_emb + + +class MatNetMHALayer(nn.Module): + def __init__( + self, + embedding_dim: int, + num_heads: int, + bias: bool = True, + feed_forward_hidden: int = 512, + normalization: Optional[str] = "instance", + ): + super().__init__() + self.MHA = MatNetMHA(embedding_dim, num_heads, bias) + + self.F_a = nn.ModuleDict( + { + "norm1": Normalization(embedding_dim, normalization), + "ffn": nn.Sequential( + nn.Linear(embedding_dim, feed_forward_hidden), + nn.ReLU(), + nn.Linear(feed_forward_hidden, embedding_dim), + ), + "norm2": Normalization(embedding_dim, normalization), + } + ) + + self.F_b = nn.ModuleDict( + { + "norm1": Normalization(embedding_dim, normalization), + "ffn": nn.Sequential( + nn.Linear(embedding_dim, feed_forward_hidden), + nn.ReLU(), + nn.Linear(feed_forward_hidden, embedding_dim), + ), + "norm2": Normalization(embedding_dim, normalization), + } + ) + + def forward(self, row_emb, col_emb, dmat): + """ + Args: + row_emb (Tensor): [b, m, d] + col_emb (Tensor): [b, n, d] + dmat (Tensor): [b, m, n] + + Returns: + Updated row_emb (Tensor): [b, m, d] + Updated col_emb (Tensor): [b, n, d] + """ + + row_emb_out, col_emb_out = self.MHA(row_emb, col_emb, dmat) + + row_emb_out = self.F_a["norm1"](row_emb + row_emb_out) + row_emb_out = self.F_a["norm2"](row_emb_out + self.F_a["ffn"](row_emb_out)) + + col_emb_out = self.F_b["norm1"](col_emb + col_emb_out) + col_emb_out = self.F_b["norm2"](col_emb_out + self.F_b["ffn"](col_emb_out)) + return row_emb_out, col_emb_out + + +class MatNetMHANetwork(nn.Module): + def __init__( + self, + embedding_dim: int = 128, + num_heads: int = 8, + num_layers: int = 3, + normalization: str = "batch", + feed_forward_hidden: int = 512, + ): + super().__init__() + self.layers = nn.ModuleList( + [ + MatNetMHALayer( + num_heads=num_heads, + embedding_dim=embedding_dim, + feed_forward_hidden=feed_forward_hidden, + normalization=normalization, + ) + for _ in range(num_layers) + ] + ) + + def forward(self, row_emb, col_emb, dmat): + """ + Args: + row_emb (Tensor): [b, m, d] + col_emb (Tensor): [b, n, d] + dmat (Tensor): [b, m, n] + + Returns: + Updated row_emb (Tensor): [b, m, d] + Updated col_emb (Tensor): [b, n, d] + """ + + for layer in self.layers: + row_emb, col_emb = layer(row_emb, col_emb, dmat) + return row_emb, col_emb + + +class MatNetATSPInitEmbedding(nn.Module): + """ + Preparing the initial row and column embeddings for ATSP. + + Reference: + https://github.com/yd-kwon/MatNet/blob/782698b60979effe2e7b61283cca155b7cdb727f/ATSP/ATSP_MatNet/ATSPModel.py#L51 + + + """ + + def __init__(self, embedding_dim: int, mode: str = "RandomOneHot") -> None: + super().__init__() + + self.embedding_dim = embedding_dim + assert mode in { + "RandomOneHot", + "Random", + }, "mode must be one of ['RandomOneHot', 'Random']" + self.mode = mode + + self.dmat_proj = nn.Linear(1, 2 * embedding_dim, bias=False) + self.row_proj = nn.Linear(embedding_dim * 4, embedding_dim, bias=False) + self.col_proj = nn.Linear(embedding_dim * 4, embedding_dim, bias=False) + + def forward(self, td: TensorDict): + dmat = td["cost_matrix"] # [b, n, n] + b, n, _ = dmat.shape + + row_emb = torch.zeros(b, n, self.embedding_dim, device=dmat.device) + + if self.mode == "RandomOneHot": + # MatNet uses one-hot encoding for column embeddings + # https://github.com/yd-kwon/MatNet/blob/782698b60979effe2e7b61283cca155b7cdb727f/ATSP/ATSP_MatNet/ATSPModel.py#L60 + + col_emb = torch.zeros(b, n, self.embedding_dim, device=dmat.device) + rand = torch.rand(b, n) + rand_idx = rand.argsort(dim=1) + b_idx = torch.arange(b)[:, None].expand(b, n) + n_idx = torch.arange(n)[None, :].expand(b, n) + col_emb[b_idx, n_idx, rand_idx] = 1.0 + + elif self.mode == "Random": + col_emb = torch.rand(b, n, self.embedding_dim, device=dmat.device) + else: + raise NotImplementedError + + return row_emb, col_emb, dmat + + +class MatNetEncoder(nn.Module): + def __init__( + self, + embedding_dim: int = 256, + num_heads: int = 16, + num_layers: int = 5, + normalization: str = "instance", + feed_forward_hidden: int = 512, + init_embedding: nn.Module = None, + init_embedding_kwargs: dict = None, + ): + super().__init__() + + if init_embedding is None: + init_embedding = MatNetATSPInitEmbedding( + embedding_dim, **init_embedding_kwargs + ) + + self.init_embedding = init_embedding + self.net = MatNetMHANetwork( + embedding_dim=embedding_dim, + num_heads=num_heads, + num_layers=num_layers, + normalization=normalization, + feed_forward_hidden=feed_forward_hidden, + ) + + def forward(self, td): + row_emb, col_emb, dmat = self.init_embedding(td) + row_emb, col_emb = self.net(row_emb, col_emb, dmat) + + embedding = (row_emb, col_emb) + init_embedding = None + return embedding, init_embedding # match output signature for the AR policy class diff --git a/rl4co/models/zoo/matnet/model.py b/rl4co/models/zoo/matnet/model.py new file mode 100644 index 00000000..1af9cace --- /dev/null +++ b/rl4co/models/zoo/matnet/model.py @@ -0,0 +1,39 @@ +from typing import Any, Union +from rl4co.models.zoo.matnet.policy import MatNetPolicy + +import torch.nn as nn + +from rl4co.models.zoo.pomo.model import POMO +from rl4co.envs.common.base import RL4COEnvBase + + +class MatNet(POMO): + def __init__( + self, + env: RL4COEnvBase, + policy: Union[nn.Module, MatNetPolicy] = None, + optimizer_kwargs: dict = {"lr": 4 * 1e-4, "weight_decay": 1e-6}, + lr_scheduler: str = "MultiStepLR", + lr_scheduler_kwargs: dict = {"milestones": [2001, 2101], "gamma": 0.1}, + use_dihedral_8: bool = False, + num_starts: int = None, + train_data_size: int = 10_000, + batch_size: int = 200, + policy_params: dict = {}, + model_params: dict = {}, + ): + if policy is None: + policy = MatNetPolicy(env_name=env.name, **policy_params) + + super(MatNet, self).__init__( + env=env, + policy=policy, + optimizer_kwargs=optimizer_kwargs, + lr_scheduler=lr_scheduler, + lr_scheduler_kwargs=lr_scheduler_kwargs, + use_dihedral_8=use_dihedral_8, + num_starts=num_starts, + train_data_size=train_data_size, + batch_size=batch_size, + **model_params, + ) diff --git a/rl4co/models/zoo/matnet/policy.py b/rl4co/models/zoo/matnet/policy.py new file mode 100644 index 00000000..8b4e1761 --- /dev/null +++ b/rl4co/models/zoo/matnet/policy.py @@ -0,0 +1,61 @@ +from rl4co.models.zoo.common.autoregressive import AutoregressivePolicy +from rl4co.models.zoo.matnet.encoder import MatNetEncoder +from rl4co.models.zoo.matnet.decoder import MatNetDecoder +from rl4co.utils.pylogger import get_pylogger + +log = get_pylogger(__name__) + + +class MatNetPolicy(AutoregressivePolicy): + """MatNet Policy from Kwon et al., 2021. + Reference: https://arxiv.org/abs/2106.11113 + + Warning: + This implementation is under development and subject to change. + + Args: + env_name: Name of the environment used to initialize embeddings + embedding_dim: Dimension of the node embeddings + num_encoder_layers: Number of layers in the encoder + num_heads: Number of heads in the attention layers + normalization: Normalization type in the attention layers + **kwargs: keyword arguments passed to the `AutoregressivePolicy` + + Default paarameters are adopted from the original implementation. + """ + + def __init__( + self, + env_name: str, + embedding_dim: int = 256, + num_encoder_layers: int = 5, + num_heads: int = 16, + normalization: str = "instance", + init_embedding_kwargs: dict = {"mode": "RandomOneHot"}, + use_graph_context: bool = False, + **kwargs, + ): + if env_name not in ["atsp"]: + log.error(f"env_name {env_name} is not originally implemented in MatNet") + + super(MatNetPolicy, self).__init__( + env_name=env_name, + encoder=MatNetEncoder( + embedding_dim=embedding_dim, + num_heads=num_heads, + num_layers=num_encoder_layers, + normalization=normalization, + init_embedding_kwargs=init_embedding_kwargs, + ), + decoder=MatNetDecoder( + env_name=env_name, + embedding_dim=embedding_dim, + num_heads=num_heads, + use_graph_context=use_graph_context, + ), + embedding_dim=embedding_dim, + num_encoder_layers=num_encoder_layers, + num_heads=num_heads, + normalization=normalization, + **kwargs, + )