From e10dc80903216f696fd84881a9e8bbab83982592 Mon Sep 17 00:00:00 2001
From: niuyazhe
Date: Tue, 10 Oct 2023 15:59:45 +0800
Subject: [PATCH] polish(nyz): complete dqn comments

---
 ding/policy/base_policy.py    |  17 ++---
 ding/policy/dqn.py            | 135 ++++++++++++++++++++++++----------
 ding/policy/policy_factory.py |   2 +-
 3 files changed, 105 insertions(+), 49 deletions(-)

diff --git a/ding/policy/base_policy.py b/ding/policy/base_policy.py
index 0d7694e737..2b5f13a541 100644
--- a/ding/policy/base_policy.py
+++ b/ding/policy/base_policy.py
@@ -1,10 +1,10 @@
+from typing import Optional, List, Dict, Any, Tuple, Union
 from abc import ABC, abstractmethod
 from collections import namedtuple
-from typing import Optional, List, Dict, Any, Tuple, Union
+from easydict import EasyDict
 
-import torch
 import copy
-from easydict import EasyDict
+import torch
 
 from ding.model import create_model
 from ding.utils import import_module, allreduce, broadcast, get_rank, allreduce_async, synchronize, deep_merge_dicts, \
@@ -446,7 +446,7 @@ def default_model(self) -> Tuple[str, List[str]]:
     # *************************************** learn function ************************************
 
     @abstractmethod
-    def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]:
+    def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
         """
         Overview:
             Policy forward function of learn mode (training policy and updating parameters). Forward means \
@@ -455,11 +455,10 @@ def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]:
            and so on. This method is left to be implemented by the subclass, and more arguments can be added in \
            ``data`` item if necessary.
         Arguments:
-            - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including a batch of training \
-                samples. The key of the dict is the name of data items and the value is the corresponding data. \
-                Usually, the date item is a list of data, and the first dimension is the batch dimension, then in \
-                the ``_forward_learn`` method, the data item should be stacked in the batch dimension by some utility \
-                methods such as ``default_preprocess_learn``.
+            - data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including a batch of \
+                training samples. For each element in the list, the key of the dict is the name of a data item and \
+                the value is the corresponding data. Usually, in the ``_forward_learn`` method, data should be \
+                stacked in the batch dimension by some utility functions such as ``default_preprocess_learn``.
         Returns:
             - output (:obj:`Dict[int, Any]`): The training information of policy forward, including some metrics for \
                 monitoring training such as loss, priority, q value, policy entropy, and some data for next step \
diff --git a/ding/policy/dqn.py b/ding/policy/dqn.py
index e7c5595789..bf943a72ac 100644
--- a/ding/policy/dqn.py
+++ b/ding/policy/dqn.py
@@ -174,8 +174,20 @@ def default_model(self) -> Tuple[str, List[str]]:
     def _init_learn(self) -> None:
         """
         Overview:
-            Learn mode init method. Called by ``self.__init__``, initialize the optimizer, algorithm arguments, main \
-            and target model.
+            Initialize the learn mode of policy, including related attributes and modules. For DQN, it mainly \
+            contains the optimizer, algorithm-specific arguments such as nstep and gamma, and the main and target model.
+            This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``.
+
+        .. note::
+            For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \
+            and ``_load_state_dict_learn`` methods.
+
+        .. note::
+            For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method.
+
+        .. note::
+            If you want to set some special member variables in ``_init_learn`` method, you'd better name them \
+            with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``.
         """
         self._priority = self._cfg.priority
         self._priority_IS_weight = self._cfg.priority_IS_weight
@@ -207,23 +219,36 @@ def _init_learn(self) -> None:
         self._learn_model.reset()
         self._target_model.reset()
 
-    def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]:
+    def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
         """
         Overview:
-            Forward computation graph of learn mode(updating policy).
+            Policy forward function of learn mode (training policy and updating parameters). Forward means \
+            that the policy inputs some training batch data from the replay buffer and then returns the output \
+            result, including various training information such as loss, q value, priority.
         Arguments:
-            - data (:obj:`Dict[str, Any]`): Dict type data, a batch of data for training, values are torch.Tensor or \
-                np.ndarray or dict/list combinations.
+            - data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including a batch of \
+                training samples. For each element in the list, the key of the dict is the name of a data item and \
+                the value is the corresponding data. Usually, the value is torch.Tensor or np.ndarray or their \
+                dict/list combinations. In the ``_forward_learn`` method, data often needs to first be stacked in \
+                the batch dimension by some utility functions such as ``default_preprocess_learn``. \
+                For DQN, each element in the list is a dict containing at least the following keys: ``obs``, \
+                ``action``, ``reward``, ``next_obs``, ``done``. Sometimes, it also contains other keys such as \
+                ``weight`` and ``value_gamma``.
         Returns:
-            - info_dict (:obj:`Dict[str, Any]`): Dict type data, a info dict indicated training result, which will be \
-                recorded in text log and tensorboard, values are python scalar or a list of scalars.
-        ArgumentsKeys:
-            - necessary: ``obs``, ``action``, ``reward``, ``next_obs``, ``done``
-            - optional: ``value_gamma``
-        ReturnsKeys:
-            - necessary: ``cur_lr``, ``total_loss``, ``priority``
-            - optional: ``action_distribution``
+            - info_dict (:obj:`Dict[str, Any]`): The information dict that indicates the training result, which will \
+                be recorded in text log and tensorboard, values must be python scalar or a list of scalars. For the \
+                detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method.
+
+        .. note::
+            The input value can be torch.Tensor or dict/list combinations and the current policy supports all of \
+            them. For the data types that are not supported, the main reason is that the corresponding model does \
+            not support them. You can implement your own model rather than use the default model. For more \
+            information, please raise an issue in the GitHub repo and we will continue to follow up.
+
+        .. note::
+            For more detailed examples, please refer to our unittest for DQNPolicy: ``ding.policy.tests.test_dqn``.
""" + # Data preprocessing operations, such as stack data, cpu to cuda device data = default_preprocess_learn( data, use_priority=self._priority, @@ -233,9 +258,7 @@ def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]: ) if self._cuda: data = to_device(data, self._device) - # ==================== # Q-learning forward - # ==================== self._learn_model.train() self._target_model.train() # Current q value (main model) @@ -252,18 +275,14 @@ def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]: value_gamma = data.get('value_gamma') loss, td_error_per_sample = q_nstep_td_error(data_n, self._gamma, nstep=self._nstep, value_gamma=value_gamma) - # ==================== - # Q-learning update - # ==================== + # Update network parameters self._optimizer.zero_grad() loss.backward() if self._cfg.multi_gpu: self.sync_gradients(self._learn_model) self._optimizer.step() - # ============= - # after update - # ============= + # Postprocessing operations, such as updating target model, return logged values and priority. self._target_model.update(self._learn_model.state_dict()) return { 'cur_lr': self._optimizer.defaults['lr'], @@ -317,8 +336,18 @@ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: def _init_collect(self) -> None: """ Overview: - Collect mode init method. Called by ``self.__init__``, initialize algorithm arguments and collect_model, \ - enable the eps_greedy_sample for exploration. + Initialize the collect mode of policy, including related attributes and modules. For DQN, it contains the \ + collect_model to balance the exploration and exploitation with epsilon-greedy sample mechanism, and other \ + algorithm-specific arguments such as unroll_len and nstep. + This method will be called in ``__init__`` method if ``collect`` field is in ``enable_field``. + + .. note:: + If you want to set some spacial member variables in ``_init_eval`` method, you'd better name them \ + with prefix ``_eval_`` to avoid conflict with other modes, such as ``self._eval_attr1``. + + .. tip:: + Some variables need to initialize independently in different modes, such as gamma and nstep in DQN. This \ + design is for the convenience of parallel execution of different policy modes. """ self._unroll_len = self._cfg.collect.unroll_len self._gamma = self._cfg.discount_factor # necessary for parallel @@ -408,9 +437,15 @@ def _process_transition(self, obs: torch.Tensor, policy_output: Dict[str, torch. return transition def _init_eval(self) -> None: - r""" + """ Overview: - Evaluate mode init method. Called by ``self.__init__``, initialize eval_model. + Initialize the eval mode of policy, including related attributes and modules. For DQN, it contains the \ + eval model to greedily select action with argmax q_value mechanism. + This method will be called in ``__init__`` method if ``eval`` field is in ``enable_field``. + + .. note:: + If you want to set some spacial member variables in ``_init_eval`` method, you'd better name them \ + with prefix ``_eval_`` to avoid conflict with other modes, such as ``self._eval_attr1``. """ self._eval_model = model_wrap(self._model, wrapper_name='argmax_sample') self._eval_model.reset() @@ -455,6 +490,7 @@ def calculate_priority(self, data: Dict[int, Any], update_target_model: bool = F Calculate priority for replay buffer. Arguments: - data (:obj:`Dict[str, Any]`): Dict type data, a batch of data for training. + - update_target_model (:obj:`bool`): Whether to update target model. 
         Returns:
             - priority (:obj:`Dict[str, Any]`): Dict type priority data, values are python scalar or a list of scalars.
         ArgumentsKeys:
             - necessary: ``obs``, ``action``, ``reward``, ``next_obs``, ``done``
@@ -620,8 +656,20 @@ class DQNSTDIMPolicy(DQNPolicy):
     def _init_learn(self) -> None:
         """
         Overview:
-            Learn mode init method. Called by ``self.__init__``. First call super class's ``_init_lear`` method, then \
-            nitialize the auxiliary model, its optimizer, and the axuliary loss weight to the main loss.
+            Initialize the learn mode of policy, including related attributes and modules. For DQNSTDIM, it first \
+            calls the super class's ``_init_learn`` method, then initializes the extra auxiliary model, its optimizer, \
+            and the loss weight. This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``.
+
+        .. note::
+            For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \
+            and ``_load_state_dict_learn`` methods.
+
+        .. note::
+            For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method.
+
+        .. note::
+            If you want to set some special member variables in ``_init_learn`` method, you'd better name them \
+            with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``.
         """
         super()._init_learn()
         x_size, y_size = self._get_encoding_size()
@@ -669,19 +717,28 @@ def _model_encode(self, data: dict) -> Tuple[torch.Tensor]:
     def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]:
         """
         Overview:
-            Forward computation graph of learn mode(updating policy).
+            Policy forward function of learn mode (training policy and updating parameters). Forward means \
+            that the policy inputs some training batch data from the replay buffer and then returns the output \
+            result, including various training information such as loss, q value, priority, aux_loss.
         Arguments:
-            - data (:obj:`Dict[str, Any]`): Dict type data, a batch of data for training, values are torch.Tensor or \
-                np.ndarray or dict/list combinations.
+            - data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including a batch of \
+                training samples. For each element in the list, the key of the dict is the name of a data item and \
+                the value is the corresponding data. Usually, the value is torch.Tensor or np.ndarray or their \
+                dict/list combinations. In the ``_forward_learn`` method, data often needs to first be stacked in \
+                the batch dimension by some utility functions such as ``default_preprocess_learn``. \
+                For DQNSTDIM, each element in the list is a dict containing at least the following keys: ``obs``, \
+                ``action``, ``reward``, ``next_obs``, ``done``. Sometimes, it also contains other keys such as \
+                ``weight`` and ``value_gamma``.
         Returns:
-            - info_dict (:obj:`Dict[str, Any]`): Dict type data, a info dict indicated training result, which will be \
-                recorded in text log and tensorboard, values are python scalar or a list of scalars.
-        ArgumentsKeys:
-            - necessary: ``obs``, ``action``, ``reward``, ``next_obs``, ``done``
-            - optional: ``value_gamma``, ``IS``
-        ReturnsKeys:
-            - necessary: ``cur_lr``, ``total_loss``, ``priority``
-            - optional: ``action_distribution``
+            - info_dict (:obj:`Dict[str, Any]`): The information dict that indicates the training result, which will \
+                be recorded in text log and tensorboard, values must be python scalar or a list of scalars. For the \
+                detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method.
+
+        .. note::
+            The input value can be torch.Tensor or dict/list combinations and the current policy supports all of \
+            them. For the data types that are not supported, the main reason is that the corresponding model does \
+            not support them. You can implement your own model rather than use the default model. For more \
+            information, please raise an issue in the GitHub repo and we will continue to follow up.
         """
         data = default_preprocess_learn(
             data,
diff --git a/ding/policy/policy_factory.py b/ding/policy/policy_factory.py
index fa682b9c3c..ba9b77df29 100644
--- a/ding/policy/policy_factory.py
+++ b/ding/policy/policy_factory.py
@@ -1,10 +1,10 @@
 from typing import Dict, Any, Callable
 from collections import namedtuple
 from easydict import EasyDict
+import gym
 import torch
 
 from ding.torch_utils import to_device
-import gym
 
 
 class PolicyFactory:
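
Note for reviewers (illustration only, not part of the patch): the polished ``_forward_learn`` docstring describes a pipeline of data preprocessing -> Q-learning forward -> parameter update -> target-model postprocessing. The standalone PyTorch sketch below mirrors that pipeline for the plain 1-step case, without DI-engine's n-step TD helper, priority weights, or model wrappers. The names ``QNet`` and ``dqn_learn_step`` are made up for this example and are not part of the ding API.

# Illustrative only: a minimal 1-step DQN update in plain PyTorch, mirroring the
# stages documented in DQNPolicy._forward_learn (preprocess -> Q-learning forward
# -> parameter update -> target model postprocessing). Not part of the ding codebase.
from typing import Any, Dict, List

import torch
import torch.nn as nn


class QNet(nn.Module):
    """Tiny MLP Q-network standing in for the model that DQNPolicy would wrap."""

    def __init__(self, obs_dim: int, action_dim: int) -> None:
        super().__init__()
        self.net = nn.Sequential(nn.Linear(obs_dim, 64), nn.ReLU(), nn.Linear(64, action_dim))

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        return self.net(obs)


def dqn_learn_step(
    model: QNet,
    target_model: QNet,
    optimizer: torch.optim.Optimizer,
    data: List[Dict[str, torch.Tensor]],
    gamma: float = 0.99,
) -> Dict[str, Any]:
    # 1. Preprocess: stack the list of transition dicts into batched tensors,
    #    the role played by ``default_preprocess_learn`` in the real policy.
    batch = {k: torch.stack([d[k] for d in data]) for k in data[0]}
    # 2. Q-learning forward: Q(s, a) from the main model, bootstrap target from the target model.
    q = model(batch['obs']).gather(1, batch['action'].long().unsqueeze(1)).squeeze(1)
    with torch.no_grad():
        next_q = target_model(batch['next_obs']).max(dim=1).values
        target = batch['reward'] + gamma * (1.0 - batch['done'].float()) * next_q
    td_error = q - target
    loss = td_error.pow(2).mean()
    # 3. Update network parameters.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # 4. Postprocessing: hard-sync the target model (the real policy delegates this to its
    #    target model wrapper, which may instead apply periodic or momentum-based updates).
    target_model.load_state_dict(model.state_dict())
    return {'total_loss': loss.item(), 'priority': td_error.abs().detach().tolist()}

Feeding ``dqn_learn_step`` a list of transition dicts (float ``obs``/``next_obs`` vectors, scalar ``action``, ``reward`` and ``done`` tensors) together with an optimizer such as ``torch.optim.Adam(model.parameters(), lr=1e-3)`` returns a dict in the same spirit as the ``info_dict`` documented above.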