polish(nyz): complete dqn comments
PaParaZz1 committed Oct 10, 2023
1 parent d85fab0 commit e10dc80
Showing 3 changed files with 105 additions and 49 deletions.
17 changes: 8 additions & 9 deletions ding/policy/base_policy.py
@@ -1,10 +1,10 @@
from typing import Optional, List, Dict, Any, Tuple, Union
from abc import ABC, abstractmethod
from collections import namedtuple
from typing import Optional, List, Dict, Any, Tuple, Union
from easydict import EasyDict

import torch
import copy
from easydict import EasyDict
import torch

from ding.model import create_model
from ding.utils import import_module, allreduce, broadcast, get_rank, allreduce_async, synchronize, deep_merge_dicts, \
@@ -446,7 +446,7 @@ def default_model(self) -> Tuple[str, List[str]]:
# *************************************** learn function ************************************

@abstractmethod
def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]:
def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
"""
Overview:
Policy forward function of learn mode (training policy and updating parameters). Forward means \
@@ -455,11 +455,10 @@ def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]:
and so on. This method is left to be implemented by the subclass, and more arguments can be added in \
``data`` item if necessary.
Arguments:
- data (:obj:`Dict[int, Any]`): The input data used for policy forward, including a batch of training \
samples. The key of the dict is the name of data items and the value is the corresponding data. \
Usually, the date item is a list of data, and the first dimension is the batch dimension, then in \
the ``_forward_learn`` method, the data item should be stacked in the batch dimension by some utility \
methods such as ``default_preprocess_learn``.
- data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including a batch of \
training samples. For each element in the list, the key of the dict is the name of the data item and the \
value is the corresponding data. Usually, in the ``_forward_learn`` method, the data should be stacked in \
the batch dimension by some utility functions such as ``default_preprocess_learn``.
Returns:
- output (:obj:`Dict[int, Any]`): The training information of policy forward, including some metrics for \
monitoring training such as loss, priority, q value, policy entropy, and some data for next step \
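To make the data format described above concrete, here is a minimal, self-contained sketch of a batch of training samples and how they are stacked along the batch dimension. The field names and shapes are illustrative; in DI-engine this collation is performed by utilities such as ``default_preprocess_learn`` rather than by hand.

```python
import torch

# A toy batch of transitions, one dict per sample, as produced by the collector.
samples = [
    {
        'obs': torch.randn(4),          # observation
        'action': torch.tensor([1]),    # discrete action index
        'reward': torch.tensor([0.5]),  # scalar reward
        'next_obs': torch.randn(4),     # next observation
        'done': torch.tensor([False]),  # episode-termination flag
    }
    for _ in range(8)
]

# Stack each field along a new batch dimension, which mirrors what a
# preprocessing utility such as ``default_preprocess_learn`` does internally.
batch = {key: torch.stack([sample[key] for sample in samples]) for key in samples[0]}

print(batch['obs'].shape)     # torch.Size([8, 4])
print(batch['action'].shape)  # torch.Size([8, 1])
```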
135 changes: 96 additions & 39 deletions ding/policy/dqn.py
@@ -174,8 +174,20 @@ def default_model(self) -> Tuple[str, List[str]]:
def _init_learn(self) -> None:
"""
Overview:
Learn mode init method. Called by ``self.__init__``, initialize the optimizer, algorithm arguments, main \
and target model.
Initialize the learn mode of policy, including related attributes and modules. For DQN, it mainly contains \
the optimizer, algorithm-specific arguments such as nstep and gamma, and the main and target model.
This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``.
.. note::
For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \
and ``_load_state_dict_learn`` methods.
.. note::
For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method.
.. note::
If you want to set some special member variables in ``_init_learn`` method, you'd better name them \
with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``.
"""
self._priority = self._cfg.priority
self._priority_IS_weight = self._cfg.priority_IS_weight
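As a rough picture of what this initialization amounts to for DQN (an optimizer over the main model, discount and n-step settings, and a separate target network), consider the plain-PyTorch sketch below. All names here are hypothetical stand-ins; the actual policy builds its target and learn models with DI-engine's model wrappers instead of a bare deepcopy.

```python
import copy
import torch
import torch.nn as nn


class TinyQNet(nn.Module):
    """A stand-in Q-network used only for this sketch."""

    def __init__(self, obs_dim: int = 4, action_dim: int = 2):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(obs_dim, 32), nn.ReLU(), nn.Linear(32, action_dim))

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        return self.net(obs)


model = TinyQNet()

# Algorithm-specific arguments that a DQN-style ``_init_learn`` keeps around.
gamma = 0.99  # discount factor
nstep = 3     # length of the n-step return

# Optimizer over the main (learn) model parameters.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Target model: a periodically synchronized copy of the main model.
target_model = copy.deepcopy(model)
target_model.load_state_dict(model.state_dict())
```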
@@ -207,23 +219,36 @@ def _init_learn(self) -> None:
self._learn_model.reset()
self._target_model.reset()

def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]:
def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
"""
Overview:
Forward computation graph of learn mode(updating policy).
Policy forward function of learn mode (training policy and updating parameters). Forward means \
that the policy inputs some training batch data from the replay buffer and then returns the output \
result, including various training information such as loss, q value, priority.
Arguments:
- data (:obj:`Dict[str, Any]`): Dict type data, a batch of data for training, values are torch.Tensor or \
np.ndarray or dict/list combinations.
- data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including a batch of \
training samples. For each element in the list, the key of the dict is the name of the data item and the \
value is the corresponding data. Usually, the value is a torch.Tensor or np.ndarray or their dict/list \
combinations. In the ``_forward_learn`` method, the data often needs to be stacked in the batch \
dimension first by some utility functions such as ``default_preprocess_learn``. \
For DQN, each element in the list is a dict containing at least the following keys: ``obs``, ``action``, \
``reward``, ``next_obs``, ``done``. Sometimes, it also contains other keys such as ``weight`` \
and ``value_gamma``.
Returns:
- info_dict (:obj:`Dict[str, Any]`): Dict type data, a info dict indicated training result, which will be \
recorded in text log and tensorboard, values are python scalar or a list of scalars.
ArgumentsKeys:
- necessary: ``obs``, ``action``, ``reward``, ``next_obs``, ``done``
- optional: ``value_gamma``
ReturnsKeys:
- necessary: ``cur_lr``, ``total_loss``, ``priority``
- optional: ``action_distribution``
- info_dict (:obj:`Dict[str, Any]`): The information dict that indicates the training result, which will be \
recorded in text log and tensorboard, values must be python scalars or a list of scalars. For the \
detailed definition of the dict, refer to the code of the ``_monitor_vars_learn`` method.
.. note::
The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \
For the data types that are not supported, the main reason is that the corresponding model does not support them. \
You can implement your own model rather than use the default model. For more information, please raise an \
issue in the GitHub repo and we will continue to follow up.
.. note::
For more detailed examples, please refer to our unittest for DQNPolicy: ``ding.policy.tests.test_dqn``.
"""
# Data preprocessing operations, such as stacking data and moving it from CPU to CUDA device
data = default_preprocess_learn(
data,
use_priority=self._priority,
@@ -233,9 +258,7 @@ def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]:
)
if self._cuda:
data = to_device(data, self._device)
# ====================
# Q-learning forward
# ====================
self._learn_model.train()
self._target_model.train()
# Current q value (main model)
@@ -252,18 +275,14 @@ def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]:
value_gamma = data.get('value_gamma')
loss, td_error_per_sample = q_nstep_td_error(data_n, self._gamma, nstep=self._nstep, value_gamma=value_gamma)

# ====================
# Q-learning update
# ====================
# Update network parameters
self._optimizer.zero_grad()
loss.backward()
if self._cfg.multi_gpu:
self.sync_gradients(self._learn_model)
self._optimizer.step()

# =============
# after update
# =============
# Postprocessing operations, such as updating the target model and returning logged values and priority.
self._target_model.update(self._learn_model.state_dict())
return {
'cur_lr': self._optimizer.defaults['lr'],
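The loss computed in this method is built around an n-step TD target. The standalone sketch below only illustrates that target (discounted rewards over n steps plus a bootstrapped target-network value) and a squared-error loss; the library routine ``q_nstep_td_error`` additionally handles details such as sample weights and ``value_gamma``, so treat this as an approximation of its behavior.

```python
import torch


def nstep_td_target(
    reward: torch.Tensor,        # (nstep, batch): rewards r_t, ..., r_{t+n-1}
    next_n_q_max: torch.Tensor,  # (batch,): max_a Q_target(s_{t+n}, a)
    done: torch.Tensor,          # (batch,): 1.0 if the episode terminated within the n steps
    gamma: float,
) -> torch.Tensor:
    """Illustrative target: sum_{i=0}^{n-1} gamma^i * r_{t+i} + gamma^n * max_a Q_target(s_{t+n}, a)."""
    nstep = reward.shape[0]
    discounted_reward = sum((gamma ** i) * reward[i] for i in range(nstep))
    bootstrap = (gamma ** nstep) * next_n_q_max * (1.0 - done)
    return discounted_reward + bootstrap


# Toy usage: 3-step returns for a batch of 8 samples.
reward = torch.rand(3, 8)
next_n_q_max = torch.rand(8)
done = torch.zeros(8)
target = nstep_td_target(reward, next_n_q_max, done, gamma=0.99)

q_taken = torch.rand(8)                      # Q(s_t, a_t) of the taken actions from the main model
td_error_per_sample = q_taken - target       # per-sample errors, also usable as replay priorities
loss = (0.5 * td_error_per_sample ** 2).mean()
```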
@@ -317,8 +336,18 @@ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
def _init_collect(self) -> None:
"""
Overview:
Collect mode init method. Called by ``self.__init__``, initialize algorithm arguments and collect_model, \
enable the eps_greedy_sample for exploration.
Initialize the collect mode of policy, including related attributes and modules. For DQN, it contains the \
collect_model to balance exploration and exploitation with the epsilon-greedy sampling mechanism, and other \
algorithm-specific arguments such as unroll_len and nstep.
This method will be called in ``__init__`` method if ``collect`` field is in ``enable_field``.
.. note::
If you want to set some special member variables in ``_init_collect`` method, you'd better name them \
with prefix ``_collect_`` to avoid conflict with other modes, such as ``self._collect_attr1``.
.. tip::
Some variables need to be initialized independently in different modes, such as gamma and nstep in DQN. This \
design is for the convenience of parallel execution of different policy modes.
"""
self._unroll_len = self._cfg.collect.unroll_len
self._gamma = self._cfg.discount_factor # necessary for parallel
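The epsilon-greedy sampling that the collect-mode model wrapper provides can be pictured with a few lines of PyTorch. This is a conceptual sketch of the mechanism, not the wrapper's actual implementation.

```python
import torch


def eps_greedy_action(q_values: torch.Tensor, eps: float) -> torch.Tensor:
    """Pick argmax actions, but with probability ``eps`` replace them by uniformly random ones."""
    batch_size, action_dim = q_values.shape
    greedy = q_values.argmax(dim=-1)
    random_actions = torch.randint(0, action_dim, (batch_size,))
    explore = torch.rand(batch_size) < eps
    return torch.where(explore, random_actions, greedy)


# Toy usage: 4 samples, 6 discrete actions, 10% exploration.
actions = eps_greedy_action(torch.randn(4, 6), eps=0.1)
```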
@@ -408,9 +437,15 @@ def _process_transition(self, obs: torch.Tensor, policy_output: Dict[str, torch.
return transition

def _init_eval(self) -> None:
r"""
"""
Overview:
Evaluate mode init method. Called by ``self.__init__``, initialize eval_model.
Initialize the eval mode of policy, including related attributes and modules. For DQN, it contains the \
eval model, which greedily selects the action with the argmax q_value mechanism.
This method will be called in ``__init__`` method if ``eval`` field is in ``enable_field``.
.. note::
If you want to set some special member variables in ``_init_eval`` method, you'd better name them \
with prefix ``_eval_`` to avoid conflict with other modes, such as ``self._eval_attr1``.
"""
self._eval_model = model_wrap(self._model, wrapper_name='argmax_sample')
self._eval_model.reset()
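In effect, the ``argmax_sample`` wrapper makes evaluation deterministic. Ignoring batching details and multi-discrete action spaces, it reduces to an argmax over the model's q-value output, roughly as follows.

```python
import torch

q_values = torch.randn(4, 6)             # (batch, action_dim) output of the eval model
greedy_action = q_values.argmax(dim=-1)  # deterministic greedy action per sample
```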
@@ -455,6 +490,7 @@ def calculate_priority(self, data: Dict[int, Any], update_target_model: bool = F
Calculate priority for replay buffer.
Arguments:
- data (:obj:`Dict[str, Any]`): Dict type data, a batch of data for training.
- update_target_model (:obj:`bool`): Whether to update target model.
Returns:
- priority (:obj:`Dict[str, Any]`): Dict type priority data, values are python scalar or a list of scalars.
ArgumentsKeys:
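A common choice for the priority produced by such a method is the absolute TD error, optionally offset by a small constant so that it stays strictly positive. The snippet below illustrates that idea only; it is not necessarily the exact formula used by this policy.

```python
import torch

td_error_per_sample = torch.randn(8)                    # per-sample TD errors from the learn step
priority = (td_error_per_sample.abs() + 1e-6).tolist()  # strictly positive, as a list of python scalars
```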
@@ -620,8 +656,20 @@ class DQNSTDIMPolicy(DQNPolicy):
def _init_learn(self) -> None:
"""
Overview:
Learn mode init method. Called by ``self.__init__``. First call super class's ``_init_lear`` method, then \
nitialize the auxiliary model, its optimizer, and the axuliary loss weight to the main loss.
Initialize the learn mode of policy, including related attributes and modules. For DQNSTDIM, it first \
calls the super class's ``_init_learn`` method, then initializes the extra auxiliary model, its optimizer, and the \
auxiliary loss weight. This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``.
.. note::
For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \
and ``_load_state_dict_learn`` methods.
.. note::
For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method.
.. note::
If you want to set some special member variables in ``_init_learn`` method, you'd better name them \
with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``.
"""
super()._init_learn()
x_size, y_size = self._get_encoding_size()
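For orientation, the auxiliary (ST-DIM style) objective initialized here contributes a weighted representation-learning term on top of the usual TD loss. The sketch below uses hypothetical names (``aux_loss``, ``aux_loss_weight``) to show how such a combined objective is typically formed.

```python
import torch

# Hypothetical stand-ins for the losses computed in ``_forward_learn``.
q_loss = torch.tensor(0.42, requires_grad=True)    # n-step TD loss of the main DQN head
aux_loss = torch.tensor(0.13, requires_grad=True)  # auxiliary representation loss from the ST-DIM model
aux_loss_weight = 0.003                            # hypothetical weighting coefficient

total_loss = q_loss + aux_loss_weight * aux_loss   # single scalar driving the backward pass
total_loss.backward()
```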
@@ -669,19 +717,28 @@ def _model_encode(self, data: dict) -> Tuple[torch.Tensor]:
def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]:
"""
Overview:
Forward computation graph of learn mode(updating policy).
Policy forward function of learn mode (training policy and updating parameters). Forward means \
that the policy inputs some training batch data from the replay buffer and then returns the output \
result, including various training information such as loss, q value, priority, aux_loss.
Arguments:
- data (:obj:`Dict[str, Any]`): Dict type data, a batch of data for training, values are torch.Tensor or \
np.ndarray or dict/list combinations.
- data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including a batch of \
training samples. For each element in the list, the key of the dict is the name of the data item and the \
value is the corresponding data. Usually, the value is a torch.Tensor or np.ndarray or their dict/list \
combinations. In the ``_forward_learn`` method, the data often needs to be stacked in the batch \
dimension first by some utility functions such as ``default_preprocess_learn``. \
For DQNSTDIM, each element in the list is a dict containing at least the following keys: ``obs``, \
``action``, ``reward``, ``next_obs``, ``done``. Sometimes, it also contains other keys such as \
``weight`` and ``value_gamma``.
Returns:
- info_dict (:obj:`Dict[str, Any]`): Dict type data, a info dict indicated training result, which will be \
recorded in text log and tensorboard, values are python scalar or a list of scalars.
ArgumentsKeys:
- necessary: ``obs``, ``action``, ``reward``, ``next_obs``, ``done``
- optional: ``value_gamma``, ``IS``
ReturnsKeys:
- necessary: ``cur_lr``, ``total_loss``, ``priority``
- optional: ``action_distribution``
- info_dict (:obj:`Dict[str, Any]`): The information dict that indicates the training result, which will be \
recorded in text log and tensorboard, values must be python scalars or a list of scalars. For the \
detailed definition of the dict, refer to the code of the ``_monitor_vars_learn`` method.
.. note::
The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \
For the data types that are not supported, the main reason is that the corresponding model does not support them. \
You can implement your own model rather than use the default model. For more information, please raise an \
issue in the GitHub repo and we will continue to follow up.
"""
data = default_preprocess_learn(
data,
2 changes: 1 addition & 1 deletion ding/policy/policy_factory.py
@@ -1,10 +1,10 @@
from typing import Dict, Any, Callable
from collections import namedtuple
from easydict import EasyDict
import gym
import torch

from ding.torch_utils import to_device
import gym


class PolicyFactory: