Commit
#625 remove ReplayBufferEntropyRewardWrapper
Jan Michelfeit committed Dec 1, 2022
1 parent ec7b853 commit d1aae17
Showing 5 changed files with 11 additions and 182 deletions.
4 changes: 3 additions & 1 deletion src/imitation/algorithms/pebble/entropy_reward.py
@@ -35,7 +35,9 @@ def __call__(

all_observations = self.replay_buffer_view.observations
# ReplayBuffer sampling flattens the venv dimension, let's adapt to that
all_observations = all_observations.reshape((-1, *state.shape[1:])) # TODO #625: fix self.obs_shape
all_observations = all_observations.reshape(
(-1, *state.shape[1:]) # TODO #625: fix self.obs_shape
)
# TODO #625: deal with the conversion back and forth between np and torch
entropies = util.compute_state_entropy(
th.tensor(state),
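Note on util.compute_state_entropy, which the hunk above feeds with the reshaped buffer contents: judging by the assertion in tests/util/test_util.py at the bottom of this diff, it returns the distance to the k-th nearest neighbor among the buffered observations, used as a particle-based entropy estimate. A minimal sketch under that assumption (the real function may chunk the pairwise computation or treat the self-distance differently):

import torch as th


def knn_state_entropy(obs: th.Tensor, all_obs: th.Tensor, k: int) -> th.Tensor:
    """Sketch of a k-NN state-entropy estimate (not the library implementation).

    Returns, for each row of ``obs``, the Euclidean distance to its k-th
    nearest neighbor in ``all_obs``; larger distances indicate states in
    sparsely visited regions of the buffer.
    """
    # Flatten any trailing observation dimensions so cdist sees 2-D inputs.
    flat_obs = obs.reshape(obs.shape[0], -1).float()
    flat_all = all_obs.reshape(all_obs.shape[0], -1).float()
    dists = th.cdist(flat_obs, flat_all)  # shape: (len(obs), len(all_obs))
    # k-th smallest distance per query state (kthvalue is 1-indexed).
    return th.kthvalue(dists, k=k, dim=1).values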
84 changes: 1 addition & 83 deletions src/imitation/policies/replay_buffer_wrapper.py
@@ -1,7 +1,6 @@
"""Wrapper for reward labeling for transitions sampled from a replay buffer."""

from typing import Callable
from typing import Mapping, Type
from typing import Callable, Mapping, Type

import numpy as np
from gym import spaces
@@ -10,7 +9,6 @@

from imitation.rewards.reward_function import RewardFn
from imitation.util import util
from imitation.util.networks import RunningNorm


def _samples_to_reward_fn_input(
@@ -143,83 +141,3 @@ def _get_samples(self):
"_get_samples() is intentionally not implemented."
"This method should not be called.",
)


class ReplayBufferEntropyRewardWrapper(ReplayBufferRewardWrapper):
"""Relabel the rewards from a ReplayBuffer, initially using entropy as reward."""

def __init__(
self,
buffer_size: int,
observation_space: spaces.Space,
action_space: spaces.Space,
*,
replay_buffer_class: Type[ReplayBuffer],
reward_fn: RewardFn,
entropy_as_reward_samples: int,
k: int = 5,
**kwargs,
):
"""Builds ReplayBufferRewardWrapper.
Args:
buffer_size: Max number of elements in the buffer
observation_space: Observation space
action_space: Action space
replay_buffer_class: Class of the replay buffer.
reward_fn: Reward function for reward relabeling.
entropy_as_reward_samples: Number of samples to use entropy as the reward,
before switching to using the reward_fn for relabeling.
k: Use the k'th nearest neighbor's distance when computing state entropy.
**kwargs: keyword arguments for ReplayBuffer.
"""
# TODO should we limit by number of batches (as this does)
# or number of observations returned?
super().__init__(
buffer_size,
observation_space,
action_space,
replay_buffer_class=replay_buffer_class,
reward_fn=reward_fn,
**kwargs,
)
self.sample_count = 0
self.k = k
# TODO support n_envs > 1
self.entropy_stats = RunningNorm(1)
self.entropy_as_reward_samples = entropy_as_reward_samples

def sample(self, *args, **kwargs):
self.sample_count += 1
samples = super().sample(*args, **kwargs)
# For some reason self.entropy_as_reward_samples seems to get cleared,
# and I have no idea why.
if self.sample_count > self.entropy_as_reward_samples:
return samples
# TODO we really ought to reset the reward network once we are done w/
# the entropy based pre-training. We also have no reason to train
# or even use the reward network before then.

if self.full:
all_obs = self.observations
else:
all_obs = self.observations[: self.pos]
# super().sample() flattens the venv dimension, let's do it too
all_obs = all_obs.reshape((-1, *self.obs_shape))
entropies = util.compute_state_entropy(
samples.observations,
all_obs,
self.k,
)

# Normalize to have mean of 0 and standard deviation of 1 according to running stats
entropies = self.entropy_stats.forward(entropies)
assert entropies.shape == samples.rewards.shape

return ReplayBufferSamples(
observations=samples.observations,
actions=samples.actions,
next_observations=samples.next_observations,
dones=samples.dones,
rewards=entropies,
)
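Aside on the RunningNorm(1) statistics the removed sample() relied on ("Normalize to have mean of 0 and standard deviation of 1 according to running stats"): the effect is a z-score under running mean and variance. A rough standalone illustration of that effect, not imitation.util.networks.RunningNorm itself (that class is a torch module with its own update logic):

import torch as th


class RunningStandardizer:
    """Illustrative running z-score, standing in for RunningNorm's effect."""

    def __init__(self, eps: float = 1e-8):
        self.count = 0
        self.mean = 0.0
        self.var = 1.0
        self.eps = eps

    def forward(self, x: th.Tensor) -> th.Tensor:
        # Merge this batch's statistics into the running mean/variance
        # (parallel-variance update), then standardize the batch.
        n = x.numel()
        batch_mean = x.mean().item()
        batch_var = x.var(unbiased=False).item()
        total = self.count + n
        delta = batch_mean - self.mean
        self.mean += delta * n / total
        self.var = (
            self.count * self.var + n * batch_var + delta**2 * self.count * n / total
        ) / total
        self.count = total
        return (x - self.mean) / ((self.var + self.eps) ** 0.5)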
9 changes: 5 additions & 4 deletions src/imitation/scripts/common/rl.py
@@ -86,10 +86,11 @@ def _maybe_add_relabel_buffer(
"""Use ReplayBufferRewardWrapper in rl_kwargs if relabel_reward_fn is not None."""
rl_kwargs = dict(rl_kwargs)
if relabel_reward_fn:
_buffer_kwargs = dict(reward_fn=relabel_reward_fn)
_buffer_kwargs["replay_buffer_class"] = rl_kwargs.get(
"replay_buffer_class",
buffers.ReplayBuffer,
_buffer_kwargs = dict(
reward_fn=relabel_reward_fn,
replay_buffer_class=rl_kwargs.get(
"replay_buffer_class", buffers.ReplayBuffer
),
)
rl_kwargs["replay_buffer_class"] = ReplayBufferRewardWrapper

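For reference, the kwargs assembled above end up wiring an SB3 algorithm the same way the tests below do: the algorithm receives ReplayBufferRewardWrapper as its replay_buffer_class, and the wrapper receives the class of the inner buffer plus the relabeling reward_fn. A rough usage sketch; zero_reward_fn and the env setup are illustrative stand-ins modeled on the test code, not part of this diff:

import numpy as np
import stable_baselines3 as sb3
from stable_baselines3.common import buffers

from imitation.policies.replay_buffer_wrapper import ReplayBufferRewardWrapper
from imitation.util import util


def zero_reward_fn(state, action, next_state, done):
    # Illustrative RewardFn: relabel every sampled transition with reward 0.
    return np.zeros(state.shape[0], dtype=np.float32)


venv = util.make_vec_env("Pendulum-v1", n_envs=1, rng=np.random.default_rng(42))
rl_algo = sb3.SAC(
    policy="MlpPolicy",
    env=venv,
    replay_buffer_class=ReplayBufferRewardWrapper,  # relabels rewards on sample()
    replay_buffer_kwargs=dict(
        replay_buffer_class=buffers.ReplayBuffer,  # inner buffer storing transitions
        reward_fn=zero_reward_fn,
    ),
    buffer_size=1000,
)
rl_algo.learn(total_timesteps=100)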
95 changes: 2 additions & 93 deletions tests/policies/test_replay_buffer_wrapper.py
@@ -13,14 +13,10 @@
from stable_baselines3.common import buffers, off_policy_algorithm, policies
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.preprocessing import get_obs_shape, get_action_dim
from stable_baselines3.common.preprocessing import get_action_dim, get_obs_shape
from stable_baselines3.common.save_util import load_from_pkl
from stable_baselines3.common.vec_env import DummyVecEnv

from imitation.policies.replay_buffer_wrapper import (
ReplayBufferEntropyRewardWrapper,
ReplayBufferRewardWrapper,
)
from imitation.policies.replay_buffer_wrapper import ReplayBufferRewardWrapper
from imitation.util import util


@@ -123,54 +119,6 @@ def test_wrapper_class(tmpdir, rng):
replay_buffer_wrapper._get_samples()


# Combine this with the above test via parameterization over the buffer class
def test_entropy_wrapper_class_no_op(tmpdir, rng):
buffer_size = 15
total_timesteps = 20
entropy_samples = 0

venv = util.make_vec_env("Pendulum-v1", n_envs=1, rng=rng)
rl_algo = sb3.SAC(
policy=sb3.sac.policies.SACPolicy,
policy_kwargs=dict(),
env=venv,
seed=42,
replay_buffer_class=ReplayBufferEntropyRewardWrapper,
replay_buffer_kwargs=dict(
replay_buffer_class=buffers.ReplayBuffer,
reward_fn=zero_reward_fn,
entropy_as_reward_samples=entropy_samples,
),
buffer_size=buffer_size,
)

rl_algo.learn(total_timesteps=total_timesteps)

buffer_path = osp.join(tmpdir, "buffer.pkl")
rl_algo.save_replay_buffer(buffer_path)
replay_buffer_wrapper = load_from_pkl(buffer_path)
replay_buffer = replay_buffer_wrapper.replay_buffer

# replay_buffer_wrapper.sample(...) should return zero-reward transitions
assert buffer_size == replay_buffer_wrapper.size() == replay_buffer.size()
assert (replay_buffer_wrapper.sample(total_timesteps).rewards == 0.0).all()
assert (replay_buffer.sample(total_timesteps).rewards != 0.0).all() # seed=42

# replay_buffer_wrapper.pos, replay_buffer_wrapper.full
assert replay_buffer_wrapper.pos == total_timesteps - buffer_size
assert replay_buffer_wrapper.full

# reset()
replay_buffer_wrapper.reset()
assert 0 == replay_buffer_wrapper.size() == replay_buffer.size()
assert replay_buffer_wrapper.pos == 0
assert not replay_buffer_wrapper.full

# to_torch()
tensor = replay_buffer_wrapper.to_torch(np.ones(42))
assert type(tensor) is th.Tensor


class ActionIsObsEnv(gym.Env):
"""Simple environment where the obs is the action."""

@@ -191,45 +139,6 @@ def reset(self):
return np.array([0])


def test_entropy_wrapper_class(tmpdir, rng):
buffer_size = 20
entropy_samples = 500
k = 4

venv = DummyVecEnv([ActionIsObsEnv])
rl_algo = sb3.SAC(
policy=sb3.sac.policies.SACPolicy,
policy_kwargs=dict(),
env=venv,
seed=42,
replay_buffer_class=ReplayBufferEntropyRewardWrapper,
replay_buffer_kwargs=dict(
replay_buffer_class=buffers.ReplayBuffer,
reward_fn=zero_reward_fn,
entropy_as_reward_samples=entropy_samples,
k=k,
),
buffer_size=buffer_size,
)

rl_algo.learn(total_timesteps=buffer_size)
initial_entropy = util.compute_state_entropy(
th.Tensor(rl_algo.replay_buffer.replay_buffer.observations),
th.Tensor(rl_algo.replay_buffer.replay_buffer.observations),
k=k,
)

rl_algo.learn(total_timesteps=entropy_samples - buffer_size)
# Expect that the entropy of our replay buffer is now higher,
# since we trained with that as the reward.
trained_entropy = util.compute_state_entropy(
th.Tensor(rl_algo.replay_buffer.replay_buffer.observations),
th.Tensor(rl_algo.replay_buffer.replay_buffer.observations),
k=k,
)
assert trained_entropy.mean() > initial_entropy.mean()


def test_replay_buffer_view_provides_buffered_observations():
space = spaces.Box(np.array([0]), np.array([5]))
n_envs = 2
1 change: 0 additions & 1 deletion tests/util/test_util.py
@@ -144,4 +144,3 @@ def test_compute_state_entropy_2d():
util.compute_state_entropy(obs, all_obs, k=3),
np.sqrt(20**2 + 2**2),
)
