[Feat] support multiple dataloaders; log idx/names
fedebotu committed Oct 3, 2023
1 parent e59d177 commit 1c14c3a
Showing 5 changed files with 90 additions and 20 deletions.
48 changes: 45 additions & 3 deletions rl4co/envs/common/base.py
@@ -21,6 +21,8 @@ class RL4COEnvBase(EnvBase):
train_file: Name of the training file
val_file: Name of the validation file
test_file: Name of the test file
val_dataloader_names: Names of the dataloaders to use for validation
test_dataloader_names: Names of the dataloaders to use for testing
check_solution: Whether to check the validity of the solution at the end of the episode
seed: Seed for the environment
device: Device to use. Generally, no need to set as tensors are updated on the fly
@@ -35,6 +37,8 @@ def __init__(
train_file: str = None,
val_file: str = None,
test_file: str = None,
val_dataloader_names: list = None,
test_dataloader_names: list = None,
check_solution: bool = True,
seed: int = None,
device: str = "cpu",
@@ -43,8 +47,39 @@
super().__init__(device=device, batch_size=[])
self.data_dir = data_dir
self.train_file = pjoin(data_dir, train_file) if train_file is not None else None
self.val_file = pjoin(data_dir, val_file) if val_file is not None else None
self.test_file = pjoin(data_dir, test_file) if test_file is not None else None

def get_files(f):
if f is not None:
if isinstance(f, list):
return [pjoin(data_dir, _f) for _f in f]
else:
return pjoin(data_dir, f)
return None

def get_multiple_dataloader_names(f, names):
if f is not None:
if isinstance(f, list):
if names is None:
names = [f"{i}" for i in range(len(f))]
else:
assert len(names) == len(
f
), "Number of dataloader names must match number of files"
else:
if names is not None:
log.warning(
"Ignoring dataloader names since only one dataloader is provided"
)
return names

self.val_file = get_files(val_file)
self.test_file = get_files(test_file)
self.val_dataloader_names = get_multiple_dataloader_names(
self.val_file, val_dataloader_names
)
self.test_dataloader_names = get_multiple_dataloader_names(
self.test_file, test_dataloader_names
)
self.check_solution = check_solution
if seed is None:
seed = torch.empty((), dtype=torch.int64).random_().item()
@@ -101,7 +136,14 @@ def dataset(self, batch_size=[], phase="train", filename=None):
"the dataset is fixed and the agent will not be able to explore new states"
)
try:
td = self.load_data(f, batch_size)
if isinstance(f, list):
names = getattr(self, f"{phase}_dataloader_names")
return {
name: TensorDictDataset(self.load_data(_f, batch_size))
for name, _f in zip(names, f)
}
else:
td = self.load_data(f, batch_size)
except FileNotFoundError:
log.error(
f"Provided file name {f} not found. Make sure to provide a file in the right path first or "
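A minimal usage sketch of the new arguments, assuming a concrete subclass such as TSPEnv and illustrative file names (none of these literals appear in this diff):

from rl4co.envs import TSPEnv

# Hypothetical data files and names, for illustration only
env = TSPEnv(
    data_dir="data/tsp",
    val_file=["val_uniform.npz", "val_clustered.npz"],
    val_dataloader_names=["uniform", "clustered"],  # must match len(val_file)
)
# If val_dataloader_names is omitted, names default to "0", "1", ...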
31 changes: 23 additions & 8 deletions rl4co/models/rl/common/base.py
@@ -211,15 +211,14 @@ def configure_optimizers(self, parameters=None):
"monitor": self.lr_scheduler_monitor,
}

def log_metrics(self, metric_dict: dict, phase: str):
def log_metrics(self, metric_dict: dict, phase: str, dataloader_idx: int = None):
"""Log metrics to logger and progress bar"""
metrics = getattr(self, f"{phase}_metrics")
metrics = {
f"{phase}/{k}": v.mean() if isinstance(v, torch.Tensor) else v
for k, v in metric_dict.items()
if k in metrics
}

log_on_step = self.log_on_step if phase == "train" else False
on_epoch = False if phase == "train" else True
self.log_dict(
@@ -228,7 +227,7 @@ def log_metrics(self, metric_dict: dict, phase: str):
on_epoch=on_epoch,
prog_bar=True,
sync_dist=True,
add_dataloader_idx=False,
add_dataloader_idx=False, # add names to dataloaders instead
)
return metrics

@@ -241,19 +240,23 @@ def forward(self, td, **kwargs):
env = kwargs.pop("env")
return self.policy(td, env, **kwargs)

def shared_step(self, batch: Any, batch_idx: int, phase: str):
def shared_step(self, batch: Any, batch_idx: int, phase: str, **kwargs):
"""Shared step between train/val/test. To be implemented in subclass"""
raise NotImplementedError("Shared step is required to be implemented in subclass")

def training_step(self, batch: Any, batch_idx: int):
# To use new data every epoch, we need to set reload_dataloaders_every_n_epochs=1 in Trainer
return self.shared_step(batch, batch_idx, phase="train")

def validation_step(self, batch: Any, batch_idx: int):
return self.shared_step(batch, batch_idx, phase="val")
def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = None):
return self.shared_step(
batch, batch_idx, phase="val", dataloader_idx=dataloader_idx
)

def test_step(self, batch: Any, batch_idx: int):
return self.shared_step(batch, batch_idx, phase="test")
def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = None):
return self.shared_step(
batch, batch_idx, phase="test", dataloader_idx=dataloader_idx
)

def train_dataloader(self):
return self._dataloader(
@@ -280,6 +283,18 @@ def wrap_dataset(self, dataset):
return dataset

def _dataloader(self, dataset, batch_size, shuffle=False):
"""Handle both single datasets and list / dict of datasets"""
if isinstance(dataset, list):
return [self._dataloader_single(ds, batch_size, shuffle) for ds in dataset]
elif isinstance(dataset, dict): # we use this by default in RL4COEnvBase
return {
k: self._dataloader_single(ds, batch_size, shuffle)
for k, ds in dataset.items()
}
else:
return self._dataloader_single(dataset, batch_size, shuffle)

def _dataloader_single(self, dataset, batch_size, shuffle=False):
"""The dataloader used by the trainer. This is a wrapper around the dataset with a custom collate_fn
to efficiently handle TensorDicts.
"""
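One way the integer dataloader_idx passed to validation_step/test_step could be mapped back to the readable names stored on the env, shown as a sketch only rather than the exact implementation in this commit:

def metric_prefix(self, phase: str, dataloader_idx: int = None) -> str:
    """Return e.g. 'val/clustered' when multiple dataloaders exist, else just 'val'."""
    names = getattr(self.env, f"{phase}_dataloader_names", None)
    if dataloader_idx is None or names is None:
        return phase
    return f"{phase}/{names[dataloader_idx]}"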
19 changes: 14 additions & 5 deletions rl4co/models/rl/ppo/ppo.py
@@ -3,6 +3,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import DataLoader

from rl4co.envs.common.base import RL4COEnvBase
@@ -77,7 +78,9 @@ def __init__(
self.automatic_optimization = False # PPO uses custom optimization routine
self.critic = critic

if isinstance(mini_batch_size, float) and (mini_batch_size <= 0 or mini_batch_size > 1):
if isinstance(mini_batch_size, float) and (
mini_batch_size <= 0 or mini_batch_size > 1
):
default_mini_batch_fraction = 0.25
log.warning(
f"mini_batch_size must be an integer or a float in the range (0, 1], got {mini_batch_size}. Setting mini_batch_size to {default_mini_batch_fraction}."
@@ -116,7 +119,9 @@ def on_train_epoch_end(self):
if isinstance(sch, torch.optim.lr_scheduler.MultiStepLR):
sch.step()

def shared_step(self, batch: Any, batch_idx: int, phase: str):
def shared_step(
self, batch: Any, batch_idx: int, phase: str, dataloader_idx: int = None
):
# Evaluate old actions, log probabilities, and rewards
with torch.no_grad():
td = self.env.reset(batch)
@@ -147,10 +152,14 @@ def shared_step(self, batch: Any, batch_idx: int, phase: str):

for _ in range(self.ppo_cfg["ppo_epochs"]): # PPO inner epoch, K
for sub_td in dataloader:
ll, entropy = self.policy.evaluate_action(sub_td, action=sub_td["action"])
ll, entropy = self.policy.evaluate_action(
sub_td, action=sub_td["action"]
)

# Compute the ratio of probabilities of new and old actions
ratio = torch.exp(ll.sum(dim=-1) - sub_td["log_prob"]).view(-1, 1) # [batch, 1]
ratio = torch.exp(ll.sum(dim=-1) - sub_td["log_prob"]).view(
-1, 1
) # [batch, 1]

# Compute the advantage
value_pred = self.critic(sub_td) # [batch, 1]
@@ -204,5 +213,5 @@
}
)

metrics = self.log_metrics(out, phase)
metrics = self.log_metrics(out, phase, dataloader_idx=dataloader_idx)
return {"loss": out.get("loss", None), **metrics}
6 changes: 4 additions & 2 deletions rl4co/models/rl/reinforce/reinforce.py
@@ -48,7 +48,9 @@ def __init__(
log.warning("baseline_kwargs is ignored when baseline is not a string")
self.baseline = baseline

def shared_step(self, batch: Any, batch_idx: int, phase: str):
def shared_step(
self, batch: Any, batch_idx: int, phase: str, dataloader_idx: int = None
):
td = self.env.reset(batch)
# Perform forward pass (i.e., constructing solution and computing log-likelihoods)
out = self.policy(td, self.env, phase=phase)
@@ -57,7 +59,7 @@ def shared_step(self, batch: Any, batch_idx: int, phase: str):
if phase == "train":
out = self.calculate_loss(td, batch, out)

metrics = self.log_metrics(out, phase)
metrics = self.log_metrics(out, phase, dataloader_idx=dataloader_idx)
return {"loss": out.get("loss", None), **metrics}

def calculate_loss(
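The calculate_loss called above implements, in spirit, the standard REINFORCE estimator; a self-contained sketch with illustrative tensor names:

import torch

def reinforce_loss(reward, baseline_val, log_likelihood):
    # Detach the advantage so gradients flow only through the policy log-likelihood
    advantage = reward - baseline_val
    return -(advantage.detach() * log_likelihood).mean()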
6 changes: 4 additions & 2 deletions rl4co/models/zoo/symnco/model.py
@@ -66,7 +66,9 @@ def __init__(
for phase in ["train", "val", "test"]:
self.set_decode_type_multistart(phase)

def shared_step(self, batch: Any, batch_idx: int, phase: str):
def shared_step(
self, batch: Any, batch_idx: int, phase: str, dataloader_idx: int = None
):
td = self.env.reset(batch)
n_aug, n_start = self.num_augment, self.num_starts
n_start = get_num_starts(td) if n_start is None else n_start
@@ -130,5 +132,5 @@ def shared_step(self, batch: Any, batch_idx: int, phase: str):
}
)

metrics = self.log_metrics(out, phase)
metrics = self.log_metrics(out, phase, dataloader_idx=dataloader_idx)
return {"loss": out.get("loss", None), **metrics}
