Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

V0.3.0 #100

Merged
merged 14 commits into from
Nov 8, 2023
4 changes: 3 additions & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@ jobs:
fail-fast: true
max-parallel: 15
matrix:
# os: [ubuntu-latest, macos-latest, windows-latest, macos-13-xlarge]
# For Apple Silicon: https://github.com/actions/runner-images/issues/8439
os: [ubuntu-latest, macos-latest, windows-latest]
python-version: ['3.8', '3.9', '3.10']
python-version: ['3.8', '3.9', '3.10', '3.11']
defaults:
run:
shell: bash
Expand Down
2 changes: 1 addition & 1 deletion configs/env/cvrp.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
_target_: rl4co.envs.cvrp.CVRPEnv
_target_: rl4co.envs.CVRPEnv
name: cvrp

num_loc: 20
Expand Down
2 changes: 1 addition & 1 deletion configs/env/default.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
_target_: rl4co.envs.tsp.TSPEnv
_target_: rl4co.envs.TSPEnv
name: tsp

num_loc: 20
Expand Down
2 changes: 1 addition & 1 deletion configs/env/dpp.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
_target_: rl4co.envs.dpp.DPPEnv
_target_: rl4co.envs.DPPEnv
name: dpp

max_decaps: 20
Expand Down
2 changes: 1 addition & 1 deletion configs/env/mdpp.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
_target_: rl4co.envs.mdpp.MDPPEnv
_target_: rl4co.envs.MDPPEnv
name: mdpp

max_decaps: 20
Expand Down
2 changes: 1 addition & 1 deletion configs/env/mtsp.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
_target_: rl4co.envs.mtsp.MTSPEnv
_target_: rl4co.envs.MTSPEnv
name: mtsp

num_loc: 20
Expand Down
2 changes: 1 addition & 1 deletion configs/env/op.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
_target_: rl4co.envs.op.OPEnv
_target_: rl4co.envs.OPEnv
name: op

num_loc: 20
Expand Down
2 changes: 1 addition & 1 deletion configs/env/pctsp.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
_target_: rl4co.envs.pctsp.PCTSPEnv
_target_: rl4co.envs.PCTSPEnv
name: pctsp

num_loc: 20
Expand Down
2 changes: 1 addition & 1 deletion configs/env/pdp.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
_target_: rl4co.envs.pdp.PDPEnv
_target_: rl4co.envs.PDPEnv
name: pdp

num_loc: 20
Expand Down
2 changes: 1 addition & 1 deletion configs/env/sdvrp.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
_target_: rl4co.envs.sdvrp.SDVRPEnv
_target_: rl4co.envs.SDVRPEnv
name: sdvrp

num_loc: 20
Expand Down
2 changes: 1 addition & 1 deletion configs/env/spctsp.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
_target_: rl4co.envs.spctsp.SPCTSPEnv
_target_: rl4co.envs.SPCTSPEnv
name: spctsp

num_loc: 20
Expand Down
2 changes: 1 addition & 1 deletion configs/env/tsp.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
_target_: rl4co.envs.tsp.TSPEnv
_target_: rl4co.envs.TSPEnv

name: tsp

Expand Down
3 changes: 0 additions & 3 deletions configs/experiment/base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,6 @@ defaults:
# that are automatically generated with seed following Kool et al. (2019).
env:
num_loc: 50
data_dir: ${paths.root_dir}/data/tsp
val_file: tsp${env.num_loc}_val_seed4321.npz
test_file: tsp${env.num_loc}_test_seed1234.npz

# Logging: we use Wandb in this case
logger:
Expand Down
9 changes: 1 addition & 8 deletions configs/trainer/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,9 @@ _target_: rl4co.utils.trainer.RL4COTrainer
default_root_dir: ${paths.output_dir}

gradient_clip_val: 1.0
accelerator: "gpu"
accelerator: "auto"
precision: "16-mixed"

# Fast distributed training: comment out to use on single GPU
# devices: 1 # change number of devices
strategy:
_target_: lightning.pytorch.strategies.DDPStrategy
find_unused_parameters: True
gradient_as_bucket_view: True

# perform a validation loop every N training epochs
check_val_every_n_epoch: 1

Expand Down
5 changes: 2 additions & 3 deletions docs/_theme/rl4co/extensions/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,17 @@
# limitations under the License.
from docutils import nodes
from docutils.statemachine import StringList
from sphinx.util.docutils import SphinxDirective

from pt_lightning_sphinx_theme.extensions.pytorch_tutorials import (
cardnode,
CustomCalloutItemDirective,
CustomCardItemDirective,
DisplayItemDirective,
LikeButtonWithTitle,
ReactGreeter,
SlackButton,
TwoColumns,
cardnode,
)
from sphinx.util.docutils import SphinxDirective


class tutoriallistnode(nodes.General, nodes.Element):
Expand Down
29 changes: 23 additions & 6 deletions docs/_theme/rl4co/extensions/pytorch_tutorials.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,8 @@
from docutils import nodes
from docutils.parsers.rst import Directive, directives
from docutils.statemachine import StringList
from sphinx.util.docutils import SphinxDirective

from pt_lightning_sphinx_theme.extensions.react import get_react_component_rst
from sphinx.util.docutils import SphinxDirective

try:
FileNotFoundError
Expand Down Expand Up @@ -272,11 +271,23 @@ def run(self):

image_class = ""
if "image_center" in self.options:
image = "<img src='" + self.options["image_center"] + "' style=height:" + image_height + " >"
image = (
"<img src='"
+ self.options["image_center"]
+ "' style=height:"
+ image_height
+ " >"
)
image_class = "image-center"

elif "image_right" in self.options:
image = "<img src='" + self.options["image_right"] + "' style=height:" + image_height + " >"
image = (
"<img src='"
+ self.options["image_right"]
+ "' style=height:"
+ image_height
+ " >"
)
image_class = "image-right"
else:
image = ""
Expand Down Expand Up @@ -371,7 +382,11 @@ def run(self):
raise
# return []
callout_rst = get_react_component_rst(
"LikeButtonWithTitle", width=width, margin=margin, title=title, padding=padding
"LikeButtonWithTitle",
width=width,
margin=margin,
title=title,
padding=padding,
)
callout_list = StringList(callout_rst.split("\n"))
callout = nodes.paragraph()
Expand Down Expand Up @@ -427,7 +442,9 @@ def run(self):
print(e)
raise
return []
callout_rst = SLACK_TEMPLATE.format(align=align, title=title, margin=margin, width=width)
callout_rst = SLACK_TEMPLATE.format(
align=align, title=title, margin=margin, width=width
)
callout_list = StringList(callout_rst.split("\n"))
callout = nodes.paragraph()
self.state.nested_parse(callout_list, self.content_offset, callout)
Expand Down
11 changes: 5 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@ dynamic = ["version"]

license = {file = "LICENSE"}

# TODO: allow new Python versions https://github.com/kaist-silab/rl4co/issues/95
requires-python = ">=3.8, <3.11" # https://github.com/kaist-silab/rl4co/issues/90
requires-python = ">=3.8"
classifiers = [
"Programming Language :: Python",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"License :: OSI Approved :: Apache Software License",
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
Expand All @@ -41,10 +41,9 @@ classifiers = [

# TODO: allow new TorchRL / TensorDict versions https://github.com/kaist-silab/rl4co/issues/95
dependencies = [
"torch>=2.0.0,<2.1.0", # Possibly TorchRL problem on Windows with older version
"torchrl==0.1.1",
"tensordict==0.1.1",
"lightning>=2.0.5",
"torchrl>=0.2.0",
"tensordict>=0.2.0",
"lightning>=2.1.0",
"hydra-core",
"hydra-colorlog",
"omegaconf",
Expand Down
2 changes: 1 addition & 1 deletion rl4co/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.2.4.dev1"
__version__ = "0.3.0dev0"
69 changes: 54 additions & 15 deletions rl4co/data/dataset.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,45 @@

import torch

from tensordict.tensordict import TensorDict
from torch.utils.data import Dataset


class TensorDictDataset(Dataset):
"""Dataset compatible with TensorDicts.
Uses more CPU and has similar performance in loading to list comprehension, but is faster in instantiation
than :class:`TensorDictDatasetList` (more than 10x faster).
"""

def __init__(self, td: TensorDict):
self.data = td

def __len__(self):
return len(self.data)

def __getitems__(self, index):
# Tricks:
# - batched data loading with `__getitems__` for faster loading
# - avoid directly indexing TensorDicts for faster loading
return TensorDict(
{key: item[index] for key, item in self.data.items()},
batch_size=torch.Size([len(index)]),
_run_checks=False, # faster this way
)

def add_key(self, key, value):
return self.data.update({key: value}) # native method


def tensordict_collate_fn(x):
"""Equivalent to collating with `lambda x: x`"""
return x


class TensorDictDatasetList(Dataset):
"""Dataset compatible with TensorDicts.
It is better to "disassemble" the TensorDict into a list of dicts.
See :class:`tensordict_collate_fn` for more details.
See :class:`tensordict_collate_fn_list` for more details.

Note:
Check out the issue on tensordict for more details:
Expand All @@ -16,21 +48,24 @@ class TensorDictDataset(Dataset):
but uses > 3x more CPU.
"""

def __init__(self, data: TensorDict):
def __init__(self, td: TensorDict):
self.data_len = td.batch_size[0]
self.data = [
{key: value[i] for key, value in data.items()} for i in range(data.shape[0])
{key: value[i] for key, value in td.items()} for i in range(self.data_len)
]

def __len__(self):
return len(self.data)
return self.data_len

def __getitem__(self, idx):
return self.data[idx]

def add_key(self, key, value):
return ExtraKeyDataset(self, value, key_name=key)

def tensordict_collate_fn(batch):
"""Collate function compatible with TensorDicts.
Reassemble the list of dicts into a TensorDict; seems to be way more efficient than using a TensorDictDataset.

def tensordict_collate_fn_list(batch):
"""Collate function compatible with TensorDicts that reassembles a list of dicts.

Note:
Check out the issue on tensordict for more details:
Expand All @@ -40,24 +75,28 @@ def tensordict_collate_fn(batch):
"""
return TensorDict(
{key: torch.stack([b[key] for b in batch]) for key in batch[0].keys()},
batch_size=len(batch),
batch_size=torch.Size([len(batch)]),
device=batch[0].device,
_run_checks=False,
)


class ExtraKeyDataset(Dataset):
class ExtraKeyDataset(TensorDictDatasetList):
"""Dataset that includes an extra key to add to the data dict.
This is useful for adding a REINFORCE baseline reward to the data dict.
Note that this is faster to instantiate than using list comprehension.
"""

def __init__(self, dataset: TensorDictDataset, extra: torch.Tensor):
def __init__(
self, dataset: TensorDictDatasetList, extra: torch.Tensor, key_name="extra"
):
self.data_len = len(dataset)
assert self.data_len == len(extra), "Data and extra must be same length"
self.data = dataset.data
self.extra = extra
assert len(self.data) == len(self.extra), "Data and extra must be same length"

def __len__(self):
return len(self.data)
self.key_name = key_name

def __getitem__(self, idx):
data = self.data[idx]
data["extra"] = self.extra[idx]
data[self.key_name] = self.extra[idx]
return data
34 changes: 20 additions & 14 deletions rl4co/envs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,31 @@
# Base environment
# Main Environments
from rl4co.envs.atsp import ATSPEnv
from rl4co.envs.common.base import RL4COEnvBase
from rl4co.envs.cvrp import CVRPEnv
from rl4co.envs.dpp import DPPEnv
from rl4co.envs.ffsp import FFSPEnv
from rl4co.envs.mdpp import MDPPEnv
from rl4co.envs.mtsp import MTSPEnv
from rl4co.envs.op import OPEnv
from rl4co.envs.pctsp import PCTSPEnv
from rl4co.envs.pdp import PDPEnv
from rl4co.envs.sdvrp import SDVRPEnv
from rl4co.envs.smtwtp import SMTWTPEnv
from rl4co.envs.spctsp import SPCTSPEnv
from rl4co.envs.tsp import TSPEnv

# EDA
from rl4co.envs.eda import DPPEnv, MDPPEnv

# Routing
from rl4co.envs.routing import (
ATSPEnv,
CVRPEnv,
MTSPEnv,
OPEnv,
PCTSPEnv,
PDPEnv,
SDVRPEnv,
SPCTSPEnv,
TSPEnv,
)

# Scheduling
from rl4co.envs.scheduling import FFSPEnv, SMTWTPEnv

# Register environments
ENV_REGISTRY = {
"atsp": ATSPEnv,
"cvrp": CVRPEnv,
"dpp": DPPEnv,
"ffsp": FFSPEnv,
"mdpp": MDPPEnv,
"mtsp": MTSPEnv,
"op": OPEnv,
Expand Down
Loading