polish(nyz): complete sac/cql comments
PaParaZz1 committed Oct 17, 2023
1 parent 33ea61b commit b654d07
Showing 3 changed files with 314 additions and 35 deletions.
83 changes: 71 additions & 12 deletions ding/policy/cql.py
@@ -152,8 +152,22 @@ class CQLPolicy(SACPolicy):
def _init_learn(self) -> None:
"""
Overview:
Learn mode init method. Called by ``self.__init__``.
Init q and policy's optimizers, algorithm config, main and target models.
Initialize the learn mode of policy, including related attributes and modules. For CQL, it mainly \
contains three optimizers, algorithm-specific arguments such as gamma, min_q_weight, with_lagrange and \
with_q_entropy, and the main and target models. In particular, the ``auto_alpha`` mechanism for \
balancing the max-entropy target is also initialized here.
This method will be called in the ``__init__`` method if the ``learn`` field is in ``enable_field``.
.. note::
For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \
and ``_load_state_dict_learn`` methods.
.. note::
For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method.
.. note::
If you want to set some special member variables in the ``_init_learn`` method, you'd better name them \
with the prefix ``_learn_`` to avoid conflicts with other modes, such as ``self._learn_attr1``.
"""
self._priority = self._cfg.priority
self._priority_IS_weight = self._cfg.priority_IS_weight
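
For reference, the ``auto_alpha`` mechanism mentioned in the new docstring is the usual SAC trick of learning the entropy temperature so that the policy entropy tracks a target value. A minimal, illustrative sketch follows (the names ``log_alpha`` and ``alpha_optim``, the action dimensionality, and the learning rate are assumptions, not DI-engine's exact attributes):

import torch

# Minimal sketch of an auto-alpha setup; names and hyperparameters are illustrative assumptions.
action_dim = 6                       # assumed action dimensionality
target_entropy = -float(action_dim)  # common heuristic: target entropy = -|A|

log_alpha = torch.zeros(1, requires_grad=True)       # learn log(alpha) for positivity and stability
alpha_optim = torch.optim.Adam([log_alpha], lr=3e-4)

def update_alpha(log_prob: torch.Tensor) -> torch.Tensor:
    # One gradient step on alpha, given log-probabilities of actions sampled from the current policy.
    alpha_loss = -(log_alpha * (log_prob + target_entropy).detach()).mean()
    alpha_optim.zero_grad()
    alpha_loss.backward()
    alpha_optim.step()
    return log_alpha.exp()  # current alpha used to weight the entropy bonus
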
@@ -241,14 +255,30 @@ def _init_learn(self) -> None:

self._forward_learn_cnt = 0

def _forward_learn(self, data: dict) -> Dict[str, Any]:
def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
"""
Overview:
Forward and backward function of learn mode.
Policy forward function of learn mode (training the policy and updating its parameters). Forward means \
that the policy takes a batch of training data from the replay buffer as input and returns the output \
result, including various training information such as loss, action, and priority.
Arguments:
- data (:obj:`dict`): Dict type data, including at least ['obs', 'action', 'reward', 'next_obs']
- data (:obj:`List[Dict[str, Any]]`): The input data used for policy forward, including a batch of \
training samples. For each element in the list, the keys of the dict are the names of the data items and \
the values are the corresponding data. Usually, a value is a torch.Tensor, an np.ndarray, or their dict/list \
combinations. In the ``_forward_learn`` method, the data usually needs to be stacked along the batch \
dimension first by some utility function such as ``default_preprocess_learn``. \
For CQL, each element in the list is a dict containing at least the following keys: ``obs``, ``action``, \
``reward``, ``next_obs``, ``done``. Sometimes, it also contains other keys such as ``weight``.
Returns:
- info_dict (:obj:`Dict[str, Any]`): Including current lr and loss.
- info_dict (:obj:`Dict[str, Any]`): The information dict that indicates the training result, which will be \
recorded in the text log and TensorBoard; values must be Python scalars or lists of scalars. For the \
detailed definition of the dict, refer to the code of the ``_monitor_vars_learn`` method.
.. note::
The input value can be torch.Tensor or dict/list combinations, and the current policy supports all of them. \
For data types that are not supported, the main reason is that the corresponding model does not support them. \
You can implement your own model rather than use the default model. For more information, please raise an \
issue in the GitHub repo and we will continue to follow up.
"""
loss_dict = {}
data = default_preprocess_learn(
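
For reference, the ``data`` argument described above is a list of per-transition dicts that ``default_preprocess_learn`` stacks along the batch dimension. A hedged toy example, assuming the usual DI-engine import path ``ding.policy.common_utils`` and illustrative observation/action sizes:

import torch
from ding.policy.common_utils import default_preprocess_learn

# A toy batch of two transitions for a hypothetical continuous-control task;
# tensor shapes are illustrative assumptions.
data = [
    {
        'obs': torch.randn(17),
        'action': torch.randn(6),
        'reward': torch.tensor(1.0),
        'next_obs': torch.randn(17),
        'done': torch.tensor(False),
    }
    for _ in range(2)
]
# Stack the list of dicts along the batch dimension, e.g. batch['obs'].shape == (2, 17).
batch = default_preprocess_learn(data, use_priority=False, ignore_done=False, use_nstep=False)
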
@@ -509,8 +539,20 @@ class DiscreteCQLPolicy(QRDQNPolicy):
def _init_learn(self) -> None:
"""
Overview:
Learn mode init method. Called by ``self.__init__``.
Init the optimizer, algorithm config, main and target models.
Initialize the learn mode of policy, including related attributes and modules. For DiscreteCQL, it mainly \
contains the optimizer, algorithm-specific arguments such as gamma, nstep and min_q_weight, and the main \
and target models. This method will be called in the ``__init__`` method if the ``learn`` field is in \
``enable_field``.
.. note::
For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \
and ``_load_state_dict_learn`` methods.
.. note::
For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method.
.. note::
If you want to set some special member variables in the ``_init_learn`` method, you'd better name them \
with the prefix ``_learn_`` to avoid conflicts with other modes, such as ``self._learn_attr1``.
"""
self._min_q_weight = self._cfg.learn.min_q_weight
self._priority = self._cfg.priority
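
For reference, ``min_q_weight`` scales the conservative regularizer of CQL. A minimal sketch of the standard discrete-action form of that penalty (illustrative only; the actual DiscreteCQL loss is computed on the QRDQN outputs, and the function name here is hypothetical):

import torch

def discrete_cql_penalty(q_values: torch.Tensor, action: torch.Tensor, min_q_weight: float) -> torch.Tensor:
    # q_values: (B, num_actions) Q-values from the learned network.
    # action:   (B,) long indices of the actions stored in the offline dataset.
    # Push down the Q-values of all actions (via logsumexp) while pushing up the dataset action's Q-value.
    logsumexp_q = torch.logsumexp(q_values, dim=-1)                    # (B,)
    dataset_q = q_values.gather(-1, action.unsqueeze(-1)).squeeze(-1)  # (B,)
    return min_q_weight * (logsumexp_q - dataset_q).mean()
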
@@ -532,14 +574,31 @@ def _init_learn(self) -> None:
self._learn_model.reset()
self._target_model.reset()

def _forward_learn(self, data: dict) -> Dict[str, Any]:
def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
"""
Overview:
Forward and backward function of learn mode.
Policy forward function of learn mode (training the policy and updating its parameters). Forward means \
that the policy takes a batch of training data from the replay buffer as input and returns the output \
result, including various training information such as loss, action, and priority.
Arguments:
- data (:obj:`dict`): Dict type data, including at least ['obs', 'action', 'reward', 'next_obs']
- data (:obj:`List[Dict[str, Any]]`): The input data used for policy forward, including a batch of \
training samples. For each element in the list, the keys of the dict are the names of the data items and \
the values are the corresponding data. Usually, a value is a torch.Tensor, an np.ndarray, or their dict/list \
combinations. In the ``_forward_learn`` method, the data usually needs to be stacked along the batch \
dimension first by some utility function such as ``default_preprocess_learn``. \
For DiscreteCQL, each element in the list is a dict containing at least the following keys: ``obs``, \
``action``, ``reward``, ``next_obs``, ``done``. Sometimes, it also contains other keys like ``weight`` \
and ``value_gamma`` for nstep return computation.
Returns:
- info_dict (:obj:`Dict[str, Any]`): Including current lr and loss.
- info_dict (:obj:`Dict[str, Any]`): The information dict that indicates the training result, which will be \
recorded in the text log and TensorBoard; values must be Python scalars or lists of scalars. For the \
detailed definition of the dict, refer to the code of the ``_monitor_vars_learn`` method.
.. note::
The input value can be torch.Tensor or dict/list combinations, and the current policy supports all of them. \
For data types that are not supported, the main reason is that the corresponding model does not support them. \
You can implement your own model rather than use the default model. For more information, please raise an \
issue in the GitHub repo and we will continue to follow up.
"""
data = default_preprocess_learn(
data, use_priority=self._priority, ignore_done=self._cfg.learn.ignore_done, use_nstep=True
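
For reference, a ``value_gamma`` entry typically carries the per-sample discount applied to the bootstrap term of the n-step return. A hedged sketch of such a computation (the function and argument names are illustrative, not the exact DI-engine utility):

import torch

def nstep_td_target(reward: torch.Tensor, next_value: torch.Tensor, done: torch.Tensor,
                    gamma: float, nstep: int, value_gamma: torch.Tensor = None) -> torch.Tensor:
    # reward:      (nstep, B) rewards of the n steps following each sample.
    # next_value:  (B,) bootstrap value at the n-th next state.
    # done:        (B,) whether the episode terminated inside the n-step window.
    # value_gamma: (B,) per-sample discount for the bootstrap term; defaults to gamma ** nstep.
    if value_gamma is None:
        value_gamma = torch.full_like(next_value, gamma ** nstep)
    discounts = torch.tensor([gamma ** i for i in range(nstep)]).unsqueeze(-1)  # (nstep, 1)
    ret = (discounts * reward).sum(dim=0)                                        # discounted reward sum
    return ret + value_gamma * next_value * (1.0 - done.float())                 # bootstrap unless done
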
2 changes: 1 addition & 1 deletion ding/policy/ddpg.py
@@ -491,7 +491,7 @@ def _init_eval(self) -> None:
self._eval_model = model_wrap(self._eval_model, wrapper_name='hybrid_argmax_sample')
self._eval_model.reset()

def _forward_eval(self, data: dict) -> dict:
def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
"""
Overview:
Policy forward function of eval mode (evaluation policy performance by interacting with envs). Forward \
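
For reference, the updated type hint reflects that eval-mode data in DI-engine is keyed by environment id. A hedged illustration of that input shape (env ids, observation size, and the returned keys are assumptions):

import torch

# Hypothetical eval-mode input: a dict mapping env_id -> observation, matching Dict[int, Any].
eval_data = {
    0: torch.randn(17),  # observation collected from env 0
    1: torch.randn(17),  # observation collected from env 1
}
# The policy stacks these observations, runs the model, and returns a dict with the same
# env_id keys, e.g. {0: {'action': tensor(...)}, 1: {'action': tensor(...)}}.
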
