Skip to content

Commit

Permalink
#625 add initialized callback to ReplayBufferRewardWrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
Jan Michelfeit committed Dec 1, 2022
1 parent c681ca3 commit ad29c34
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 1 deletion.
8 changes: 7 additions & 1 deletion src/imitation/policies/replay_buffer_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Wrapper for reward labeling for transitions sampled from a replay buffer."""

from typing import Callable
from typing import Mapping, Type

import numpy as np
Expand All @@ -10,7 +11,6 @@
from imitation.rewards.reward_function import RewardFn
from imitation.util import util
from imitation.util.networks import RunningNorm
from typing import Callable


def _samples_to_reward_fn_input(
Expand Down Expand Up @@ -59,6 +59,7 @@ def __init__(
*,
replay_buffer_class: Type[ReplayBuffer],
reward_fn: RewardFn,
on_initialized_callback: Callable[["ReplayBufferRewardWrapper"], None] = None,
**kwargs,
):
"""Builds ReplayBufferRewardWrapper.
Expand All @@ -69,6 +70,9 @@ def __init__(
action_space: Action space
replay_buffer_class: Class of the replay buffer.
reward_fn: Reward function for reward relabeling.
on_initialized_callback: Callback called with reference to this object after
this instance is fully initialized. This provides a hook to access the
buffer after it is created from inside a Stable Baselines algorithm.
**kwargs: keyword arguments for ReplayBuffer.
"""
# Note(yawen-d): we directly inherit ReplayBuffer and leave out the case of
Expand All @@ -86,6 +90,8 @@ def __init__(
self.reward_fn = reward_fn
_base_kwargs = {k: v for k, v in kwargs.items() if k in ["device", "n_envs"]}
super().__init__(buffer_size, observation_space, action_space, **_base_kwargs)
if on_initialized_callback is not None:
on_initialized_callback(self)

# TODO(juan) remove the type ignore once the merged PR
# https://github.com/python/mypy/pull/13475
Expand Down
15 changes: 15 additions & 0 deletions tests/policies/test_replay_buffer_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,3 +264,18 @@ def test_replay_buffer_view_provides_buffered_observations():
# ReplayBuffer internally uses a circular buffer
expected = np.roll(observations, 1, axis=0)
np.testing.assert_allclose(view.observations, expected)


def test_replay_buffer_reward_wrapper_calls_initialization_callback_with_itself():
    """Check the wrapper invokes ``on_initialized_callback`` once with itself.

    A ``Mock`` is passed as ``on_initialized_callback`` when constructing a
    ``ReplayBufferRewardWrapper``; after construction the mock must have been
    called exactly once, and its first positional argument must be the very
    wrapper instance that was just created.
    """
    callback = Mock()
    buffer = ReplayBufferRewardWrapper(
        10,
        spaces.Discrete(2),
        spaces.Discrete(2),
        replay_buffer_class=ReplayBuffer,
        reward_fn=Mock(),
        n_envs=2,
        handle_timeout_termination=False,
        on_initialized_callback=callback,
    )
    # Guard first: if the hook never fired, ``call_args`` would be ``None`` and
    # the identity check below would die with an unhelpful AttributeError.
    # This also catches the hook firing more than once.
    callback.assert_called_once()
    # Identity (not equality): the callback must receive this exact object.
    assert callback.call_args.args[0] is buffer

0 comments on commit ad29c34

Please sign in to comment.