diff --git a/ding/policy/base_policy.py b/ding/policy/base_policy.py
index d5bc3583c8..0b944852d8 100644
--- a/ding/policy/base_policy.py
+++ b/ding/policy/base_policy.py
@@ -81,11 +81,17 @@ def default_config(cls: type) -> EasyDict:
         )
     total_field = set(['learn', 'collect', 'eval'])
     config = dict(
+        # (bool) Whether the learning policy is the same as the collecting data policy (on-policy).
         on_policy=False,
+        # (bool) Whether to use cuda in policy.
         cuda=False,
+        # (bool) Whether to use data parallel multi-gpu mode in policy.
         multi_gpu=False,
+        # (bool) Whether to synchronously update the model parameters after allreducing their gradients.
         bp_update_sync=True,
+        # (bool) Whether to enable infinite trajectory length in data collecting.
         traj_len_inf=False,
+        # (dict) The neural network model config.
         model=dict(),
     )
@@ -132,6 +138,7 @@ def __init__(
             multi_gpu = self._cfg.multi_gpu
             self._rank = get_rank() if multi_gpu else 0
             if self._cuda:
+                # ``model.cuda()`` is an in-place operation.
                 model.cuda()
             if multi_gpu:
                 bp_update_sync = self._cfg.bp_update_sync
@@ -140,6 +147,7 @@ def __init__(
         else:
             self._rank = 0
             if self._cuda:
+                # ``model.cuda()`` is an in-place operation.
                 model.cuda()
             self._model = model
             self._device = 'cuda:{}'.format(self._rank % torch.cuda.device_count()) if self._cuda else 'cpu'
@@ -148,6 +156,7 @@ def __init__(
             self._rank = 0
             self._device = 'cpu'
 
+        # Call the initialization methods of different modes, such as ``_init_learn``, ``_init_collect``, ``_init_eval``.
         for field in self._enable_field:
             getattr(self, '_init_' + field)()
@@ -276,6 +285,20 @@ def _init_eval(self) -> None:
 
     @property
     def learn_mode(self) -> 'Policy.learn_function':  # noqa
+        """
+        Overview:
+            Return the interfaces of learn mode of policy, which is used to train the model. Here we use namedtuple \
+            to define immutable interfaces and restrict the usage of policy in different modes. Moreover, derived \
+            subclasses can override the interfaces to customize their own learn mode.
+        Returns:
+            - interfaces (:obj:`Policy.learn_function`): The interfaces of learn mode of policy, it is a namedtuple \
+                whose fields are the corresponding internal methods.
+        Examples:
+            >>> policy = Policy(cfg, model)
+            >>> policy_learn = policy.learn_mode
+            >>> train_output = policy_learn.forward(data)
+            >>> state_dict = policy_learn.state_dict()
+        """
         return Policy.learn_function(
             self._forward_learn,
             self._reset_learn,
@@ -289,6 +312,21 @@ def learn_mode(self) -> 'Policy.learn_function':  # noqa
 
     @property
     def collect_mode(self) -> 'Policy.collect_function':  # noqa
+        """
+        Overview:
+            Return the interfaces of collect mode of policy, which is used to collect training data. Here we use \
+            namedtuple to define immutable interfaces and restrict the usage of policy in different modes. Moreover, \
+            derived subclasses can override the interfaces to customize their own collect mode.
+        Returns:
+            - interfaces (:obj:`Policy.collect_function`): The interfaces of collect mode of policy, it is a \
+                namedtuple whose fields are the corresponding internal methods.
+        Examples:
+            >>> policy = Policy(cfg, model)
+            >>> policy_collect = policy.collect_mode
+            >>> obs = env_manager.ready_obs
+            >>> inference_output = policy_collect.forward(obs)
+            >>> next_obs, rew, done, info = env_manager.step(inference_output.action)
+        """
         return Policy.collect_function(
             self._forward_collect,
             self._process_transition,
@@ -302,6 +340,21 @@ def collect_mode(self) -> 'Policy.collect_function':  # noqa
 
     @property
     def eval_mode(self) -> 'Policy.eval_function':  # noqa
+        """
+        Overview:
+            Return the interfaces of eval mode of policy, which is used to evaluate the policy performance. Here we \
+            use namedtuple to define immutable interfaces and restrict the usage of policy in different modes. \
+            Moreover, derived subclasses can override the interfaces to customize their own eval mode.
+        Returns:
+            - interfaces (:obj:`Policy.eval_function`): The interfaces of eval mode of policy, it is a namedtuple \
+                whose fields are the corresponding internal methods.
+        Examples:
+            >>> policy = Policy(cfg, model)
+            >>> policy_eval = policy.eval_mode
+            >>> obs = env_manager.ready_obs
+            >>> inference_output = policy_eval.forward(obs)
+            >>> next_obs, rew, done, info = env_manager.step(inference_output.action)
+        """
         return Policy.eval_function(
             self._forward_eval,
             self._reset_eval,
@@ -393,7 +446,26 @@ def default_model(self) -> Tuple[str, List[str]]:
     # *************************************** learn function ************************************
 
     @abstractmethod
-    def _forward_learn(self, data: dict) -> Dict[str, Any]:
+    def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        """
+        Overview:
+            Policy forward function of learn mode (training policy and updating parameters). Forward means \
+            that the policy inputs some training batch data from the replay buffer and then returns the output \
+            result, including various training information such as loss value, policy entropy, q value, priority, \
+            and so on. This method is left to be implemented by the subclass, and more arguments can be added to \
+            the ``data`` dict if necessary.
+        Arguments:
+            - data (:obj:`Dict[str, Any]`): The input data used for policy forward, including a batch of training \
+                samples. The key of the dict is the name of data items and the value is the corresponding data. \
+                Usually, the data item is a list of data, and the first dimension is the batch dimension, then in \
+                the ``_forward_learn`` method, the data item should be stacked in the batch dimension by some utility \
+                methods such as ``default_preprocess_learn``.
+        Returns:
+            - output (:obj:`Dict[str, Any]`): The training information of policy forward, including some metrics for \
+                monitoring training such as loss, q value and policy entropy, and some data for the next training \
+                step such as priority. Note the output data items should be Python native scalars rather than \
+                PyTorch tensors, which is convenient for the outside to use.
+        """
         raise NotImplementedError
 
     # don't need to implement _reset_learn method by force
@@ -474,7 +546,7 @@ def _forward_collect(self, data: Dict[int, Any], **kwargs) -> Dict[int, Any]:
             part if necessary.
         Arguments:
             - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \
-                key of the dict is environment id and the value if the corresponding data of the env.
+                key of the dict is environment id and the value is the corresponding data of the env.
         Returns:
             - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action and \
                 other necessary data for learn mode defined in ``self._process_transition`` method. The key of the \
@@ -598,7 +670,7 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
             This method is left to be implemented by the subclass.
         Arguments:
             - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \
-                key of the dict is environment id and the value if the corresponding data of the env.
+                key of the dict is environment id and the value is the corresponding data of the env.
         Returns:
             - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. The \
                 key of the dict is the same as the input data, i.e. environment id.
@@ -669,6 +741,19 @@ class CommandModePolicy(Policy):
 
     @property
    def command_mode(self) -> 'Policy.command_function':  # noqa
+        """
+        Overview:
+            Return the interfaces of command mode of policy, which is used to adjust the settings of learn/collect/eval \
+            modes. Here we use namedtuple to define immutable interfaces and restrict the usage of policy in different \
+            modes. Moreover, derived subclasses can override the interfaces to customize their own command mode.
+        Returns:
+            - interfaces (:obj:`Policy.command_function`): The interfaces of command mode, it is a namedtuple \
+                whose fields are the corresponding internal methods.
+        Examples:
+            >>> policy = CommandModePolicy(cfg, model)
+            >>> policy_command = policy.command_mode
+            >>> settings = policy_command.get_setting_learn(command_info)
+        """
         return CommandModePolicy.command_function(
             self._get_setting_learn, self._get_setting_collect, self._get_setting_eval
         )
diff --git a/ding/policy/common_utils.py b/ding/policy/common_utils.py
index 0273b68704..d04abf3658 100644
--- a/ding/policy/common_utils.py
+++ b/ding/policy/common_utils.py
@@ -1,4 +1,4 @@
-from typing import List, Any
+from typing import List, Any, Dict
 import torch
 import numpy as np
 import treetensor.torch as ttorch
@@ -12,14 +12,24 @@ def default_preprocess_learn(
         use_priority: bool = False,
         use_nstep: bool = False,
         ignore_done: bool = False,
-) -> dict:
+) -> Dict[str, torch.Tensor]:
     """
     Overview:
-        Default data pre-processing in policy's ``_forward_learn``.
+        Default data pre-processing in policy's ``_forward_learn`` method, including stacking the batch data, and \
+        processing the ignore-done flag, the nstep reward and the priority IS weight.
+    Arguments:
+        - data (:obj:`List[Any]`): The list of training batch samples, where each sample is a dict of PyTorch Tensors.
+        - use_priority_IS_weight (:obj:`bool`): Whether to use priority IS weight correction; if True, this function \
+            will set the weight of each sample to the priority IS weight.
+        - use_priority (:obj:`bool`): Whether to use priority; if True, this function will set the priority IS weight.
+        - use_nstep (:obj:`bool`): Whether to use nstep TD error; if True, this function will reshape the reward.
+        - ignore_done (:obj:`bool`): Whether to ignore done; if True, this function will set the done signals to 0.
+    Returns:
+        - data (:obj:`Dict[str, torch.Tensor]`): The preprocessed dict data whose values can be directly used for \
+            the following model forward and loss computation.
""" # data preprocess - if data[0]['action'].dtype in [np.int8, np.int16, np.int32, np.int64, torch.int8, torch.int16, torch.int32, - torch.int64]: + if data[0]['action'].dtype in [np.int64, torch.int64]: data = default_collate(data, cat_1dim=True) # for discrete action else: data = default_collate(data, cat_1dim=False) # for continuous action diff --git a/ding/policy/dqn.py b/ding/policy/dqn.py index 536e9a7dbc..e7c5595789 100644 --- a/ding/policy/dqn.py +++ b/ding/policy/dqn.py @@ -98,7 +98,7 @@ class DQNPolicy(Policy): # (int) The number of step for calculating target q_value. nstep=1, model=dict( - #(list(int)) Sequence of ``hidden_size`` of subsequent conv layers and the final dense layer. + # (list(int)) Sequence of ``hidden_size`` of subsequent conv layers and the final dense layer. encoder_hidden_size_list=[128, 128, 64], ), # learn_mode config @@ -335,7 +335,7 @@ def _forward_collect(self, data: Dict[int, Any], eps: float) -> Dict[int, Any]: exploration, i.e., classic epsilon-greedy exploration strategy. Arguments: - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \ - key of the dict is environment id and the value if the corresponding data of the env. + key of the dict is environment id and the value is the corresponding data of the env. - eps (:obj:`float`): The epsilon value for exploration. Returns: - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action and \ @@ -423,7 +423,7 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: action to interact with the envs. Arguments: - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \ - key of the dict is environment id and the value if the corresponding data of the env. + key of the dict is environment id and the value is the corresponding data of the env. Returns: - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. The \ key of the dict is the same as the input data, i.e. environment id. diff --git a/ding/policy/ppo.py b/ding/policy/ppo.py index 849e2c86ab..e717ea23e2 100644 --- a/ding/policy/ppo.py +++ b/ding/policy/ppo.py @@ -329,7 +329,7 @@ def _forward_collect(self, data: Dict[int, Any]) -> Dict[int, Any]: data, such as the action to interact with the envs. Arguments: - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \ - key of the dict is environment id and the value if the corresponding data of the env. + key of the dict is environment id and the value is the corresponding data of the env. Returns: - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action and \ other necessary data (action logit and value) for learn mode defined in ``self._process_transition`` \ @@ -472,7 +472,7 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: exploitation. Arguments: - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \ - key of the dict is environment id and the value if the corresponding data of the env. + key of the dict is environment id and the value is the corresponding data of the env. Returns: - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. The \ key of the dict is the same as the input data, i.e. environment id. 
@@ -682,7 +682,7 @@ def _forward_collect(self, data: Dict[int, Any]) -> Dict[int, Any]:
             data, such as the action to interact with the envs.
         Arguments:
             - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \
-                key of the dict is environment id and the value if the corresponding data of the env.
+                key of the dict is environment id and the value is the corresponding data of the env.
         Returns:
             - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action and \
                 other necessary data (action logit) for learn mode defined in ``self._process_transition`` \
@@ -785,7 +785,7 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
             exploitation.
         Arguments:
             - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \
-                key of the dict is environment id and the value if the corresponding data of the env.
+                key of the dict is environment id and the value is the corresponding data of the env.
         Returns:
             - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. The \
                 key of the dict is the same as the input data, i.e. environment id.
@@ -1171,7 +1171,7 @@ def _forward_collect(self, data: Dict[int, Any]) -> Dict[int, Any]:
             data, such as the action to interact with the envs.
         Arguments:
             - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \
-                key of the dict is environment id and the value if the corresponding data of the env.
+                key of the dict is environment id and the value is the corresponding data of the env.
         Returns:
             - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action and \
                 other necessary data (action logit and value) for learn mode defined in ``self._process_transition`` \
@@ -1314,7 +1314,7 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
             exploitation.
         Arguments:
             - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \
-                key of the dict is environment id and the value if the corresponding data of the env.
+                key of the dict is environment id and the value is the corresponding data of the env.
         Returns:
             - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. The \
                 key of the dict is the same as the input data, i.e. environment id.
diff --git a/ding/policy/r2d2.py b/ding/policy/r2d2.py
index 70718d0f6e..f8fe9a21d4 100644
--- a/ding/policy/r2d2.py
+++ b/ding/policy/r2d2.py
@@ -444,7 +444,7 @@ def _forward_collect(self, data: Dict[int, Any], eps: float) -> Dict[int, Any]:
             exploration, i.e., classic epsilon-greedy exploration strategy.
         Arguments:
             - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \
-                key of the dict is environment id and the value if the corresponding data of the env.
+                key of the dict is environment id and the value is the corresponding data of the env.
             - eps (:obj:`float`): The epsilon value for exploration.
         Returns:
             - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action and \
@@ -556,7 +556,7 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
             q_value is the highest.
         Arguments:
             - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \
-                key of the dict is environment id and the value if the corresponding data of the env.
+                key of the dict is environment id and the value is the corresponding data of the env.
         Returns:
             - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. The \
                 key of the dict is the same as the input data, i.e. environment id.
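
The mode docstrings added above all hinge on one design: each ``*_mode`` property hands out an immutable namedtuple of bound methods instead of the whole policy object, so callers can only touch the interfaces of that mode. A minimal, self-contained sketch of the pattern (``TinyPolicy`` is a hypothetical stand-in for illustration, not a class from this PR):

```python
from collections import namedtuple

# Hypothetical miniature of the ``learn_mode`` pattern: expose a fixed,
# immutable set of bound methods rather than the full policy object.
learn_function = namedtuple('learn_function', ['forward', 'state_dict'])


class TinyPolicy:

    def _forward_learn(self, data: dict) -> dict:
        # A real policy would compute losses here and return native scalars.
        return {'total_loss': 0.0}

    def _state_dict_learn(self) -> dict:
        return {}

    @property
    def learn_mode(self) -> 'learn_function':
        return learn_function(self._forward_learn, self._state_dict_learn)


policy = TinyPolicy()
print(policy.learn_mode.forward({'obs': None}))  # {'total_loss': 0.0}
# policy.learn_mode.reset()  # AttributeError: the namedtuple has no such field
```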
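The new ``default_preprocess_learn`` docstring describes three responsibilities: stacking a list of per-sample dicts into batched tensors, optionally zeroing the ``done`` flags, and optionally turning the priority IS ratio into a loss weight. A simplified standalone reading of that behavior (a sketch under those assumptions, not the library function; the real implementation delegates stacking to ``default_collate`` and also handles nstep reward reshaping):

```python
from typing import Dict, List

import torch


def preprocess_learn_sketch(
    data: List[Dict[str, torch.Tensor]],
    use_priority_IS_weight: bool = False,
    ignore_done: bool = False,
) -> Dict[str, torch.Tensor]:
    # Stack each per-sample tensor along a new leading batch dimension.
    batch = {k: torch.stack([sample[k] for sample in data]) for k in data[0]}
    if ignore_done:
        # Treat every transition as non-terminal, e.g. for envs that only
        # terminate on a time limit.
        batch['done'] = torch.zeros_like(batch['done']).float()
    else:
        batch['done'] = batch['done'].float()
    if use_priority_IS_weight:
        # Use the importance-sampling ratio carried by each sample as loss weight.
        batch['weight'] = batch['IS']
    return batch


samples = [
    {
        'obs': torch.rand(4),
        'action': torch.tensor(1),
        'reward': torch.tensor(1.0),
        'done': torch.tensor(True),
        'IS': torch.tensor(0.8),
    } for _ in range(8)
]
batch = preprocess_learn_sketch(samples, use_priority_IS_weight=True, ignore_done=True)
print(batch['obs'].shape, batch['done'].sum())  # torch.Size([8, 4]) tensor(0.)
```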
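Both the DQN and R2D2 ``_forward_collect`` docstrings name the same exploration mechanism, classic epsilon-greedy. For reference, a minimal tensor-level sketch of that strategy (my illustration only; ding's actual exploration samplers live elsewhere in the library):

```python
import torch


def epsilon_greedy(q_values: torch.Tensor, eps: float) -> torch.Tensor:
    # q_values: (batch, action_dim). With probability ``eps`` pick a uniform
    # random action, otherwise take the argmax (greedy) action.
    batch_size, action_dim = q_values.shape
    greedy_action = q_values.argmax(dim=-1)
    random_action = torch.randint(action_dim, (batch_size,))
    explore = torch.rand(batch_size) < eps
    return torch.where(explore, random_action, greedy_action)


q = torch.randn(4, 6)               # e.g. 4 envs, 6 discrete actions
print(epsilon_greedy(q, eps=0.05))  # mostly greedy, occasionally random
```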