From 6d4ed070f1f53a87fb3cff2eb82a56db093bccc6 Mon Sep 17 00:00:00 2001 From: Iaroslav Omelianenko Date: Fri, 13 Dec 2024 23:08:10 +0200 Subject: [PATCH] =?UTF-8?q?=E2=98=84=EF=B8=8F=20Add=20support=20for=20Come?= =?UTF-8?q?t=20=20experiment=20management=20SDK=20integration=20(#2462)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Added support for Comet URL integration into model cards created by trainers. * Moved `get_comet_experiment_url()` into utils.py * Updated Comet badge in the model card to use PNG image instead of text. * Fixed bug related to running PPO example during model saving. The error as following: 'GPTNeoXForCausalLM' object has no attribute 'policy'. Introduced guard check that attribute `policy` exists. * Implemented utility method to handle logging of tabular data to the Comet experiment. * Implemented logging of the completions table to Comet by `PPOTrainer`. * Implemented logging of the completions table to Comet by `WinRateCallback`. * Implemented logging of the completions table to Comet by `RLOOTrainer` and `RewardTrainer`. * Restored line to the main branch version. * Moved Comet related utility methods into `trainer/utils.py` to resolve merge conflict with master branch, * Update trl/trainer/utils.py Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> * Implemented raising of `ModuleNotFoundError` error when logging table to Comet if `comet-ml` is not installed. * import comet with other imports --------- Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> Co-authored-by: Quentin Gallouédec --- tests/test_utils.py | 3 ++ trl/templates/lm_model_card.md | 3 +- trl/trainer/alignprop_trainer.py | 3 +- trl/trainer/bco_trainer.py | 2 + trl/trainer/callbacks.py | 59 ++++++++++++++++++++++------ trl/trainer/cpo_trainer.py | 2 + trl/trainer/ddpo_trainer.py | 3 +- trl/trainer/dpo_trainer.py | 2 + trl/trainer/gkd_trainer.py | 9 ++++- trl/trainer/iterative_sft_trainer.py | 3 +- trl/trainer/kto_trainer.py | 2 + trl/trainer/nash_md_trainer.py | 10 ++++- trl/trainer/online_dpo_trainer.py | 2 + trl/trainer/orpo_trainer.py | 2 + trl/trainer/ppo_trainer.py | 9 +++++ trl/trainer/reward_trainer.py | 9 +++++ trl/trainer/rloo_trainer.py | 9 ++++- trl/trainer/sft_trainer.py | 2 + trl/trainer/utils.py | 39 ++++++++++++++++++ trl/trainer/xpo_trainer.py | 10 ++++- 20 files changed, 162 insertions(+), 21 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index a1cabcfc19..df513a025d 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -147,6 +147,7 @@ def test_full(self): dataset_name="username/my_dataset", tags=["trl", "trainer-tag"], wandb_url="https://wandb.ai/username/project_id/runs/abcd1234", + comet_url="https://www.comet.com/username/project_id/experiment_id", trainer_name="My Trainer", trainer_citation="@article{my_trainer, ...}", paper_title="My Paper", @@ -158,6 +159,7 @@ def test_full(self): self.assertIn('pipeline("text-generation", model="username/my_hub_model", device="cuda")', card_text) self.assertIn("datasets: username/my_dataset", card_text) self.assertIn("](https://wandb.ai/username/project_id/runs/abcd1234)", card_text) + self.assertIn("](https://www.comet.com/username/project_id/experiment_id", card_text) self.assertIn("My Trainer", card_text) self.assertIn("```bibtex\n@article{my_trainer, ...}\n```", card_text) self.assertIn("[My Paper](https://huggingface.co/papers/1234.56789)", card_text) @@ -170,6 +172,7 @@ def test_val_none(self): dataset_name=None, tags=[], wandb_url=None, + comet_url=None, trainer_name="My Trainer", trainer_citation=None, paper_title=None, diff --git a/trl/templates/lm_model_card.md b/trl/templates/lm_model_card.md index 316c5d829e..dbe7c1a7b2 100644 --- a/trl/templates/lm_model_card.md +++ b/trl/templates/lm_model_card.md @@ -20,7 +20,8 @@ print(output["generated_text"]) ## Training procedure -{% if wandb_url %}[Visualize in Weights & Biases]({{ wandb_url }}){% endif %} +{% if wandb_url %}[Visualize in Weights & Biases]({{ wandb_url }}){% endif %} +{% if comet_url %}[Visualize in Comet]({{ comet_url }}){% endif %} This model was trained with {{ trainer_name }}{% if paper_id %}, a method introduced in [{{ paper_title }}](https://huggingface.co/papers/{{ paper_id }}){% endif %}. diff --git a/trl/trainer/alignprop_trainer.py b/trl/trainer/alignprop_trainer.py index d03696c091..c9cd718c44 100644 --- a/trl/trainer/alignprop_trainer.py +++ b/trl/trainer/alignprop_trainer.py @@ -26,7 +26,7 @@ from ..models import DDPOStableDiffusionPipeline from . import AlignPropConfig, BaseTrainer -from .utils import generate_model_card +from .utils import generate_model_card, get_comet_experiment_url if is_wandb_available(): @@ -438,6 +438,7 @@ def create_model_card( dataset_name=dataset_name, tags=tags, wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None, + comet_url=get_comet_experiment_url(), trainer_name="AlignProp", trainer_citation=citation, paper_title="Aligning Text-to-Image Diffusion Models with Reward Backpropagation", diff --git a/trl/trainer/bco_trainer.py b/trl/trainer/bco_trainer.py index dc267de882..c2d58ab3f2 100644 --- a/trl/trainer/bco_trainer.py +++ b/trl/trainer/bco_trainer.py @@ -59,6 +59,7 @@ RunningMoments, disable_dropout_in_model, generate_model_card, + get_comet_experiment_url, pad_to_length, peft_module_casting_to_bf16, ) @@ -1514,6 +1515,7 @@ def create_model_card( dataset_name=dataset_name, tags=tags, wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None, + comet_url=get_comet_experiment_url(), trainer_name="BCO", trainer_citation=citation, paper_title="Binary Classifier Optimization for Large Language Model Alignment", diff --git a/trl/trainer/callbacks.py b/trl/trainer/callbacks.py index 7e25366431..db1f29040b 100644 --- a/trl/trainer/callbacks.py +++ b/trl/trainer/callbacks.py @@ -13,7 +13,7 @@ # limitations under the License. import os -from typing import Optional, Union +from typing import List, Optional, Union import pandas as pd import torch @@ -42,6 +42,7 @@ from ..mergekit_utils import MergeConfig, merge_models, upload_model_to_hf from ..models.utils import unwrap_model_for_generation from .judges import BasePairwiseJudge +from .utils import log_table_to_comet_experiment if is_deepspeed_available(): @@ -199,6 +200,16 @@ def on_train_end(self, args, state, control, **kwargs): self.current_step = None +def _win_rate_completions_df( + state: TrainerState, prompts: List[str], completions: List[str], winner_indices: List[str] +) -> pd.DataFrame: + global_step = [str(state.global_step)] * len(prompts) + data = list(zip(global_step, prompts, completions, winner_indices)) + # Split completions from reference model and policy + split_data = [(item[0], item[1], item[2][0], item[2][1], item[3]) for item in data] + return pd.DataFrame(split_data, columns=["step", "prompt", "reference_model", "policy", "winner_index"]) + + class WinRateCallback(TrainerCallback): """ A [`~transformers.TrainerCallback`] that computes the win rate of a model based on a reference. @@ -311,15 +322,26 @@ def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: import wandb if wandb.run is not None: - global_step = [str(state.global_step)] * len(prompts) - data = list(zip(global_step, prompts, completions, winner_indices)) - # Split completions from referenece model and policy - split_data = [(item[0], item[1], item[2][0], item[2][1], item[3]) for item in data] - df = pd.DataFrame( - split_data, columns=["step", "prompt", "reference_model", "policy", "winner_index"] + df = _win_rate_completions_df( + state=state, + prompts=prompts, + completions=completions, + winner_indices=winner_indices, ) wandb.log({"win_rate_completions": wandb.Table(dataframe=df)}) + if "comet_ml" in args.report_to: + df = _win_rate_completions_df( + state=state, + prompts=prompts, + completions=completions, + winner_indices=winner_indices, + ) + log_table_to_comet_experiment( + name="win_rate_completions.csv", + table=df, + ) + def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): # At every evaluation step, we generate completions for the model and compare them with the reference # completions that have been generated at the beginning of training. We then compute the win rate and log it to @@ -363,15 +385,26 @@ def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: Tra import wandb if wandb.run is not None: - global_step = [str(state.global_step)] * len(prompts) - data = list(zip(global_step, prompts, completions, winner_indices)) - # Split completions from referenece model and policy - split_data = [(item[0], item[1], item[2][0], item[2][1], item[3]) for item in data] - df = pd.DataFrame( - split_data, columns=["step", "prompt", "reference_model", "policy", "winner_index"] + df = _win_rate_completions_df( + state=state, + prompts=prompts, + completions=completions, + winner_indices=winner_indices, ) wandb.log({"win_rate_completions": wandb.Table(dataframe=df)}) + if "comet_ml" in args.report_to: + df = _win_rate_completions_df( + state=state, + prompts=prompts, + completions=completions, + winner_indices=winner_indices, + ) + log_table_to_comet_experiment( + name="win_rate_completions.csv", + table=df, + ) + class LogCompletionsCallback(WandbCallback): r""" diff --git a/trl/trainer/cpo_trainer.py b/trl/trainer/cpo_trainer.py index 49b2849766..2998d534cf 100644 --- a/trl/trainer/cpo_trainer.py +++ b/trl/trainer/cpo_trainer.py @@ -55,6 +55,7 @@ add_eos_token_if_needed, disable_dropout_in_model, generate_model_card, + get_comet_experiment_url, pad_to_length, peft_module_casting_to_bf16, ) @@ -1052,6 +1053,7 @@ def create_model_card( dataset_name=dataset_name, tags=tags, wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None, + comet_url=get_comet_experiment_url(), trainer_name="CPO", trainer_citation=citation, paper_title="Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation", diff --git a/trl/trainer/ddpo_trainer.py b/trl/trainer/ddpo_trainer.py index a3b6813b4c..01a1d0d5c5 100644 --- a/trl/trainer/ddpo_trainer.py +++ b/trl/trainer/ddpo_trainer.py @@ -27,7 +27,7 @@ from ..models import DDPOStableDiffusionPipeline from . import BaseTrainer, DDPOConfig -from .utils import PerPromptStatTracker, generate_model_card +from .utils import PerPromptStatTracker, generate_model_card, get_comet_experiment_url if is_wandb_available(): @@ -641,6 +641,7 @@ def create_model_card( dataset_name=dataset_name, tags=tags, wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None, + comet_url=get_comet_experiment_url(), trainer_name="DDPO", trainer_citation=citation, paper_title="Training Diffusion Models with Reinforcement Learning", diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 7ed0ac387f..3c4c7771b2 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -60,6 +60,7 @@ cap_exp, disable_dropout_in_model, generate_model_card, + get_comet_experiment_url, pad, pad_to_length, peft_module_casting_to_bf16, @@ -1483,6 +1484,7 @@ def create_model_card( dataset_name=dataset_name, tags=tags, wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None, + comet_url=get_comet_experiment_url(), trainer_name="DPO", trainer_citation=citation, paper_title="Direct Preference Optimization: Your Language Model is Secretly a Reward Model", diff --git a/trl/trainer/gkd_trainer.py b/trl/trainer/gkd_trainer.py index 173a8d6107..be48f1925b 100644 --- a/trl/trainer/gkd_trainer.py +++ b/trl/trainer/gkd_trainer.py @@ -42,7 +42,13 @@ from ..models.utils import unwrap_model_for_generation from .gkd_config import GKDConfig from .sft_trainer import SFTTrainer -from .utils import DataCollatorForChatML, disable_dropout_in_model, empty_cache, generate_model_card +from .utils import ( + DataCollatorForChatML, + disable_dropout_in_model, + empty_cache, + generate_model_card, + get_comet_experiment_url, +) if is_deepspeed_available(): @@ -378,6 +384,7 @@ def create_model_card( dataset_name=dataset_name, tags=tags, wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None, + comet_url=get_comet_experiment_url(), trainer_name="GKD", trainer_citation=citation, paper_title="On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes", diff --git a/trl/trainer/iterative_sft_trainer.py b/trl/trainer/iterative_sft_trainer.py index 9621822eb8..e1ba8bee21 100644 --- a/trl/trainer/iterative_sft_trainer.py +++ b/trl/trainer/iterative_sft_trainer.py @@ -36,7 +36,7 @@ from transformers.utils import is_peft_available from ..core import PPODecorators -from .utils import generate_model_card +from .utils import generate_model_card, get_comet_experiment_url if is_peft_available(): @@ -434,6 +434,7 @@ def create_model_card( dataset_name=dataset_name, tags=tags, wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None, + comet_url=get_comet_experiment_url(), trainer_name="Iterative SFT", ) diff --git a/trl/trainer/kto_trainer.py b/trl/trainer/kto_trainer.py index 3aa3396e39..fb0b39bbe9 100644 --- a/trl/trainer/kto_trainer.py +++ b/trl/trainer/kto_trainer.py @@ -58,6 +58,7 @@ DPODataCollatorWithPadding, disable_dropout_in_model, generate_model_card, + get_comet_experiment_url, pad_to_length, peft_module_casting_to_bf16, ) @@ -1526,6 +1527,7 @@ def create_model_card( dataset_name=dataset_name, tags=tags, wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None, + comet_url=get_comet_experiment_url(), trainer_name="KTO", trainer_citation=citation, paper_title="KTO: Model Alignment as Prospect Theoretic Optimization", diff --git a/trl/trainer/nash_md_trainer.py b/trl/trainer/nash_md_trainer.py index a37f4bb170..1d714e2c1d 100644 --- a/trl/trainer/nash_md_trainer.py +++ b/trl/trainer/nash_md_trainer.py @@ -40,7 +40,14 @@ from .judges import BasePairwiseJudge from .nash_md_config import NashMDConfig from .online_dpo_trainer import OnlineDPOTrainer -from .utils import SIMPLE_CHAT_TEMPLATE, empty_cache, generate_model_card, get_reward, truncate_right +from .utils import ( + SIMPLE_CHAT_TEMPLATE, + empty_cache, + generate_model_card, + get_comet_experiment_url, + get_reward, + truncate_right, +) if is_apex_available(): @@ -500,6 +507,7 @@ def create_model_card( dataset_name=dataset_name, tags=tags, wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None, + comet_url=get_comet_experiment_url(), trainer_name="Nash-MD", trainer_citation=citation, paper_title="Nash Learning from Human Feedback", diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py index c014ce1e13..dd34d19af0 100644 --- a/trl/trainer/online_dpo_trainer.py +++ b/trl/trainer/online_dpo_trainer.py @@ -58,6 +58,7 @@ disable_dropout_in_model, empty_cache, generate_model_card, + get_comet_experiment_url, get_reward, prepare_deepspeed, truncate_right, @@ -734,6 +735,7 @@ def create_model_card( dataset_name=dataset_name, tags=tags, wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None, + comet_url=get_comet_experiment_url(), trainer_name="Online DPO", trainer_citation=citation, paper_title="Direct Language Model Alignment from Online AI Feedback", diff --git a/trl/trainer/orpo_trainer.py b/trl/trainer/orpo_trainer.py index e6f148f90f..f94522923b 100644 --- a/trl/trainer/orpo_trainer.py +++ b/trl/trainer/orpo_trainer.py @@ -59,6 +59,7 @@ add_eos_token_if_needed, disable_dropout_in_model, generate_model_card, + get_comet_experiment_url, pad_to_length, peft_module_casting_to_bf16, ) @@ -1077,6 +1078,7 @@ def create_model_card( dataset_name=dataset_name, tags=tags, wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None, + comet_url=get_comet_experiment_url(), trainer_name="ORPO", trainer_citation=citation, paper_title="ORPO: Monolithic Preference Optimization without Reference Model", diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py index 3827caae30..51897eeb44 100644 --- a/trl/trainer/ppo_trainer.py +++ b/trl/trainer/ppo_trainer.py @@ -60,7 +60,9 @@ first_true_indices, forward, generate_model_card, + get_comet_experiment_url, get_reward, + log_table_to_comet_experiment, peft_module_casting_to_bf16, prepare_deepspeed, print_rich_table, @@ -727,6 +729,12 @@ def generate_completions(self, sampling: bool = False): if wandb.run is not None: wandb.log({"completions": wandb.Table(dataframe=df)}) + if "comet_ml" in args.report_to: + log_table_to_comet_experiment( + name="completions.csv", + table=df, + ) + def create_model_card( self, model_name: Optional[str] = None, @@ -774,6 +782,7 @@ def create_model_card( dataset_name=dataset_name, tags=tags, wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None, + comet_url=get_comet_experiment_url(), trainer_name="PPO", trainer_citation=citation, paper_title="Fine-Tuning Language Models from Human Preferences", diff --git a/trl/trainer/reward_trainer.py b/trl/trainer/reward_trainer.py index 3157900b83..109d8a47cf 100644 --- a/trl/trainer/reward_trainer.py +++ b/trl/trainer/reward_trainer.py @@ -48,6 +48,8 @@ compute_accuracy, decode_and_strip_padding, generate_model_card, + get_comet_experiment_url, + log_table_to_comet_experiment, print_rich_table, ) @@ -359,6 +361,12 @@ def visualize_samples(self, num_print_samples: int): if wandb.run is not None: wandb.log({"completions": wandb.Table(dataframe=df)}) + if "comet_ml" in self.args.report_to: + log_table_to_comet_experiment( + name="completions.csv", + table=df, + ) + def create_model_card( self, model_name: Optional[str] = None, @@ -398,6 +406,7 @@ def create_model_card( dataset_name=dataset_name, tags=tags, wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None, + comet_url=get_comet_experiment_url(), trainer_name="Reward", ) diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index 6ad38d37c4..fa9634696a 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -60,7 +60,7 @@ truncate_response, ) from .rloo_config import RLOOConfig -from .utils import generate_model_card +from .utils import generate_model_card, get_comet_experiment_url, log_table_to_comet_experiment if is_wandb_available(): @@ -556,6 +556,12 @@ def generate_completions(self, sampling: bool = False): if wandb.run is not None: wandb.log({"completions": wandb.Table(dataframe=df)}) + if "comet_ml" in args.report_to: + log_table_to_comet_experiment( + name="completions.csv", + table=df, + ) + def create_model_card( self, model_name: Optional[str] = None, @@ -606,6 +612,7 @@ def create_model_card( dataset_name=dataset_name, tags=tags, wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None, + comet_url=get_comet_experiment_url(), trainer_name="RLOO", trainer_citation=citation, paper_title="Back to Basics: Revisiting REINFORCE-Style Optimization for Learning from Human Feedback in LLMs", diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py index 809eff5679..3a536c50ad 100644 --- a/trl/trainer/sft_trainer.py +++ b/trl/trainer/sft_trainer.py @@ -49,6 +49,7 @@ ConstantLengthDataset, DataCollatorForCompletionOnlyLM, generate_model_card, + get_comet_experiment_url, peft_module_casting_to_bf16, ) @@ -540,6 +541,7 @@ def create_model_card( dataset_name=dataset_name, tags=tags, wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None, + comet_url=get_comet_experiment_url(), trainer_name="SFT", ) diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index fd30bea929..30e453532f 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -42,6 +42,7 @@ PreTrainedTokenizerBase, TrainerState, TrainingArguments, + is_comet_available, ) from transformers.utils import ( is_peft_available, @@ -55,6 +56,9 @@ from ..trainer.model_config import ModelConfig +if is_comet_available(): + import comet_ml + if is_peft_available(): from peft import LoraConfig, PeftConfig @@ -1435,6 +1439,7 @@ def generate_model_card( trainer_citation: Optional[str] = None, paper_title: Optional[str] = None, paper_id: Optional[str] = None, + comet_url: Optional[str] = None, ) -> ModelCard: """ Generate a `ModelCard` from a template. @@ -1452,6 +1457,8 @@ def generate_model_card( Tags. wandb_url (`str` or `None`): Weights & Biases run URL. + comet_url (`str` or `None`): + Comet experiment URL. trainer_name (`str`): Trainer name. trainer_citation (`str` or `None`, defaults to `None`): @@ -1481,6 +1488,7 @@ def generate_model_card( hub_model_id=hub_model_id, dataset_name=dataset_name, wandb_url=wandb_url, + comet_url=comet_url, trainer_name=trainer_name, trainer_citation=trainer_citation, paper_title=paper_title, @@ -1492,3 +1500,34 @@ def generate_model_card( tokenizers_version=version("tokenizers"), ) return card + + +def get_comet_experiment_url() -> Optional[str]: + """ + If Comet integration is enabled, return the URL of the current Comet experiment; otherwise, return `None`. + """ + if not is_comet_available(): + return None + + if comet_ml.get_running_experiment() is not None: + return comet_ml.get_running_experiment().url + + return None + + +def log_table_to_comet_experiment(name: str, table: pd.DataFrame) -> None: + """ + If Comet integration is enabled logs a table to the Comet experiment if it is currently running. + + Args: + name (`str`): + Table name. + table (`pd.DataFrame`): + The Pandas DataFrame containing the table to log. + """ + if not is_comet_available(): + raise ModuleNotFoundError("The comet-ml is not installed. Please install it first: pip install comet-ml") + + experiment = comet_ml.get_running_experiment() + if experiment is not None: + experiment.log_table(tabular_data=table, filename=name) diff --git a/trl/trainer/xpo_trainer.py b/trl/trainer/xpo_trainer.py index baa580d136..1be32ab1de 100644 --- a/trl/trainer/xpo_trainer.py +++ b/trl/trainer/xpo_trainer.py @@ -38,7 +38,14 @@ from ..models.utils import unwrap_model_for_generation from .judges import BasePairwiseJudge from .online_dpo_trainer import OnlineDPOTrainer -from .utils import SIMPLE_CHAT_TEMPLATE, empty_cache, generate_model_card, get_reward, truncate_right +from .utils import ( + SIMPLE_CHAT_TEMPLATE, + empty_cache, + generate_model_card, + get_comet_experiment_url, + get_reward, + truncate_right, +) from .xpo_config import XPOConfig @@ -555,6 +562,7 @@ def create_model_card( dataset_name=dataset_name, tags=tags, wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None, + comet_url=get_comet_experiment_url(), trainer_name="XPO", trainer_citation=citation, paper_title="Exploratory Preference Optimization: Harnessing Implicit Q*-Approximation for Sample-Efficient RLHF",