#641 code review: refactor PebbleStateEntropyReward so that inner RewardNets are initialized in constructor
Jan Michelfeit committed Dec 10, 2022
1 parent 50577b0 commit a3369d4
Showing 4 changed files with 43 additions and 40 deletions.
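
The change visible across all four files is that PebbleStateEntropyReward now constructs its inner EntropyRewardNet and NormalizedRewardNet up front, so callers have to supply the observation and action spaces at construction time. A minimal sketch of the new construction pattern, based on the diff below (the Box spaces and the zero-returning learned_reward_fn are illustrative placeholders, not part of the commit):

import numpy as np
from gym.spaces import Box

from imitation.algorithms.pebble.entropy_reward import PebbleStateEntropyReward

obs_space = Box(low=-1.0, high=1.0, shape=(4,))
act_space = Box(low=-1.0, high=1.0, shape=(2,))


def learned_reward_fn(obs, acts, next_obs, dones):
    # Stand-in for the real learned reward, e.g. reward_net.predict_processed.
    return np.zeros(len(obs))


# Spaces are now constructor arguments; previously the inner reward nets were
# only created once a replay buffer was attached.
reward_fn = PebbleStateEntropyReward(
    learned_reward_fn,
    observation_space=obs_space,
    action_space=act_space,
    nearest_neighbor_k=5,
)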
66 changes: 33 additions & 33 deletions src/imitation/algorithms/pebble/entropy_reward.py
@@ -18,42 +18,42 @@
from imitation.util.networks import RunningNorm


class PebbleRewardPhase(enum.Enum):
"""States representing different behaviors for PebbleStateEntropyReward."""

UNSUPERVISED_EXPLORATION = enum.auto() # Entropy based reward
POLICY_AND_REWARD_LEARNING = enum.auto() # Learned reward


class InsufficientObservations(RuntimeError):
pass


class EntropyRewardNet(RewardNet):
class EntropyRewardNet(RewardNet, ReplayBufferAwareRewardFn):
def __init__(
self,
nearest_neighbor_k: int,
replay_buffer_view: ReplayBufferView,
observation_space: gym.Space,
action_space: gym.Space,
normalize_images: bool = True,
replay_buffer_view: Optional[ReplayBufferView] = None,
):
"""Initialize the RewardNet.
Args:
nearest_neighbor_k: Parameter for entropy computation (see
compute_state_entropy())
observation_space: the observation space of the environment
action_space: the action space of the environment
normalize_images: whether to automatically normalize
image observations to [0, 1] (from 0 to 255). Defaults to True.
replay_buffer_view: Replay buffer view with observations to compare
against when computing entropy. If None is given, the buffer needs to
be set with on_replay_buffer_initialized() before EntropyRewardNet can
be used
"""
super().__init__(observation_space, action_space, normalize_images)
self.nearest_neighbor_k = nearest_neighbor_k
self._replay_buffer_view = replay_buffer_view

def set_replay_buffer(self, replay_buffer: ReplayBufferRewardWrapper):
"""This method needs to be called after unpickling.
def on_replay_buffer_initialized(self, replay_buffer: ReplayBufferRewardWrapper):
"""Sets replay buffer.
See also __getstate__() / __setstate__()
This method needs to be called, e.g., after unpickling.
See also __getstate__() / __setstate__().
"""
assert self.observation_space == replay_buffer.observation_space
assert self.action_space == replay_buffer.action_space
@@ -111,6 +111,13 @@ def __setstate__(self, state):
self._replay_buffer_view = None


class PebbleRewardPhase(enum.Enum):
"""States representing different behaviors for PebbleStateEntropyReward."""

UNSUPERVISED_EXPLORATION = enum.auto() # Entropy based reward
POLICY_AND_REWARD_LEARNING = enum.auto() # Learned reward


class PebbleStateEntropyReward(ReplayBufferAwareRewardFn):
"""Reward function for implementation of the PEBBLE learning algorithm.
@@ -126,14 +133,15 @@ class PebbleStateEntropyReward(ReplayBufferAwareRewardFn):
reward is returned.
The second phase requires that a buffer with observations to compare against is
supplied with set_replay_buffer() or on_replay_buffer_initialized().
To transition to the last phase, unsupervised_exploration_finish() needs
to be called.
supplied with on_replay_buffer_initialized(). To transition to the last phase,
unsupervised_exploration_finish() needs to be called.
"""

def __init__(
self,
learned_reward_fn: RewardFn,
observation_space: gym.Space,
action_space: gym.Space,
nearest_neighbor_k: int = 5,
):
"""Builds this class.
@@ -146,28 +154,20 @@ def __init__(
"""
self.learned_reward_fn = learned_reward_fn
self.nearest_neighbor_k = nearest_neighbor_k

self.state = PebbleRewardPhase.UNSUPERVISED_EXPLORATION

# These two need to be set with set_replay_buffer():
self._entropy_reward_net: Optional[EntropyRewardNet] = None
self._normalized_entropy_reward_net: Optional[RewardNet] = None
self._entropy_reward_net = EntropyRewardNet(
nearest_neighbor_k=self.nearest_neighbor_k,
observation_space=observation_space,
action_space=action_space,
normalize_images=False,
)
self._normalized_entropy_reward_net = NormalizedRewardNet(
self._entropy_reward_net, RunningNorm
)

def on_replay_buffer_initialized(self, replay_buffer: ReplayBufferRewardWrapper):
if self._normalized_entropy_reward_net is None:
self._entropy_reward_net = EntropyRewardNet(
nearest_neighbor_k=self.nearest_neighbor_k,
replay_buffer_view=replay_buffer.buffer_view,
observation_space=replay_buffer.observation_space,
action_space=replay_buffer.action_space,
normalize_images=False,
)
self._normalized_entropy_reward_net = NormalizedRewardNet(
self._entropy_reward_net, RunningNorm
)
else:
assert self._entropy_reward_net is not None
self._entropy_reward_net.set_replay_buffer(replay_buffer)
self._entropy_reward_net.on_replay_buffer_initialized(replay_buffer)

def unsupervised_exploration_finish(self):
assert self.state == PebbleRewardPhase.UNSUPERVISED_EXPLORATION
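
With the entropy_reward.py diff in view, here is a rough end-to-end sketch of the two phases described in the PebbleStateEntropyReward docstring. It follows the mocking style of tests/algorithms/pebble/test_entropy_reward.py (a Mock standing in for the ReplayBufferRewardWrapper, exposing buffer_view and the two spaces); the shapes, the nearest_neighbor_k value, and the Mock learned reward are illustrative assumptions rather than part of the commit:

from unittest.mock import Mock

import numpy as np
from gym.spaces import Box

from imitation.algorithms.pebble.entropy_reward import PebbleStateEntropyReward
from imitation.policies.replay_buffer_wrapper import ReplayBufferView

space = Box(low=-1.0, high=1.0, shape=(4,))
venvs, buffer_size, k = 8, 100, 5

reward_fn = PebbleStateEntropyReward(
    Mock(return_value=np.zeros(venvs)), space, space, nearest_neighbor_k=k
)

# Phase 1 (unsupervised exploration): the entropy reward needs a view of the
# replay buffer's observations, supplied via on_replay_buffer_initialized().
all_observations = np.random.rand(buffer_size, venvs, *space.shape)
buffer = Mock()
buffer.buffer_view = ReplayBufferView(all_observations, lambda: slice(None))
buffer.observation_space = space
buffer.action_space = space
reward_fn.on_replay_buffer_initialized(buffer)

obs = np.random.rand(venvs, *space.shape)
acts = np.random.rand(venvs, *space.shape)
dones = np.zeros(venvs, dtype=bool)
entropy_reward = reward_fn(obs, acts, obs, dones)  # normalized state entropy

# Phase 2 (policy and reward learning): after this call the wrapped learned
# reward function is used instead of the entropy reward.
reward_fn.unsupervised_exploration_finish()
learned_reward = reward_fn(obs, acts, obs, dones)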
4 changes: 3 additions & 1 deletion src/imitation/scripts/train_preference_comparisons.py
@@ -74,7 +74,9 @@ def make_reward_function(
if pebble_enabled:
relabel_reward_fn = PebbleStateEntropyReward(
relabel_reward_fn, # type: ignore[assignment]
pebble_nearest_neighbor_k,
observation_space=reward_net.observation_space,
action_space=reward_net.action_space,
nearest_neighbor_k=pebble_nearest_neighbor_k,
)
return relabel_reward_fn

11 changes: 6 additions & 5 deletions tests/algorithms/pebble/test_entropy_reward.py
@@ -5,8 +5,9 @@

import numpy as np
import torch as th
from gym.spaces import Discrete, Box
from gym.spaces import Box
from gym.spaces.space import Space

from imitation.algorithms.pebble.entropy_reward import PebbleStateEntropyReward
from imitation.policies.replay_buffer_wrapper import ReplayBufferView
from imitation.util import util
@@ -23,7 +24,7 @@
def test_pebble_entropy_reward_returns_entropy_for_pretraining(rng):
all_observations = rng.random((BUFFER_SIZE, VENVS) + SPACE.shape)

reward_fn = PebbleStateEntropyReward(Mock(), K)
reward_fn = PebbleStateEntropyReward(Mock(), SPACE, SPACE, K)
reward_fn.on_replay_buffer_initialized(
replay_buffer_mock(
ReplayBufferView(all_observations, lambda: slice(None)),
@@ -50,7 +51,7 @@ def test_pebble_entropy_reward_returns_normalized_values_for_pretraining():
# only stats collection in this test
m.side_effect = lambda obs, all_obs, k: obs

reward_fn = PebbleStateEntropyReward(Mock(), K)
reward_fn = PebbleStateEntropyReward(Mock(), SPACE, SPACE, K)
all_observations = np.empty((BUFFER_SIZE, VENVS, *SPACE.shape))
reward_fn.on_replay_buffer_initialized(
replay_buffer_mock(
@@ -88,7 +89,7 @@ def test_pebble_entropy_reward_function_returns_learned_reward_after_pre_trainin
expected_reward = np.ones(1)
learned_reward_mock = Mock()
learned_reward_mock.return_value = expected_reward
reward_fn = PebbleStateEntropyReward(learned_reward_mock)
reward_fn = PebbleStateEntropyReward(learned_reward_mock, SPACE, SPACE)
# move all the way to the last state
reward_fn.unsupervised_exploration_finish()

@@ -111,7 +112,7 @@ def test_pebble_entropy_reward_can_pickle():
replay_buffer = ReplayBufferView(all_observations, lambda: slice(None))

obs1 = np.random.rand(VENVS, *SPACE.shape)
reward_fn = PebbleStateEntropyReward(reward_fn_stub, K)
reward_fn = PebbleStateEntropyReward(reward_fn_stub, SPACE, SPACE, K)
reward_fn.on_replay_buffer_initialized(replay_buffer_mock(replay_buffer, SPACE))
reward_fn(obs1, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)

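
test_pebble_entropy_reward_can_pickle above exercises the behaviour documented in EntropyRewardNet's __getstate__/__setstate__: the replay buffer view is dropped on pickling, so on_replay_buffer_initialized() has to be called again after unpickling. A small sketch of that round trip in the same mocked style (the module-level learned_reward function is a stand-in, since Mock objects cannot be pickled; shapes are illustrative):

import pickle
from unittest.mock import Mock

import numpy as np
from gym.spaces import Box

from imitation.algorithms.pebble.entropy_reward import PebbleStateEntropyReward
from imitation.policies.replay_buffer_wrapper import ReplayBufferView

space = Box(low=-1.0, high=1.0, shape=(4,))


def learned_reward(obs, acts, next_obs, dones):
    # Module-level stand-in for the learned reward; Mock objects cannot be pickled.
    return np.zeros(len(obs))


def make_buffer_mock():
    buffer = Mock()
    buffer.buffer_view = ReplayBufferView(
        np.random.rand(100, 8, *space.shape), lambda: slice(None)
    )
    buffer.observation_space = space
    buffer.action_space = space
    return buffer


reward_fn = PebbleStateEntropyReward(learned_reward, space, space, nearest_neighbor_k=5)
reward_fn.on_replay_buffer_initialized(make_buffer_mock())

# The buffer view is dropped in __getstate__, so the restored object must be
# pointed at a replay buffer again before it can compute entropy rewards.
restored = pickle.loads(pickle.dumps(reward_fn))
restored.on_replay_buffer_initialized(make_buffer_mock())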
2 changes: 1 addition & 1 deletion tests/algorithms/test_preference_comparisons.py
@@ -85,7 +85,7 @@ def pebble_agent_trainer(agent, reward_net, venv, rng, replay_buffer):
replay_buffer_mock = Mock()
replay_buffer_mock.buffer_view = replay_buffer
replay_buffer_mock.obs_shape = (4,)
reward_fn = PebbleStateEntropyReward(reward_net.predict_processed)
reward_fn = PebbleStateEntropyReward(reward_net.predict_processed, venv.observation_space, venv.action_space)
reward_fn.on_replay_buffer_initialized(replay_buffer_mock)
return preference_comparisons.PebbleAgentTrainer(
algorithm=agent,