Commit
#625 plug in pebble according to parameters
Jan Michelfeit committed Dec 1, 2022
1 parent ad8d76e commit 2ab0780
Showing 4 changed files with 68 additions and 473 deletions.
3 changes: 3 additions & 0 deletions src/imitation/scripts/config/train_preference_comparisons.py
@@ -60,11 +60,14 @@ def train_defaults():

    checkpoint_interval = 0  # Num epochs between saving (<0 disables, =0 final only)
    query_schedule = "hyperbolic"
    # Whether to use the PEBBLE algorithm (https://arxiv.org/pdf/2106.05091.pdf)
    pebble_enabled = False


@train_preference_comparisons_ex.named_config
def pebble():
    pebble_enabled = True
    # fraction of total_timesteps for training before preference gathering
    unsupervised_agent_pretrain_frac = 0.05
    pebble_nearest_neighbor_k = 5
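
A usage note (editor's sketch, not part of this commit): the new `pebble` named config can be activated through Sacred's standard mechanisms. The snippet below uses Sacred's `Experiment.run(named_configs=...)` API, imports the script module so that the experiment's main command is registered, and treats the `total_timesteps` override as a purely hypothetical example value.

# Illustrative sketch only, not part of this diff. Selecting the `pebble`
# named config is equivalent to `with pebble` on the Sacred CLI and enables
# pebble_enabled together with the PEBBLE defaults shown above.
from imitation.scripts import train_preference_comparisons  # noqa: F401
from imitation.scripts.config.train_preference_comparisons import (
    train_preference_comparisons_ex,
)

run = train_preference_comparisons_ex.run(
    named_configs=["pebble"],
    config_updates={"total_timesteps": 10_000},  # hypothetical override
)
print(run.result)
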

163 changes: 0 additions & 163 deletions src/imitation/scripts/config/train_preference_comparisons_pebble.py

This file was deleted.

83 changes: 65 additions & 18 deletions src/imitation/scripts/train_preference_comparisons.py
Expand Up @@ -3,24 +3,27 @@
Can be used as a CLI script, or the `train_preference_comparisons` function
can be called directly.
"""

import functools
import pathlib
from typing import Any, Mapping, Optional, Type, Union

import numpy as np
import torch as th
from sacred.observers import FileStorageObserver
from stable_baselines3.common import type_aliases
from stable_baselines3.common import base_class, type_aliases, vec_env

from imitation.algorithms import preference_comparisons
from imitation.algorithms.pebble.entropy_reward import PebbleStateEntropyReward
from imitation.data import types
from imitation.policies import serialize
from imitation.rewards import reward_function, reward_nets
from imitation.scripts.common import common, reward
from imitation.scripts.common import rl as rl_common
from imitation.scripts.common import train
from imitation.scripts.config.train_preference_comparisons import (
train_preference_comparisons_ex,
)
from imitation.util import logger as imit_logger


def save_model(
@@ -57,6 +60,59 @@ def save_checkpoint(
)


@train_preference_comparisons_ex.capture
def make_reward_function(
    reward_net: reward_nets.RewardNet,
    *,
    pebble_enabled: bool = False,
    pebble_nearest_neighbor_k: Optional[int] = None,
) -> reward_function.RewardFn:
    """Builds the relabeling reward function from the given reward network.

    When PEBBLE is enabled, the function is wrapped in `PebbleStateEntropyReward`.
    """
    relabel_reward_fn = functools.partial(
        reward_net.predict_processed,
        update_stats=False,
    )
    if pebble_enabled:
        relabel_reward_fn = PebbleStateEntropyReward(
            relabel_reward_fn,
            pebble_nearest_neighbor_k,
        )
    return relabel_reward_fn
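
To make the return contract concrete: whatever `make_reward_function` returns is used as an imitation `RewardFn`, i.e. a callable that maps batched `(state, action, next_state, done)` arrays to an array of rewards, which is how `relabel_reward_fn` is consumed further down. The toy stand-in below is an editor's sketch with hypothetical shapes and a placeholder reward rule; it illustrates only the call signature, not the PEBBLE entropy logic.

# Editor's toy sketch of the RewardFn call signature; shapes and reward logic
# are hypothetical.
import numpy as np

def toy_relabel_reward_fn(
    state: np.ndarray,
    action: np.ndarray,
    next_state: np.ndarray,
    done: np.ndarray,
) -> np.ndarray:
    # Placeholder reward: negative distance of the next state from the origin.
    return -np.linalg.norm(next_state, axis=-1)

rewards = toy_relabel_reward_fn(
    state=np.zeros((8, 4), dtype=np.float32),
    action=np.zeros((8, 1), dtype=np.float32),
    next_state=np.ones((8, 4), dtype=np.float32),
    done=np.zeros((8,), dtype=bool),
)
assert rewards.shape == (8,)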


@train_preference_comparisons_ex.capture
def make_agent_trajectory_generator(
    venv: vec_env.VecEnv,
    agent: base_class.BaseAlgorithm,
    reward_net: reward_nets.RewardNet,
    relabel_reward_fn: reward_function.RewardFn,
    rng: np.random.Generator,
    custom_logger: Optional[imit_logger.HierarchicalLogger],
    *,
    exploration_frac: float,
    pebble_enabled: bool,
    trajectory_generator_kwargs: Mapping[str, Any],
) -> preference_comparisons.AgentTrainer:
    """Creates the agent trajectory generator, using the PEBBLE trainer if enabled."""
    if pebble_enabled:
        return preference_comparisons.PebbleAgentTrainer(
            algorithm=agent,
            reward_fn=relabel_reward_fn,
            venv=venv,
            exploration_frac=exploration_frac,
            rng=rng,
            custom_logger=custom_logger,
            **trajectory_generator_kwargs,
        )
    else:
        return preference_comparisons.AgentTrainer(
            algorithm=agent,
            reward_fn=reward_net,
            venv=venv,
            exploration_frac=exploration_frac,
            rng=rng,
            custom_logger=custom_logger,
            **trajectory_generator_kwargs,
        )
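
Both helpers above are Sacred captured functions: the keyword-only parameters after the bare `*` are filled from the experiment config whenever the caller omits them, which is why the call sites later in this file pass only the objects they build locally (reward net, venv, agent, logger). The toy experiment below is an editor's illustration of that mechanism; the experiment name, config value, and function are hypothetical, and only the `@capture` behaviour is Sacred's.

# Editor's toy sketch of Sacred's config injection for captured functions.
from sacred import Experiment

toy_ex = Experiment("toy_pebble_demo")

@toy_ex.config
def toy_config():
    pebble_enabled = True  # analogous to the flag introduced by this commit

@toy_ex.capture
def toy_make_reward_fn(base_reward: float, *, pebble_enabled: bool):
    # pebble_enabled is injected from the active config when not passed explicitly.
    return ("pebble", base_reward) if pebble_enabled else ("plain", base_reward)

@toy_ex.main
def toy_main():
    return toy_make_reward_fn(1.0)

if __name__ == "__main__":
    print(toy_ex.run().result)  # ("pebble", 1.0)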


@train_preference_comparisons_ex.main
def train_preference_comparisons(
    total_timesteps: int,
@@ -83,7 +139,6 @@ def train_preference_comparisons(
    checkpoint_interval: int,
    query_schedule: Union[str, type_aliases.Schedule],
    unsupervised_agent_pretrain_frac: Optional[float],
    pebble_nearest_neighbor_k: Optional[int],
) -> Mapping[str, Any]:
    """Train a reward model using preference comparisons.
@@ -146,8 +201,6 @@ def train_preference_comparisons(
        unsupervised_agent_pretrain_frac: fraction of total_timesteps for which the
            agent will be trained without preference gathering (and reward model
            training)
        pebble_nearest_neighbor_k: Parameter for state entropy computation (for PEBBLE
            training only)

    Returns:
        Rollout statistics from trained policy.
@@ -160,10 +213,8 @@ def train_preference_comparisons(

    with common.make_venv() as venv:
        reward_net = reward.make_reward_net(venv)
        relabel_reward_fn = functools.partial(
            reward_net.predict_processed,
            update_stats=False,
        )
        relabel_reward_fn = make_reward_function(reward_net)

        if agent_path is None:
            agent = rl_common.make_rl_algo(venv, relabel_reward_fn=relabel_reward_fn)
        else:
@@ -176,21 +227,17 @@ def train_preference_comparisons(
        if trajectory_path is None:
            # Setting the logger here is not necessary (PreferenceComparisons takes care
            # of it automatically) but it avoids creating unnecessary loggers.
            agent_trainer = preference_comparisons.AgentTrainer(
                algorithm=agent,
                reward_fn=reward_net,
            trajectory_generator = make_agent_trajectory_generator(
                venv=venv,
                exploration_frac=exploration_frac,
                agent=agent,
                reward_net=reward_net,
                relabel_reward_fn=relabel_reward_fn,
                rng=rng,
                custom_logger=custom_logger,
                **trajectory_generator_kwargs,
            )
            # Stable Baselines will automatically occupy GPU 0 if it is available.
            # Let's use the same device as the SB3 agent for the reward model.
            reward_net = reward_net.to(agent_trainer.algorithm.device)
            trajectory_generator: preference_comparisons.TrajectoryGenerator = (
                agent_trainer
            )
            reward_net = reward_net.to(trajectory_generator.algorithm.device)
        else:
            if exploration_frac > 0:
                raise ValueError(
