From 3a89ee19c6f204f742849b4f4cda24c0e3e47f63 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?=
Date: Sat, 21 Dec 2024 13:22:59 +0000
Subject: [PATCH] nash

---
 trl/trainer/nash_md_config.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/trl/trainer/nash_md_config.py b/trl/trainer/nash_md_config.py
index dadad01f03..c8395fd136 100644
--- a/trl/trainer/nash_md_config.py
+++ b/trl/trainer/nash_md_config.py
@@ -31,7 +31,14 @@ class NashMDConfig(OnlineDPOConfig):
             epochs.
     """
 
-    mixture_coef: list[float] = field(default_factory=lambda: [0.5])
+    mixture_coef: list[float] = field(
+        default_factory=lambda: [0.5],
+        metadata={
+            "help": "Logit mixture coefficient for the model and reference model. If a list of floats is provided "
+            "then the mixture coefficient is selected for each new epoch and the last coefficient is used for the "
+            "rest of the epochs."
+        },
+    )
 
     def __post_init__(self):
         super().__post_init__()