diff --git a/tests/test_utils.py b/tests/test_utils.py
index a1cabcfc19..df513a025d 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -147,6 +147,7 @@ def test_full(self):
dataset_name="username/my_dataset",
tags=["trl", "trainer-tag"],
wandb_url="https://wandb.ai/username/project_id/runs/abcd1234",
+ comet_url="https://www.comet.com/username/project_id/experiment_id",
trainer_name="My Trainer",
trainer_citation="@article{my_trainer, ...}",
paper_title="My Paper",
@@ -158,6 +159,7 @@ def test_full(self):
self.assertIn('pipeline("text-generation", model="username/my_hub_model", device="cuda")', card_text)
self.assertIn("datasets: username/my_dataset", card_text)
self.assertIn("](https://wandb.ai/username/project_id/runs/abcd1234)", card_text)
+ self.assertIn("](https://www.comet.com/username/project_id/experiment_id", card_text)
self.assertIn("My Trainer", card_text)
self.assertIn("```bibtex\n@article{my_trainer, ...}\n```", card_text)
self.assertIn("[My Paper](https://huggingface.co/papers/1234.56789)", card_text)
@@ -170,6 +172,7 @@ def test_val_none(self):
dataset_name=None,
tags=[],
wandb_url=None,
+ comet_url=None,
trainer_name="My Trainer",
trainer_citation=None,
paper_title=None,
diff --git a/trl/templates/lm_model_card.md b/trl/templates/lm_model_card.md
index 316c5d829e..dbe7c1a7b2 100644
--- a/trl/templates/lm_model_card.md
+++ b/trl/templates/lm_model_card.md
@@ -20,7 +20,8 @@ print(output["generated_text"])
## Training procedure
-{% if wandb_url %}[]({{ wandb_url }}){% endif %}
+{% if wandb_url %}[]({{ wandb_url }}){% endif %}
+{% if comet_url %}[]({{ comet_url }}){% endif %}
This model was trained with {{ trainer_name }}{% if paper_id %}, a method introduced in [{{ paper_title }}](https://huggingface.co/papers/{{ paper_id }}){% endif %}.
diff --git a/trl/trainer/alignprop_trainer.py b/trl/trainer/alignprop_trainer.py
index d03696c091..c9cd718c44 100644
--- a/trl/trainer/alignprop_trainer.py
+++ b/trl/trainer/alignprop_trainer.py
@@ -26,7 +26,7 @@
from ..models import DDPOStableDiffusionPipeline
from . import AlignPropConfig, BaseTrainer
-from .utils import generate_model_card
+from .utils import generate_model_card, get_comet_experiment_url
if is_wandb_available():
@@ -438,6 +438,7 @@ def create_model_card(
dataset_name=dataset_name,
tags=tags,
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
+ comet_url=get_comet_experiment_url(),
trainer_name="AlignProp",
trainer_citation=citation,
paper_title="Aligning Text-to-Image Diffusion Models with Reward Backpropagation",
diff --git a/trl/trainer/bco_trainer.py b/trl/trainer/bco_trainer.py
index dc267de882..c2d58ab3f2 100644
--- a/trl/trainer/bco_trainer.py
+++ b/trl/trainer/bco_trainer.py
@@ -59,6 +59,7 @@
RunningMoments,
disable_dropout_in_model,
generate_model_card,
+ get_comet_experiment_url,
pad_to_length,
peft_module_casting_to_bf16,
)
@@ -1514,6 +1515,7 @@ def create_model_card(
dataset_name=dataset_name,
tags=tags,
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
+ comet_url=get_comet_experiment_url(),
trainer_name="BCO",
trainer_citation=citation,
paper_title="Binary Classifier Optimization for Large Language Model Alignment",
diff --git a/trl/trainer/callbacks.py b/trl/trainer/callbacks.py
index 7e25366431..db1f29040b 100644
--- a/trl/trainer/callbacks.py
+++ b/trl/trainer/callbacks.py
@@ -13,7 +13,7 @@
# limitations under the License.
import os
-from typing import Optional, Union
+from typing import List, Optional, Union
import pandas as pd
import torch
@@ -42,6 +42,7 @@
from ..mergekit_utils import MergeConfig, merge_models, upload_model_to_hf
from ..models.utils import unwrap_model_for_generation
from .judges import BasePairwiseJudge
+from .utils import log_table_to_comet_experiment
if is_deepspeed_available():
@@ -199,6 +200,16 @@ def on_train_end(self, args, state, control, **kwargs):
self.current_step = None
+def _win_rate_completions_df(
+ state: TrainerState, prompts: List[str], completions: List[str], winner_indices: List[str]
+) -> pd.DataFrame:
+ global_step = [str(state.global_step)] * len(prompts)
+ data = list(zip(global_step, prompts, completions, winner_indices))
+ # Split completions from reference model and policy
+ split_data = [(item[0], item[1], item[2][0], item[2][1], item[3]) for item in data]
+ return pd.DataFrame(split_data, columns=["step", "prompt", "reference_model", "policy", "winner_index"])
+
+
class WinRateCallback(TrainerCallback):
"""
A [`~transformers.TrainerCallback`] that computes the win rate of a model based on a reference.
@@ -311,15 +322,26 @@ def on_train_begin(self, args: TrainingArguments, state: TrainerState, control:
import wandb
if wandb.run is not None:
- global_step = [str(state.global_step)] * len(prompts)
- data = list(zip(global_step, prompts, completions, winner_indices))
- # Split completions from referenece model and policy
- split_data = [(item[0], item[1], item[2][0], item[2][1], item[3]) for item in data]
- df = pd.DataFrame(
- split_data, columns=["step", "prompt", "reference_model", "policy", "winner_index"]
+ df = _win_rate_completions_df(
+ state=state,
+ prompts=prompts,
+ completions=completions,
+ winner_indices=winner_indices,
)
wandb.log({"win_rate_completions": wandb.Table(dataframe=df)})
+ if "comet_ml" in args.report_to:
+ df = _win_rate_completions_df(
+ state=state,
+ prompts=prompts,
+ completions=completions,
+ winner_indices=winner_indices,
+ )
+ log_table_to_comet_experiment(
+ name="win_rate_completions.csv",
+ table=df,
+ )
+
def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
# At every evaluation step, we generate completions for the model and compare them with the reference
# completions that have been generated at the beginning of training. We then compute the win rate and log it to
@@ -363,15 +385,26 @@ def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: Tra
import wandb
if wandb.run is not None:
- global_step = [str(state.global_step)] * len(prompts)
- data = list(zip(global_step, prompts, completions, winner_indices))
- # Split completions from referenece model and policy
- split_data = [(item[0], item[1], item[2][0], item[2][1], item[3]) for item in data]
- df = pd.DataFrame(
- split_data, columns=["step", "prompt", "reference_model", "policy", "winner_index"]
+ df = _win_rate_completions_df(
+ state=state,
+ prompts=prompts,
+ completions=completions,
+ winner_indices=winner_indices,
)
wandb.log({"win_rate_completions": wandb.Table(dataframe=df)})
+ if "comet_ml" in args.report_to:
+ df = _win_rate_completions_df(
+ state=state,
+ prompts=prompts,
+ completions=completions,
+ winner_indices=winner_indices,
+ )
+ log_table_to_comet_experiment(
+ name="win_rate_completions.csv",
+ table=df,
+ )
+
class LogCompletionsCallback(WandbCallback):
r"""
diff --git a/trl/trainer/cpo_trainer.py b/trl/trainer/cpo_trainer.py
index 49b2849766..2998d534cf 100644
--- a/trl/trainer/cpo_trainer.py
+++ b/trl/trainer/cpo_trainer.py
@@ -55,6 +55,7 @@
add_eos_token_if_needed,
disable_dropout_in_model,
generate_model_card,
+ get_comet_experiment_url,
pad_to_length,
peft_module_casting_to_bf16,
)
@@ -1052,6 +1053,7 @@ def create_model_card(
dataset_name=dataset_name,
tags=tags,
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
+ comet_url=get_comet_experiment_url(),
trainer_name="CPO",
trainer_citation=citation,
paper_title="Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation",
diff --git a/trl/trainer/ddpo_trainer.py b/trl/trainer/ddpo_trainer.py
index a3b6813b4c..01a1d0d5c5 100644
--- a/trl/trainer/ddpo_trainer.py
+++ b/trl/trainer/ddpo_trainer.py
@@ -27,7 +27,7 @@
from ..models import DDPOStableDiffusionPipeline
from . import BaseTrainer, DDPOConfig
-from .utils import PerPromptStatTracker, generate_model_card
+from .utils import PerPromptStatTracker, generate_model_card, get_comet_experiment_url
if is_wandb_available():
@@ -641,6 +641,7 @@ def create_model_card(
dataset_name=dataset_name,
tags=tags,
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
+ comet_url=get_comet_experiment_url(),
trainer_name="DDPO",
trainer_citation=citation,
paper_title="Training Diffusion Models with Reinforcement Learning",
diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py
index 7ed0ac387f..3c4c7771b2 100644
--- a/trl/trainer/dpo_trainer.py
+++ b/trl/trainer/dpo_trainer.py
@@ -60,6 +60,7 @@
cap_exp,
disable_dropout_in_model,
generate_model_card,
+ get_comet_experiment_url,
pad,
pad_to_length,
peft_module_casting_to_bf16,
@@ -1483,6 +1484,7 @@ def create_model_card(
dataset_name=dataset_name,
tags=tags,
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
+ comet_url=get_comet_experiment_url(),
trainer_name="DPO",
trainer_citation=citation,
paper_title="Direct Preference Optimization: Your Language Model is Secretly a Reward Model",
diff --git a/trl/trainer/gkd_trainer.py b/trl/trainer/gkd_trainer.py
index 173a8d6107..be48f1925b 100644
--- a/trl/trainer/gkd_trainer.py
+++ b/trl/trainer/gkd_trainer.py
@@ -42,7 +42,13 @@
from ..models.utils import unwrap_model_for_generation
from .gkd_config import GKDConfig
from .sft_trainer import SFTTrainer
-from .utils import DataCollatorForChatML, disable_dropout_in_model, empty_cache, generate_model_card
+from .utils import (
+ DataCollatorForChatML,
+ disable_dropout_in_model,
+ empty_cache,
+ generate_model_card,
+ get_comet_experiment_url,
+)
if is_deepspeed_available():
@@ -378,6 +384,7 @@ def create_model_card(
dataset_name=dataset_name,
tags=tags,
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
+ comet_url=get_comet_experiment_url(),
trainer_name="GKD",
trainer_citation=citation,
paper_title="On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes",
diff --git a/trl/trainer/iterative_sft_trainer.py b/trl/trainer/iterative_sft_trainer.py
index 9621822eb8..e1ba8bee21 100644
--- a/trl/trainer/iterative_sft_trainer.py
+++ b/trl/trainer/iterative_sft_trainer.py
@@ -36,7 +36,7 @@
from transformers.utils import is_peft_available
from ..core import PPODecorators
-from .utils import generate_model_card
+from .utils import generate_model_card, get_comet_experiment_url
if is_peft_available():
@@ -434,6 +434,7 @@ def create_model_card(
dataset_name=dataset_name,
tags=tags,
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
+ comet_url=get_comet_experiment_url(),
trainer_name="Iterative SFT",
)
diff --git a/trl/trainer/kto_trainer.py b/trl/trainer/kto_trainer.py
index 3aa3396e39..fb0b39bbe9 100644
--- a/trl/trainer/kto_trainer.py
+++ b/trl/trainer/kto_trainer.py
@@ -58,6 +58,7 @@
DPODataCollatorWithPadding,
disable_dropout_in_model,
generate_model_card,
+ get_comet_experiment_url,
pad_to_length,
peft_module_casting_to_bf16,
)
@@ -1526,6 +1527,7 @@ def create_model_card(
dataset_name=dataset_name,
tags=tags,
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
+ comet_url=get_comet_experiment_url(),
trainer_name="KTO",
trainer_citation=citation,
paper_title="KTO: Model Alignment as Prospect Theoretic Optimization",
diff --git a/trl/trainer/nash_md_trainer.py b/trl/trainer/nash_md_trainer.py
index a37f4bb170..1d714e2c1d 100644
--- a/trl/trainer/nash_md_trainer.py
+++ b/trl/trainer/nash_md_trainer.py
@@ -40,7 +40,14 @@
from .judges import BasePairwiseJudge
from .nash_md_config import NashMDConfig
from .online_dpo_trainer import OnlineDPOTrainer
-from .utils import SIMPLE_CHAT_TEMPLATE, empty_cache, generate_model_card, get_reward, truncate_right
+from .utils import (
+ SIMPLE_CHAT_TEMPLATE,
+ empty_cache,
+ generate_model_card,
+ get_comet_experiment_url,
+ get_reward,
+ truncate_right,
+)
if is_apex_available():
@@ -500,6 +507,7 @@ def create_model_card(
dataset_name=dataset_name,
tags=tags,
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
+ comet_url=get_comet_experiment_url(),
trainer_name="Nash-MD",
trainer_citation=citation,
paper_title="Nash Learning from Human Feedback",
diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py
index c014ce1e13..dd34d19af0 100644
--- a/trl/trainer/online_dpo_trainer.py
+++ b/trl/trainer/online_dpo_trainer.py
@@ -58,6 +58,7 @@
disable_dropout_in_model,
empty_cache,
generate_model_card,
+ get_comet_experiment_url,
get_reward,
prepare_deepspeed,
truncate_right,
@@ -734,6 +735,7 @@ def create_model_card(
dataset_name=dataset_name,
tags=tags,
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
+ comet_url=get_comet_experiment_url(),
trainer_name="Online DPO",
trainer_citation=citation,
paper_title="Direct Language Model Alignment from Online AI Feedback",
diff --git a/trl/trainer/orpo_trainer.py b/trl/trainer/orpo_trainer.py
index e6f148f90f..f94522923b 100644
--- a/trl/trainer/orpo_trainer.py
+++ b/trl/trainer/orpo_trainer.py
@@ -59,6 +59,7 @@
add_eos_token_if_needed,
disable_dropout_in_model,
generate_model_card,
+ get_comet_experiment_url,
pad_to_length,
peft_module_casting_to_bf16,
)
@@ -1077,6 +1078,7 @@ def create_model_card(
dataset_name=dataset_name,
tags=tags,
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
+ comet_url=get_comet_experiment_url(),
trainer_name="ORPO",
trainer_citation=citation,
paper_title="ORPO: Monolithic Preference Optimization without Reference Model",
diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py
index 3827caae30..51897eeb44 100644
--- a/trl/trainer/ppo_trainer.py
+++ b/trl/trainer/ppo_trainer.py
@@ -60,7 +60,9 @@
first_true_indices,
forward,
generate_model_card,
+ get_comet_experiment_url,
get_reward,
+ log_table_to_comet_experiment,
peft_module_casting_to_bf16,
prepare_deepspeed,
print_rich_table,
@@ -727,6 +729,12 @@ def generate_completions(self, sampling: bool = False):
if wandb.run is not None:
wandb.log({"completions": wandb.Table(dataframe=df)})
+ if "comet_ml" in args.report_to:
+ log_table_to_comet_experiment(
+ name="completions.csv",
+ table=df,
+ )
+
def create_model_card(
self,
model_name: Optional[str] = None,
@@ -774,6 +782,7 @@ def create_model_card(
dataset_name=dataset_name,
tags=tags,
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
+ comet_url=get_comet_experiment_url(),
trainer_name="PPO",
trainer_citation=citation,
paper_title="Fine-Tuning Language Models from Human Preferences",
diff --git a/trl/trainer/reward_trainer.py b/trl/trainer/reward_trainer.py
index 3157900b83..109d8a47cf 100644
--- a/trl/trainer/reward_trainer.py
+++ b/trl/trainer/reward_trainer.py
@@ -48,6 +48,8 @@
compute_accuracy,
decode_and_strip_padding,
generate_model_card,
+ get_comet_experiment_url,
+ log_table_to_comet_experiment,
print_rich_table,
)
@@ -359,6 +361,12 @@ def visualize_samples(self, num_print_samples: int):
if wandb.run is not None:
wandb.log({"completions": wandb.Table(dataframe=df)})
+ if "comet_ml" in self.args.report_to:
+ log_table_to_comet_experiment(
+ name="completions.csv",
+ table=df,
+ )
+
def create_model_card(
self,
model_name: Optional[str] = None,
@@ -398,6 +406,7 @@ def create_model_card(
dataset_name=dataset_name,
tags=tags,
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
+ comet_url=get_comet_experiment_url(),
trainer_name="Reward",
)
diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py
index 6ad38d37c4..fa9634696a 100644
--- a/trl/trainer/rloo_trainer.py
+++ b/trl/trainer/rloo_trainer.py
@@ -60,7 +60,7 @@
truncate_response,
)
from .rloo_config import RLOOConfig
-from .utils import generate_model_card
+from .utils import generate_model_card, get_comet_experiment_url, log_table_to_comet_experiment
if is_wandb_available():
@@ -556,6 +556,12 @@ def generate_completions(self, sampling: bool = False):
if wandb.run is not None:
wandb.log({"completions": wandb.Table(dataframe=df)})
+ if "comet_ml" in args.report_to:
+ log_table_to_comet_experiment(
+ name="completions.csv",
+ table=df,
+ )
+
def create_model_card(
self,
model_name: Optional[str] = None,
@@ -606,6 +612,7 @@ def create_model_card(
dataset_name=dataset_name,
tags=tags,
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
+ comet_url=get_comet_experiment_url(),
trainer_name="RLOO",
trainer_citation=citation,
paper_title="Back to Basics: Revisiting REINFORCE-Style Optimization for Learning from Human Feedback in LLMs",
diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py
index 809eff5679..3a536c50ad 100644
--- a/trl/trainer/sft_trainer.py
+++ b/trl/trainer/sft_trainer.py
@@ -49,6 +49,7 @@
ConstantLengthDataset,
DataCollatorForCompletionOnlyLM,
generate_model_card,
+ get_comet_experiment_url,
peft_module_casting_to_bf16,
)
@@ -540,6 +541,7 @@ def create_model_card(
dataset_name=dataset_name,
tags=tags,
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
+ comet_url=get_comet_experiment_url(),
trainer_name="SFT",
)
diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py
index fd30bea929..30e453532f 100644
--- a/trl/trainer/utils.py
+++ b/trl/trainer/utils.py
@@ -42,6 +42,7 @@
PreTrainedTokenizerBase,
TrainerState,
TrainingArguments,
+ is_comet_available,
)
from transformers.utils import (
is_peft_available,
@@ -55,6 +56,9 @@
from ..trainer.model_config import ModelConfig
+if is_comet_available():
+ import comet_ml
+
if is_peft_available():
from peft import LoraConfig, PeftConfig
@@ -1435,6 +1439,7 @@ def generate_model_card(
trainer_citation: Optional[str] = None,
paper_title: Optional[str] = None,
paper_id: Optional[str] = None,
+ comet_url: Optional[str] = None,
) -> ModelCard:
"""
Generate a `ModelCard` from a template.
@@ -1452,6 +1457,8 @@ def generate_model_card(
Tags.
wandb_url (`str` or `None`):
Weights & Biases run URL.
+ comet_url (`str` or `None`):
+ Comet experiment URL.
trainer_name (`str`):
Trainer name.
trainer_citation (`str` or `None`, defaults to `None`):
@@ -1481,6 +1488,7 @@ def generate_model_card(
hub_model_id=hub_model_id,
dataset_name=dataset_name,
wandb_url=wandb_url,
+ comet_url=comet_url,
trainer_name=trainer_name,
trainer_citation=trainer_citation,
paper_title=paper_title,
@@ -1492,3 +1500,34 @@ def generate_model_card(
tokenizers_version=version("tokenizers"),
)
return card
+
+
+def get_comet_experiment_url() -> Optional[str]:
+ """
+ If Comet integration is enabled, return the URL of the current Comet experiment; otherwise, return `None`.
+ """
+ if not is_comet_available():
+ return None
+
+ experiment = comet_ml.get_running_experiment()
+ return experiment.url if experiment is not None else None
+
+
+def log_table_to_comet_experiment(name: str, table: pd.DataFrame) -> None:
+ """
+ If Comet integration is enabled, logs a table to the Comet experiment if it is currently running.
+
+ Args:
+ name (`str`):
+ Table name.
+ table (`pd.DataFrame`):
+ The Pandas DataFrame containing the table to log.
+ """
+ if not is_comet_available():
+ raise ModuleNotFoundError("The comet-ml is not installed. Please install it first: pip install comet-ml")
+
+ experiment = comet_ml.get_running_experiment()
+ if experiment is not None:
+ experiment.log_table(tabular_data=table, filename=name)
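
A minimal usage sketch (not part of the patch) for the two helpers added above, assuming `comet-ml` is installed and `COMET_API_KEY` is configured; without a running experiment, `get_comet_experiment_url` returns `None` and `log_table_to_comet_experiment` is a no-op.

```python
# Illustrative only; the project name is a placeholder.
import comet_ml
import pandas as pd

from trl.trainer.utils import get_comet_experiment_url, log_table_to_comet_experiment

# The newly created experiment should be returned by comet_ml.get_running_experiment().
experiment = comet_ml.Experiment(project_name="trl-comet-demo")

print(get_comet_experiment_url())  # URL of the running experiment, or None if none is running

df = pd.DataFrame({"prompt": ["Hello"], "completion": ["Hi there!"]})
log_table_to_comet_experiment(name="completions.csv", table=df)  # logged as a table on the experiment

experiment.end()
```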
diff --git a/trl/trainer/xpo_trainer.py b/trl/trainer/xpo_trainer.py
index baa580d136..1be32ab1de 100644
--- a/trl/trainer/xpo_trainer.py
+++ b/trl/trainer/xpo_trainer.py
@@ -38,7 +38,14 @@
from ..models.utils import unwrap_model_for_generation
from .judges import BasePairwiseJudge
from .online_dpo_trainer import OnlineDPOTrainer
-from .utils import SIMPLE_CHAT_TEMPLATE, empty_cache, generate_model_card, get_reward, truncate_right
+from .utils import (
+ SIMPLE_CHAT_TEMPLATE,
+ empty_cache,
+ generate_model_card,
+ get_comet_experiment_url,
+ get_reward,
+ truncate_right,
+)
from .xpo_config import XPOConfig
@@ -555,6 +562,7 @@ def create_model_card(
dataset_name=dataset_name,
tags=tags,
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
+ comet_url=get_comet_experiment_url(),
trainer_name="XPO",
trainer_citation=citation,
paper_title="Exploratory Preference Optimization: Harnessing Implicit Q*-Approximation for Sample-Efficient RLHF",