Skip to content

Commit

Permalink
remove redunant call to eval and train (#2372)
Browse files Browse the repository at this point in the history
  • Loading branch information
kashif authored Nov 20, 2024
1 parent 066fc37 commit bb0afc2
Showing 1 changed file with 1 addition and 3 deletions.
4 changes: 1 addition & 3 deletions trl/trainer/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def _generate_completions(
batch_size: int = 1,
) -> List[str]:
"""
Generates completions for a list of pre-formatted prompts.
Generates completions for a list of pre-formatted prompts from the given model.
Args:
prompts (List[str]): A list of input prompts for which completions are to be generated.
Expand All @@ -68,7 +68,6 @@ def _generate_completions(
"""
completions = []
with unwrap_model_for_generation(model, accelerator) as unwrapped_model:
unwrapped_model.eval()
for idx in range(0, len(prompts), batch_size):
batch = prompts[idx : idx + batch_size]
tokenized_batch = tokenizer(batch, return_tensors="pt", padding=True, truncation=True).to(model.device)
Expand All @@ -81,7 +80,6 @@ def _generate_completions(
generation = generation[len(prompt) :]
completion = tokenizer.decode(generation, skip_special_tokens=True)
completions.append(completion)
unwrapped_model.train()
return completions


Expand Down

0 comments on commit bb0afc2

Please sign in to comment.