Skip to content

Commit

Permalink
[RLlib] Cleanup examples folder #10: Add custom_rl_module.py exampl…
Browse files Browse the repository at this point in the history
…e script and matching RLModule example class (tiny CNN). (ray-project#45774)
  • Loading branch information
sven1977 authored Jun 7, 2024
1 parent 641f0fa commit ef54ee5
Show file tree
Hide file tree
Showing 12 changed files with 362 additions and 108 deletions.
16 changes: 9 additions & 7 deletions rllib/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -2119,7 +2119,6 @@ py_test(

# subdirectory: checkpoints/
# ....................................

py_test(
name = "examples/checkpoints/checkpoint_by_custom_criteria",
main = "examples/checkpoints/checkpoint_by_custom_criteria.py",
Expand Down Expand Up @@ -2283,7 +2282,6 @@ py_test(

# subdirectory: curriculum/
# ....................................

py_test(
name = "examples/curriculum/curriculum_learning",
main = "examples/curriculum/curriculum_learning.py",
Expand All @@ -2295,7 +2293,6 @@ py_test(

# subdirectory: debugging/
# ....................................

#@OldAPIStack
py_test(
name = "examples/debugging/deterministic_training_torch",
Expand All @@ -2308,7 +2305,6 @@ py_test(

# subdirectory: envs/
# ....................................

py_test(
name = "examples/envs/custom_gym_env",
main = "examples/envs/custom_gym_env.py",
Expand Down Expand Up @@ -2449,7 +2445,6 @@ py_test(

# subdirectory: gpus/
# ....................................

py_test(
name = "examples/gpus/fractional_0.5_gpus_per_learner",
main = "examples/gpus/fractional_gpus_per_learner.py",
Expand All @@ -2469,7 +2464,6 @@ py_test(

# subdirectory: hierarchical/
# ....................................

#@OldAPIStack
py_test(
name = "examples/hierarchical/hierarchical_training_tf",
Expand All @@ -2492,7 +2486,6 @@ py_test(

# subdirectory: inference/
# ....................................

#@OldAPIStack
py_test(
name = "examples/inference/policy_inference_after_training_tf",
Expand Down Expand Up @@ -2905,6 +2898,15 @@ py_test(

# subdirectory: rl_modules/
# ....................................
# Registers the new custom_rl_module example script as an RLlib example test.
# Runs on the new API stack and stops after 3 training iterations to keep the
# test within the "medium" size budget.
py_test(
    name = "examples/rl_modules/custom_rl_module",
    main = "examples/rl_modules/custom_rl_module.py",
    tags = ["team:rllib", "examples"],
    size = "medium",
    srcs = ["examples/rl_modules/custom_rl_module.py"],
    args = ["--enable-new-api-stack", "--stop-iters=3"],
)

#@OldAPIStack @HybridAPIStack
py_test(
name = "examples/rl_modules/classes/mobilenet_rlm_hybrid_api_stack",
Expand Down
3 changes: 1 addition & 2 deletions rllib/algorithms/ppo/torch/ppo_torch_rl_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@ def setup(self):
super().setup()

# If not an inference-only module (e.g., for evaluation), set up the
# parameter names to be removed or renamed when syncing from the state dict
# when synching.
# parameter names to be removed or renamed when syncing from the state dict.
if not self.inference_only:
# Set the expected and unexpected keys for the inference-only module.
self._set_inference_only_state_dict_keys()
Expand Down
31 changes: 14 additions & 17 deletions rllib/core/rl_module/rl_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import datetime
import json
import pathlib
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Mapping, Any, TYPE_CHECKING, Optional, Type, Dict, Union

import gymnasium as gym
Expand Down Expand Up @@ -203,7 +203,7 @@ class RLModuleConfig:

observation_space: gym.Space = None
action_space: gym.Space = None
model_config_dict: Dict[str, Any] = None
model_config_dict: Dict[str, Any] = field(default_factory=dict)
catalog_class: Type["Catalog"] = None

def get_catalog(self) -> "Catalog":
Expand Down Expand Up @@ -456,22 +456,23 @@ def setup(self):
This is called automatically during the __init__ method of this class,
therefore, the subclass should call super.__init__() in its constructor. This
abstraction can be used to create any component that your RLModule needs.
abstraction can be used to create any components (e.g. NN layers) that your
RLModule needs.
"""
return None

@OverrideToImplementCustomLogic
def get_train_action_dist_cls(self) -> Type[Distribution]:
"""Returns the action distribution class for this RLModule used for training.
This class is used to create action distributions from outputs of the
forward_train method. If the case that no action distribution class is needed,
This class is used to get the correct action distribution class to be used by
the training components. In case that no action distribution class is needed,
this method can return None.
Note that RLlib's distribution classes all implement the `Distribution`
interface. This requires two special methods: `Distribution.from_logits()` and
`Distribution.to_deterministic()`. See the documentation for `Distribution`
for more detail.
`Distribution.to_deterministic()`. See the documentation of the
:py:class:`~ray.rllib.models.distributions.Distribution` class for more details.
"""
raise NotImplementedError

Expand All @@ -485,8 +486,8 @@ def get_exploration_action_dist_cls(self) -> Type[Distribution]:
Note that RLlib's distribution classes all implement the `Distribution`
interface. This requires two special methods: `Distribution.from_logits()` and
`Distribution.to_deterministic()`. See the documentation for `Distribution`
for more detail.
`Distribution.to_deterministic()`. See the documentation of the
:py:class:`~ray.rllib.models.distributions.Distribution` class for more details.
"""
raise NotImplementedError

Expand All @@ -500,8 +501,8 @@ def get_inference_action_dist_cls(self) -> Type[Distribution]:
Note that RLlib's distribution classes all implement the `Distribution`
interface. This requires two special methods: `Distribution.from_logits()` and
`Distribution.to_deterministic()`. See the documentation for `Distribution`
for more detail.
`Distribution.to_deterministic()`. See the documentation of the
:py:class:`~ray.rllib.models.distributions.Distribution` class for more details.
"""
raise NotImplementedError

Expand Down Expand Up @@ -596,9 +597,7 @@ def output_specs_inference(self) -> SpecType:
a dict that has `action_dist` key and its value is an instance of
`Distribution`.
"""
# TODO (sven): We should probably change this to [ACTION_DIST_INPUTS], b/c this
# is what most algos will do.
return {"action_dist": Distribution}
return [Columns.ACTION_DIST_INPUTS]

@OverrideToImplementCustomLogic_CallToSuperRecommended
def output_specs_exploration(self) -> SpecType:
Expand All @@ -609,9 +608,7 @@ def output_specs_exploration(self) -> SpecType:
a dict that has `action_dist` key and its value is an instance of
`Distribution`.
"""
# TODO (sven): We should probably change this to [ACTION_DIST_INPUTS], b/c this
# is what most algos will do.
return {"action_dist": Distribution}
return [Columns.ACTION_DIST_INPUTS]

def output_specs_train(self) -> SpecType:
"""Returns the output specs of the forward_train method."""
Expand Down
82 changes: 41 additions & 41 deletions rllib/core/rl_module/torch/torch_rl_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,47 +21,6 @@
torch, nn = try_import_torch()


def compile_wrapper(rl_module: "TorchRLModule", compile_config: TorchCompileConfig):
    """A wrapper that compiles the forward methods of a TorchRLModule.

    Replaces `_forward_train`, `_forward_inference`, and `_forward_exploration`
    on `rl_module` in place with their `torch.compile`d equivalents, using the
    dynamo backend/mode (and extra kwargs) from `compile_config`, then returns
    the same (mutated) module.
    """

    # TODO(Artur): Remove this once our requirements enforce torch >= 2.0.0
    # Check if torch framework supports torch.compile.
    # NOTE(review): if torch is None this check is skipped and the
    # torch.compile calls below fail with an AttributeError instead of a
    # clear error — presumably callers only reach this with torch installed.
    if (
        torch is not None
        and version.parse(torch.__version__) < TORCH_COMPILE_REQUIRED_VERSION
    ):
        raise ValueError("torch.compile is only supported from torch 2.0.0")

    compiled_forward_train = torch.compile(
        rl_module._forward_train,
        backend=compile_config.torch_dynamo_backend,
        mode=compile_config.torch_dynamo_mode,
        **compile_config.kwargs
    )

    # Monkey-patch the compiled version onto the module instance.
    rl_module._forward_train = compiled_forward_train

    compiled_forward_inference = torch.compile(
        rl_module._forward_inference,
        backend=compile_config.torch_dynamo_backend,
        mode=compile_config.torch_dynamo_mode,
        **compile_config.kwargs
    )

    rl_module._forward_inference = compiled_forward_inference

    compiled_forward_exploration = torch.compile(
        rl_module._forward_exploration,
        backend=compile_config.torch_dynamo_backend,
        mode=compile_config.torch_dynamo_mode,
        **compile_config.kwargs
    )

    rl_module._forward_exploration = compiled_forward_exploration

    # Returned for call-chaining convenience; the module was mutated in place.
    return rl_module


class TorchRLModule(nn.Module, RLModule):
"""A base class for RLlib PyTorch RLModules.
Expand Down Expand Up @@ -234,3 +193,44 @@ class TorchDDPRLModuleWithTargetNetworksInterface(
@override(RLModuleWithTargetNetworksInterface)
def get_target_network_pairs(self) -> List[Tuple[NetworkType, NetworkType]]:
return self.module.get_target_network_pairs()


def compile_wrapper(rl_module: "TorchRLModule", compile_config: TorchCompileConfig):
    """A wrapper that compiles the forward methods of a TorchRLModule.

    Replaces `_forward_train`, `_forward_inference`, and `_forward_exploration`
    on `rl_module` in place with their `torch.compile`d equivalents and returns
    the (mutated) module.

    Args:
        rl_module: The TorchRLModule whose forward methods should be compiled.
        compile_config: Holds the torch-dynamo backend and mode (plus any extra
            kwargs) to pass through to `torch.compile`.

    Returns:
        The same `rl_module` instance, with its forward methods compiled.

    Raises:
        ValueError: If the installed torch version does not support
            `torch.compile` (torch < 2.0.0).
    """

    # TODO(Artur): Remove this once our requirements enforce torch >= 2.0.0
    # Check if torch framework supports torch.compile.
    if (
        torch is not None
        and version.parse(torch.__version__) < TORCH_COMPILE_REQUIRED_VERSION
    ):
        raise ValueError("torch.compile is only supported from torch 2.0.0")

    # Compile all three forward methods with identical settings. Using a loop
    # (instead of three copy/pasted `torch.compile(...)` calls) keeps the
    # settings for the three methods from drifting apart.
    for method_name in (
        "_forward_train",
        "_forward_inference",
        "_forward_exploration",
    ):
        setattr(
            rl_module,
            method_name,
            torch.compile(
                getattr(rl_module, method_name),
                backend=compile_config.torch_dynamo_backend,
                mode=compile_config.torch_dynamo_mode,
                **compile_config.kwargs,
            ),
        )

    return rl_module
7 changes: 2 additions & 5 deletions rllib/env/single_agent_env_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,9 @@ def __init__(self, config: AlgorithmConfig, **kwargs):
try:
module_spec: SingleAgentRLModuleSpec = self.config.rl_module_spec
module_spec.observation_space = self._env_to_module.observation_space
# TODO (simon): The `gym.Wrapper` for `gym.vector.VectorEnv` should
# actually hold the spaces for a single env, but for boxes the
# shape is (1, 1) which brings a problem with the action dists.
# shape=(1,) is expected.
module_spec.action_space = self.env.envs[0].action_space
module_spec.model_config_dict = self.config.model_config
if module_spec.model_config_dict is None:
module_spec.model_config_dict = self.config.model_config
# Only load a light version of the module, if available. This is useful
# if the module has target or critic networks not needed in sampling
# or inference.
Expand Down
6 changes: 0 additions & 6 deletions rllib/examples/rl_modules/action_masking_rlm.py

This file was deleted.

Loading

0 comments on commit ef54ee5

Please sign in to comment.