#641 code review: replace RunningNorm with NormalizedRewardNet
Jan Michelfeit committed Dec 10, 2022
1 parent c80fb80 commit 50577b0
Showing 4 changed files with 139 additions and 60 deletions.
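
The gist of the change: instead of normalizing entropy rewards by hand with a RunningNorm member, the entropy computation now lives in a RewardNet subclass and normalization is delegated to NormalizedRewardNet. Below is a minimal sketch of the new wiring, assuming the imitation APIs exactly as they appear in the diff; the spaces, array shapes, and placeholder batches are illustrative and not part of the commit.

import numpy as np
from gym.spaces import Box

from imitation.algorithms.pebble.entropy_reward import EntropyRewardNet
from imitation.policies.replay_buffer_wrapper import ReplayBufferView
from imitation.rewards.reward_nets import NormalizedRewardNet
from imitation.util.networks import RunningNorm

obs_space = Box(-1, 1, shape=(1,))
act_space = Box(-1, 1, shape=(1,))

# A view over buffered observations, shaped (buffer_size, n_envs, *obs_shape).
all_observations = np.random.rand(20, 1, 1)
buffer_view = ReplayBufferView(all_observations, lambda: slice(None))

entropy_net = EntropyRewardNet(
    nearest_neighbor_k=4,
    replay_buffer_view=buffer_view,
    observation_space=obs_space,
    action_space=act_space,
    normalize_images=False,
)
# NormalizedRewardNet keeps running statistics (RunningNorm) of the raw entropies.
normalized_net = NormalizedRewardNet(entropy_net, RunningNorm)

# predict_processed() returns normalized entropies and, with update_stats=True,
# updates the running statistics as a side effect.
batch = np.random.rand(8, 1)
placeholder = np.empty(obs_space.shape)
rewards = normalized_net.predict_processed(
    batch, placeholder, placeholder, placeholder, update_stats=True
)
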
144 changes: 111 additions & 33 deletions src/imitation/algorithms/pebble/entropy_reward.py
@@ -1,8 +1,9 @@
"""Reward function for the PEBBLE training algorithm."""

import enum
from typing import Dict, Optional, Tuple, Union
from typing import Optional, Tuple

import gym
import numpy as np
import torch as th

@@ -12,6 +13,7 @@
ReplayBufferView,
)
from imitation.rewards.reward_function import RewardFn
from imitation.rewards.reward_nets import NormalizedRewardNet, RewardNet
from imitation.util import util
from imitation.util.networks import RunningNorm

@@ -23,6 +25,92 @@ class PebbleRewardPhase(enum.Enum):
POLICY_AND_REWARD_LEARNING = enum.auto() # Learned reward


class InsufficientObservations(RuntimeError):
    """Error raised when there are too few observations for the entropy estimate."""


class EntropyRewardNet(RewardNet):
    """RewardNet that scores states by a k-nearest-neighbor entropy estimate.

    The entropy is estimated against observations from a replay buffer view.
    """

    def __init__(
        self,
        nearest_neighbor_k: int,
        replay_buffer_view: ReplayBufferView,
        observation_space: gym.Space,
        action_space: gym.Space,
        normalize_images: bool = True,
    ):
        """Initialize the RewardNet.

        Args:
            nearest_neighbor_k: k used for the k-nearest-neighbor entropy estimate.
            replay_buffer_view: view of the replay buffer providing the
                observations that states are compared against.
            observation_space: the observation space of the environment.
            action_space: the action space of the environment.
            normalize_images: whether to automatically normalize
                image observations to [0, 1] (from 0 to 255). Defaults to True.
        """
super().__init__(observation_space, action_space, normalize_images)
self.nearest_neighbor_k = nearest_neighbor_k
self._replay_buffer_view = replay_buffer_view

    def set_replay_buffer(self, replay_buffer: ReplayBufferRewardWrapper):
        """Set the replay buffer whose observations entropy is computed against.

        This method needs to be called again after unpickling; see also
        __getstate__() / __setstate__().
        """
assert self.observation_space == replay_buffer.observation_space
assert self.action_space == replay_buffer.action_space
self._replay_buffer_view = replay_buffer.buffer_view

def forward(
self,
state: th.Tensor,
action: th.Tensor,
next_state: th.Tensor,
done: th.Tensor,
) -> th.Tensor:
assert (
self._replay_buffer_view is not None
), "Missing replay buffer (possibly after unpickle)"

all_observations = self._replay_buffer_view.observations
# ReplayBuffer sampling flattens the venv dimension, let's adapt to that
all_observations = all_observations.reshape(
(-1,) + self.observation_space.shape
)

if all_observations.shape[0] < self.nearest_neighbor_k:
raise InsufficientObservations(
"Insufficient observations for entropy calculation"
)

return util.compute_state_entropy(
state, all_observations, self.nearest_neighbor_k
)

def preprocess(
self,
state: np.ndarray,
action: np.ndarray,
next_state: np.ndarray,
done: np.ndarray,
) -> Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor]:
"""Override default preprocessing to avoid the default one-hot encoding.
We also know forward() only works with state, so no need to convert
other tensors.
"""
state_th = util.safe_to_tensor(state).to(self.device)
action_th = next_state_th = done_th = th.empty(0)
return state_th, action_th, next_state_th, done_th

def __getstate__(self):
state = self.__dict__.copy()
del state["_replay_buffer_view"]
return state

def __setstate__(self, state):
self.__dict__.update(state)
self._replay_buffer_view = None


class PebbleStateEntropyReward(ReplayBufferAwareRewardFn):
"""Reward function for implementation of the PEBBLE learning algorithm.
@@ -59,17 +147,27 @@ def __init__(
self.learned_reward_fn = learned_reward_fn
self.nearest_neighbor_k = nearest_neighbor_k

self.entropy_stats = RunningNorm(1)
self.state = PebbleRewardPhase.UNSUPERVISED_EXPLORATION

# These two need to be set with set_replay_buffer():
self.replay_buffer_view: Optional[ReplayBufferView] = None
self.obs_shape: Union[Tuple[int, ...], Dict[str, Tuple[int, ...]], None] = None
self._entropy_reward_net: Optional[EntropyRewardNet] = None
self._normalized_entropy_reward_net: Optional[RewardNet] = None

def on_replay_buffer_initialized(self, replay_buffer: ReplayBufferRewardWrapper):
self.replay_buffer_view = replay_buffer.buffer_view
self.obs_shape = replay_buffer.obs_shape

if self._normalized_entropy_reward_net is None:
self._entropy_reward_net = EntropyRewardNet(
nearest_neighbor_k=self.nearest_neighbor_k,
replay_buffer_view=replay_buffer.buffer_view,
observation_space=replay_buffer.observation_space,
action_space=replay_buffer.action_space,
normalize_images=False,
)
self._normalized_entropy_reward_net = NormalizedRewardNet(
self._entropy_reward_net, RunningNorm
)
else:
assert self._entropy_reward_net is not None
self._entropy_reward_net.set_replay_buffer(replay_buffer)

def unsupervised_exploration_finish(self):
assert self.state == PebbleRewardPhase.UNSUPERVISED_EXPLORATION
@@ -88,35 +186,15 @@ def __call__(
return self.learned_reward_fn(state, action, next_state, done)

def _entropy_reward(self, state, action, next_state, done):
if self.replay_buffer_view is None:
if self._normalized_entropy_reward_net is None:
raise ValueError(
"Replay buffer must be supplied before entropy reward can be used",
)
all_observations = self.replay_buffer_view.observations
# ReplayBuffer sampling flattens the venv dimension, let's adapt to that
all_observations = all_observations.reshape((-1, *self.obs_shape))

if all_observations.shape[0] < self.nearest_neighbor_k:
try:
return self._normalized_entropy_reward_net.predict_processed(
state, action, next_state, done, update_stats=True
)
except InsufficientObservations:
            # Not enough observations to compare against; fall back to the learned
            # reward function (falling back to a constant may also be acceptable).
return self.learned_reward_fn(state, action, next_state, done)
else:
# TODO #625: deal with the conversion back and forth between np and torch
entropies = util.compute_state_entropy(
th.tensor(state),
th.tensor(all_observations),
self.nearest_neighbor_k,
)

normalized_entropies = self.entropy_stats.forward(entropies)

return normalized_entropies.numpy()

def __getstate__(self):
state = self.__dict__.copy()
del state["replay_buffer_view"]
return state

def __setstate__(self, state):
self.__dict__.update(state)
self.replay_buffer_view = None
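
For orientation, here is a rough sketch of how the reworked PebbleStateEntropyReward is meant to be driven, pieced together from the hunks above. The learned reward stub, the replay_buffer_wrapper, and the batch arrays are placeholders, not part of this commit.

import numpy as np

from imitation.algorithms.pebble.entropy_reward import PebbleStateEntropyReward


def learned_reward_fn(state, action, next_state, done):
    # Stand-in for the reward function learned from preferences.
    return np.zeros(len(state))


reward_fn = PebbleStateEntropyReward(learned_reward_fn, nearest_neighbor_k=4)

# Must be called once the wrapped replay buffer exists, and again after
# unpickling: it builds the EntropyRewardNet wrapped in a NormalizedRewardNet,
# or rebinds the buffer view on an already constructed net.
reward_fn.on_replay_buffer_initialized(replay_buffer_wrapper)  # placeholder wrapper

# During unsupervised exploration, calls return normalized state entropy,
# falling back to learned_reward_fn while the buffer holds too few observations.
rewards = reward_fn(obs, acts, next_obs, dones)  # placeholder batches

# Once pretraining ends, all subsequent calls return the learned reward.
reward_fn.unsupervised_exploration_finish()
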
1 change: 1 addition & 0 deletions (file name not shown)
@@ -158,3 +158,4 @@ def fast():
reward_trainer_kwargs = {
"epochs": 1,
}
locals() # quieten flake8
5 changes: 4 additions & 1 deletion src/imitation/util/util.py
@@ -384,12 +384,15 @@ def compute_state_entropy(
for idx in range(len(all_obs) // batch_size + 1):
start = idx * batch_size
end = (idx + 1) * batch_size
all_obs_batch = all_obs[start:end]
distances_tensor = th.linalg.vector_norm(
obs[:, None] - all_obs[None, start:end],
obs[:, None] - all_obs_batch[None, :],
dim=non_batch_dimensions,
ord=2,
)
assert distances_tensor.shape == (obs.shape[0], all_obs_batch.shape[0])
dists.append(distances_tensor)
all_dists = th.cat(dists, dim=1)
knn_dists = th.kthvalue(all_dists, k=k + 1, dim=1).values
return knn_dists
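
The change above just makes the batching over all_obs explicit and asserts the shape of each distance block; the function's contract is unchanged. A small usage sketch with arbitrarily chosen shapes:

import torch as th

from imitation.util import util

obs = th.rand(8, 3)        # 8 query states with 3-dimensional observations
all_obs = th.rand(100, 3)  # observations drawn from the replay buffer

# One value per query state: the distance to its (k + 1)-th nearest neighbor
# in all_obs, used as a proxy for state entropy.
entropies = util.compute_state_entropy(obs, all_obs, 4)
assert entropies.shape == (8,)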

49 changes: 23 additions & 26 deletions tests/algorithms/pebble/test_entropy_reward.py
@@ -5,15 +5,14 @@

import numpy as np
import torch as th
from gym.spaces import Discrete

from gym.spaces import Discrete, Box
from gym.spaces.space import Space
from imitation.algorithms.pebble.entropy_reward import PebbleStateEntropyReward
from imitation.policies.replay_buffer_wrapper import ReplayBufferView
from imitation.util import util

SPACE = Discrete(4)
OBS_SHAPE = (1,)
PLACEHOLDER = np.empty(OBS_SHAPE)
SPACE = Box(-1, 1, shape=(1,))
PLACEHOLDER = np.empty(SPACE.shape)

BUFFER_SIZE = 20
K = 4
@@ -22,30 +21,27 @@


def test_pebble_entropy_reward_returns_entropy_for_pretraining(rng):
all_observations = rng.random((BUFFER_SIZE, VENVS, *OBS_SHAPE))
all_observations = rng.random((BUFFER_SIZE, VENVS) + SPACE.shape)

reward_fn = PebbleStateEntropyReward(Mock(), K)
reward_fn.on_replay_buffer_initialized(
replay_buffer_mock(
ReplayBufferView(all_observations, lambda: slice(None)),
OBS_SHAPE,
SPACE,
)
)

# Act
observations = th.rand((BATCH_SIZE, *OBS_SHAPE))
observations = th.rand((BATCH_SIZE, *SPACE.shape))
reward = reward_fn(observations, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)

# Assert
expected = util.compute_state_entropy(
observations,
all_observations.reshape(-1, *OBS_SHAPE),
all_observations.reshape(-1, *SPACE.shape),
K,
)
expected_normalized = reward_fn.entropy_stats.normalize(
th.as_tensor(expected),
).numpy()
np.testing.assert_allclose(reward, expected_normalized)
np.testing.assert_allclose(reward, expected, rtol=0.005, atol=0.005)


def test_pebble_entropy_reward_returns_normalized_values_for_pretraining():
@@ -55,11 +51,11 @@ def test_pebble_entropy_reward_returns_normalized_values_for_pretraining():
m.side_effect = lambda obs, all_obs, k: obs

reward_fn = PebbleStateEntropyReward(Mock(), K)
all_observations = np.empty((BUFFER_SIZE, VENVS, *OBS_SHAPE))
all_observations = np.empty((BUFFER_SIZE, VENVS, *SPACE.shape))
reward_fn.on_replay_buffer_initialized(
replay_buffer_mock(
ReplayBufferView(all_observations, lambda: slice(None)),
OBS_SHAPE,
SPACE,
)
)

@@ -97,7 +93,7 @@ def test_pebble_entropy_reward_function_returns_learned_reward_after_pre_trainin
reward_fn.unsupervised_exploration_finish()

# Act
observations = np.ones((BATCH_SIZE, *OBS_SHAPE))
observations = np.ones((BATCH_SIZE, *SPACE.shape))
reward = reward_fn(observations, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)

# Assert
@@ -111,23 +107,23 @@ def test_pebble_entropy_reward_function_returns_learned_reward_after_pre_trainin


def test_pebble_entropy_reward_can_pickle():
all_observations = np.empty((BUFFER_SIZE, VENVS, *OBS_SHAPE))
all_observations = np.empty((BUFFER_SIZE, VENVS, *SPACE.shape))
replay_buffer = ReplayBufferView(all_observations, lambda: slice(None))

obs1 = np.random.rand(VENVS, *OBS_SHAPE)
obs1 = np.random.rand(VENVS, *SPACE.shape)
reward_fn = PebbleStateEntropyReward(reward_fn_stub, K)
reward_fn.on_replay_buffer_initialized(replay_buffer_mock(replay_buffer, OBS_SHAPE))
reward_fn.on_replay_buffer_initialized(replay_buffer_mock(replay_buffer, SPACE))
reward_fn(obs1, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)

# Act
pickled = pickle.dumps(reward_fn)
reward_fn_deserialized = pickle.loads(pickled)
reward_fn_deserialized.on_replay_buffer_initialized(
replay_buffer_mock(replay_buffer, OBS_SHAPE)
replay_buffer_mock(replay_buffer, SPACE)
)

# Assert
obs2 = np.random.rand(VENVS, *OBS_SHAPE)
obs2 = np.random.rand(VENVS, *SPACE.shape)
expected_result = reward_fn(obs2, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)
actual_result = reward_fn_deserialized(obs2, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)
np.testing.assert_allclose(actual_result, expected_result)
@@ -137,8 +133,9 @@ def reward_fn_stub(state, action, next_state, done):
return state


def replay_buffer_mock(buffer_view: ReplayBufferView, obs_shape: tuple) -> Mock:
replay_buffer_mock = Mock()
replay_buffer_mock.buffer_view = buffer_view
replay_buffer_mock.obs_shape = obs_shape
return replay_buffer_mock
def replay_buffer_mock(buffer_view: ReplayBufferView, obs_space: Space) -> Mock:
mock = Mock()
mock.buffer_view = buffer_view
mock.observation_space = obs_space
mock.action_space = SPACE
return mock
