Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
vwxyzjn authored Oct 30, 2023
1 parent d79f486 commit 9e33709
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions trl/trainer/ppo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,12 +230,13 @@ def __init__(
elif ref_model is None and not self.is_peft_model:
self.ref_model = create_reference_model(self.model, num_shared_layers=num_shared_layers)
elif self.is_peft_model:
self.ref_model = self.model
self.ref_model = None
else:
raise ValueError(
f"ref_model must be a PreTrainedModelWrapper or `None`, got {type(ref_model)} - supported "
f"architectures are: {SUPPORTED_ARCHITECTURES} "
)
self.optional_peft_ref_model = self.model if self.is_peft_model else if self.ref_model
self.optional_peft_ctx = (
self.accelerator.unwrap_model(self.model).pretrained_model.disable_adapter
if self.is_peft_model
Expand Down Expand Up @@ -463,7 +464,7 @@ def generate(
if generate_ref_response:
with self.optional_peft_ctx():
ref_response = self._generate_batched(
self.accelerator.unwrap_model(self.ref_model),
self.accelerator.unwrap_model(self.optional_peft_ref_model),
query_tensor,
length_sampler=length_sampler,
batch_size=batch_size,
Expand All @@ -484,7 +485,7 @@ def generate(
)
if generate_ref_response:
with self.optional_peft_ctx():
ref_response = self.accelerator.unwrap_model(self.ref_model).generate(
ref_response = self.accelerator.unwrap_model(self.optional_peft_ref_model).generate(
input_ids=query_tensor.unsqueeze(dim=0), **generation_kwargs
)

Expand Down Expand Up @@ -711,7 +712,7 @@ def step(
)
with self.optional_peft_ctx():
ref_logprobs, ref_logits_or_none, _, _ = self.batched_forward_pass(
self.ref_model, queries, responses, model_inputs, return_logits=full_kl_penalty
self.optional_peft_ref_model, queries, responses, model_inputs, return_logits=full_kl_penalty
)

timing["time/ppo/forward_pass"] = time.time() - t
Expand Down

0 comments on commit 9e33709

Please sign in to comment.