Commit

log
qgallouedec committed Dec 20, 2024
1 parent 3cdc3a8 commit c0bc747
Showing 1 changed file with 33 additions and 80 deletions.
113 changes: 33 additions & 80 deletions trl/trainer/kto_trainer.py
@@ -320,6 +320,7 @@ def make_inputs_require_grad(module, input, output):
self.model_adapter_name = model_adapter_name
self.ref_adapter_name = ref_adapter_name

# Get the reference model
if ref_model:
self.ref_model = ref_model
elif self.is_peft_model or False:
@@ -328,48 +329,20 @@ def make_inputs_require_grad(module, input, output):
else:
self.ref_model = create_reference_model(model)

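Note on the block above: the reference model is resolved one of three ways — use the model passed in, reuse the policy with its PEFT adapters disabled, or deep-copy the initial policy via `create_reference_model`. A minimal sketch of that pattern (the function name and PEFT branch here are simplifications, not the trainer's exact code):

```python
import copy

def resolve_reference_model(model, ref_model=None, is_peft_model=False):
    """Sketch of how a KTO/DPO-style trainer picks its frozen reference model."""
    if ref_model is not None:
        # An explicit reference model was provided by the caller.
        return ref_model
    if is_peft_model:
        # With PEFT/LoRA the base weights already equal the reference policy,
        # so no copy is needed: the trainer disables the adapters at forward
        # time instead of holding a second model in memory.
        return None
    # Otherwise, freeze a deep copy of the initial policy.
    ref = copy.deepcopy(model)
    for param in ref.parameters():
        param.requires_grad = False
    return ref.eval()
```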
if processing_class is None:
raise ValueError(
"max_length or a processing_class must be specified when using the default DPODataCollatorWithPadding"
)
if args.max_length is None:
warnings.warn(
"When using DPODataCollatorWithPadding, you should set `max_length` in the KTOTrainer's init"
" it will be set to `512` by default, but you should do it yourself in the future.",
UserWarning,
)
max_length = 512
if args.max_length is not None:
max_length = args.max_length

if args.max_prompt_length is None:
warnings.warn(
"When using DPODataCollatorWithPadding, you should set `max_prompt_length` in the KTOTrainer's init"
" it will be set to `128` by default, but you should do it yourself in the future.",
UserWarning,
)
max_prompt_length = 128
if args.max_prompt_length is not None:
max_prompt_length = args.max_prompt_length

max_completion_length = None

if data_collator is None:
data_collator = DataCollatorForUnpairedPreference(pad_token_id=processing_class.pad_token_id)

# Disable dropout if needed
if args.disable_dropout:
disable_dropout_in_model(model)
if self.ref_model is not None:
disable_dropout_in_model(self.ref_model)

self.max_length = max_length
# Define the data collator
self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
if data_collator is None:
data_collator = DataCollatorForUnpairedPreference(pad_token_id=self.padding_value)

self.generate_during_eval = args.generate_during_eval
self.label_pad_token_id = args.label_pad_token_id
self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
self.max_prompt_length = max_prompt_length
self.truncation_mode = args.truncation_mode
self.max_completion_length = max_completion_length
self.processing_class = processing_class
self.max_length = args.max_length

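Note on the collator change above: the padding value now has a single source of truth — an explicit `args.padding_value` wins, otherwise the tokenizer's `pad_token_id` — and the same value feeds both the data collator and generation. A rough sketch of what an unpaired-preference collator does with it (field names are assumptions, not the exact `DataCollatorForUnpairedPreference` API):

```python
import torch

def collate_unpaired(examples, pad_token_id):
    """Right-pad variable-length sequences from an unpaired-preference batch."""
    max_len = max(len(ex["input_ids"]) for ex in examples)
    input_ids, attention_mask, labels = [], [], []
    for ex in examples:
        pad = max_len - len(ex["input_ids"])
        input_ids.append(ex["input_ids"] + [pad_token_id] * pad)
        attention_mask.append([1] * len(ex["input_ids"]) + [0] * pad)
        labels.append(ex["label"])  # True = desirable, False = undesirable
    return {
        "input_ids": torch.tensor(input_ids),
        "attention_mask": torch.tensor(attention_mask),
        "label": torch.tensor(labels),
    }
```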
# metric
self._stored_metrics = defaultdict(lambda: defaultdict(list))
@@ -388,7 +361,7 @@ def make_inputs_require_grad(module, input, output):
# issued.
model.warnings_issued["estimate_tokens"] = True

# 4. Handle the dataset - UNCOMMENT WHEN _prepare_dataset READY
# Dataset preparation
train_dataset = self._prepare_dataset(train_dataset, processing_class, args, "train")
if eval_dataset is not None:
if isinstance(eval_dataset, dict):
@@ -399,7 +372,7 @@ def make_inputs_require_grad(module, input, output):
else:
eval_dataset = self._prepare_dataset(eval_dataset, processing_class, args, "eval")

# Calculate dataset desirability balance
# Calculate dataset desirability balance and display warning if weights are not in an ideal range
num_desirable = max(sum(train_dataset["label"]), 1)
num_undesirable = max(len(train_dataset["label"]) - num_desirable, 1) # "label" is binary

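The counts above feed the check hinted at in the updated comment: KTO weights the two classes, and the effective ratio desirable_weight·n_desirable : undesirable_weight·n_undesirable should stay close to 1 (the KTO paper suggests roughly 1 to 4/3). A sketch of such a check — the exact bounds and warning text are assumptions, not copied from the trainer:

```python
import warnings

def check_desirability_balance(labels, desirable_weight=1.0, undesirable_weight=1.0):
    """Warn when weighted desirable/undesirable examples are badly imbalanced."""
    num_desirable = max(sum(labels), 1)
    num_undesirable = max(len(labels) - num_desirable, 1)  # labels are binary
    ratio = (desirable_weight * num_desirable) / (undesirable_weight * num_undesirable)
    if not (1.0 <= ratio <= 4.0 / 3.0):
        warnings.warn(
            f"Weighted desirable/undesirable ratio is {ratio:.2f}; consider adjusting "
            "desirable_weight or undesirable_weight so it falls in [1, 4/3].",
            UserWarning,
        )
```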
@@ -780,22 +753,28 @@ def get_batch_loss_metrics(
all_num_rejected = self.accelerator.gather(num_rejected).sum().item()

if all_num_chosen > 0:
metrics["rewards/chosen_sum"] = self.accelerator.gather(chosen_rewards.nansum()).nansum().item()
metrics["logps/chosen_sum"] = (
self.accelerator.gather(model_output["chosen_logps"].nansum()).nansum().item()
metrics["rewards/chosen"] = (
self.accelerator.gather(chosen_rewards.nansum()).nansum().item() / all_num_chosen
)
metrics["logps/chosen"] = (
self.accelerator.gather(model_output["chosen_logps"].nansum()).nansum().item() / all_num_chosen
)
metrics["logits/chosen"] = (
self.accelerator.gather(model_output["sum_chosen_logits"]).nansum().item() / all_num_chosen
)
metrics["logits/chosen_sum"] = self.accelerator.gather(model_output["sum_chosen_logits"]).nansum().item()
metrics["count/chosen"] = all_num_chosen

if all_num_rejected > 0:
metrics["rewards/rejected_sum"] = self.accelerator.gather(rejected_rewards.nansum()).nansum().item()
metrics["logps/rejected_sum"] = (
self.accelerator.gather(model_output["rejected_logps"].nansum()).nansum().item()
metrics["rewards/rejected"] = (
self.accelerator.gather(rejected_rewards.nansum()).nansum().item() / all_num_rejected
)
metrics["logits/rejected_sum"] = (
self.accelerator.gather(model_output["sum_rejected_logits"]).nansum().item()
metrics["logps/rejected"] = (
self.accelerator.gather(model_output["rejected_logps"].nansum()).nansum().item() / all_num_rejected
)
metrics["count/rejected"] = all_num_rejected
metrics["logits/rejected"] = (
self.accelerator.gather(model_output["sum_rejected_logits"]).nansum().item() / all_num_rejected
)

metrics["rewards/margins"] = metrics["rewards/chosen"] - metrics["rewards/rejected"]

loss = losses.nanmean()

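The new metric code replaces the old `*_sum` bookkeeping: each process takes a `nansum` over its local batch, the sums are gathered across processes, and the gathered total is divided by the global example count, so each logged value is already a correct global mean. A toy illustration of why summing before dividing matters when shards are uneven (plain tensors standing in for `accelerator.gather`):

```python
import torch

# Per-process chosen rewards; process 0 has 3 examples, process 1 has 1.
shard_0 = torch.tensor([1.0, 2.0, 3.0])
shard_1 = torch.tensor([10.0])

# Averaging per-process means weights the shards equally and is wrong:
mean_of_means = (shard_0.mean() + shard_1.mean()) / 2  # (2 + 10) / 2 = 6.0

# Gathering sums and dividing by the global count matches the true mean:
global_sum = shard_0.nansum() + shard_1.nansum()        # 16.0
global_count = shard_0.numel() + shard_1.numel()        # 4
global_mean = global_sum / global_count                 # 4.0

print(mean_of_means.item(), global_mean.item())         # 6.0 4.0
```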
@@ -845,7 +824,7 @@ def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]):
attention_mask=batch["prompt_attention_mask"],
max_length=self.max_length,
do_sample=True,
pad_token_id=self.processing_class.pad_token_id,
pad_token_id=self.padding_value,
)

# if ref_output in batch use that otherwise use the reference model
@@ -859,21 +838,21 @@ def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]):
attention_mask=batch["prompt_attention_mask"],
max_length=self.max_length,
do_sample=True,
pad_token_id=self.processing_class.pad_token_id,
pad_token_id=self.padding_value,
)
else:
ref_output = self.ref_model.generate(
input_ids=batch["prompt_input_ids"],
attention_mask=batch["prompt_attention_mask"],
max_length=self.max_length,
do_sample=True,
pad_token_id=self.processing_class.pad_token_id,
pad_token_id=self.padding_value,
)

policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id)
policy_output = pad_to_length(policy_output, self.max_length, self.padding_value)
policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True)

ref_output = pad_to_length(ref_output, self.max_length, self.processing_class.pad_token_id)
ref_output = pad_to_length(ref_output, self.max_length, self.padding_value)
ref_output_decoded = self.processing_class.batch_decode(ref_output, skip_special_tokens=True)

return policy_output_decoded, ref_output_decoded
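Here `pad_to_length` pads the generated sequences out to `self.max_length` with the configured padding value so the policy and reference outputs can be decoded as equally sized tensors. A minimal sketch of such a helper, assuming right-padding along the last dimension (TRL's actual utility may differ in details):

```python
import torch

def pad_to_length(tensor: torch.Tensor, length: int, pad_value: int, dim: int = -1) -> torch.Tensor:
    """Right-pad `tensor` along `dim` with `pad_value` until it reaches `length`."""
    if tensor.size(dim) >= length:
        return tensor
    pad_shape = list(tensor.shape)
    pad_shape[dim] = length - tensor.size(dim)
    padding = torch.full(pad_shape, pad_value, dtype=tensor.dtype, device=tensor.device)
    return torch.cat([tensor, padding], dim=dim)

# Example: pad a (2, 3) batch of token ids to length 5 with pad id 0.
ids = torch.tensor([[5, 6, 7], [8, 9, 10]])
print(pad_to_length(ids, 5, 0))
```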
@@ -971,37 +950,11 @@ def evaluation_loop(
return initial_output

def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
"""
Log `logs` on the various objects watching training, including stored metrics.
Args:
logs (`dict[str, float]`):
The values to log.
start_time (`float` or `None`, *optional*, defaults to `None`):
Start time of the training.
"""
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
# train metrics should have no prefix, eval should have 'eval_'
prefix = "eval_" if train_eval == "eval" else ""
# accumulate average metrics from sums and lengths
for split in ["chosen", "rejected"]:
if f"count/{split}" in self._stored_metrics[train_eval]:
count_sum = torch.Tensor(self._stored_metrics[train_eval][f"count/{split}"]).sum().item()
for metric in ["rewards", "logps", "logits"]:
logs[f"{prefix}{metric}/{split}"] = (
torch.Tensor(self._stored_metrics[train_eval][f"{metric}/{split}_sum"]).sum().item()
/ count_sum
)
# delete obsolete metric
del self._stored_metrics[train_eval][f"{metric}/{split}_sum"]
del self._stored_metrics[train_eval][f"count/{split}"]
# calculate reward margin
if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs:
logs[f"{prefix}rewards/margins"] = logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"]
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item()
logs[key] = sum(metrics) / len(metrics)
del self._stored_metrics[train_eval]

if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):

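With the per-batch metrics already averaged in `get_batch_loss_metrics`, the old sum/count bookkeeping in `log` becomes unnecessary: `_stored_metrics` just accumulates one value per batch, and logging reduces to a plain mean. A condensed sketch of that store-then-average flow (`store_metrics` is an assumed helper mirroring the `defaultdict` set up earlier, not the trainer's exact method):

```python
from collections import defaultdict

class MetricTracker:
    """Accumulate already-averaged per-batch metrics and flush their mean at log time."""

    def __init__(self):
        self._stored_metrics = defaultdict(lambda: defaultdict(list))

    def store_metrics(self, metrics: dict, train_eval: str = "train") -> None:
        for key, value in metrics.items():
            self._stored_metrics[train_eval][key].append(value)

    def log(self, logs: dict) -> dict:
        train_eval = "train" if "loss" in logs else "eval"
        for key, values in self._stored_metrics[train_eval].items():
            logs[key] = sum(values) / len(values)  # simple mean over stored batches
        del self._stored_metrics[train_eval]
        return logs

tracker = MetricTracker()
tracker.store_metrics({"rewards/chosen": 0.2})
tracker.store_metrics({"rewards/chosen": 0.4})
print(tracker.log({"loss": 1.0}))  # {'loss': 1.0, 'rewards/chosen': 0.30000000000000004}
```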