From 33ea61b6dea4b0466d1bf9dfa4039698100137ab Mon Sep 17 00:00:00 2001 From: niuyazhe Date: Mon, 16 Oct 2023 18:57:15 +0800 Subject: [PATCH] polish(nyz): complete ddpg/bc comments --- ding/policy/bc.py | 104 +++++++++++++++++++++++++++++--------- ding/policy/ddpg.py | 118 +++++++++++++++++++++++++++++++++----------- 2 files changed, 169 insertions(+), 53 deletions(-) diff --git a/ding/policy/bc.py b/ding/policy/bc.py index 57049bd773..0c95b8abec 100644 --- a/ding/policy/bc.py +++ b/ding/policy/bc.py @@ -57,8 +57,7 @@ class BehaviourCloningPolicy(Policy): max=0.5, ), ), - eval=dict(), - other=dict(replay_buffer=dict(replay_buffer_size=10000, )), + eval=dict(), # for compatibility ) def default_model(self) -> Tuple[str, List[str]]: @@ -79,8 +78,25 @@ def default_model(self) -> Tuple[str, List[str]]: else: return 'discrete_bc', ['ding.model.template.bc'] - def _init_learn(self): - assert self._cfg.learn.optimizer in ['SGD', 'Adam'] + def _init_learn(self) -> None: + """ + Overview: + Initialize the learn mode of policy, including related attributes and modules. For BC, it mainly contains \ + optimizer, algorithm-specific arguments such as lr_scheduler, loss, etc. \ + This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``. + + .. note:: + For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \ + and ``_load_state_dict_learn`` methods. + + .. note:: + For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method. + + .. note:: + If you want to set some spacial member variables in ``_init_learn`` method, you'd better name them \ + with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``. + """ + assert self._cfg.learn.optimizer in ['SGD', 'Adam'], self._cfg.learn.optimizer if self._cfg.learn.optimizer == 'SGD': self._optimizer = SGD( self._model.parameters(), @@ -120,20 +136,38 @@ def lr_scheduler_fn(epoch): elif self._cfg.loss_type == 'mse_loss': self._loss = nn.MSELoss() else: - raise KeyError + raise KeyError("not support loss type: {}".format(self._cfg.loss_type)) else: if not self._cfg.learn.ce_label_smooth: self._loss = nn.CrossEntropyLoss() else: self._loss = LabelSmoothCELoss(0.1) - if self._cfg.learn.show_accuracy: - # accuracy statistics for debugging in discrete action space env, e.g. for gfootball - self.total_accuracy_in_dataset = [] - self.action_accuracy_in_dataset = {k: [] for k in range(self._cfg.action_shape)} + def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: + """ + Overview: + Policy forward function of learn mode (training policy and updating parameters). Forward means \ + that the policy inputs some training batch data from the replay buffer and then returns the output \ + result, including various training information such as loss and time. + Arguments: + - data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including a batch of \ + training samples. For each element in list, the key of the dict is the name of data items and the \ + value is the corresponding data. Usually, the value is torch.Tensor or np.ndarray or there dict/list \ + combinations. In the ``_forward_learn`` method, data often need to first be stacked in the batch \ + dimension by some utility functions such as ``default_preprocess_learn``. \ + For BC, each element in list is a dict containing at least the following keys: ``obs``, ``action``. + Returns: + - info_dict (:obj:`Dict[str, Any]`): The information dict that indicated training result, which will be \ + recorded in text log and tensorboard, values must be python scalar or a list of scalars. For the \ + detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method. - def _forward_learn(self, data): - if not isinstance(data, dict): + .. note:: + The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \ + For the data type that not supported, the main reason is that the corresponding model does not support it. \ + You can implement you own model rather than use the default model. For more information, please raise an \ + issue in GitHub repo and we will continue to follow up. + """ + if isinstance(data, list): data = default_collate(data) if self._cuda: data = to_device(data, self._device) @@ -142,10 +176,10 @@ def _forward_learn(self, data): obs, action = data['obs'], data['action'].squeeze() if self._cfg.continuous: if self._cfg.learn.tanh_mask: - ''' + """tanh_mask We mask the action out of range of [tanh(-1),tanh(1)], model will learn information and produce action in [-1,1]. So the action won't always converge to -1 or 1. - ''' + """ mu = self._eval_model.forward(data['obs'])['action'] bound = 1 - 2 / (math.exp(2) + 1) # tanh(1): (e-e**(-1))/(e+e**(-1)) mask = mu.ge(-bound) & mu.le(bound) @@ -200,7 +234,7 @@ def _forward_learn(self, data): 'sync_time': sync_time, } - def _monitor_vars_learn(self): + def _monitor_vars_learn(self) -> List[str]: """ Overview: Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \ @@ -211,24 +245,46 @@ def _monitor_vars_learn(self): return ['cur_lr', 'total_loss', 'forward_time', 'backward_time', 'sync_time'] def _init_eval(self): + """ + Overview: + Initialize the eval mode of policy, including related attributes and modules. For BC, it contains the \ + eval model to greedily select action with argmax q_value mechanism for discrete action space. + This method will be called in ``__init__`` method if ``eval`` field is in ``enable_field``. + + .. note:: + If you want to set some spacial member variables in ``_init_eval`` method, you'd better name them \ + with prefix ``_eval_`` to avoid conflict with other modes, such as ``self._eval_attr1``. + """ if self._cfg.continuous: self._eval_model = model_wrap(self._model, wrapper_name='base') else: self._eval_model = model_wrap(self._model, wrapper_name='argmax_sample') self._eval_model.reset() - def _forward_eval(self, data): - gfootball_flag = False + def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: + """ + Overview: + Policy forward function of eval mode (evaluation policy performance by interacting with envs). Forward \ + means that the policy gets some necessary data (mainly observation) from the envs and then returns the \ + action to interact with the envs. + Arguments: + - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \ + key of the dict is environment id and the value is the corresponding data of the env. + Returns: + - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. The \ + key of the dict is the same as the input data, i.e. environment id. + + .. note:: + The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \ + For the data type that not supported, the main reason is that the corresponding model does not support it. \ + You can implement you own model rather than use the default model. For more information, please raise an \ + issue in GitHub repo and we will continue to follow up. + """ tensor_input = isinstance(data, torch.Tensor) if tensor_input: data = default_collate(list(data)) else: data_id = list(data.keys()) - if data_id == ['processed_obs', 'raw_obs']: - # for gfootball - gfootball_flag = True - data = {0: data} - data_id = list(data.keys()) data = default_collate(list(data.values())) if self._cuda: data = to_device(data, self._device) @@ -237,7 +293,7 @@ def _forward_eval(self, data): output = self._eval_model.forward(data) if self._cuda: output = to_device(output, 'cpu') - if tensor_input or gfootball_flag: + if tensor_input: return output else: output = default_decollate(output) @@ -282,11 +338,11 @@ def _forward_collect(self, data: Dict[int, Any], **kwargs) -> Dict[int, Any]: output = default_decollate(output) return {i: d for i, d in zip(data_id, output)} - def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict: + def _process_transition(self, obs: Any, policy_output: dict, timestep: namedtuple) -> dict: transition = { 'obs': obs, 'next_obs': timestep.obs, - 'action': model_output['action'], + 'action': policy_output['action'], 'reward': timestep.reward, 'done': timestep.done, } diff --git a/ding/policy/ddpg.py b/ding/policy/ddpg.py index 2ddb2c4af5..fa830d2da1 100644 --- a/ding/policy/ddpg.py +++ b/ding/policy/ddpg.py @@ -151,8 +151,20 @@ def default_model(self) -> Tuple[str, List[str]]: def _init_learn(self) -> None: """ Overview: - Learn mode init method. Called by ``self.__init__``. - Init actor and critic optimizers, algorithm config, main and target models. + Initialize the learn mode of policy, including related attributes and modules. For DDPG, it mainly \ + contains two optimizers, algorithm-specific arguments such as gamma and twin_critic, main and target model. + This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``. + + .. note:: + For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \ + and ``_load_state_dict_learn`` methods. + + .. note:: + For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method. + + .. note:: + If you want to set some spacial member variables in ``_init_learn`` method, you'd better name them \ + with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``. """ self._priority = self._cfg.priority self._priority_IS_weight = self._cfg.priority_IS_weight @@ -173,7 +185,9 @@ def _init_learn(self) -> None: # main and target models self._target_model = copy.deepcopy(self._model) + self._learn_model = model_wrap(self._model, wrapper_name='base') if self._cfg.action_space == 'hybrid': + self._learn_model = model_wrap(self._learn_model, wrapper_name='hybrid_argmax_sample') self._target_model = model_wrap(self._target_model, wrapper_name='hybrid_argmax_sample') self._target_model = model_wrap( self._target_model, @@ -192,23 +206,39 @@ def _init_learn(self) -> None: }, noise_range=self._cfg.learn.noise_range ) - self._learn_model = model_wrap(self._model, wrapper_name='base') - if self._cfg.action_space == 'hybrid': - self._learn_model = model_wrap(self._learn_model, wrapper_name='hybrid_argmax_sample') self._learn_model.reset() self._target_model.reset() self._forward_learn_cnt = 0 # count iterations - def _forward_learn(self, data: dict) -> Dict[str, Any]: + def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: """ Overview: - Forward and backward function of learn mode. + Policy forward function of learn mode (training policy and updating parameters). Forward means \ + that the policy inputs some training batch data from the replay buffer and then returns the output \ + result, including various training information such as loss, action, priority. Arguments: - - data (:obj:`dict`): Dict type data, including at least ['obs', 'action', 'reward', 'next_obs']. + - data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including a batch of \ + training samples. For each element in list, the key of the dict is the name of data items and the \ + value is the corresponding data. Usually, the value is torch.Tensor or np.ndarray or there dict/list \ + combinations. In the ``_forward_learn`` method, data often need to first be stacked in the batch \ + dimension by some utility functions such as ``default_preprocess_learn``. \ + For DDPG, each element in list is a dict containing at least the following keys: ``obs``, ``action``, \ + ``reward``, ``next_obs``, ``done``. Sometimes, it also contains other keys such as ``weight`` \ + and ``logit`` which is used for hybrid action space. Returns: - - info_dict (:obj:`Dict[str, Any]`): Dict type data, a info dict indicated training result, which will be \ - recorded in text log and tensorboard, values are python scalar or a list of scalars. + - info_dict (:obj:`Dict[str, Any]`): The information dict that indicated training result, which will be \ + recorded in text log and tensorboard, values must be python scalar or a list of scalars. For the \ + detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method. + + .. note:: + The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \ + For the data type that not supported, the main reason is that the corresponding model does not support it. \ + You can implement you own model rather than use the default model. For more information, please raise an \ + issue in GitHub repo and we will continue to follow up. + + .. note:: + For more detailed examples, please refer to our unittest for DDPGPolicy: ``ding.policy.tests.test_ddpg``. """ loss_dict = {} data = default_preprocess_learn( @@ -344,8 +374,14 @@ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: def _init_collect(self) -> None: """ Overview: - Collect mode init method. Called by ``self.__init__``. - Init traj and unroll length, collect model. + Initialize the collect mode of policy, including related attributes and modules. For DDPG, it contains the \ + collect_model to balance the exploration and exploitation with the perturbed noise mechanism, and other \ + algorithm-specific arguments such as unroll_len. \ + This method will be called in ``__init__`` method if ``collect`` field is in ``enable_field``. + + .. note:: + If you want to set some spacial member variables in ``_init_collect`` method, you'd better name them \ + with prefix ``_collect_`` to avoid conflict with other modes, such as ``self._collect_attr1``. """ self._unroll_len = self._cfg.collect.unroll_len # collect model @@ -363,18 +399,28 @@ def _init_collect(self) -> None: self._collect_model = model_wrap(self._collect_model, wrapper_name='hybrid_eps_greedy_multinomial_sample') self._collect_model.reset() - def _forward_collect(self, data: dict, **kwargs) -> dict: + def _forward_collect(self, data: Dict[int, Any], **kwargs) -> Dict[int, Any]: """ Overview: - Forward function of collect mode. + Policy forward function of collect mode (collecting training data by interacting with envs). Forward means \ + that the policy gets some necessary data (mainly observation) from the envs and then returns the output \ + data, such as the action to interact with the envs. Arguments: - - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \ - values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer. + - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \ + key of the dict is environment id and the value is the corresponding data of the env. Returns: - - output (:obj:`Dict[int, Any]`): Dict type data, including at least inferred action according to input obs. - ReturnsKeys - - necessary: ``action`` - - optional: ``logit`` + - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action and \ + other necessary data for learn mode defined in ``self._process_transition`` method. The key of the \ + dict is the same as the input data, i.e. environment id. + + .. note:: + The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \ + For the data type that not supported, the main reason is that the corresponding model does not support it. \ + You can implement you own model rather than use the default model. For more information, please raise an \ + issue in GitHub repo and we will continue to follow up. + + .. note:: + For more detailed examples, please refer to our unittest for DDPGPolicy: ``ding.policy.tests.test_ddpg``. """ data_id = list(data.keys()) data = default_collate(list(data.values())) @@ -432,8 +478,13 @@ def _get_train_sample(self, transitions: List[Dict[str, Any]]) -> List[Dict[str, def _init_eval(self) -> None: """ Overview: - Evaluate mode init method. Called by ``self.__init__``. - Init eval model. Unlike learn and collect model, eval model does not need noise. + Initialize the eval mode of policy, including related attributes and modules. For DDPG, it contains the \ + eval model to greedily select action type with argmax q_value mechanism for hybrid action space. \ + This method will be called in ``__init__`` method if ``eval`` field is in ``enable_field``. + + .. note:: + If you want to set some spacial member variables in ``_init_eval`` method, you'd better name them \ + with prefix ``_eval_`` to avoid conflict with other modes, such as ``self._eval_attr1``. """ self._eval_model = model_wrap(self._model, wrapper_name='base') if self._cfg.action_space == 'hybrid': @@ -443,15 +494,24 @@ def _init_eval(self) -> None: def _forward_eval(self, data: dict) -> dict: """ Overview: - Forward function of eval mode, similar to ``self._forward_collect``. + Policy forward function of eval mode (evaluation policy performance by interacting with envs). Forward \ + means that the policy gets some necessary data (mainly observation) from the envs and then returns the \ + action to interact with the envs. Arguments: - - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \ - values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer. + - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \ + key of the dict is environment id and the value is the corresponding data of the env. Returns: - - output (:obj:`Dict[int, Any]`): The dict of predicting action for the interaction with env. - ReturnsKeys - - necessary: ``action`` - - optional: ``logit`` + - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. The \ + key of the dict is the same as the input data, i.e. environment id. + + .. note:: + The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \ + For the data type that not supported, the main reason is that the corresponding model does not support it. \ + You can implement you own model rather than use the default model. For more information, please raise an \ + issue in GitHub repo and we will continue to follow up. + + .. note:: + For more detailed examples, please refer to our unittest for DDPGPolicy: ``ding.policy.tests.test_ddpg``. """ data_id = list(data.keys()) data = default_collate(list(data.values()))