Skip to content

Commit

Permalink
☄️ Add support for Comet experiment management SDK integration (#2462)
Browse files Browse the repository at this point in the history
* Added support for Comet URL integration into model cards created by trainers.

* Moved `get_comet_experiment_url()` into utils.py

* Updated Comet badge in the model card to use PNG image instead of text.

* Fixed bug related to running PPO example during model saving. The error as following: 'GPTNeoXForCausalLM' object has no attribute 'policy'. Introduced guard check that attribute `policy` exists.

* Implemented utility method to handle logging of tabular data to the Comet experiment.

* Implemented logging of the completions table to Comet by `PPOTrainer`.

* Implemented logging of the completions table to Comet by `WinRateCallback`.

* Implemented logging of the completions table to Comet by `RLOOTrainer` and `RewardTrainer`.

* Restored line to the main branch version.

* Moved Comet related utility methods into `trainer/utils.py` to resolve merge conflict with master branch,

* Update trl/trainer/utils.py

Co-authored-by: Quentin Gallouédec <[email protected]>

* Implemented raising of `ModuleNotFoundError` error when logging table to Comet if `comet-ml` is not installed.

* import comet with other imports

---------

Co-authored-by: Quentin Gallouédec <[email protected]>
Co-authored-by: Quentin Gallouédec <[email protected]>
  • Loading branch information
3 people authored Dec 13, 2024
1 parent cd7156f commit 6d4ed07
Show file tree
Hide file tree
Showing 20 changed files with 162 additions and 21 deletions.
3 changes: 3 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ def test_full(self):
dataset_name="username/my_dataset",
tags=["trl", "trainer-tag"],
wandb_url="https://wandb.ai/username/project_id/runs/abcd1234",
comet_url="https://www.comet.com/username/project_id/experiment_id",
trainer_name="My Trainer",
trainer_citation="@article{my_trainer, ...}",
paper_title="My Paper",
Expand All @@ -158,6 +159,7 @@ def test_full(self):
self.assertIn('pipeline("text-generation", model="username/my_hub_model", device="cuda")', card_text)
self.assertIn("datasets: username/my_dataset", card_text)
self.assertIn("](https://wandb.ai/username/project_id/runs/abcd1234)", card_text)
self.assertIn("](https://www.comet.com/username/project_id/experiment_id", card_text)
self.assertIn("My Trainer", card_text)
self.assertIn("```bibtex\n@article{my_trainer, ...}\n```", card_text)
self.assertIn("[My Paper](https://huggingface.co/papers/1234.56789)", card_text)
Expand All @@ -170,6 +172,7 @@ def test_val_none(self):
dataset_name=None,
tags=[],
wandb_url=None,
comet_url=None,
trainer_name="My Trainer",
trainer_citation=None,
paper_title=None,
Expand Down
3 changes: 2 additions & 1 deletion trl/templates/lm_model_card.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ print(output["generated_text"])

## Training procedure

{% if wandb_url %}[<img src="https://raw.githubusercontent.com/wandb/assets/main/wandb-github-badge-28.svg" alt="Visualize in Weights & Biases" width="150" height="24"/>]({{ wandb_url }}){% endif %}
{% if wandb_url %}[<img src="https://raw.githubusercontent.com/wandb/assets/main/wandb-github-badge-28.svg" alt="Visualize in Weights & Biases" width="150" height="24"/>]({{ wandb_url }}){% endif %}
{% if comet_url %}[<img src="https://raw.githubusercontent.com/comet-ml/comet-examples/master/logo/comet_badge.png" alt="Visualize in Comet" width="135" height="20"/>]({{ comet_url }}){% endif %}

This model was trained with {{ trainer_name }}{% if paper_id %}, a method introduced in [{{ paper_title }}](https://huggingface.co/papers/{{ paper_id }}){% endif %}.

Expand Down
3 changes: 2 additions & 1 deletion trl/trainer/alignprop_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

from ..models import DDPOStableDiffusionPipeline
from . import AlignPropConfig, BaseTrainer
from .utils import generate_model_card
from .utils import generate_model_card, get_comet_experiment_url


if is_wandb_available():
Expand Down Expand Up @@ -438,6 +438,7 @@ def create_model_card(
dataset_name=dataset_name,
tags=tags,
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
comet_url=get_comet_experiment_url(),
trainer_name="AlignProp",
trainer_citation=citation,
paper_title="Aligning Text-to-Image Diffusion Models with Reward Backpropagation",
Expand Down
2 changes: 2 additions & 0 deletions trl/trainer/bco_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
RunningMoments,
disable_dropout_in_model,
generate_model_card,
get_comet_experiment_url,
pad_to_length,
peft_module_casting_to_bf16,
)
Expand Down Expand Up @@ -1514,6 +1515,7 @@ def create_model_card(
dataset_name=dataset_name,
tags=tags,
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
comet_url=get_comet_experiment_url(),
trainer_name="BCO",
trainer_citation=citation,
paper_title="Binary Classifier Optimization for Large Language Model Alignment",
Expand Down
59 changes: 46 additions & 13 deletions trl/trainer/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import os
from typing import Optional, Union
from typing import List, Optional, Union

import pandas as pd
import torch
Expand Down Expand Up @@ -42,6 +42,7 @@
from ..mergekit_utils import MergeConfig, merge_models, upload_model_to_hf
from ..models.utils import unwrap_model_for_generation
from .judges import BasePairwiseJudge
from .utils import log_table_to_comet_experiment


if is_deepspeed_available():
Expand Down Expand Up @@ -199,6 +200,16 @@ def on_train_end(self, args, state, control, **kwargs):
self.current_step = None


def _win_rate_completions_df(
state: TrainerState, prompts: List[str], completions: List[str], winner_indices: List[str]
) -> pd.DataFrame:
global_step = [str(state.global_step)] * len(prompts)
data = list(zip(global_step, prompts, completions, winner_indices))
# Split completions from reference model and policy
split_data = [(item[0], item[1], item[2][0], item[2][1], item[3]) for item in data]
return pd.DataFrame(split_data, columns=["step", "prompt", "reference_model", "policy", "winner_index"])


class WinRateCallback(TrainerCallback):
"""
A [`~transformers.TrainerCallback`] that computes the win rate of a model based on a reference.
Expand Down Expand Up @@ -311,15 +322,26 @@ def on_train_begin(self, args: TrainingArguments, state: TrainerState, control:
import wandb

if wandb.run is not None:
global_step = [str(state.global_step)] * len(prompts)
data = list(zip(global_step, prompts, completions, winner_indices))
# Split completions from referenece model and policy
split_data = [(item[0], item[1], item[2][0], item[2][1], item[3]) for item in data]
df = pd.DataFrame(
split_data, columns=["step", "prompt", "reference_model", "policy", "winner_index"]
df = _win_rate_completions_df(
state=state,
prompts=prompts,
completions=completions,
winner_indices=winner_indices,
)
wandb.log({"win_rate_completions": wandb.Table(dataframe=df)})

if "comet_ml" in args.report_to:
df = _win_rate_completions_df(
state=state,
prompts=prompts,
completions=completions,
winner_indices=winner_indices,
)
log_table_to_comet_experiment(
name="win_rate_completions.csv",
table=df,
)

def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
# At every evaluation step, we generate completions for the model and compare them with the reference
# completions that have been generated at the beginning of training. We then compute the win rate and log it to
Expand Down Expand Up @@ -363,15 +385,26 @@ def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: Tra
import wandb

if wandb.run is not None:
global_step = [str(state.global_step)] * len(prompts)
data = list(zip(global_step, prompts, completions, winner_indices))
# Split completions from referenece model and policy
split_data = [(item[0], item[1], item[2][0], item[2][1], item[3]) for item in data]
df = pd.DataFrame(
split_data, columns=["step", "prompt", "reference_model", "policy", "winner_index"]
df = _win_rate_completions_df(
state=state,
prompts=prompts,
completions=completions,
winner_indices=winner_indices,
)
wandb.log({"win_rate_completions": wandb.Table(dataframe=df)})

if "comet_ml" in args.report_to:
df = _win_rate_completions_df(
state=state,
prompts=prompts,
completions=completions,
winner_indices=winner_indices,
)
log_table_to_comet_experiment(
name="win_rate_completions.csv",
table=df,
)


class LogCompletionsCallback(WandbCallback):
r"""
Expand Down
2 changes: 2 additions & 0 deletions trl/trainer/cpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
add_eos_token_if_needed,
disable_dropout_in_model,
generate_model_card,
get_comet_experiment_url,
pad_to_length,
peft_module_casting_to_bf16,
)
Expand Down Expand Up @@ -1052,6 +1053,7 @@ def create_model_card(
dataset_name=dataset_name,
tags=tags,
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
comet_url=get_comet_experiment_url(),
trainer_name="CPO",
trainer_citation=citation,
paper_title="Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation",
Expand Down
3 changes: 2 additions & 1 deletion trl/trainer/ddpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

from ..models import DDPOStableDiffusionPipeline
from . import BaseTrainer, DDPOConfig
from .utils import PerPromptStatTracker, generate_model_card
from .utils import PerPromptStatTracker, generate_model_card, get_comet_experiment_url


if is_wandb_available():
Expand Down Expand Up @@ -641,6 +641,7 @@ def create_model_card(
dataset_name=dataset_name,
tags=tags,
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
comet_url=get_comet_experiment_url(),
trainer_name="DDPO",
trainer_citation=citation,
paper_title="Training Diffusion Models with Reinforcement Learning",
Expand Down
2 changes: 2 additions & 0 deletions trl/trainer/dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
cap_exp,
disable_dropout_in_model,
generate_model_card,
get_comet_experiment_url,
pad,
pad_to_length,
peft_module_casting_to_bf16,
Expand Down Expand Up @@ -1483,6 +1484,7 @@ def create_model_card(
dataset_name=dataset_name,
tags=tags,
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
comet_url=get_comet_experiment_url(),
trainer_name="DPO",
trainer_citation=citation,
paper_title="Direct Preference Optimization: Your Language Model is Secretly a Reward Model",
Expand Down
9 changes: 8 additions & 1 deletion trl/trainer/gkd_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,13 @@
from ..models.utils import unwrap_model_for_generation
from .gkd_config import GKDConfig
from .sft_trainer import SFTTrainer
from .utils import DataCollatorForChatML, disable_dropout_in_model, empty_cache, generate_model_card
from .utils import (
DataCollatorForChatML,
disable_dropout_in_model,
empty_cache,
generate_model_card,
get_comet_experiment_url,
)


if is_deepspeed_available():
Expand Down Expand Up @@ -378,6 +384,7 @@ def create_model_card(
dataset_name=dataset_name,
tags=tags,
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
comet_url=get_comet_experiment_url(),
trainer_name="GKD",
trainer_citation=citation,
paper_title="On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes",
Expand Down
3 changes: 2 additions & 1 deletion trl/trainer/iterative_sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from transformers.utils import is_peft_available

from ..core import PPODecorators
from .utils import generate_model_card
from .utils import generate_model_card, get_comet_experiment_url


if is_peft_available():
Expand Down Expand Up @@ -434,6 +434,7 @@ def create_model_card(
dataset_name=dataset_name,
tags=tags,
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
comet_url=get_comet_experiment_url(),
trainer_name="Iterative SFT",
)

Expand Down
2 changes: 2 additions & 0 deletions trl/trainer/kto_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
DPODataCollatorWithPadding,
disable_dropout_in_model,
generate_model_card,
get_comet_experiment_url,
pad_to_length,
peft_module_casting_to_bf16,
)
Expand Down Expand Up @@ -1526,6 +1527,7 @@ def create_model_card(
dataset_name=dataset_name,
tags=tags,
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
comet_url=get_comet_experiment_url(),
trainer_name="KTO",
trainer_citation=citation,
paper_title="KTO: Model Alignment as Prospect Theoretic Optimization",
Expand Down
10 changes: 9 additions & 1 deletion trl/trainer/nash_md_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,14 @@
from .judges import BasePairwiseJudge
from .nash_md_config import NashMDConfig
from .online_dpo_trainer import OnlineDPOTrainer
from .utils import SIMPLE_CHAT_TEMPLATE, empty_cache, generate_model_card, get_reward, truncate_right
from .utils import (
SIMPLE_CHAT_TEMPLATE,
empty_cache,
generate_model_card,
get_comet_experiment_url,
get_reward,
truncate_right,
)


if is_apex_available():
Expand Down Expand Up @@ -500,6 +507,7 @@ def create_model_card(
dataset_name=dataset_name,
tags=tags,
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
comet_url=get_comet_experiment_url(),
trainer_name="Nash-MD",
trainer_citation=citation,
paper_title="Nash Learning from Human Feedback",
Expand Down
2 changes: 2 additions & 0 deletions trl/trainer/online_dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
disable_dropout_in_model,
empty_cache,
generate_model_card,
get_comet_experiment_url,
get_reward,
prepare_deepspeed,
truncate_right,
Expand Down Expand Up @@ -734,6 +735,7 @@ def create_model_card(
dataset_name=dataset_name,
tags=tags,
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
comet_url=get_comet_experiment_url(),
trainer_name="Online DPO",
trainer_citation=citation,
paper_title="Direct Language Model Alignment from Online AI Feedback",
Expand Down
2 changes: 2 additions & 0 deletions trl/trainer/orpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
add_eos_token_if_needed,
disable_dropout_in_model,
generate_model_card,
get_comet_experiment_url,
pad_to_length,
peft_module_casting_to_bf16,
)
Expand Down Expand Up @@ -1077,6 +1078,7 @@ def create_model_card(
dataset_name=dataset_name,
tags=tags,
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
comet_url=get_comet_experiment_url(),
trainer_name="ORPO",
trainer_citation=citation,
paper_title="ORPO: Monolithic Preference Optimization without Reference Model",
Expand Down
9 changes: 9 additions & 0 deletions trl/trainer/ppo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@
first_true_indices,
forward,
generate_model_card,
get_comet_experiment_url,
get_reward,
log_table_to_comet_experiment,
peft_module_casting_to_bf16,
prepare_deepspeed,
print_rich_table,
Expand Down Expand Up @@ -727,6 +729,12 @@ def generate_completions(self, sampling: bool = False):
if wandb.run is not None:
wandb.log({"completions": wandb.Table(dataframe=df)})

if "comet_ml" in args.report_to:
log_table_to_comet_experiment(
name="completions.csv",
table=df,
)

def create_model_card(
self,
model_name: Optional[str] = None,
Expand Down Expand Up @@ -774,6 +782,7 @@ def create_model_card(
dataset_name=dataset_name,
tags=tags,
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
comet_url=get_comet_experiment_url(),
trainer_name="PPO",
trainer_citation=citation,
paper_title="Fine-Tuning Language Models from Human Preferences",
Expand Down
9 changes: 9 additions & 0 deletions trl/trainer/reward_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@
compute_accuracy,
decode_and_strip_padding,
generate_model_card,
get_comet_experiment_url,
log_table_to_comet_experiment,
print_rich_table,
)

Expand Down Expand Up @@ -359,6 +361,12 @@ def visualize_samples(self, num_print_samples: int):
if wandb.run is not None:
wandb.log({"completions": wandb.Table(dataframe=df)})

if "comet_ml" in self.args.report_to:
log_table_to_comet_experiment(
name="completions.csv",
table=df,
)

def create_model_card(
self,
model_name: Optional[str] = None,
Expand Down Expand Up @@ -398,6 +406,7 @@ def create_model_card(
dataset_name=dataset_name,
tags=tags,
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
comet_url=get_comet_experiment_url(),
trainer_name="Reward",
)

Expand Down
Loading

0 comments on commit 6d4ed07

Please sign in to comment.