Skip to content

Commit

Permalink
fix(nyz): fix EDAC bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
PaParaZz1 committed Jan 23, 2024
1 parent 25a0d4d commit a57bc30
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 5 deletions.
10 changes: 10 additions & 0 deletions ding/policy/edac.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,16 @@ class EDACPolicy(SACPolicy):
),
)

def default_model(self) -> Tuple[str, List[str]]:
"""
Overview:
Return this algorithm default neural network model setting for demonstration. ``__init__`` method will \
automatically call this method to get the default model setting and create model.
Returns:
- model_info (:obj:`Tuple[str, List[str]]`): The registered model name and model's import_names.
"""
return 'edac', ['ding.model.template.edac']

def _init_learn(self) -> None:
r"""
Overview:
Expand Down
11 changes: 6 additions & 5 deletions ding/policy/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,11 +716,12 @@ def _init_learn(self) -> None:
self._twin_critic = self._cfg.model.twin_critic

# Weight Init for the last output layer
init_w = self._cfg.learn.init_w
self._model.actor_head[-1].mu.weight.data.uniform_(-init_w, init_w)
self._model.actor_head[-1].mu.bias.data.uniform_(-init_w, init_w)
self._model.actor_head[-1].log_sigma_layer.weight.data.uniform_(-init_w, init_w)
self._model.actor_head[-1].log_sigma_layer.bias.data.uniform_(-init_w, init_w)
if hasattr(self._model, 'actor_head'): # keep compatibility
init_w = self._cfg.learn.init_w
self._model.actor_head[-1].mu.weight.data.uniform_(-init_w, init_w)
self._model.actor_head[-1].mu.bias.data.uniform_(-init_w, init_w)
self._model.actor_head[-1].log_sigma_layer.weight.data.uniform_(-init_w, init_w)
self._model.actor_head[-1].log_sigma_layer.bias.data.uniform_(-init_w, init_w)

self._optimizer_q = Adam(
self._model.critic.parameters(),
Expand Down
4 changes: 4 additions & 0 deletions ding/utils/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,10 @@ def __init__(self, cfg: dict) -> None:
self._data = []
self._load_d4rl(dataset)

@property
def data(self) -> List:
return self._data

def __len__(self) -> int:
"""
Overview:
Expand Down

0 comments on commit a57bc30

Please sign in to comment.