From 32bd62f777224edcff331e8a53f43cd56e3552e5 Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Tue, 13 Aug 2024 12:33:03 -0400 Subject: [PATCH] allow embedding resizing passed through --- llmfoundry/models/hf/hf_causal_lm.py | 2 ++ llmfoundry/models/hf/model_wrapper.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/llmfoundry/models/hf/hf_causal_lm.py b/llmfoundry/models/hf/hf_causal_lm.py index 536cd0257d..7d944b3ab5 100644 --- a/llmfoundry/models/hf/hf_causal_lm.py +++ b/llmfoundry/models/hf/hf_causal_lm.py @@ -88,6 +88,7 @@ def __init__( config_overrides: Optional[Dict[str, Any]] = None, peft_config: Optional[Dict[str, Any]] = None, use_train_metrics: bool = True, + allow_embedding_resizing: bool = False, additional_train_metrics: Optional[List] = None, additional_eval_metrics: Optional[List] = None, should_save_peft_only: bool = True, @@ -130,6 +131,7 @@ def __init__( tokenizer=tokenizer, metrics=train_metrics, eval_metrics=eval_metrics, + allow_embedding_resizing=allow_embedding_resizing, init_device=init_device, peft_config=peft_config_object, should_save_peft_only=should_save_peft_only, diff --git a/llmfoundry/models/hf/model_wrapper.py b/llmfoundry/models/hf/model_wrapper.py index 7051986df8..013cee7085 100644 --- a/llmfoundry/models/hf/model_wrapper.py +++ b/llmfoundry/models/hf/model_wrapper.py @@ -38,6 +38,7 @@ def __init__( metrics: Optional[List[Metric]] = None, eval_metrics: Optional[List[Metric]] = None, shift_labels: bool = False, + allow_embedding_resizing: bool = False, init_device: Optional[str] = None, peft_config: Optional['PeftConfig'] = None, should_save_peft_only: bool = True, @@ -49,6 +50,7 @@ def __init__( metrics=metrics, eval_metrics=eval_metrics, shift_labels=shift_labels, + allow_embedding_resizing=allow_embedding_resizing, peft_config=peft_config, should_save_peft_only=should_save_peft_only, )