Commit

polish(nyz): polish policy mode comments
PaParaZz1 committed Oct 7, 2023
1 parent ec401c1 commit 203c646
Showing 5 changed files with 114 additions and 19 deletions.
91 changes: 88 additions & 3 deletions ding/policy/base_policy.py
@@ -81,11 +81,17 @@ def default_config(cls: type) -> EasyDict:
)
total_field = set(['learn', 'collect', 'eval'])
config = dict(
# (bool) Whether the learning policy is the same as the data-collecting policy (i.e., on-policy).
on_policy=False,
# (bool) Whether to use cuda in policy.
cuda=False,
# (bool) Whether to use data parallel multi-gpu mode in policy.
multi_gpu=False,
# (bool) Whether to synchronously update the model parameters after allreducing their gradients.
bp_update_sync=True,
# (bool) Whether to enable infinite trajectory length in data collecting.
traj_len_inf=False,
# neural network model config
model=dict(),
)
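As a usage sketch (assuming DI-engine is installed; ``DQNPolicy`` stands in for any concrete subclass), these defaults are typically fetched and overridden like this:

from ding.policy import DQNPolicy

cfg = DQNPolicy.default_config()  # EasyDict merged from the defaults above
cfg.cuda = True                   # override a default field
cfg.nstep = 3                     # a DQN-specific field shown further below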

@@ -132,6 +138,7 @@ def __init__(
multi_gpu = self._cfg.multi_gpu
self._rank = get_rank() if multi_gpu else 0
if self._cuda:
# model.cuda() is an in-place operation.
model.cuda()
if multi_gpu:
bp_update_sync = self._cfg.bp_update_sync
@@ -140,6 +147,7 @@ def __init__(
else:
self._rank = 0
if self._cuda:
# model.cuda() is an in-place operation.
model.cuda()
self._model = model
self._device = 'cuda:{}'.format(self._rank % torch.cuda.device_count()) if self._cuda else 'cpu'
@@ -148,6 +156,7 @@ def __init__(
self._rank = 0
self._device = 'cpu'

# call the initialization method of each enabled mode, such as ``_init_learn``, ``_init_collect``, ``_init_eval``
for field in self._enable_field:
getattr(self, '_init_' + field)()
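As a hedged sketch of how this dispatch is driven from outside (``DQNPolicy``, ``cfg`` and ``model`` are illustrative assumptions), constructing a policy with ``enable_field=['learn']`` makes the loop above call only ``_init_learn``:

# only the learn mode is initialized; collect/eval are skipped
policy = DQNPolicy(cfg.policy, model=model, enable_field=['learn'])
train_output = policy.learn_mode.forward(train_data)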

@@ -276,6 +285,20 @@ def _init_eval(self) -> None:

@property
def learn_mode(self) -> 'Policy.learn_function': # noqa
"""
Overview:
Return the interfaces of learn mode of policy, which is used to train the model. Here we use namedtuple \
to define immutable interfaces and restrict the usage of policy in different modes. Moreover, a derived \
subclass can override the interfaces to customize its own learn mode.
Returns:
- interfaces (:obj:`Policy.learn_function`): The interfaces of learn mode of policy; it is a namedtuple \
whose values of distinct fields are different internal methods.
Examples:
>>> policy = Policy(cfg, model)
>>> policy_learn = policy.learn_mode
>>> train_output = policy_learn.forward(data)
>>> state_dict = policy_learn.state_dict()
"""
return Policy.learn_function(
self._forward_learn,
self._reset_learn,
@@ -289,6 +312,21 @@ def learn_mode(self) -> 'Policy.learn_function': # noqa

@property
def collect_mode(self) -> 'Policy.collect_function': # noqa
"""
Overview:
Return the interfaces of collect mode of policy, which is used to collect training data by interacting \
with the envs. Here we use namedtuple to define immutable interfaces and restrict the usage of policy in \
different modes. Moreover, a derived subclass can override the interfaces to customize its own collect mode.
Returns:
- interfaces (:obj:`Policy.collect_function`): The interfaces of collect mode of policy; it is a namedtuple \
whose values of distinct fields are different internal methods.
Examples:
>>> policy = Policy(cfg, model)
>>> policy_collect = policy.collect_mode
>>> obs = env_manager.ready_obs
>>> inference_output = policy_collect.forward(obs)
>>> next_obs, rew, done, info = env_manager.step(inference_output.action)
"""
return Policy.collect_function(
self._forward_collect,
self._process_transition,
@@ -302,6 +340,21 @@ def collect_mode(self) -> 'Policy.collect_function': # noqa

@property
def eval_mode(self) -> 'Policy.eval_function': # noqa
"""
Overview:
Return the interfaces of eval mode of policy, which is used to evaluate the performance of the trained \
policy. Here we use namedtuple to define immutable interfaces and restrict the usage of policy in \
different modes. Moreover, a derived subclass can override the interfaces to customize its own eval mode.
Returns:
- interfaces (:obj:`Policy.eval_function`): The interfaces of eval mode of policy; it is a namedtuple \
whose values of distinct fields are different internal methods.
Examples:
>>> policy = Policy(cfg, model)
>>> policy_eval = policy.eval_mode
>>> obs = env_manager.ready_obs
>>> inference_output = policy_eval.forward(obs)
>>> next_obs, rew, done, info = env_manager.step(inference_output.action)
"""
return Policy.eval_function(
self._forward_eval,
self._reset_eval,
@@ -393,7 +446,26 @@ def default_model(self) -> Tuple[str, List[str]]:
# *************************************** learn function ************************************

@abstractmethod
- def _forward_learn(self, data: dict) -> Dict[str, Any]:
+ def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]:
"""
Overview:
Policy forward function of learn mode (training policy and updating parameters). Forward means \
that the policy takes some training batch data from the replay buffer as input and returns the \
output result, including various training information such as loss value, policy entropy, q value, \
and priority. This method is left to be implemented by the subclass, and more items can be added to \
``data`` if necessary.
Arguments:
- data (:obj:`Dict[str, Any]`): The input data used for policy forward, including a batch of training \
samples. The key of the dict is the name of a data item and the value is the corresponding data. \
Usually, each data item is a list of per-sample data whose first dimension is the batch dimension; in \
the ``_forward_learn`` method, the data items should be stacked along the batch dimension by utility \
methods such as ``default_preprocess_learn``.
Returns:
- output (:obj:`Dict[str, Any]`): The training information of policy forward, including some metrics for \
monitoring training such as loss, q value, and policy entropy, and some data for the next training \
step such as priority. Note that the output data items should be Python native scalars rather than \
PyTorch tensors, which is convenient for the outside to use.
"""
raise NotImplementedError
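To make the contract above concrete, here is a minimal, hypothetical subclass sketch (not part of this commit; the ``'logit'`` output format, the ``'target_q'`` field, and the MSE loss are illustrative assumptions, and the other required methods are omitted):

import torch.nn.functional as F
from typing import Any, Dict, List
from ding.policy.common_utils import default_preprocess_learn


class MySimplePolicy(Policy):  # hypothetical subclass for illustration

    def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
        # stack the list of per-sample dicts along the batch dimension
        data = default_preprocess_learn(data, use_priority=False, ignore_done=False)
        # assumed model output format: a dict with a 'logit' tensor
        q_value = self._learn_model.forward(data['obs'])['logit']
        # placeholder loss against a hypothetical precomputed target
        loss = F.mse_loss(q_value, data['target_q'])
        self._optimizer.zero_grad()
        loss.backward()
        self._optimizer.step()
        # return Python native scalars for monitoring, as required above
        return {'total_loss': loss.item()}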

# don't need to implement _reset_learn method by force
@@ -474,7 +546,7 @@ def _forward_collect(self, data: Dict[int, Any], **kwargs) -> Dict[int, Any]:
part if necessary.
Arguments:
- data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \
- key of the dict is environment id and the value if the corresponding data of the env.
+ key of the dict is environment id and the value is the corresponding data of the env.
Returns:
- output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action and \
other necessary data for learn mode defined in ``self._process_transition`` method. The key of the \
@@ -598,7 +670,7 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
This method is left to be implemented by the subclass.
Arguments:
- data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \
- key of the dict is environment id and the value if the corresponding data of the env.
+ key of the dict is environment id and the value is the corresponding data of the env.
Returns:
- output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. The \
key of the dict is the same as the input data, i.e. environment id.
@@ -669,6 +741,19 @@ class CommandModePolicy(Policy):

@property
def command_mode(self) -> 'Policy.command_function': # noqa
"""
Overview:
Return the interfaces of command mode of policy, which is used to generate the settings (such as \
exploration hyper-parameters) consumed by the other modes. Here we use namedtuple to define immutable \
interfaces and restrict the usage of policy in different modes. Moreover, a derived subclass can override \
the interfaces to customize its own command mode.
Returns:
- interfaces (:obj:`Policy.command_function`): The interfaces of command mode; it is a namedtuple \
whose values of distinct fields are different internal methods.
Examples:
>>> policy = CommandModePolicy(cfg, model)
>>> policy_command = policy.command_mode
>>> settings = policy_command.get_setting_learn(command_info)
"""
return CommandModePolicy.command_function(
self._get_setting_learn, self._get_setting_collect, self._get_setting_eval
)
20 changes: 15 additions & 5 deletions ding/policy/common_utils.py
@@ -1,4 +1,4 @@
- from typing import List, Any
+ from typing import List, Any, Dict
import torch
import numpy as np
import treetensor.torch as ttorch
@@ -12,14 +12,24 @@ def default_preprocess_learn(
use_priority: bool = False,
use_nstep: bool = False,
ignore_done: bool = False,
- ) -> dict:
+ ) -> Dict[str, torch.Tensor]:
"""
Overview:
- Default data pre-processing in policy's ``_forward_learn``.
+ Default data pre-processing in policy's ``_forward_learn`` method, including stacking batch data and \
+ preprocessing the ignore-done flag, n-step reward, and priority IS weight.
Arguments:
- data (:obj:`List[Any]`): The list of training batch samples, where each sample is a dict of PyTorch tensors.
- use_priority_IS_weight (:obj:`bool`): Whether to use priority IS weight correction; if True, this function \
will set the weight of each sample to its priority IS weight.
- use_priority (:obj:`bool`): Whether to use priority; if True, this function will set the priority IS weight.
- use_nstep (:obj:`bool`): Whether to use the n-step TD error; if True, this function will reshape the reward.
- ignore_done (:obj:`bool`): Whether to ignore the done flag; if True, this function will set all done flags to 0.
Returns:
- data (:obj:`Dict[str, torch.Tensor]`): The preprocessed dict data whose values can be directly used for \
the following model forward and loss computation.
"""
# data preprocess
- if data[0]['action'].dtype in [np.int8, np.int16, np.int32, np.int64, torch.int8, torch.int16, torch.int32,
-                                torch.int64]:
+ if data[0]['action'].dtype in [np.int64, torch.int64]:
data = default_collate(data, cat_1dim=True) # for discrete action
else:
data = default_collate(data, cat_1dim=False) # for continuous action
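A toy usage sketch of ``default_preprocess_learn`` (field names follow the usual DI-engine transition layout; the shapes in the comments are the expected behavior, not verified library output):

import torch
from ding.policy.common_utils import default_preprocess_learn

batch = [
    {
        'obs': torch.randn(4),
        'next_obs': torch.randn(4),
        'action': torch.tensor([1]),   # int64, so treated as a discrete action
        'reward': torch.tensor([0.5]),
        'done': False,
    } for _ in range(8)
]
data = default_preprocess_learn(batch, use_priority=False, use_nstep=False)
print(data['obs'].shape)     # expected: torch.Size([8, 4]), batch dim stacked
print(data['action'].shape)  # expected: torch.Size([8]), cat_1dim concatenates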
6 changes: 3 additions & 3 deletions ding/policy/dqn.py
@@ -98,7 +98,7 @@ class DQNPolicy(Policy):
# (int) The number of step for calculating target q_value.
nstep=1,
model=dict(
- #(list(int)) Sequence of ``hidden_size`` of subsequent conv layers and the final dense layer.
+ # (list(int)) Sequence of ``hidden_size`` of subsequent conv layers and the final dense layer.
encoder_hidden_size_list=[128, 128, 64],
),
# learn_mode config
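As an aside, a hedged sketch of what ``encoder_hidden_size_list`` configures (assuming ding's standard DQN model; the obs/action shapes are illustrative):

from ding.model import DQN

# each entry sizes one encoder layer; the last entry feeds the head
model = DQN(
    obs_shape=[4, 84, 84],
    action_shape=6,
    encoder_hidden_size_list=[128, 128, 64],
)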
@@ -335,7 +335,7 @@ def _forward_collect(self, data: Dict[int, Any], eps: float) -> Dict[int, Any]:
exploration, i.e., classic epsilon-greedy exploration strategy.
Arguments:
- data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \
- key of the dict is environment id and the value if the corresponding data of the env.
+ key of the dict is environment id and the value is the corresponding data of the env.
- eps (:obj:`float`): The epsilon value for exploration.
Returns:
- output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action and \
Expand Down Expand Up @@ -423,7 +423,7 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
action to interact with the envs.
Arguments:
- data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \
- key of the dict is environment id and the value if the corresponding data of the env.
+ key of the dict is environment id and the value is the corresponding data of the env.
Returns:
- output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. The \
key of the dict is the same as the input data, i.e. environment id.
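For reference, the classic epsilon-greedy rule that the ``eps`` argument above drives, as a toy sketch (DI-engine implements this inside its model wrappers; this is only the idea, not the library code):

import random
import torch

def epsilon_greedy_action(q_value: torch.Tensor, eps: float) -> int:
    # with probability eps take a uniformly random action, else the greedy one
    if random.random() < eps:
        return random.randrange(q_value.shape[-1])
    return int(q_value.argmax(dim=-1).item())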
12 changes: 6 additions & 6 deletions ding/policy/ppo.py
@@ -329,7 +329,7 @@ def _forward_collect(self, data: Dict[int, Any]) -> Dict[int, Any]:
data, such as the action to interact with the envs.
Arguments:
- data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \
- key of the dict is environment id and the value if the corresponding data of the env.
+ key of the dict is environment id and the value is the corresponding data of the env.
Returns:
- output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action and \
other necessary data (action logit and value) for learn mode defined in ``self._process_transition`` \
@@ -472,7 +472,7 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
exploitation.
Arguments:
- data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \
- key of the dict is environment id and the value if the corresponding data of the env.
+ key of the dict is environment id and the value is the corresponding data of the env.
Returns:
- output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. The \
key of the dict is the same as the input data, i.e. environment id.
@@ -682,7 +682,7 @@ def _forward_collect(self, data: Dict[int, Any]) -> Dict[int, Any]:
data, such as the action to interact with the envs.
Arguments:
- data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \
- key of the dict is environment id and the value if the corresponding data of the env.
+ key of the dict is environment id and the value is the corresponding data of the env.
Returns:
- output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action and \
other necessary data (action logit) for learn mode defined in ``self._process_transition`` \
@@ -785,7 +785,7 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
exploitation.
Arguments:
- data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \
- key of the dict is environment id and the value if the corresponding data of the env.
+ key of the dict is environment id and the value is the corresponding data of the env.
Returns:
- output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. The \
key of the dict is the same as the input data, i.e. environment id.
@@ -1171,7 +1171,7 @@ def _forward_collect(self, data: Dict[int, Any]) -> Dict[int, Any]:
data, such as the action to interact with the envs.
Arguments:
- data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \
- key of the dict is environment id and the value if the corresponding data of the env.
+ key of the dict is environment id and the value is the corresponding data of the env.
Returns:
- output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action and \
other necessary data (action logit and value) for learn mode defined in ``self._process_transition`` \
@@ -1314,7 +1314,7 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
exploitation.
Arguments:
- data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \
- key of the dict is environment id and the value if the corresponding data of the env.
+ key of the dict is environment id and the value is the corresponding data of the env.
Returns:
- output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. The \
key of the dict is the same as the input data, i.e. environment id.
4 changes: 2 additions & 2 deletions ding/policy/r2d2.py
@@ -444,7 +444,7 @@ def _forward_collect(self, data: Dict[int, Any], eps: float) -> Dict[int, Any]:
exploration, i.e., classic epsilon-greedy exploration strategy.
Arguments:
- data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \
- key of the dict is environment id and the value if the corresponding data of the env.
+ key of the dict is environment id and the value is the corresponding data of the env.
- eps (:obj:`float`): The epsilon value for exploration.
Returns:
- output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action and \
@@ -556,7 +556,7 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
q_value is the highest.
Arguments:
- data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \
- key of the dict is environment id and the value if the corresponding data of the env.
+ key of the dict is environment id and the value is the corresponding data of the env.
Returns:
- output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. The \
key of the dict is the same as the input data, i.e. environment id.
