polish(nyz): complete ddpg/bc comments
PaParaZz1 committed Oct 16, 2023
1 parent 1a3e259 commit 33ea61b
Showing 2 changed files with 169 additions and 53 deletions.
104 changes: 80 additions & 24 deletions ding/policy/bc.py
@@ -57,8 +57,7 @@ class BehaviourCloningPolicy(Policy):
max=0.5,
),
),
-        eval=dict(),
-        other=dict(replay_buffer=dict(replay_buffer_size=10000, )),
+        eval=dict(),  # for compatibility
)

def default_model(self) -> Tuple[str, List[str]]:
@@ -79,8 +78,25 @@ def default_model(self) -> Tuple[str, List[str]]:
else:
return 'discrete_bc', ['ding.model.template.bc']

-    def _init_learn(self):
-        assert self._cfg.learn.optimizer in ['SGD', 'Adam']
+    def _init_learn(self) -> None:
+        """
+        Overview:
+            Initialize the learn mode of policy, including related attributes and modules. For BC, it mainly \
+            contains the optimizer and algorithm-specific arguments such as lr_scheduler, loss, etc. \
+            This method will be called in the ``__init__`` method if the ``learn`` field is in ``enable_field``.
+
+        .. note::
+            For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \
+            and ``_load_state_dict_learn`` methods.
+
+        .. note::
+            For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method.
+
+        .. note::
+            If you want to set some special member variables in the ``_init_learn`` method, you'd better name them \
+            with the prefix ``_learn_`` to avoid conflicts with other modes, such as ``self._learn_attr1``.
+        """
+        assert self._cfg.learn.optimizer in ['SGD', 'Adam'], self._cfg.learn.optimizer
if self._cfg.learn.optimizer == 'SGD':
self._optimizer = SGD(
self._model.parameters(),
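
The hunk above adds an informative assertion that restricts ``learn.optimizer`` to 'SGD' or 'Adam' before the optimizer construction (which is truncated at the collapsed hunk boundary). As a standalone sketch of that selection logic, using plain PyTorch; the config values and the stand-in model are assumptions for illustration, not code from this commit:

    # Illustrative sketch of the optimizer selection guarded by the assertion above.
    import torch.nn as nn
    from torch.optim import SGD, Adam

    cfg_optimizer, cfg_lr = 'Adam', 1e-3      # assumed config values
    model = nn.Linear(4, 2)                   # stand-in for the BC model

    assert cfg_optimizer in ['SGD', 'Adam'], cfg_optimizer
    if cfg_optimizer == 'SGD':
        optimizer = SGD(model.parameters(), lr=cfg_lr)
    else:
        optimizer = Adam(model.parameters(), lr=cfg_lr)
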
@@ -120,20 +136,38 @@ def lr_scheduler_fn(epoch):
elif self._cfg.loss_type == 'mse_loss':
self._loss = nn.MSELoss()
else:
-                raise KeyError
+                raise KeyError("not support loss type: {}".format(self._cfg.loss_type))
else:
if not self._cfg.learn.ce_label_smooth:
self._loss = nn.CrossEntropyLoss()
else:
self._loss = LabelSmoothCELoss(0.1)

if self._cfg.learn.show_accuracy:
# accuracy statistics for debugging in discrete action space env, e.g. for gfootball
self.total_accuracy_in_dataset = []
self.action_accuracy_in_dataset = {k: [] for k in range(self._cfg.action_shape)}

-    def _forward_learn(self, data):
-        if not isinstance(data, dict):
+    def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
+        """
+        Overview:
+            Policy forward function of learn mode (training policy and updating parameters). Forward means \
+            that the policy inputs some training batch data from the replay buffer and then returns the output \
+            result, including various training information such as loss and time.
+        Arguments:
+            - data (:obj:`List[Dict[str, Any]]`): The input data used for policy forward, including a batch of \
+                training samples. For each element in the list, the key of the dict is the name of a data item and \
+                the value is the corresponding data. Usually, the value is a torch.Tensor or np.ndarray or their \
+                dict/list combinations. In the ``_forward_learn`` method, data often needs to first be stacked in \
+                the batch dimension by some utility functions such as ``default_preprocess_learn``. \
+                For BC, each element in the list is a dict containing at least the following keys: ``obs``, ``action``.
+        Returns:
+            - info_dict (:obj:`Dict[str, Any]`): The information dict that indicates the training result, which will \
+                be recorded in the text log and tensorboard; values must be python scalars or lists of scalars. For \
+                the detailed definition of the dict, refer to the code of the ``_monitor_vars_learn`` method.
+
+        .. note::
+            The input value can be torch.Tensor or dict/list combinations, and the current policy supports all of \
+            them. For data types that are not supported, the main reason is that the corresponding model does not \
+            support them. You can implement your own model rather than use the default model. For more information, \
+            please raise an issue in the GitHub repo and we will continue to follow up.
+        """
+        if isinstance(data, list):
data = default_collate(data)
if self._cuda:
data = to_device(data, self._device)
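
To make the batch layout described in the new ``_forward_learn`` docstring concrete, here is a minimal, self-contained sketch of how a list of BC samples is stacked along the batch dimension before ``obs``/``action`` are unpacked (the observation shape and sample values are assumptions for illustration; the real code relies on ``default_collate``):

    # Illustrative only: mimics the list -> batched dict collation described above.
    import torch

    samples = [
        {'obs': torch.randn(4), 'action': torch.tensor([1])},
        {'obs': torch.randn(4), 'action': torch.tensor([0])},
    ]
    batch = {
        'obs': torch.stack([s['obs'] for s in samples]),        # shape (2, 4)
        'action': torch.stack([s['action'] for s in samples]),  # shape (2, 1)
    }
    obs, action = batch['obs'], batch['action'].squeeze()       # action becomes shape (2,)
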
@@ -142,10 +176,10 @@ def _forward_learn(self, data):
obs, action = data['obs'], data['action'].squeeze()
if self._cfg.continuous:
if self._cfg.learn.tanh_mask:
-                '''
+                """tanh_mask
                We mask out the actions that fall outside the range [tanh(-1), tanh(1)]; the model will then learn
                to produce actions in [-1, 1], so the action won't always converge to -1 or 1.
-                '''
+                """
mu = self._eval_model.forward(data['obs'])['action']
bound = 1 - 2 / (math.exp(2) + 1) # tanh(1): (e-e**(-1))/(e+e**(-1))
mask = mu.ge(-bound) & mu.le(bound)
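
As a quick sanity check on the ``bound`` used above: 1 - 2 / (e^2 + 1) equals (e^2 - 1) / (e^2 + 1), which is exactly tanh(1), and the mask keeps only the actions inside [-tanh(1), tanh(1)]. A standalone verification with assumed example values, not part of the diff:

    # Verify the tanh(1) identity and the masking behaviour (illustrative values).
    import math
    import torch

    bound = 1 - 2 / (math.exp(2) + 1)
    assert abs(bound - math.tanh(1)) < 1e-12   # bound == tanh(1) ~= 0.7616

    mu = torch.tensor([-0.99, -0.5, 0.0, 0.5, 0.99])
    mask = mu.ge(-bound) & mu.le(bound)        # tensor([False, True, True, True, False])
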
@@ -200,7 +234,7 @@ def _forward_learn(self, data):
'sync_time': sync_time,
}

-    def _monitor_vars_learn(self):
+    def _monitor_vars_learn(self) -> List[str]:
"""
Overview:
Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \
@@ -211,24 +245,46 @@ def _monitor_vars_learn(self):
return ['cur_lr', 'total_loss', 'forward_time', 'backward_time', 'sync_time']

def _init_eval(self):
"""
Overview:
Initialize the eval mode of policy, including related attributes and modules. For BC, it contains the \
eval model to greedily select action with argmax q_value mechanism for discrete action space.
This method will be called in ``__init__`` method if ``eval`` field is in ``enable_field``.
.. note::
If you want to set some spacial member variables in ``_init_eval`` method, you'd better name them \
with prefix ``_eval_`` to avoid conflict with other modes, such as ``self._eval_attr1``.
"""
if self._cfg.continuous:
self._eval_model = model_wrap(self._model, wrapper_name='base')
else:
self._eval_model = model_wrap(self._model, wrapper_name='argmax_sample')
self._eval_model.reset()

-    def _forward_eval(self, data):
-        gfootball_flag = False
+    def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
+        """
+        Overview:
+            Policy forward function of eval mode (evaluating policy performance by interacting with envs). Forward \
+            means that the policy gets some necessary data (mainly observations) from the envs and then returns the \
+            actions to interact with the envs.
+        Arguments:
+            - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \
+                key of the dict is the environment id and the value is the corresponding data of that env.
+        Returns:
+            - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. The \
+                key of the dict is the same as in the input data, i.e. the environment id.
+
+        .. note::
+            The input value can be torch.Tensor or dict/list combinations, and the current policy supports all of \
+            them. For data types that are not supported, the main reason is that the corresponding model does not \
+            support them. You can implement your own model rather than use the default model. For more information, \
+            please raise an issue in the GitHub repo and we will continue to follow up.
+        """
tensor_input = isinstance(data, torch.Tensor)
if tensor_input:
data = default_collate(list(data))
else:
data_id = list(data.keys())
-            if data_id == ['processed_obs', 'raw_obs']:
-                # for gfootball
-                gfootball_flag = True
-                data = {0: data}
-                data_id = list(data.keys())
data = default_collate(list(data.values()))
if self._cuda:
data = to_device(data, self._device)
@@ -237,7 +293,7 @@ def _forward_eval(self, data):
output = self._eval_model.forward(data)
if self._cuda:
output = to_device(output, 'cpu')
-        if tensor_input or gfootball_flag:
+        if tensor_input:
return output
else:
output = default_decollate(output)
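
To make the env-id-keyed interface described in the ``_forward_eval`` docstring concrete, here is a hand-rolled sketch of the collate/decollate round trip (the observation shape, action count, and toy "model" are assumptions for illustration; the real code uses ``default_collate``/``default_decollate`` and the wrapped eval model):

    # Illustrative round trip: dict keyed by env id -> batched tensor -> per-env outputs.
    import torch

    data = {0: torch.randn(4), 2: torch.randn(4), 5: torch.randn(4)}   # one obs per env id
    data_id = list(data.keys())                                        # [0, 2, 5]
    batch = torch.stack(list(data.values()))                           # shape (3, 4)

    logits = batch @ torch.randn(4, 6)       # toy discrete "model" with 6 actions
    actions = logits.argmax(dim=-1)          # greedy action, like the argmax_sample wrapper
    output = {i: {'action': a} for i, a in zip(data_id, actions)}
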
@@ -282,11 +338,11 @@ def _forward_collect(self, data: Dict[int, Any], **kwargs) -> Dict[int, Any]:
output = default_decollate(output)
return {i: d for i, d in zip(data_id, output)}

-    def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:
+    def _process_transition(self, obs: Any, policy_output: dict, timestep: namedtuple) -> dict:
transition = {
'obs': obs,
'next_obs': timestep.obs,
-            'action': model_output['action'],
+            'action': policy_output['action'],
'reward': timestep.reward,
'done': timestep.done,
}
118 changes: 89 additions & 29 deletions ding/policy/ddpg.py (diff not loaded on this page)
