diff --git a/ludwig/schema/model_types/utils.py b/ludwig/schema/model_types/utils.py
index ea7f2190549..d229214118f 100644
--- a/ludwig/schema/model_types/utils.py
+++ b/ludwig/schema/model_types/utils.py
@@ -34,7 +34,7 @@
 from ludwig.schema.trainer import ECDTrainerConfig
 from ludwig.types import HyperoptConfigDict, ModelConfigDict
 from ludwig.utils.data_utils import get_sanitized_feature_name
-from ludwig.utils.llm_utils import _PHI_MODELS, get_context_len
+from ludwig.utils.llm_utils import get_context_len
 
 if TYPE_CHECKING:
     from ludwig.schema.model_types.base import ModelConfig
@@ -323,6 +323,9 @@ def set_llm_parameters(config: "ModelConfig") -> None:
     # PEFT PR: https://github.com/huggingface/peft/pull/1375
     _set_phi2_target_modules(config)
 
+    # HACK(Arnav): Set Phi-3 target modules when using LoRA
+    _set_phi3_target_modules(config)
+
     # HACK(Arnav): Set Gemma target modules when using LoRA
     # GitHub issue: https://github.com/ludwig-ai/ludwig/issues/3937
     # PEFT PR: https://github.com/huggingface/peft/pull/1499
@@ -441,7 +444,11 @@ def _set_mixtral_target_modules(config: "ModelConfig") -> None:
 def _set_phi2_target_modules(config: "ModelConfig") -> None:
     """If the base model is Phi-2, LoRA is enabled and the target modules are not set, set the target modules to
     maximize performance."""
-    if config.base_model not in _PHI_MODELS:
+    if config.base_model not in {
+        "microsoft/phi-1",
+        "microsoft/phi-1_5",
+        "microsoft/phi-2",
+    }:
         return
 
     if not config.adapter:
@@ -456,6 +463,25 @@ def _set_phi2_target_modules(config: "ModelConfig") -> None:
     config.adapter.target_modules = target_modules
 
 
+def _set_phi3_target_modules(config: "ModelConfig") -> None:
+    if config.base_model not in {
+        "microsoft/Phi-3-mini-4k-instruct",
+        "microsoft/Phi-3-mini-128k-instruct",
+    }:
+        return
+
+    if not config.adapter:
+        return
+
+    if config.adapter.type != "lora" or config.adapter.target_modules:
+        return
+
+    target_modules = ["qkv_proj", "o_proj", "gate_up_proj", "down_proj"]
+
+    logger.info(f"Setting adapter target modules to {target_modules} for Phi-3 base model with LoRA adapter.")
+    config.adapter.target_modules = target_modules
+
+
 def _set_gemma_target_modules(config: "ModelConfig") -> None:
     """If the base model is Gemma, LoRA is enabled and the target modules are not set, set the target modules to
     maximize performance."""
diff --git a/ludwig/utils/llm_utils.py b/ludwig/utils/llm_utils.py
index 16e640301af..29452237e33 100644
--- a/ludwig/utils/llm_utils.py
+++ b/ludwig/utils/llm_utils.py
@@ -27,18 +27,7 @@
 
 FALLBACK_CONTEXT_LEN = 2048
 
-_PHI_MODELS = {
-    "susnato/phi-1_dev",
-    "susnato/phi-1_5_dev",
-    "susnato/phi-2",
-    "microsoft/phi-1",
-    "microsoft/phi-1_5",
-    "microsoft/phi-2",
-}
-
 _MODELS_WITH_DEVICE_MAP_AUTO_EXCLUSION = set()
-# Phi models don't support "device_map='auto'" at model load time as of transformers 4.37.0.
-_MODELS_WITH_DEVICE_MAP_AUTO_EXCLUSION.update(_PHI_MODELS)
 
 
 @default_retry(tries=8)
@@ -52,8 +41,7 @@ def load_pretrained_from_config(
         # Apply quantization configuration at model load time
         load_kwargs["torch_dtype"] = getattr(torch, config_obj.quantization.bnb_4bit_compute_dtype)
         load_kwargs["quantization_config"] = config_obj.quantization.to_bitsandbytes()
-        if config_obj.base_model not in _MODELS_WITH_DEVICE_MAP_AUTO_EXCLUSION:
-            load_kwargs["device_map"] = "auto"
+        load_kwargs["device_map"] = "auto"
 
     if transformers_436:
         load_kwargs["attn_implementation"] = "eager"
diff --git a/requirements.txt b/requirements.txt
index 6b7998d93b2..d293073e2c7 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -12,7 +12,7 @@ torchaudio
 torchtext
 torchvision
 pydantic<2.0
-transformers>=4.39.0
+transformers>=4.42.3
 tifffile
 imagecodecs
 tokenizers>=0.15
diff --git a/requirements_llm.txt b/requirements_llm.txt
index 655344392b2..c691bc0bac3 100644
--- a/requirements_llm.txt
+++ b/requirements_llm.txt
@@ -4,4 +4,4 @@
 faiss-cpu
 accelerate
 loralib
-peft>=0.9.0
+peft>=0.10.0
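
Note for reviewers: below is a minimal, illustrative sketch (not part of the patch) of how the new _set_phi3_target_modules hook behaves once this change lands. The SimpleNamespace objects are hypothetical stand-ins for the real ModelConfig and LoRA adapter schema instances; only the attributes the function actually reads are populated.

from types import SimpleNamespace

from ludwig.schema.model_types.utils import _set_phi3_target_modules

# Stand-ins for ModelConfig and its LoRA adapter config; only the attributes
# read by _set_phi3_target_modules are filled in.
adapter = SimpleNamespace(type="lora", target_modules=None)
config = SimpleNamespace(base_model="microsoft/Phi-3-mini-4k-instruct", adapter=adapter)

_set_phi3_target_modules(config)
assert config.adapter.target_modules == ["qkv_proj", "o_proj", "gate_up_proj", "down_proj"]

# A non-Phi-3 base model (or user-specified target_modules) is left untouched.
other = SimpleNamespace(
    base_model="microsoft/phi-2",
    adapter=SimpleNamespace(type="lora", target_modules=None),
)
_set_phi3_target_modules(other)
assert other.adapter.target_modules is None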