Skip to content

Commit

Permalink
#625 specialized PebbleAgentTrainer to distinguish from old preferenc…
Browse files Browse the repository at this point in the history
…e comparison trainer
  • Loading branch information
Jan Michelfeit committed Dec 1, 2022
1 parent 716c710 commit 152efa6
Showing 1 changed file with 25 additions and 1 deletion.
26 changes: 25 additions & 1 deletion src/imitation/algorithms/preference_comparisons.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from tqdm.auto import tqdm

from imitation.algorithms import base
from imitation.algorithms.pebble.entropy_reward import PebbleStateEntropyReward
from imitation.data import rollout, types, wrappers
from imitation.data.types import (
AnyPath,
Expand Down Expand Up @@ -329,6 +330,27 @@ def logger(self, value: imit_logger.HierarchicalLogger) -> None:
self.algorithm.set_logger(self.logger)


class PebbleAgentTrainer(AgentTrainer):
"""
Specialization of AgentTrainer for PEBBLE training.
Includes unsupervised pretraining with an entropy based reward function.
"""

reward_fn: PebbleStateEntropyReward

def __init__(
self,
*,
reward_fn: PebbleStateEntropyReward,
**kwargs,
) -> None:
super().__init__(reward_fn=reward_fn, **kwargs)

def unsupervised_pretrain(self, steps: int, **kwargs: Any) -> None:
self.train(steps, **kwargs)
self.reward_fn.unsupervised_exploration_finish()


def _get_trajectories(
trajectories: Sequence[TrajectoryWithRew],
steps: int,
Expand Down Expand Up @@ -1705,7 +1727,9 @@ def train(
self.logger.log(
f"Pre-training agent for {unsupervised_pretrain_timesteps} timesteps"
)
self.trajectory_generator.unsupervised_pretrain(unsupervised_pretrain_timesteps)
self.trajectory_generator.unsupervised_pretrain(
unsupervised_pretrain_timesteps
)

for i, num_pairs in enumerate(preference_query_schedule):
##########################
Expand Down

0 comments on commit 152efa6

Please sign in to comment.