Commit

polish(pu): polish comments in a2c/bcq/fqf/ibc policy (#768)
* polish(pu): polish comments in dqn and a2c

* polish(pu): polish comments in bcq

* polish(pu): polish comments in fqf

* polish(pu): polish comments in ibc

* style(pu): flake8 format

* polish(pu): polish config comments

* polish(pu): fix some typos and comments
puyuan1996 authored Jan 25, 2024
1 parent a57bc30 commit 74c6a1e
Showing 9 changed files with 612 additions and 249 deletions.
2 changes: 1 addition & 1 deletion ding/example/bcq.py
@@ -15,7 +15,7 @@

def main():
# If you don't have offline data, you need to prepare if first and set the data_path in config
- # For demostration, we also can train a RL policy (e.g. SAC) and collect some data
+ # For demonstration, we also can train a RL policy (e.g. SAC) and collect some data
logging.getLogger().setLevel(logging.INFO)
cfg = compile_config(main_config, create_cfg=create_config, auto=True)
ding_init(cfg)
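
The comment above points at the usual offline-RL workflow: when no offline dataset exists yet, first run an online policy (SAC is the example given), dump its transitions to disk, and point the data_path in the BCQ config at that file. A minimal, framework-agnostic sketch of that collection step, using toy stand-ins for the environment and behaviour policy rather than DI-engine's actual pipeline API:

    import pickle
    import random

    def collect_transitions(policy_fn, reset_fn, step_fn, num_episodes=10, max_steps=200):
        """Roll out a behaviour policy and record (obs, action, reward, next_obs, done) tuples."""
        dataset = []
        for _ in range(num_episodes):
            obs, done, t = reset_fn(), False, 0
            while not done and t < max_steps:
                action = policy_fn(obs)
                next_obs, reward, done = step_fn(action)
                dataset.append({'obs': obs, 'action': action, 'reward': reward,
                                'next_obs': next_obs, 'done': done})
                obs, t = next_obs, t + 1
        return dataset

    if __name__ == '__main__':
        # Toy 1-D environment and a random behaviour policy standing in for a trained SAC agent.
        state = [0.0]

        def reset_fn():
            state[0] = 0.0
            return state[0]

        def step_fn(action):
            state[0] += action
            return state[0], -abs(state[0]), abs(state[0]) > 5.0

        data = collect_transitions(lambda obs: random.uniform(-1.0, 1.0), reset_fn, step_fn)
        with open('expert_data.pkl', 'wb') as f:
            pickle.dump(data, f)  # point data_path in the offline config at this file
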
261 changes: 180 additions & 81 deletions ding/policy/a2c.py

Large diffs are not rendered by default.

12 changes: 6 additions & 6 deletions ding/policy/acer.py
@@ -47,11 +47,11 @@ class ACERPolicy(Policy):
config = dict(
type='acer',
cuda=False,
- # (bool) whether use on-policy training pipeline(behaviour policy and training policy are the same)
+ # (bool) whether to use on-policy training pipeline (behaviour policy and training policy are the same)
# here we follow ppo serial pipeline, the original is False
on_policy=False,
priority=False,
- # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True.
+ # (bool) Whether to use Importance Sampling Weight to correct biased update. If True, priority must be True.
priority_IS_weight=False,
learn=dict(
# (str) the type of gradient clip method
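
For context on the priority_IS_weight flag documented above: when transitions are sampled with priority-proportional probability, each sample's loss is scaled by an importance-sampling weight so the gradient stays unbiased. A short sketch of that correction using the standard prioritized-replay formula; the exponent beta and the toy tensors are illustrative values, not ones taken from the ACER config:

    import torch

    def importance_sampling_weights(priorities: torch.Tensor, beta: float = 0.4) -> torch.Tensor:
        """w_i = (N * P(i)) ** (-beta), normalized by the largest weight for stability."""
        probs = priorities / priorities.sum()
        weights = (len(priorities) * probs) ** (-beta)
        return weights / weights.max()

    # Per-sample errors from some value loss, re-weighted before averaging.
    priorities = torch.tensor([0.5, 2.0, 0.1, 1.4])
    td_error = torch.tensor([0.3, -1.2, 0.05, 0.8])
    weights = importance_sampling_weights(priorities)
    loss = (weights * td_error.pow(2)).mean()
    print(loss.item())
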
@@ -295,7 +295,7 @@ def _reshape_data(
Update values and rewards with the weight
Arguments:
- output (:obj:`Dict[int, Any]`): Dict type data, output of learn_model forward. \
- Values are torch.Tensor or np.ndarray or dict/list combinations,keys are value, logit.
+ Values are torch.Tensor or np.ndarray or dict/list combinations, keys are value, logit.
- data (:obj:`Dict[int, Any]`): Dict type data, input of policy._forward_learn \
Values are torch.Tensor or np.ndarray or dict/list combinations. Keys includes at \
least ['logit', 'action', 'reward', 'done',]
@@ -378,7 +378,7 @@ def _forward_collect(self, data: Dict[int, Any]) -> Dict[int, Dict[str, Any]]:
action, values are torch.Tensor or np.ndarray or dict/list combinations,keys \
are env_id indicated by integer.
Returns:
- - output (:obj:`Dict[int, Dict[str,Any]]`): Dict of predicting policy_output(logit, action) for each env.
+ - output (:obj:`Dict[int, Dict[str, Any]]`): Dict of predicting policy_output(logit, action) for each env.
ReturnsKeys
- necessary: ``logit``, ``action``
"""
@@ -479,7 +479,7 @@ def _monitor_vars_learn(self) -> List[str]:
Returns:
- model_info (:obj:`Tuple[str, List[str]]`): model name and mode import_names
.. note::
- The user can define and use customized network model but must obey the same interface definition indicated \
- by import_names path. For IMPALA, ``ding.model.interface.IMPALA``
+ The user can define and use a customized network model but must obey the same interface definition \
+ indicated by import_names path. For IMPALA, ``ding.model.interface.IMPALA``
"""
return ['actor_loss', 'bc_loss', 'policy_loss', 'critic_loss', 'entropy_loss', 'kl_div']
20 changes: 10 additions & 10 deletions ding/policy/base_policy.py
@@ -196,14 +196,14 @@ def hook(*ignore):
def _create_model(self, cfg: EasyDict, model: Optional[torch.nn.Module] = None) -> torch.nn.Module:
"""
Overview:
- Create or validate the neural network model according to input configures and model. If the input model is \
- None, then the model will be created according to ``default_model`` method and ``cfg.model`` field. \
- Otherwise, the model will be verified as an instance of ``torch.nn.Module`` and set to the ``model`` \
- instance created by outside caller.
+ Create or validate the neural network model according to the input configuration and model. \
+ If the input model is None, then the model will be created according to ``default_model`` \
+ method and ``cfg.model`` field. Otherwise, the model will be verified as an instance of \
+ ``torch.nn.Module`` and set to the ``model`` instance created by outside caller.
Arguments:
- cfg (:obj:`EasyDict`): The final merged config used to initialize policy.
- model (:obj:`torch.nn.Module`): The neural network model used to initialize policy. User can refer to \
- the default model defined in corresponding policy to customize its own model.
+ the default model defined in the corresponding policy to customize its own model.
Returns:
- model (:obj:`torch.nn.Module`): The created neural network model. The different modes of policy will \
add distinct wrappers and plugins to the model, which is used to train, collect and evaluate.
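
The create-or-validate behaviour described in _create_model reduces to a single branch: build the policy's default network when no model is supplied, otherwise type-check the caller's model and use it as-is. A condensed sketch of that logic, with a generic MLP standing in for the per-policy default_model:

    from typing import Optional
    import torch.nn as nn

    def create_model(obs_dim: int, action_dim: int, model: Optional[nn.Module] = None) -> nn.Module:
        if model is None:
            # No model passed in: build a default network from the config-style sizes.
            return nn.Sequential(nn.Linear(obs_dim, 64), nn.ReLU(), nn.Linear(64, action_dim))
        if not isinstance(model, nn.Module):
            raise RuntimeError("a customized model must be an instance of torch.nn.Module")
        return model  # validated model supplied by the outside caller

    print(create_model(8, 2))
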
@@ -272,7 +272,7 @@ def _init_eval(self) -> None:
Overview:
Initialize the eval mode of policy, including related attributes and modules. This method will be \
called in ``__init__`` method if ``eval`` field is in ``enable_field``. Almost different policies have \
- its own eval mode, so this method must be overrided in subclass.
+ its own eval mode, so this method must be override in subclass.
.. note::
For the member variables that need to be saved and loaded, please refer to the ``_state_dict_eval`` \
@@ -289,7 +289,7 @@ def learn_mode(self) -> 'Policy.learn_function': # noqa
"""
Overview:
Return the interfaces of learn mode of policy, which is used to train the model. Here we use namedtuple \
- to define immutable interfaces and restrict the usage of policy in different mode. Moreover, derived \
+ to define immutable interfaces and restrict the usage of policy in different modes. Moreover, derived \
subclass can override the interfaces to customize its own learn mode.
Returns:
- interfaces (:obj:`Policy.learn_function`): The interfaces of learn mode of policy, it is a namedtuple \
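
The learn/collect/eval modes discussed in this hunk are exposed as namedtuples of bound methods, so callers interact with a small immutable interface instead of the full policy object. A toy sketch of the idea with a deliberately short field list; the real interfaces carry more entries:

    from collections import namedtuple

    learn_function = namedtuple('learn_function', ['forward', 'state_dict'])

    class ToyPolicy:

        def _forward_learn(self, batch):
            return {'total_loss': sum(batch) / len(batch)}

        def _state_dict_learn(self):
            return {'model': 'parameters-would-go-here'}

        @property
        def learn_mode(self) -> 'learn_function':
            # Only the learn-related interfaces are reachable through this handle.
            return learn_function(self._forward_learn, self._state_dict_learn)

    policy = ToyPolicy()
    print(policy.learn_mode.forward([1.0, 2.0, 3.0]))
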
@@ -316,7 +316,7 @@ def collect_mode(self) -> 'Policy.collect_function': # noqa
"""
Overview:
Return the interfaces of collect mode of policy, which is used to train the model. Here we use namedtuple \
- to define immutable interfaces and restrict the usage of policy in different mode. Moreover, derived \
+ to define immutable interfaces and restrict the usage of policy in different modes. Moreover, derived \
subclass can override the interfaces to customize its own collect mode.
Returns:
- interfaces (:obj:`Policy.collect_function`): The interfaces of collect mode of policy, it is a \
@@ -370,7 +370,7 @@ def _set_attribute(self, name: str, value: Any) -> None:
Overview:
In order to control the access of the policy attributes, we expose different modes to outside rather than \
directly use the policy instance. And we also provide a method to set the attribute of the policy in \
- different modes. And the new attribute will named as ``_{name}``.
+ different modes. And the new attribute will name as ``_{name}``.
Arguments:
- name (:obj:`str`): The name of the attribute.
- value (:obj:`Any`): The value of the attribute.
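
The naming rule documented above is just an underscore prefix, so attributes injected through the mode wrappers land on the policy as private members. A tiny sketch of that convention, with a hypothetical attribute name:

    class ToyPolicy:

        def _set_attribute(self, name, value):
            setattr(self, '_{}'.format(name), value)  # stored as self._<name>

        def _get_attribute(self, name):
            return getattr(self, '_{}'.format(name))

    p = ToyPolicy()
    p._set_attribute('device', 'cuda:0')
    print(p._get_attribute('device'))  # -> cuda:0
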
@@ -416,7 +416,7 @@ def sync_gradients(self, model: torch.nn.Module) -> None:
- model (:obj:`torch.nn.Module`): The model to synchronize gradients.
.. note::
- This method is only used in multi-gpu training, and it shoule be called after ``backward`` method and \
+ This method is only used in multi-gpu training, and it should be called after ``backward`` method and \
before ``step`` method. The user can also use ``bp_update_sync`` config to control whether to synchronize \
gradients allreduce and optimizer updates.
"""
