diff --git a/trl/trainer/reward_config.py b/trl/trainer/reward_config.py
index 8018a2844c..2ed2363819 100644
--- a/trl/trainer/reward_config.py
+++ b/trl/trainer/reward_config.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from dataclasses import dataclass
-from typing import Optional
+from typing import Any, Optional
 
 from transformers import TrainingArguments
 
@@ -48,3 +48,4 @@ class RewardConfig(TrainingArguments):
     dataset_num_proc: Optional[int] = None
     center_rewards_coefficient: Optional[float] = None
     remove_unused_columns: bool = False
+    model_init_kwargs: Optional[dict[str, Any]] = None
diff --git a/trl/trainer/reward_trainer.py b/trl/trainer/reward_trainer.py
index 79b237b9e7..08d45c061a 100644
--- a/trl/trainer/reward_trainer.py
+++ b/trl/trainer/reward_trainer.py
@@ -56,7 +56,7 @@
 
 if is_peft_available():
-    from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training
+    from peft import PeftConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training
 
 if is_wandb_available():
     import wandb
@@ -144,13 +144,48 @@ def __init__(
             raise ValueError(
                 "You cannot specify both `max_length` and `args.max_length`. Please use the `RewardConfig` to set `max_length` once."
             )
+
+        if args.model_init_kwargs is None:
+            model_init_kwargs = {}
+        elif not isinstance(model, str):
+            raise ValueError("You passed `model_init_kwargs` to the `RewardConfig`, but your model is already instantiated.")
+        else:
+            model_init_kwargs = args.model_init_kwargs
+            torch_dtype = model_init_kwargs.get("torch_dtype")
+            if torch_dtype is not None:
+                # Convert to `torch.dtype` if a str is passed
+                if isinstance(torch_dtype, str) and torch_dtype != "auto":
+                    torch_dtype = getattr(torch, torch_dtype)
+                if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
+                    raise ValueError(
+                        f"Invalid `torch_dtype` passed to the `RewardConfig`. Expected either 'auto' or a string representing a `torch.dtype`, but got {torch_dtype}."
+                    )
+                model_init_kwargs["torch_dtype"] = torch_dtype
+
         if not is_peft_available() and peft_config is not None:
             raise ValueError(
                 "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
             )
         elif is_peft_available() and peft_config is not None:
+            if not isinstance(peft_config, PeftConfig):
+                raise ValueError(
+                    "If you want to use the `PeftModel`, you need to pass a valid `PeftConfig` object to the `RewardTrainer`,"
+                    f" but you passed a {type(peft_config)}."
+                )
+
             if not isinstance(model, PeftModel):
-                if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_quantized", False):
+                is_sharded_qlora = False
+                # Below is to support QLoRA + FSDP / DS-Zero3 - one should never call
+                # peft_module_casting_to_bf16 or prepare_model_for_kbit_training when doing
+                # QLoRA + FSDP / DS-Zero3
+                if getattr(model, "is_loaded_in_4bit", False):
+                    for _, param in model.named_parameters():
+                        if param.__class__.__name__ == "Params4bit":
+                            is_sharded_qlora = param.data.device.type in {"cpu", "meta"}
+                            break
+
+                if getattr(model, "is_loaded_in_8bit", False) or (
+                    getattr(model, "is_loaded_in_4bit", False) and not is_sharded_qlora):
                     _supports_gc_kwargs = "gradient_checkpointing_kwargs" in list(
                         inspect.signature(prepare_model_for_kbit_training).parameters
                     )
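
Below is a minimal usage sketch for the new `model_init_kwargs` field, assuming this patch is applied. The output directory and dtype are placeholder values, and the final block simply mirrors the `torch_dtype` conversion added to `RewardTrainer.__init__` in the hunk above; it is not part of the patch itself.

import torch

from trl import RewardConfig

# Placeholder config: `torch_dtype` is given as a string and is meant to be
# converted to a real `torch.dtype` by the trainer before the model is loaded.
config = RewardConfig(
    output_dir="reward-model",
    model_init_kwargs={"torch_dtype": "bfloat16"},
)

# Mirror of the conversion logic added in `__init__` (for illustration only).
model_init_kwargs = dict(config.model_init_kwargs)
torch_dtype = model_init_kwargs.get("torch_dtype")
if isinstance(torch_dtype, str) and torch_dtype != "auto":
    torch_dtype = getattr(torch, torch_dtype)
model_init_kwargs["torch_dtype"] = torch_dtype

print(model_init_kwargs)  # {'torch_dtype': torch.bfloat16}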