In PEFT finetuning, only the trainable parameters need to be saved #27825

Merged 1 commit on Dec 18, 2023
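
The idea behind this PR, as a standalone sketch (the helper below is hypothetical and not part of the diff): in a PEFT run only the adapter weights are trainable, so a checkpoint that keeps just the parameters with requires_grad=True is a small fraction of the full model.

import torch
from torch import nn


def trainable_state_dict(model: nn.Module) -> dict:
    # Collect only the parameters that require gradients, e.g. LoRA adapter weights.
    return {name: p.detach().cpu() for name, p in model.named_parameters() if p.requires_grad}


# For a typical LoRA setup this is well under 1% of the full state dict, which is
# why excluding frozen parameters from DeepSpeed checkpoints is worthwhile.
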
src/transformers/trainer.py (27 changes: 18 additions & 9 deletions)
@@ -212,6 +212,10 @@
from accelerate.utils import DeepSpeedSchedulerWrapper


def _is_peft_model(model):
return is_peft_available() and isinstance(model, PeftModel)


if TYPE_CHECKING:
import optuna

@@ -398,13 +402,12 @@ def __init__(
" to `True` to avoid any unexpected behavior such as device placement mismatching."
)

_is_peft_model = is_peft_available() and isinstance(model, PeftModel)
_is_quantized_and_base_model = getattr(model, "is_quantized", False) and not getattr(
model, "_hf_peft_config_loaded", False
)

# At this stage the model is already loaded
if _is_quantized_and_base_model and not _is_peft_model:
if _is_quantized_and_base_model and not _is_peft_model(model):
raise ValueError(
"You cannot perform fine-tuning on purely quantized models. Please attach trainable adapters on top of"
" the quantized model to correctly perform fine-tuning. Please see: https://huggingface.co/docs/transformers/peft"
@@ -619,7 +622,7 @@ def _activate_neftune(self, model):
"""
unwrapped_model = unwrap_model(model)

if is_peft_available() and isinstance(unwrapped_model, PeftModel):
if _is_peft_model(unwrapped_model):
embeddings = unwrapped_model.base_model.model.get_input_embeddings()
else:
embeddings = unwrapped_model.get_input_embeddings()
@@ -640,7 +643,7 @@ def _deactivate_neftune(self, model):

unwrapped_model = unwrap_model(model)

if is_peft_available() and isinstance(unwrapped_model, PeftModel):
if _is_peft_model(unwrapped_model):
embeddings = unwrapped_model.base_model.model.get_input_embeddings()
else:
embeddings = unwrapped_model.get_input_embeddings()
@@ -696,7 +699,7 @@ def _set_signature_columns_if_needed(self):
if self._signature_columns is None:
# Inspect model forward signature to keep only the arguments it accepts.
model_to_inspect = self.model
if is_peft_available() and isinstance(self.model, PeftModel):
if _is_peft_model(self.model):
model_to_inspect = self.model.get_base_model()
signature = inspect.signature(model_to_inspect.forward)
self._signature_columns = list(signature.parameters.keys())
@@ -2109,7 +2112,7 @@ def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
self._issue_warnings_after_load(load_result)

# Load adapters following PR # 24096
elif is_peft_available() and isinstance(model, PeftModel):
elif _is_peft_model(model):
# If training a model using PEFT & LoRA, assume that the adapter has been saved properly.
if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"):
if os.path.exists(resume_from_checkpoint):
@@ -2172,7 +2175,7 @@ def _load_best_model(self):
state_dict["_smp_is_partial"] = False
load_result = model.load_state_dict(state_dict, strict=True)
else:
if is_peft_available() and isinstance(model, PeftModel):
if _is_peft_model(model):
# If training a model using PEFT & LoRA, assume that the adapter has been saved properly.
if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"):
if os.path.exists(best_adapter_model_path) or os.path.exists(best_safe_adapter_model_path):
@@ -2448,7 +2451,13 @@ def _save_optimizer_and_scheduler(self, output_dir):
elif self.is_deepspeed_enabled:
# Under ZeRO-3 the model file itself doesn't get saved since it's bogus, unless the DeepSpeed
# config `stage3_gather_16bit_weights_on_model_save` is True
self.model_wrapped.save_checkpoint(output_dir)
accept_exclude_frozen_parameters = "exclude_frozen_parameters" in set(
inspect.signature(self.model_wrapped.save_checkpoint).parameters.keys()
)
if accept_exclude_frozen_parameters and _is_peft_model(self.model):
self.model_wrapped.save_checkpoint(output_dir, exclude_frozen_parameters=True)
else:
self.model_wrapped.save_checkpoint(output_dir)
elif self.is_fsdp_enabled:
# save fsdp specific ckpt for resuming from ckpt
save_fsdp_model(self.accelerator.state.fsdp_plugin, self.accelerator, self.model, output_dir)
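
The guard added above can be sketched in isolation; this is an illustration of the pattern rather than code from the PR. `exclude_frozen_parameters` is only available in newer DeepSpeed releases, so the keyword is passed only when `save_checkpoint` actually declares it:

import inspect


def accepts_kwarg(fn, name: str) -> bool:
    # True when the callable declares a parameter with this name. Note this is a
    # conservative check: a function that only takes **kwargs would report False.
    return name in inspect.signature(fn).parameters


# Mirroring the diff (`engine` stands in for self.model_wrapped, a DeepSpeed engine):
# if accepts_kwarg(engine.save_checkpoint, "exclude_frozen_parameters") and _is_peft_model(model):
#     engine.save_checkpoint(output_dir, exclude_frozen_parameters=True)
# else:
#     engine.save_checkpoint(output_dir)
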
@@ -2761,7 +2770,7 @@ def compute_loss(self, model, inputs, return_outputs=False):

if labels is not None:
unwrapped_model = unwrap_model(model)
if is_peft_available() and isinstance(unwrapped_model, PeftModel):
if _is_peft_model(unwrapped_model):
model_name = unwrapped_model.base_model.model._get_name()
else:
model_name = unwrapped_model._get_name()
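
For reference, a minimal usage sketch of the consolidated check, assuming `peft` is installed; GPT-2 is used only as a small example checkpoint:

from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM
from transformers.trainer import _is_peft_model  # module-level helper added in this PR

base = AutoModelForCausalLM.from_pretrained("gpt2")
lora = LoraConfig(target_modules=["c_attn"], task_type="CAUSAL_LM")
peft_model = get_peft_model(base, lora)

print(_is_peft_model(peft_model))  # True: the model is wrapped in a PeftModel
print(_is_peft_model(base))        # False: plain transformers model
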