Add Phi-3 Support (#4014)
arnavgarg1 authored Jul 11, 2024
1 parent 830c3f0 commit cab0ad6
Showing 4 changed files with 31 additions and 17 deletions.
30 changes: 28 additions & 2 deletions ludwig/schema/model_types/utils.py
@@ -34,7 +34,7 @@
 from ludwig.schema.trainer import ECDTrainerConfig
 from ludwig.types import HyperoptConfigDict, ModelConfigDict
 from ludwig.utils.data_utils import get_sanitized_feature_name
-from ludwig.utils.llm_utils import _PHI_MODELS, get_context_len
+from ludwig.utils.llm_utils import get_context_len

 if TYPE_CHECKING:
     from ludwig.schema.model_types.base import ModelConfig
@@ -323,6 +323,9 @@ def set_llm_parameters(config: "ModelConfig") -> None:
     # PEFT PR: https://github.com/huggingface/peft/pull/1375
     _set_phi2_target_modules(config)

+    # HACK(Arnav): Set Phi-3 target modules when using LoRA
+    _set_phi3_target_modules(config)
+
     # HACK(Arnav): Set Gemma target modules when using LoRA
     # GitHub issue: https://github.com/ludwig-ai/ludwig/issues/3937
     # PEFT PR: https://github.com/huggingface/peft/pull/1499
@@ -441,7 +444,11 @@ def _set_mixtral_target_modules(config: "ModelConfig") -> None:
 def _set_phi2_target_modules(config: "ModelConfig") -> None:
     """If the base model is Phi-2, LoRA is enabled and the target modules are not set, set the target modules to
     maximize performance."""
-    if config.base_model not in _PHI_MODELS:
+    if config.base_model not in {
+        "microsoft/phi-1",
+        "microsoft/phi-1_5",
+        "microsoft/phi-2",
+    }:
         return

     if not config.adapter:
@@ -456,6 +463,25 @@ def _set_phi2_target_modules(config: "ModelConfig") -> None:
     config.adapter.target_modules = target_modules


+def _set_phi3_target_modules(config: "ModelConfig") -> None:
+    if config.base_model not in {
+        "microsoft/Phi-3-mini-4k-instruct",
+        "microsoft/Phi-3-mini-128k-instruct",
+    }:
+        return
+
+    if not config.adapter:
+        return
+
+    if config.adapter.type != "lora" or config.adapter.target_modules:
+        return
+
+    target_modules = ["qkv_proj", "o_proj", "gate_up_proj", "down_proj"]
+
+    logger.info(f"Setting adapter target modules to {target_modules} for Phi-3 base model with LoRA adapter.")
+    config.adapter.target_modules = target_modules
+
+
 def _set_gemma_target_modules(config: "ModelConfig") -> None:
     """If the base model is Gemma, LoRA is enabled and the target modules are not set, set the target modules to
     maximize performance."""
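With this helper in place, a Phi-3 config that enables LoRA but leaves target_modules unset should pick up the defaults above automatically. A minimal sketch of how that surfaces to a user, assuming the usual flow where building a ModelConfig from a dict runs set_llm_parameters on LLM configs (the feature names below are made up for illustration):

from ludwig.schema.model_types.base import ModelConfig

# Hypothetical LLM fine-tuning config: LoRA is enabled and target_modules is
# left unset, so _set_phi3_target_modules is expected to fill in the defaults.
config = ModelConfig.from_dict(
    {
        "model_type": "llm",
        "base_model": "microsoft/Phi-3-mini-4k-instruct",
        "input_features": [{"name": "prompt", "type": "text"}],
        "output_features": [{"name": "response", "type": "text"}],
        "adapter": {"type": "lora"},
        "trainer": {"type": "finetune"},
    }
)

# Expected: ["qkv_proj", "o_proj", "gate_up_proj", "down_proj"]
print(config.adapter.target_modules)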
14 changes: 1 addition & 13 deletions ludwig/utils/llm_utils.py
@@ -27,18 +27,7 @@

 FALLBACK_CONTEXT_LEN = 2048

-_PHI_MODELS = {
-    "susnato/phi-1_dev",
-    "susnato/phi-1_5_dev",
-    "susnato/phi-2",
-    "microsoft/phi-1",
-    "microsoft/phi-1_5",
-    "microsoft/phi-2",
-}
-
 _MODELS_WITH_DEVICE_MAP_AUTO_EXCLUSION = set()
-# Phi models don't support "device_map='auto'" at model load time as of transformers 4.37.0.
-_MODELS_WITH_DEVICE_MAP_AUTO_EXCLUSION.update(_PHI_MODELS)


 @default_retry(tries=8)
@@ -52,8 +41,7 @@ def load_pretrained_from_config(
         # Apply quantization configuration at model load time
         load_kwargs["torch_dtype"] = getattr(torch, config_obj.quantization.bnb_4bit_compute_dtype)
         load_kwargs["quantization_config"] = config_obj.quantization.to_bitsandbytes()
-        if config_obj.base_model not in _MODELS_WITH_DEVICE_MAP_AUTO_EXCLUSION:
-            load_kwargs["device_map"] = "auto"
+        load_kwargs["device_map"] = "auto"

         if transformers_436:
             load_kwargs["attn_implementation"] = "eager"
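With the exclusion set gone, every quantized load now requests device_map="auto", Phi models included; presumably the limitation noted for transformers 4.37.0 no longer applies with the new transformers>=4.42.3 floor. As a rough standalone sketch of what the assembled load_kwargs amount to in plain transformers terms (not Ludwig's exact call; the model and dtype here are only examples):

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# 4-bit quantization config, comparable to what
# config_obj.quantization.to_bitsandbytes() would produce for a float16 compute dtype.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
)

model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Phi-3-mini-4k-instruct",
    torch_dtype=torch.float16,
    quantization_config=quantization_config,
    device_map="auto",  # now passed unconditionally; the Phi exclusion set is gone
    attn_implementation="eager",  # mirrors the transformers_436 branch above
)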
2 changes: 1 addition & 1 deletion requirements.txt
@@ -12,7 +12,7 @@ torchaudio
 torchtext
 torchvision
 pydantic<2.0
-transformers>=4.39.0
+transformers>=4.42.3
 tifffile
 imagecodecs
 tokenizers>=0.15
2 changes: 1 addition & 1 deletion requirements_llm.txt
@@ -4,4 +4,4 @@ faiss-cpu
 accelerate
 loralib

-peft>=0.9.0
+peft>=0.10.0
