
Commit

#625 introduce parameter for pretraining steps
Jan Michelfeit committed Dec 1, 2022
1 parent c2bc9dc commit c681ca3
Showing 2 changed files with 23 additions and 4 deletions.
24 changes: 20 additions & 4 deletions src/imitation/algorithms/preference_comparisons.py
@@ -1493,6 +1493,7 @@ def __init__(
         transition_oversampling: float = 1,
         initial_comparison_frac: float = 0.1,
         initial_epoch_multiplier: float = 200.0,
+        initial_agent_pretrain_frac: float = 0.01,
         custom_logger: Optional[imit_logger.HierarchicalLogger] = None,
         allow_variable_horizon: bool = False,
         rng: Optional[np.random.Generator] = None,
@@ -1542,6 +1543,9 @@ def __init__(
             initial_epoch_multiplier: before agent training begins, train the reward
                 model for this many more epochs than usual (on fragments sampled from a
                 random agent).
+            initial_agent_pretrain_frac: fraction of total_timesteps for which the
+                agent will be trained without preference gathering (and reward model
+                training)
             custom_logger: Where to log to; if None (default), creates a new logger.
             allow_variable_horizon: If False (default), algorithm will raise an
                 exception if it detects trajectories of different length during
@@ -1640,6 +1644,7 @@ def __init__(
         self.fragment_length = fragment_length
         self.initial_comparison_frac = initial_comparison_frac
         self.initial_epoch_multiplier = initial_epoch_multiplier
+        self.initial_agent_pretrain_frac = initial_agent_pretrain_frac
         self.num_iterations = num_iterations
         self.transition_oversampling = transition_oversampling
         if callable(query_schedule):
@@ -1672,10 +1677,11 @@ def train(
         preference_query_schedule = self._preference_gather_schedule(total_comparisons)
         print(f"Query schedule: {preference_query_schedule}")

-        timesteps_per_iteration, extra_timesteps = divmod(
-            total_timesteps,
-            self.num_iterations,
-        )
+        (
+            agent_pretrain_timesteps,
+            timesteps_per_iteration,
+            extra_timesteps,
+        ) = self._compute_timesteps(total_timesteps)
         reward_loss = None
         reward_accuracy = None

@@ -1752,3 +1758,13 @@ def _preference_gather_schedule(self, total_comparisons):
         shares = util.oric(probs * total_comparisons)
         schedule = [initial_comparisons] + shares.tolist()
         return schedule
+
+    def _compute_timesteps(self, total_timesteps: int) -> Tuple[int, int, int]:
+        agent_pretrain_timesteps = int(
+            total_timesteps * self.initial_agent_pretrain_frac
+        )
+        timesteps_per_iteration, extra_timesteps = divmod(
+            total_timesteps - agent_pretrain_timesteps,
+            self.num_iterations,
+        )
+        return agent_pretrain_timesteps, timesteps_per_iteration, extra_timesteps
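
To make the new arithmetic concrete, here is a minimal standalone sketch of the split that _compute_timesteps performs: the pretraining budget is taken off the top, and only the remainder is divided across the preference-comparison iterations. The free function, the 1_000_000-timestep budget, and the 10 iterations are illustrative assumptions; only the 0.01 default fraction comes from this commit.

from typing import Tuple


def compute_timesteps(
    total_timesteps: int,
    initial_agent_pretrain_frac: float,
    num_iterations: int,
) -> Tuple[int, int, int]:
    """Carve out agent-pretraining timesteps, then split the rest per iteration."""
    agent_pretrain_timesteps = int(total_timesteps * initial_agent_pretrain_frac)
    timesteps_per_iteration, extra_timesteps = divmod(
        total_timesteps - agent_pretrain_timesteps,
        num_iterations,
    )
    return agent_pretrain_timesteps, timesteps_per_iteration, extra_timesteps


# Illustrative budget and iteration count; 0.01 is the new constructor default.
print(compute_timesteps(1_000_000, 0.01, 10))  # (10000, 99000, 0)

One consequence of taking the pretraining budget off the top is that, for the same total_timesteps, each reward-model/agent iteration now receives slightly fewer timesteps than before this change.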
3 changes: 3 additions & 0 deletions in the preference comparisons training script configuration
@@ -68,6 +68,8 @@ def train_defaults():
     initial_comparison_frac = 0.1
     # fraction of sampled trajectories that will include some random actions
     exploration_frac = 0.0
+    # fraction of total_timesteps for training before preference gathering
+    initial_agent_pretrain_frac = 0.05
     preference_model_kwargs = {}
     reward_trainer_kwargs = {
         "epochs": 3,
@@ -153,6 +155,7 @@ def fast():
     total_timesteps = 50
     total_comparisons = 5
     initial_comparison_frac = 0.2
+    initial_agent_pretrain_frac = 0.2
     num_iterations = 1
     fragment_length = 2
     reward_trainer_kwargs = {
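As a quick sanity check of how these config values interact with the split above, plugging the fast() settings into the same arithmetic (reproduced here outside the class purely for illustration) yields 10 pretraining timesteps and 40 timesteps for the single iteration:

total_timesteps = 50  # from fast()
initial_agent_pretrain_frac = 0.2  # from fast()
num_iterations = 1  # from fast()

agent_pretrain_timesteps = int(total_timesteps * initial_agent_pretrain_frac)  # 10
timesteps_per_iteration, extra_timesteps = divmod(
    total_timesteps - agent_pretrain_timesteps,
    num_iterations,
)  # (40, 0)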
