diff --git a/ding/policy/cql.py b/ding/policy/cql.py
index af8675859c..0910ebe262 100644
--- a/ding/policy/cql.py
+++ b/ding/policy/cql.py
@@ -152,8 +152,22 @@ class CQLPolicy(SACPolicy):
     def _init_learn(self) -> None:
         """
         Overview:
-            Learn mode init method. Called by ``self.__init__``.
-            Init q and policy's optimizers, algorithm config, main and target models.
+            Initialize the learn mode of policy, including related attributes and modules. For CQL, it mainly \
+            contains three optimizers, algorithm-specific arguments such as gamma, min_q_weight, with_lagrange and \
+            with_q_entropy, main and target model. In particular, the ``auto_alpha`` mechanism for balancing the max \
+            entropy target is also initialized here.
+            This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``.
+
+        .. note::
+            For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \
+            and ``_load_state_dict_learn`` methods.
+
+        .. note::
+            For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method.
+
+        .. note::
+            If you want to set some special member variables in ``_init_learn`` method, you'd better name them \
+            with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``.
         """
         self._priority = self._cfg.priority
         self._priority_IS_weight = self._cfg.priority_IS_weight
@@ -241,14 +255,30 @@ def _init_learn(self) -> None:
         self._forward_learn_cnt = 0
 
-    def _forward_learn(self, data: dict) -> Dict[str, Any]:
+    def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
         """
         Overview:
-            Forward and backward function of learn mode.
+            Policy forward function of learn mode (training policy and updating parameters). Forward means \
+            that the policy inputs some training batch data from the replay buffer and then returns the output \
+            result, including various training information such as loss, action, priority.
         Arguments:
-            - data (:obj:`dict`): Dict type data, including at least ['obs', 'action', 'reward', 'next_obs']
+            - data (:obj:`List[Dict[str, Any]]`): The input data used for policy forward, including a batch of \
+              training samples. For each element in the list, the key of the dict is the name of data items and the \
+              value is the corresponding data. Usually, the value is torch.Tensor or np.ndarray or their dict/list \
+              combinations. In the ``_forward_learn`` method, data often needs to first be stacked in the batch \
+              dimension by some utility functions such as ``default_preprocess_learn``. \
+              For CQL, each element in the list is a dict containing at least the following keys: ``obs``, ``action``, \
+              ``reward``, ``next_obs``, ``done``. Sometimes, it also contains other keys such as ``weight``.
         Returns:
-            - info_dict (:obj:`Dict[str, Any]`): Including current lr and loss.
+            - info_dict (:obj:`Dict[str, Any]`): The information dict that indicates the training result, which will \
+              be recorded in text log and tensorboard, values must be python scalars or lists of scalars. For the \
+              detailed definition of the dict, refer to the code of the ``_monitor_vars_learn`` method.
+
+        .. note::
+            The input value can be torch.Tensor or dict/list combinations, and the current policy supports all of them. \
+            For data types that are not supported, the main reason is that the corresponding model does not support them. \
+            You can implement your own model rather than use the default model. For more information, please raise an \
+            issue in the GitHub repo and we will continue to follow up.
         """
         loss_dict = {}
         data = default_preprocess_learn(
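To make the learn-mode data contract described above concrete, here is a minimal, self-contained sketch in plain PyTorch. The helper ``stack_transitions`` and all tensor shapes are illustrative assumptions; it only stands in for what ``default_preprocess_learn`` conceptually does, i.e. stacking a list of transition dicts along a new batch dimension::

    import torch

    def stack_transitions(data):
        """Stack a list of transition dicts into one dict of batched tensors."""
        return {k: torch.stack([torch.as_tensor(d[k]) for d in data]) for k in data[0]}

    # Each element mirrors one transition sampled from the replay buffer.
    batch = [
        {
            'obs': torch.randn(17),       # observation at step t (shape is illustrative)
            'action': torch.randn(6),     # continuous action taken at step t
            'reward': torch.tensor(1.0),  # scalar reward
            'next_obs': torch.randn(17),  # observation at step t + 1
            'done': torch.tensor(False),  # episode termination flag
        } for _ in range(4)
    ]

    stacked = stack_transitions(batch)
    print(stacked['obs'].shape)     # torch.Size([4, 17]) -- batch dimension first
    print(stacked['action'].shape)  # torch.Size([4, 6])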
""" loss_dict = {} data = default_preprocess_learn( @@ -509,8 +539,20 @@ class DiscreteCQLPolicy(QRDQNPolicy): def _init_learn(self) -> None: """ Overview: - Learn mode init method. Called by ``self.__init__``. - Init the optimizer, algorithm config, main and target models. + Initialize the learn mode of policy, including related attributes and modules. For DiscreteCQL, it mainly \ + contains the optimizer, algorithm-specific arguments such as gamma, nstep and min_q_weight, main and \ + target model. This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``. + + .. note:: + For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \ + and ``_load_state_dict_learn`` methods. + + .. note:: + For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method. + + .. note:: + If you want to set some spacial member variables in ``_init_learn`` method, you'd better name them \ + with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``. """ self._min_q_weight = self._cfg.learn.min_q_weight self._priority = self._cfg.priority @@ -532,14 +574,31 @@ def _init_learn(self) -> None: self._learn_model.reset() self._target_model.reset() - def _forward_learn(self, data: dict) -> Dict[str, Any]: + def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: """ Overview: - Forward and backward function of learn mode. + Policy forward function of learn mode (training policy and updating parameters). Forward means \ + that the policy inputs some training batch data from the replay buffer and then returns the output \ + result, including various training information such as loss, action, priority. Arguments: - - data (:obj:`dict`): Dict type data, including at least ['obs', 'action', 'reward', 'next_obs'] + - data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including a batch of \ + training samples. For each element in list, the key of the dict is the name of data items and the \ + value is the corresponding data. Usually, the value is torch.Tensor or np.ndarray or there dict/list \ + combinations. In the ``_forward_learn`` method, data often need to first be stacked in the batch \ + dimension by some utility functions such as ``default_preprocess_learn``. \ + For DiscreteCQL, each element in list is a dict containing at least the following keys: ``obs``, \ + ``action``, ``reward``, ``next_obs``, ``done``. Sometimes, it also contains other keys like ``weight`` \ + and ``value_gamma`` for nstep return computation. Returns: - - info_dict (:obj:`Dict[str, Any]`): Including current lr and loss. + - info_dict (:obj:`Dict[str, Any]`): The information dict that indicated training result, which will be \ + recorded in text log and tensorboard, values must be python scalar or a list of scalars. For the \ + detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method. + + .. note:: + The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \ + For the data type that not supported, the main reason is that the corresponding model does not support it. \ + You can implement you own model rather than use the default model. For more information, please raise an \ + issue in GitHub repo and we will continue to follow up. 
""" data = default_preprocess_learn( data, use_priority=self._priority, ignore_done=self._cfg.learn.ignore_done, use_nstep=True diff --git a/ding/policy/ddpg.py b/ding/policy/ddpg.py index fa830d2da1..a4dd5dc4bf 100644 --- a/ding/policy/ddpg.py +++ b/ding/policy/ddpg.py @@ -491,7 +491,7 @@ def _init_eval(self) -> None: self._eval_model = model_wrap(self._eval_model, wrapper_name='hybrid_argmax_sample') self._eval_model.reset() - def _forward_eval(self, data: dict) -> dict: + def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: """ Overview: Policy forward function of eval mode (evaluation policy performance by interacting with envs). Forward \ diff --git a/ding/policy/sac.py b/ding/policy/sac.py index dd3c64bacd..5ee3cc3eda 100644 --- a/ding/policy/sac.py +++ b/ding/policy/sac.py @@ -124,8 +124,21 @@ def default_model(self) -> Tuple[str, List[str]]: def _init_learn(self) -> None: """ Overview: - Learn mode init method. Called by ``self.__init__``. - Init q function and policy's optimizers, algorithm config, main and target models. + Initialize the learn mode of policy, including related attributes and modules. For DiscreteSAC, it mainly \ + contains three optimizers, algorithm-specific arguments such as gamma and twin_critic, main and target \ + model. Especially, the ``auto_alpha`` mechanism for balancing max entropy target is also initialized here. + This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``. + + .. note:: + For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \ + and ``_load_state_dict_learn`` methods. + + .. note:: + For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method. + + .. note:: + If you want to set some spacial member variables in ``_init_learn`` method, you'd better name them \ + with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``. """ self._priority = self._cfg.priority self._priority_IS_weight = self._cfg.priority_IS_weight @@ -179,10 +192,34 @@ def _init_learn(self) -> None: self._learn_model.reset() self._target_model.reset() - def _forward_learn(self, data: dict) -> Dict[str, Any]: + def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: """ Overview: - Forward function of learn mode. + Policy forward function of learn mode (training policy and updating parameters). Forward means \ + that the policy inputs some training batch data from the replay buffer and then returns the output \ + result, including various training information such as loss, action, priority. + Arguments: + - data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including a batch of \ + training samples. For each element in list, the key of the dict is the name of data items and the \ + value is the corresponding data. Usually, the value is torch.Tensor or np.ndarray or there dict/list \ + combinations. In the ``_forward_learn`` method, data often need to first be stacked in the batch \ + dimension by some utility functions such as ``default_preprocess_learn``. \ + For SAC, each element in list is a dict containing at least the following keys: ``obs``, ``action``, \ + ``logit``, ``reward``, ``next_obs``, ``done``. Sometimes, it also contains other keys like ``weight``. 
diff --git a/ding/policy/sac.py b/ding/policy/sac.py
index dd3c64bacd..5ee3cc3eda 100644
--- a/ding/policy/sac.py
+++ b/ding/policy/sac.py
@@ -124,8 +124,21 @@ def default_model(self) -> Tuple[str, List[str]]:
     def _init_learn(self) -> None:
         """
         Overview:
-            Learn mode init method. Called by ``self.__init__``.
-            Init q function and policy's optimizers, algorithm config, main and target models.
+            Initialize the learn mode of policy, including related attributes and modules. For DiscreteSAC, it mainly \
+            contains three optimizers, algorithm-specific arguments such as gamma and twin_critic, main and target \
+            model. In particular, the ``auto_alpha`` mechanism for balancing the max entropy target is also initialized here.
+            This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``.
+
+        .. note::
+            For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \
+            and ``_load_state_dict_learn`` methods.
+
+        .. note::
+            For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method.
+
+        .. note::
+            If you want to set some special member variables in ``_init_learn`` method, you'd better name them \
+            with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``.
         """
         self._priority = self._cfg.priority
         self._priority_IS_weight = self._cfg.priority_IS_weight
@@ -179,10 +192,34 @@ def _init_learn(self) -> None:
         self._learn_model.reset()
         self._target_model.reset()
 
-    def _forward_learn(self, data: dict) -> Dict[str, Any]:
+    def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
         """
         Overview:
-            Forward function of learn mode.
+            Policy forward function of learn mode (training policy and updating parameters). Forward means \
+            that the policy inputs some training batch data from the replay buffer and then returns the output \
+            result, including various training information such as loss, action, priority.
+        Arguments:
+            - data (:obj:`List[Dict[str, Any]]`): The input data used for policy forward, including a batch of \
+              training samples. For each element in the list, the key of the dict is the name of data items and the \
+              value is the corresponding data. Usually, the value is torch.Tensor or np.ndarray or their dict/list \
+              combinations. In the ``_forward_learn`` method, data often needs to first be stacked in the batch \
+              dimension by some utility functions such as ``default_preprocess_learn``. \
+              For SAC, each element in the list is a dict containing at least the following keys: ``obs``, ``action``, \
+              ``logit``, ``reward``, ``next_obs``, ``done``. Sometimes, it also contains other keys like ``weight``.
+        Returns:
+            - info_dict (:obj:`Dict[str, Any]`): The information dict that indicates the training result, which will \
+              be recorded in text log and tensorboard, values must be python scalars or lists of scalars. For the \
+              detailed definition of the dict, refer to the code of the ``_monitor_vars_learn`` method.
+
+        .. note::
+            The input value can be torch.Tensor or dict/list combinations, and the current policy supports all of them. \
+            For data types that are not supported, the main reason is that the corresponding model does not support them. \
+            You can implement your own model rather than use the default model. For more information, please raise an \
+            issue in the GitHub repo and we will continue to follow up.
+
+        .. note::
+            For more detailed examples, please refer to our unittest for DiscreteSACPolicy: \
+            ``ding.policy.tests.test_discrete_sac``.
         """
         loss_dict = {}
         data = default_preprocess_learn(
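The ``auto_alpha`` mechanism mentioned in ``_init_learn`` tunes the entropy temperature automatically. A condensed plain-PyTorch sketch of the idea; the variable names, learning rate and the random stand-in for the sampled log-probabilities are illustrative, not the policy's actual code::

    import torch

    action_dim = 6
    target_entropy = -action_dim                    # common heuristic: -|A|
    log_alpha = torch.zeros(1, requires_grad=True)  # optimize log(alpha) to keep alpha positive
    alpha_optim = torch.optim.Adam([log_alpha], lr=3e-4)

    log_prob = torch.randn(32)  # stand-in for log pi(a|s) over a batch

    # Push alpha up when policy entropy is below the target, down otherwise.
    alpha_loss = -(log_alpha * (log_prob + target_entropy).detach()).mean()
    alpha_optim.zero_grad()
    alpha_loss.backward()
    alpha_optim.step()

    alpha = log_alpha.exp().item()  # temperature used in the actor/critic losses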
@@ -348,7 +385,14 @@ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
     def _init_collect(self) -> None:
         """
         Overview:
-            Initialize the collect_mode of policy, mainly including collect_model.
+            Initialize the collect mode of policy, including related attributes and modules. For DiscreteSAC, it \
+            contains the collect_model to balance exploration and exploitation with the epsilon-greedy and \
+            multinomial sample mechanism, and other algorithm-specific arguments such as unroll_len. \
+            This method will be called in ``__init__`` method if ``collect`` field is in ``enable_field``.
+
+        .. note::
+            If you want to set some special member variables in ``_init_collect`` method, you'd better name them \
+            with prefix ``_collect_`` to avoid conflict with other modes, such as ``self._collect_attr1``.
         """
         self._unroll_len = self._cfg.collect.unroll_len
         # Empirically, we found that eps_greedy_multinomial_sample works better than multinomial_sample
@@ -357,10 +401,31 @@ def _init_collect(self) -> None:
         self._collect_model = model_wrap(self._model, wrapper_name='eps_greedy_multinomial_sample')
         self._collect_model.reset()
 
-    def _forward_collect(self, data: dict, eps: float) -> dict:
+    def _forward_collect(self, data: Dict[int, Any], eps: float) -> Dict[int, Any]:
         """
         Overview:
-            Forward function for collect_mode.
+            Policy forward function of collect mode (collecting training data by interacting with envs). Forward means \
+            that the policy gets some necessary data (mainly observation) from the envs and then returns the output \
+            data, such as the action to interact with the envs. Besides, this policy also needs the ``eps`` argument \
+            for exploration, i.e., the classic epsilon-greedy exploration strategy.
+        Arguments:
+            - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \
+              key of the dict is environment id and the value is the corresponding data of the env.
+            - eps (:obj:`float`): The epsilon value for exploration.
+        Returns:
+            - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action and \
+              other necessary data for learn mode defined in ``self._process_transition`` method. The key of the \
+              dict is the same as the input data, i.e. environment id.
+
+        .. note::
+            The input value can be torch.Tensor or dict/list combinations, and the current policy supports all of them. \
+            For data types that are not supported, the main reason is that the corresponding model does not support them. \
+            You can implement your own model rather than use the default model. For more information, please raise an \
+            issue in the GitHub repo and we will continue to follow up.
+
+        .. note::
+            For more detailed examples, please refer to our unittest for DiscreteSACPolicy: \
+            ``ding.policy.tests.test_discrete_sac``.
         """
         data_id = list(data.keys())
         data = default_collate(list(data.values()))
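One plausible reading of the epsilon-greedy multinomial sampling described for collect mode, written as a stand-alone illustrative function rather than the actual ``eps_greedy_multinomial_sample`` wrapper: with probability ``eps`` act uniformly at random, otherwise sample from the categorical distribution defined by the logits::

    import torch

    def eps_greedy_multinomial(logit, eps):
        batch, num_action = logit.shape
        sampled = torch.distributions.Categorical(logits=logit).sample()  # exploit: sample from pi
        random = torch.randint(num_action, (batch, ))                     # explore: uniform action
        use_random = torch.rand(batch) < eps
        return torch.where(use_random, random, sampled)

    actions = eps_greedy_multinomial(torch.randn(4, 5), eps=0.1)
    print(actions.shape)  # torch.Size([4])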
@@ -417,15 +482,39 @@ def _get_train_sample(self, transitions: List[Dict[str, Any]]) -> List[Dict[str,
     def _init_eval(self) -> None:
         """
         Overview:
-            Initialize the eval_mode of policy, mainly including eval_model.
+            Initialize the eval mode of policy, including related attributes and modules. For DiscreteSAC, it contains \
+            the eval model to greedily select action type with argmax q_value mechanism.
+            This method will be called in ``__init__`` method if ``eval`` field is in ``enable_field``.
+
+        .. note::
+            If you want to set some special member variables in ``_init_eval`` method, you'd better name them \
+            with prefix ``_eval_`` to avoid conflict with other modes, such as ``self._eval_attr1``.
         """
         self._eval_model = model_wrap(self._model, wrapper_name='argmax_sample')
         self._eval_model.reset()
 
-    def _forward_eval(self, data: dict) -> dict:
+    def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
         """
         Overview:
-            Forward function for eval_mode.
+            Policy forward function of eval mode (evaluation policy performance by interacting with envs). Forward \
+            means that the policy gets some necessary data (mainly observation) from the envs and then returns the \
+            action to interact with the envs.
+        Arguments:
+            - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \
+              key of the dict is environment id and the value is the corresponding data of the env.
+        Returns:
+            - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. The \
+              key of the dict is the same as the input data, i.e. environment id.
+
+        .. note::
+            The input value can be torch.Tensor or dict/list combinations, and the current policy supports all of them. \
+            For data types that are not supported, the main reason is that the corresponding model does not support them. \
+            You can implement your own model rather than use the default model. For more information, please raise an \
+            issue in the GitHub repo and we will continue to follow up.
+
+        .. note::
+            For more detailed examples, please refer to our unittest for DiscreteSACPolicy: \
+            ``ding.policy.tests.test_discrete_sac``.
         """
         data_id = list(data.keys())
         data = default_collate(list(data.values()))
@@ -606,7 +695,21 @@ def default_model(self) -> Tuple[str, List[str]]:
     def _init_learn(self) -> None:
         """
         Overview:
-            Initialize the learn model and algorithm related object.
+            Initialize the learn mode of policy, including related attributes and modules. For SAC, it mainly \
+            contains three optimizers, algorithm-specific arguments such as gamma and twin_critic, main and target \
+            model. In particular, the ``auto_alpha`` mechanism for balancing the max entropy target is also initialized here.
+            This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``.
+
+        .. note::
+            For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \
+            and ``_load_state_dict_learn`` methods.
+
+        .. note::
+            For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method.
+
+        .. note::
+            If you want to set some special member variables in ``_init_learn`` method, you'd better name them \
+            with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``.
         """
         self._priority = self._cfg.priority
         self._priority_IS_weight = self._cfg.priority_IS_weight
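The target model referred to in ``_init_learn`` is typically kept close to the main model by a soft (Polyak/momentum) update. A generic sketch of that update; ``tau`` and the two small networks are chosen only for illustration and are not the wrapper's actual configuration::

    import torch
    import torch.nn as nn

    critic = nn.Linear(8, 1)
    target_critic = nn.Linear(8, 1)
    target_critic.load_state_dict(critic.state_dict())  # start from identical weights

    def soft_update(target, source, tau=0.005):
        with torch.no_grad():
            for tp, sp in zip(target.parameters(), source.parameters()):
                tp.mul_(1.0 - tau).add_(sp, alpha=tau)  # theta' <- (1 - tau) * theta' + tau * theta

    soft_update(target_critic, critic)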
@@ -667,10 +770,33 @@ def _init_learn(self) -> None:
         self._learn_model.reset()
         self._target_model.reset()
 
-    def _forward_learn(self, data: dict) -> Dict[str, Any]:
+    def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
         """
         Overview:
-            Forward function of learn mode.
+            Policy forward function of learn mode (training policy and updating parameters). Forward means \
+            that the policy inputs some training batch data from the replay buffer and then returns the output \
+            result, including various training information such as loss, action, priority.
+        Arguments:
+            - data (:obj:`List[Dict[str, Any]]`): The input data used for policy forward, including a batch of \
+              training samples. For each element in the list, the key of the dict is the name of data items and the \
+              value is the corresponding data. Usually, the value is torch.Tensor or np.ndarray or their dict/list \
+              combinations. In the ``_forward_learn`` method, data often needs to first be stacked in the batch \
+              dimension by some utility functions such as ``default_preprocess_learn``. \
+              For SAC, each element in the list is a dict containing at least the following keys: ``obs``, ``action``, \
+              ``reward``, ``next_obs``, ``done``. Sometimes, it also contains other keys such as ``weight``.
+        Returns:
+            - info_dict (:obj:`Dict[str, Any]`): The information dict that indicates the training result, which will \
+              be recorded in text log and tensorboard, values must be python scalars or lists of scalars. For the \
+              detailed definition of the dict, refer to the code of the ``_monitor_vars_learn`` method.
+
+        .. note::
+            The input value can be torch.Tensor or dict/list combinations, and the current policy supports all of them. \
+            For data types that are not supported, the main reason is that the corresponding model does not support them. \
+            You can implement your own model rather than use the default model. For more information, please raise an \
+            issue in the GitHub repo and we will continue to follow up.
+
+        .. note::
+            For more detailed examples, please refer to our unittest for SACPolicy: ``ding.policy.tests.test_sac``.
         """
         loss_dict = {}
         data = default_preprocess_learn(
@@ -832,16 +958,43 @@ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
     def _init_collect(self) -> None:
         """
         Overview:
-            Initialize the collect mode.
+            Initialize the collect mode of policy, including related attributes and modules. For SAC, it contains the \
+            collect_model and other algorithm-specific arguments such as unroll_len. \
+            This method will be called in ``__init__`` method if ``collect`` field is in ``enable_field``.
+
+        .. note::
+            If you want to set some special member variables in ``_init_collect`` method, you'd better name them \
+            with prefix ``_collect_`` to avoid conflict with other modes, such as ``self._collect_attr1``.
         """
         self._unroll_len = self._cfg.collect.unroll_len
         self._collect_model = model_wrap(self._model, wrapper_name='base')
         self._collect_model.reset()
 
-    def _forward_collect(self, data: dict) -> dict:
+    def _forward_collect(self, data: Dict[int, Any], **kwargs) -> Dict[int, Any]:
         """
         Overview:
-            Forward function of collect mode.
+            Policy forward function of collect mode (collecting training data by interacting with envs). Forward means \
+            that the policy gets some necessary data (mainly observation) from the envs and then returns the output \
+            data, such as the action to interact with the envs.
+        Arguments:
+            - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \
+              key of the dict is environment id and the value is the corresponding data of the env.
+        Returns:
+            - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action and \
+              other necessary data for learn mode defined in ``self._process_transition`` method. The key of the \
+              dict is the same as the input data, i.e. environment id.
+
+        .. note::
+            The input value can be torch.Tensor or dict/list combinations, and the current policy supports all of them. \
+            For data types that are not supported, the main reason is that the corresponding model does not support them. \
+            You can implement your own model rather than use the default model. For more information, please raise an \
+            issue in the GitHub repo and we will continue to follow up.
+
+        .. note::
+            ``logit`` in SAC means the mu and sigma of Gaussian distribution. Here we use this name for consistency.
+
+        .. note::
+            For more detailed examples, please refer to our unittest for SACPolicy: ``ding.policy.tests.test_sac``.
         """
         data_id = list(data.keys())
         data = default_collate(list(data.values()))
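As the note above says, ``logit`` for continuous SAC means the mu and sigma of a Gaussian. A short sketch of how such a pair is typically turned into a bounded action with the reparameterization trick and a tanh squash; the shapes are illustrative only::

    import torch
    from torch.distributions import Normal, Independent

    mu = torch.zeros(4, 6)           # per-dimension mean, batch of 4
    sigma = torch.ones(4, 6) * 0.5   # per-dimension std, must be positive

    dist = Independent(Normal(mu, sigma), 1)  # treat the 6 action dims as one event
    pre_tanh = dist.rsample()                 # reparameterized sample, keeps gradients
    action = torch.tanh(pre_tanh)             # squash into (-1, 1)
    # Standard change-of-variables correction for the tanh squash.
    log_prob = dist.log_prob(pre_tanh) - torch.log(1 - action.pow(2) + 1e-6).sum(-1)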
@@ -912,15 +1065,41 @@ def _get_train_sample(self, transitions: List[Dict[str, Any]]) -> List[Dict[str,
     def _init_eval(self) -> None:
         """
         Overview:
-            Initialize the eval mode.
+            Initialize the eval mode of policy, including related attributes and modules. For SAC, it contains the \
+            eval model, which is equipped with the ``base`` model wrapper to ensure compatibility.
+            This method will be called in ``__init__`` method if ``eval`` field is in ``enable_field``.
+
+        .. note::
+            If you want to set some special member variables in ``_init_eval`` method, you'd better name them \
+            with prefix ``_eval_`` to avoid conflict with other modes, such as ``self._eval_attr1``.
         """
         self._eval_model = model_wrap(self._model, wrapper_name='base')
         self._eval_model.reset()
 
-    def _forward_eval(self, data: dict) -> dict:
+    def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
         """
         Overview:
-            Forward function of eval mode.
+            Policy forward function of eval mode (evaluation policy performance by interacting with envs). Forward \
+            means that the policy gets some necessary data (mainly observation) from the envs and then returns the \
+            action to interact with the envs.
+        Arguments:
+            - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \
+              key of the dict is environment id and the value is the corresponding data of the env.
+        Returns:
+            - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. The \
+              key of the dict is the same as the input data, i.e. environment id.
+
+        .. note::
+            The input value can be torch.Tensor or dict/list combinations, and the current policy supports all of them. \
+            For data types that are not supported, the main reason is that the corresponding model does not support them. \
+            You can implement your own model rather than use the default model. For more information, please raise an \
+            issue in the GitHub repo and we will continue to follow up.
+
+        .. note::
+            ``logit`` in SAC means the mu and sigma of Gaussian distribution. Here we use this name for consistency.
+
+        .. note::
+            For more detailed examples, please refer to our unittest for SACPolicy: ``ding.policy.tests.test_sac``.
         """
         data_id = list(data.keys())
         data = default_collate(list(data.values()))
@@ -972,7 +1151,21 @@ class SQILSACPolicy(SACPolicy):
     def _init_learn(self) -> None:
         """
         Overview:
-            Initialize the learn mode.
+            Initialize the learn mode of policy, including related attributes and modules. For SQIL+SAC, it mainly \
+            contains three optimizers, algorithm-specific arguments such as gamma and twin_critic, main and target \
+            model. In particular, the ``auto_alpha`` mechanism for balancing the max entropy target is also initialized here.
+            This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``.
+
+        .. note::
+            For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \
+            and ``_load_state_dict_learn`` methods.
+
+        .. note::
+            For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method.
+
+        .. note::
+            If you want to set some special member variables in ``_init_learn`` method, you'd better name them \
+            with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``.
         """
         self._priority = self._cfg.priority
         self._priority_IS_weight = self._cfg.priority_IS_weight
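The notes above point to ``_state_dict_learn`` and ``_load_state_dict_learn`` for checkpointing. A minimal sketch of the usual pattern, assuming a model plus optimizer pair; the attribute names are illustrative, not the policy's exact members::

    import torch
    import torch.nn as nn

    model = nn.Linear(8, 2)
    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

    # Save: collect the state dicts of everything that must survive a restart.
    checkpoint = {'model': model.state_dict(), 'optimizer': optimizer.state_dict()}
    torch.save(checkpoint, 'ckpt_example.pth.tar')

    # Load: restore them in the same order.
    state = torch.load('ckpt_example.pth.tar')
    model.load_state_dict(state['model'])
    optimizer.load_state_dict(state['optimizer'])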
@@ -1037,10 +1230,37 @@ def _init_learn(self) -> None:
         self._monitor_cos = True
         self._monitor_entropy = True
 
-    def _forward_learn(self, data: dict) -> Dict[str, Any]:
+    def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
         """
         Overview:
-            Forward function of learn mode.
+            Policy forward function of learn mode (training policy and updating parameters). Forward means \
+            that the policy inputs some training batch data from the replay buffer and then returns the output \
+            result, including various training information such as loss, action, priority.
+        Arguments:
+            - data (:obj:`List[Dict[str, Any]]`): The input data used for policy forward, including a batch of \
+              training samples. For each element in the list, the key of the dict is the name of data items and the \
+              value is the corresponding data. Usually, the value is torch.Tensor or np.ndarray or their dict/list \
+              combinations. In the ``_forward_learn`` method, data often needs to first be stacked in the batch \
+              dimension by some utility functions such as ``default_preprocess_learn``. \
+              For SAC, each element in the list is a dict containing at least the following keys: ``obs``, ``action``, \
+              ``reward``, ``next_obs``, ``done``. Sometimes, it also contains other keys such as ``weight``.
+        Returns:
+            - info_dict (:obj:`Dict[str, Any]`): The information dict that indicates the training result, which will \
+              be recorded in text log and tensorboard, values must be python scalars or lists of scalars. For the \
+              detailed definition of the dict, refer to the code of the ``_monitor_vars_learn`` method.
+
+        .. note::
+            For SQIL + SAC, input data is composed of two parts with the same size: agent data and expert data. \
+            Both of them are relabelled with new rewards according to the SQIL algorithm.
+
+        .. note::
+            The input value can be torch.Tensor or dict/list combinations, and the current policy supports all of them. \
+            For data types that are not supported, the main reason is that the corresponding model does not support them. \
+            You can implement your own model rather than use the default model. For more information, please raise an \
+            issue in the GitHub repo and we will continue to follow up.
+
+        .. note::
+            For more detailed examples, please refer to our unittest for SACPolicy: ``ding.policy.tests.test_sac``.
         """
         loss_dict = {}
         if self._monitor_cos:
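A toy sketch of the SQIL-style relabelling mentioned in the docstring above: agent transitions are given reward 0 and expert transitions reward 1 before both equal-sized halves are trained on together. The transition dicts here are illustrative, not the actual buffer format::

    import torch

    def relabel(agent_batch, expert_batch):
        for t in agent_batch:
            t['reward'] = torch.tensor(0.)  # online data: reward 0
        for t in expert_batch:
            t['reward'] = torch.tensor(1.)  # demonstration data: reward 1
        return agent_batch + expert_batch   # equal-sized halves, trained jointly

    agent = [{'obs': torch.randn(17), 'reward': torch.tensor(-1.2)} for _ in range(2)]
    expert = [{'obs': torch.randn(17), 'reward': torch.tensor(0.)} for _ in range(2)]
    combined = relabel(agent, expert)
    print([t['reward'].item() for t in combined])  # [0.0, 0.0, 1.0, 1.0]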