From 1c14c3aef72bd511d0099d4faafce284b78aa07a Mon Sep 17 00:00:00 2001 From: Federico Berto Date: Tue, 3 Oct 2023 10:13:21 +0900 Subject: [PATCH] [Feat] support multiple dataloaders; log idx/names --- rl4co/envs/common/base.py | 48 ++++++++++++++++++++++++-- rl4co/models/rl/common/base.py | 31 ++++++++++++----- rl4co/models/rl/ppo/ppo.py | 19 +++++++--- rl4co/models/rl/reinforce/reinforce.py | 6 ++-- rl4co/models/zoo/symnco/model.py | 6 ++-- 5 files changed, 90 insertions(+), 20 deletions(-) diff --git a/rl4co/envs/common/base.py b/rl4co/envs/common/base.py index f5c9f82c..56e9edf2 100644 --- a/rl4co/envs/common/base.py +++ b/rl4co/envs/common/base.py @@ -21,6 +21,8 @@ class RL4COEnvBase(EnvBase): train_file: Name of the training file val_file: Name of the validation file test_file: Name of the test file + val_dataloader_names: Names of the dataloaders to use for validation + test_dataloader_names: Names of the dataloaders to use for testing check_solution: Whether to check the validity of the solution at the end of the episode seed: Seed for the environment device: Device to use. Generally, no need to set as tensors are updated on the fly @@ -35,6 +37,8 @@ def __init__( train_file: str = None, val_file: str = None, test_file: str = None, + val_dataloader_names: list = None, + test_dataloader_names: list = None, check_solution: bool = True, seed: int = None, device: str = "cpu", @@ -43,8 +47,39 @@ def __init__( super().__init__(device=device, batch_size=[]) self.data_dir = data_dir self.train_file = pjoin(data_dir, train_file) if train_file is not None else None - self.val_file = pjoin(data_dir, val_file) if val_file is not None else None - self.test_file = pjoin(data_dir, test_file) if test_file is not None else None + + def get_files(f): + if f is not None: + if isinstance(f, list): + return [pjoin(data_dir, _f) for _f in f] + else: + return pjoin(data_dir, f) + return None + + def get_multiple_dataloader_names(f, names): + if f is not None: + if isinstance(f, list): + if names is None: + names = [f"{i}" for i in range(len(f))] + else: + assert len(names) == len( + f + ), "Number of dataloader names must match number of files" + else: + if names is not None: + log.warning( + "Ignoring dataloader names since only one dataloader is provided" + ) + return names + + self.val_file = get_files(val_file) + self.test_file = get_files(test_file) + self.val_dataloader_names = get_multiple_dataloader_names( + self.val_file, val_dataloader_names + ) + self.test_dataloader_names = get_multiple_dataloader_names( + self.test_file, test_dataloader_names + ) self.check_solution = check_solution if seed is None: seed = torch.empty((), dtype=torch.int64).random_().item() @@ -101,7 +136,14 @@ def dataset(self, batch_size=[], phase="train", filename=None): "the dataset is fixed and the agent will not be able to explore new states" ) try: - td = self.load_data(f, batch_size) + if isinstance(f, list): + names = getattr(self, f"{phase}_dataloader_names") + return { + name: TensorDictDataset(self.load_data(_f, batch_size)) + for name, _f in zip(names, f) + } + else: + td = self.load_data(f, batch_size) except FileNotFoundError: log.error( f"Provided file name {f} not found. Make sure to provide a file in the right path first or " diff --git a/rl4co/models/rl/common/base.py b/rl4co/models/rl/common/base.py index b1927ec2..ddc0e60e 100644 --- a/rl4co/models/rl/common/base.py +++ b/rl4co/models/rl/common/base.py @@ -211,7 +211,7 @@ def configure_optimizers(self, parameters=None): "monitor": self.lr_scheduler_monitor, } - def log_metrics(self, metric_dict: dict, phase: str): + def log_metrics(self, metric_dict: dict, phase: str, dataloader_idx: int = None): """Log metrics to logger and progress bar""" metrics = getattr(self, f"{phase}_metrics") metrics = { @@ -219,7 +219,6 @@ def log_metrics(self, metric_dict: dict, phase: str): for k, v in metric_dict.items() if k in metrics } - log_on_step = self.log_on_step if phase == "train" else False on_epoch = False if phase == "train" else True self.log_dict( @@ -228,7 +227,7 @@ def log_metrics(self, metric_dict: dict, phase: str): on_epoch=on_epoch, prog_bar=True, sync_dist=True, - add_dataloader_idx=False, + add_dataloader_idx=False, # add names to dataloaders instead ) return metrics @@ -241,7 +240,7 @@ def forward(self, td, **kwargs): env = kwargs.pop("env") return self.policy(td, env, **kwargs) - def shared_step(self, batch: Any, batch_idx: int, phase: str): + def shared_step(self, batch: Any, batch_idx: int, phase: str, **kwargs): """Shared step between train/val/test. To be implemented in subclass""" raise NotImplementedError("Shared step is required to implemented in subclass") @@ -249,11 +248,15 @@ def training_step(self, batch: Any, batch_idx: int): # To use new data every epoch, we need to call reload_dataloaders_every_epoch=True in Trainer return self.shared_step(batch, batch_idx, phase="train") - def validation_step(self, batch: Any, batch_idx: int): - return self.shared_step(batch, batch_idx, phase="val") + def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = None): + return self.shared_step( + batch, batch_idx, phase="val", dataloader_idx=dataloader_idx + ) - def test_step(self, batch: Any, batch_idx: int): - return self.shared_step(batch, batch_idx, phase="test") + def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = None): + return self.shared_step( + batch, batch_idx, phase="test", dataloader_idx=dataloader_idx + ) def train_dataloader(self): return self._dataloader( @@ -280,6 +283,18 @@ def wrap_dataset(self, dataset): return dataset def _dataloader(self, dataset, batch_size, shuffle=False): + """Handle both single datasets and list / dict of datasets""" + if isinstance(dataset, list): + return [self._dataloader_single(ds, batch_size, shuffle) for ds in dataset] + elif isinstance(dataset, dict): # we use this by default in RL4COEnvBase + return { + k: self._dataloader_single(ds, batch_size, shuffle) + for k, ds in dataset.items() + } + else: + return self._dataloader_single(dataset, batch_size, shuffle) + + def _dataloader_single(self, dataset, batch_size, shuffle=False): """The dataloader used by the trainer. This is a wrapper around the dataset with a custom collate_fn to efficiently handle TensorDicts. """ diff --git a/rl4co/models/rl/ppo/ppo.py b/rl4co/models/rl/ppo/ppo.py index b9d3d827..6f3cc950 100644 --- a/rl4co/models/rl/ppo/ppo.py +++ b/rl4co/models/rl/ppo/ppo.py @@ -3,6 +3,7 @@ import torch import torch.nn as nn import torch.nn.functional as F + from torch.utils.data import DataLoader from rl4co.envs.common.base import RL4COEnvBase @@ -77,7 +78,9 @@ def __init__( self.automatic_optimization = False # PPO uses custom optimization routine self.critic = critic - if isinstance(mini_batch_size, float) and (mini_batch_size <= 0 or mini_batch_size > 1): + if isinstance(mini_batch_size, float) and ( + mini_batch_size <= 0 or mini_batch_size > 1 + ): default_mini_batch_fraction = 0.25 log.warning( f"mini_batch_size must be an integer or a float in the range (0, 1], got {mini_batch_size}. Setting mini_batch_size to {default_mini_batch_fraction}." @@ -116,7 +119,9 @@ def on_train_epoch_end(self): if isinstance(sch, torch.optim.lr_scheduler.MultiStepLR): sch.step() - def shared_step(self, batch: Any, batch_idx: int, phase: str): + def shared_step( + self, batch: Any, batch_idx: int, phase: str, dataloader_idx: int = None + ): # Evaluate old actions, log probabilities, and rewards with torch.no_grad(): td = self.env.reset(batch) @@ -147,10 +152,14 @@ def shared_step(self, batch: Any, batch_idx: int, phase: str): for _ in range(self.ppo_cfg["ppo_epochs"]): # PPO inner epoch, K for sub_td in dataloader: - ll, entropy = self.policy.evaluate_action(sub_td, action=sub_td["action"]) + ll, entropy = self.policy.evaluate_action( + sub_td, action=sub_td["action"] + ) # Compute the ratio of probabilities of new and old actions - ratio = torch.exp(ll.sum(dim=-1) - sub_td["log_prob"]).view(-1, 1) # [batch, 1] + ratio = torch.exp(ll.sum(dim=-1) - sub_td["log_prob"]).view( + -1, 1 + ) # [batch, 1] # Compute the advantage value_pred = self.critic(sub_td) # [batch, 1] @@ -204,5 +213,5 @@ def shared_step(self, batch: Any, batch_idx: int, phase: str): } ) - metrics = self.log_metrics(out, phase) + metrics = self.log_metrics(out, phase, dataloader_idx=dataloader_idx) return {"loss": out.get("loss", None), **metrics} diff --git a/rl4co/models/rl/reinforce/reinforce.py b/rl4co/models/rl/reinforce/reinforce.py index 1003198c..f33b1193 100644 --- a/rl4co/models/rl/reinforce/reinforce.py +++ b/rl4co/models/rl/reinforce/reinforce.py @@ -48,7 +48,9 @@ def __init__( log.warning("baseline_kwargs is ignored when baseline is not a string") self.baseline = baseline - def shared_step(self, batch: Any, batch_idx: int, phase: str): + def shared_step( + self, batch: Any, batch_idx: int, phase: str, dataloader_idx: int = None + ): td = self.env.reset(batch) # Perform forward pass (i.e., constructing solution and computing log-likelihoods) out = self.policy(td, self.env, phase=phase) @@ -57,7 +59,7 @@ def shared_step(self, batch: Any, batch_idx: int, phase: str): if phase == "train": out = self.calculate_loss(td, batch, out) - metrics = self.log_metrics(out, phase) + metrics = self.log_metrics(out, phase, dataloader_idx=dataloader_idx) return {"loss": out.get("loss", None), **metrics} def calculate_loss( diff --git a/rl4co/models/zoo/symnco/model.py b/rl4co/models/zoo/symnco/model.py index 5e0f2e56..d0639552 100644 --- a/rl4co/models/zoo/symnco/model.py +++ b/rl4co/models/zoo/symnco/model.py @@ -66,7 +66,9 @@ def __init__( for phase in ["train", "val", "test"]: self.set_decode_type_multistart(phase) - def shared_step(self, batch: Any, batch_idx: int, phase: str): + def shared_step( + self, batch: Any, batch_idx: int, phase: str, dataloader_idx: int = None + ): td = self.env.reset(batch) n_aug, n_start = self.num_augment, self.num_starts n_start = get_num_starts(td) if n_start is None else n_start @@ -130,5 +132,5 @@ def shared_step(self, batch: Any, batch_idx: int, phase: str): } ) - metrics = self.log_metrics(out, phase) + metrics = self.log_metrics(out, phase, dataloader_idx=dataloader_idx) return {"loss": out.get("loss", None), **metrics}