peft_config & is_loaded_in_4bit check added to Reward_Trainer #2427

Closed
wants to merge 4 commits into from
3 changes: 2 additions & 1 deletion trl/trainer/reward_config.py
@@ -13,7 +13,7 @@
# limitations under the License.

from dataclasses import dataclass
-from typing import Optional
+from typing import Any, Optional

from transformers import TrainingArguments

@@ -48,3 +48,4 @@ class RewardConfig(TrainingArguments):
    dataset_num_proc: Optional[int] = None
    center_rewards_coefficient: Optional[float] = None
    remove_unused_columns: bool = False
+    model_init_kwargs: Optional[dict[str, Any]] = None
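
For context, a minimal sketch of how the new field would be configured. The model output directory and dtype value are illustrative, and the assumption that these kwargs are forwarded to the model's `from_pretrained` call when `model` is passed as a string follows the pattern of the other trainers rather than anything guaranteed by this diff:

```python
from trl import RewardConfig

# Hypothetical usage of the new `model_init_kwargs` field. The kwargs are assumed to be
# forwarded to `from_pretrained` when the trainer receives a model id string.
config = RewardConfig(
    output_dir="reward-model",
    model_init_kwargs={"torch_dtype": "bfloat16"},  # the trainer resolves the string to torch.bfloat16
)
```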
39 changes: 37 additions & 2 deletions trl/trainer/reward_trainer.py
@@ -56,7 +56,7 @@


if is_peft_available():
-    from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training
+    from peft import PeftConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training

if is_wandb_available():
    import wandb
@@ -144,13 +144,48 @@ def __init__(
            raise ValueError(
                "You cannot specify both `max_length` and `args.max_length`. Please use the `RewardConfig` to set `max_length` once."
            )

        if args.model_init_kwargs is None:
            model_init_kwargs = {}
        elif not isinstance(model, str):
            raise ValueError("You passed `model_init_kwargs` to the RewardTrainer, but your model is already instantiated.")
        else:
            model_init_kwargs = args.model_init_kwargs
            torch_dtype = model_init_kwargs.get("torch_dtype")
            if torch_dtype is not None:
                # Convert to `torch.dtype` if a string is passed
                if isinstance(torch_dtype, str) and torch_dtype != "auto":
                    torch_dtype = getattr(torch, torch_dtype)
                if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
                    raise ValueError(
                        f"Invalid `torch_dtype` passed to the RewardConfig. Expected either a string representing a `torch.dtype` or 'auto', but got {torch_dtype}."
                    )
                model_init_kwargs["torch_dtype"] = torch_dtype

        if not is_peft_available() and peft_config is not None:
            raise ValueError(
                "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
            )
        elif is_peft_available() and peft_config is not None:
            if not isinstance(peft_config, PeftConfig):
                raise ValueError(
                    "If you want to use the PeftModel, you need to pass a valid PeftConfig object to the RewardTrainer,"
                    f" but you passed a {type(peft_config)}."
                )

Comment on lines +170 to +175
Member

Suggested change
            if not isinstance(peft_config, PeftConfig):
                raise ValueError(
                    "If you want to use the PeftModel, you need to pass a valid PeftConfig object to the RewardTrainer,"
                    f" but you passed a {type(peft_config)}."
                )

I don't think it's necessary here. The code will eventually fail if the peft config doesn't have the right type.
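
For reference, a minimal sketch of how a `peft_config` is typically passed to `RewardTrainer`; the model id, tokenizer, toy dataset, and LoRA settings are illustrative and not part of this PR, and the keyword names assume the trl version this PR targets:

```python
from datasets import Dataset
from peft import LoraConfig
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from trl import RewardConfig, RewardTrainer

# Illustrative model and data; a sequence-classification head with num_labels=1 serves as the reward model.
model_id = "Qwen/Qwen2.5-0.5B"
model = AutoModelForSequenceClassification.from_pretrained(model_id, num_labels=1)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Any PeftConfig subclass instance is valid here; as the reviewer notes, passing something
# that is not a PeftConfig would eventually fail inside `get_peft_model` anyway.
peft_config = LoraConfig(task_type="SEQ_CLS", r=8, lora_alpha=16)

trainer = RewardTrainer(
    model=model,
    args=RewardConfig(output_dir="reward-model"),
    processing_class=tokenizer,
    train_dataset=Dataset.from_dict({"chosen": ["A helpful reply."], "rejected": ["An unhelpful reply."]}),
    peft_config=peft_config,
)
```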

            if not isinstance(model, PeftModel):
                if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_quantized", False):
                    is_sharded_qlora = False
                    # Below is to support QLoRA + FSDP / DS-Zero3 - one should never call
                    # peft_module_casting_to_bf16 or prepare_model_for_kbit_training when doing
                    # QLoRA + FSDP / DS-Zero3
                    if getattr(model, "is_loaded_in_4bit", False):
                        for _, param in model.named_parameters():
                            if param.__class__.__name__ == "Params4bit":
                                is_sharded_qlora = param.data.device.type in {"cpu", "meta"}
                                break

                    if getattr(model, "is_loaded_in_8bit", False) or (
                        getattr(model, "is_loaded_in_4bit", False) and not is_sharded_qlora
                    ):
                        _supports_gc_kwargs = "gradient_checkpointing_kwargs" in list(
                            inspect.signature(prepare_model_for_kbit_training).parameters
                        )
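
Finally, a standalone sketch of the sharded-QLoRA heuristic the new check relies on (the helper name is made up for illustration): under QLoRA + FSDP or DeepSpeed ZeRO-3, bitsandbytes `Params4bit` weights sit on the `cpu` or `meta` device before sharding, and in that case `prepare_model_for_kbit_training` should not be called.

```python
import inspect

from peft import prepare_model_for_kbit_training


def looks_like_sharded_qlora(model) -> bool:
    """Illustrative helper mirroring the diff: a 4-bit model whose bitsandbytes
    Params4bit weights live on 'cpu' or 'meta' is assumed to be a QLoRA model
    sharded by FSDP / DeepSpeed ZeRO-3."""
    if not getattr(model, "is_loaded_in_4bit", False):
        return False
    for _, param in model.named_parameters():
        if param.__class__.__name__ == "Params4bit":
            return param.data.device.type in {"cpu", "meta"}
    return False


# The diff also feature-detects whether the installed peft version accepts
# `gradient_checkpointing_kwargs` before forwarding it to prepare_model_for_kbit_training:
supports_gc_kwargs = "gradient_checkpointing_kwargs" in inspect.signature(
    prepare_model_for_kbit_training
).parameters
```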