Add support for Phi-1 and Phi 1.5 (#3831)
arnavgarg1 authored Dec 14, 2023
1 parent bccfb4e commit 06cc508
Showing 4 changed files with 36 additions and 3 deletions.
4 changes: 3 additions & 1 deletion ludwig/models/llm.py
@@ -21,6 +21,7 @@
 from ludwig.utils.error_handling_utils import default_retry
 from ludwig.utils.llm_quantization_utils import convert_quantized_linear_to_linear
 from ludwig.utils.llm_utils import (
+    _MODELS_WITH_DEVICE_MAP_AUTO_EXCLUSION,
     add_left_padding,
     generate_merged_ids,
     get_context_len,
@@ -87,7 +88,8 @@ def load_pretrained_from_config(
         # Apply quantization configuration at model load time
         load_kwargs["torch_dtype"] = getattr(torch, config_obj.quantization.bnb_4bit_compute_dtype)
         load_kwargs["quantization_config"] = config_obj.quantization.to_bitsandbytes()
-        load_kwargs["device_map"] = "auto"
+        if config_obj.base_model not in _MODELS_WITH_DEVICE_MAP_AUTO_EXCLUSION:
+            load_kwargs["device_map"] = "auto"
 
         if config_obj.model_parameters:
             # Add any model specific parameters to the load kwargs
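For context, here is a minimal sketch of what the new guard amounts to when quantization is enabled. `build_quantized_load_kwargs` is an illustrative stand-in rather than a Ludwig function, and the bitsandbytes config is stubbed out with `None`.

```python
import torch

# Copied from ludwig/utils/llm_utils.py as added in this commit.
_MODELS_WITH_DEVICE_MAP_AUTO_EXCLUSION = {"susnato/phi-1_dev", "susnato/phi-1_5_dev"}


def build_quantized_load_kwargs(base_model: str, compute_dtype: str = "float16") -> dict:
    """Illustrative stand-in for the guarded block in load_pretrained_from_config."""
    load_kwargs = {
        "torch_dtype": getattr(torch, compute_dtype),
        "quantization_config": None,  # in Ludwig: config_obj.quantization.to_bitsandbytes()
    }
    # The susnato Phi ports fail with device_map="auto", so only set it for other models.
    if base_model not in _MODELS_WITH_DEVICE_MAP_AUTO_EXCLUSION:
        load_kwargs["device_map"] = "auto"
    return load_kwargs


assert "device_map" not in build_quantized_load_kwargs("susnato/phi-1_5_dev")
assert build_quantized_load_kwargs("meta-llama/Llama-2-7b-hf")["device_map"] == "auto"
```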
3 changes: 3 additions & 0 deletions ludwig/schema/llms/base_model.py
@@ -9,6 +9,7 @@
 from ludwig.error import ConfigValidationError
 from ludwig.schema.metadata import LLM_METADATA
 from ludwig.schema.metadata.parameter_metadata import convert_metadata_to_json
+from ludwig.utils.llm_utils import _PHI_BASE_MODEL_MAPPING
 
 # Maps a preset LLM name to the full slash-delimited HF path. If the user chooses a preset LLM, the preset LLM name is
 # replaced with the full slash-delimited HF path using this map, after JSON validation but before config object
@@ -72,6 +73,8 @@ def validate(model_name: str):
         return MODEL_PRESETS[model_name]
     if os.path.isdir(model_name):
         return model_name
+    if model_name in _PHI_BASE_MODEL_MAPPING:
+        return _PHI_BASE_MODEL_MAPPING[model_name]
     try:
         AutoConfig.from_pretrained(model_name)
         return model_name
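Sketched below is the resolution order that validate() now follows; `resolve_base_model` and the one-entry `MODEL_PRESETS` are illustrative stand-ins, not Ludwig's actual names or full preset table.

```python
import os

MODEL_PRESETS = {"llama-2-7b": "meta-llama/Llama-2-7b-hf"}  # illustrative subset
_PHI_BASE_MODEL_MAPPING = {
    "microsoft/phi-1": "susnato/phi-1_dev",
    "microsoft/phi-1.5": "susnato/phi-1_5_dev",
}


def resolve_base_model(model_name: str) -> str:
    if model_name in MODEL_PRESETS:  # 1. preset shorthand -> full HF path
        return MODEL_PRESETS[model_name]
    if os.path.isdir(model_name):  # 2. local checkpoint directory
        return model_name
    if model_name in _PHI_BASE_MODEL_MAPPING:  # 3. new: remap the official Phi repos
        return _PHI_BASE_MODEL_MAPPING[model_name]
    return model_name  # 4. otherwise validated against the HF Hub via AutoConfig.from_pretrained


print(resolve_base_model("microsoft/phi-1.5"))  # susnato/phi-1_5_dev
```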
18 changes: 17 additions & 1 deletion ludwig/schema/model_types/utils.py
@@ -34,7 +34,7 @@
 from ludwig.schema.trainer import ECDTrainerConfig
 from ludwig.types import HyperoptConfigDict, ModelConfigDict
 from ludwig.utils.data_utils import get_sanitized_feature_name
-from ludwig.utils.llm_utils import get_context_len
+from ludwig.utils.llm_utils import _PHI_BASE_MODEL_MAPPING, get_context_len
 
 if TYPE_CHECKING:
     from ludwig.schema.model_types.base import ModelConfig
@@ -307,13 +307,29 @@ def set_llm_parameters(config: "ModelConfig") -> None:
     if config.model_type != MODEL_LLM:
         return
 
+    # Do an in-place replacement for Phi models since they don't work well out of the box
+    _replace_phi_model_with_supported_model(config)
+
     # Set preprocessing parameters for text features for LLM model type
     _set_llm_tokenizers(config)
 
     # Set max_new_tokens in generation config to the max sequence length of the output features
     _set_generation_max_new_tokens(config)
 
 
+def _replace_phi_model_with_supported_model(config: "ModelConfig") -> None:
+    """Replaces the Phi model with a supported model that is compatible with the LLM model type."""
+    if config.base_model not in _PHI_BASE_MODEL_MAPPING:
+        return
+
+    logger.warning(
+        f"{config.base_model} does not work correctly out of the box since it requires running remote code. "
+        f"Replacing {config.base_model} with {_PHI_BASE_MODEL_MAPPING[config.base_model]} as the base LLM model."
+    )
+
+    config.base_model = _PHI_BASE_MODEL_MAPPING[config.base_model]
+
+
 def _set_llm_tokenizers(config: "ModelConfig") -> None:
     """Sets the tokenizers for the LLM model to the pretrained model name or path. This ensures that they use the
     correct shared vocabulary from the tokenizer.
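The net effect for users: a config can keep referring to the official Phi repos, and the schema layer swaps in the supported port before the model is loaded. The config below is illustrative (the feature names are made up), and the last two lines simply simulate the mapping applied by _replace_phi_model_with_supported_model.

```python
_PHI_BASE_MODEL_MAPPING = {"microsoft/phi-1": "susnato/phi-1_dev", "microsoft/phi-1.5": "susnato/phi-1_5_dev"}

config_dict = {
    "model_type": "llm",
    "base_model": "microsoft/phi-1.5",  # what the user writes
    "input_features": [{"name": "prompt", "type": "text"}],
    "output_features": [{"name": "response", "type": "text"}],
}

# During config validation, set_llm_parameters() rewrites the base model in place and logs a warning.
config_dict["base_model"] = _PHI_BASE_MODEL_MAPPING.get(config_dict["base_model"], config_dict["base_model"])
print(config_dict["base_model"])  # susnato/phi-1_5_dev
```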
14 changes: 13 additions & 1 deletion ludwig/utils/llm_utils.py
@@ -18,6 +18,16 @@

 FALLBACK_CONTEXT_LEN = 2048
 
+# The official microsoft phi models don't work out of the box because the weights aren't compatible with HF.
+# See https://github.com/huggingface/transformers/issues/28049 for more context.
+_PHI_BASE_MODEL_MAPPING = {
+    "microsoft/phi-1": "susnato/phi-1_dev",
+    "microsoft/phi-1.5": "susnato/phi-1_5_dev",
+}
+
+# The susnato Phi models as of Transformers 4.36.1 don't support "device_map='auto'" at model load time.
+_MODELS_WITH_DEVICE_MAP_AUTO_EXCLUSION = {"susnato/phi-1_dev", "susnato/phi-1_5_dev"}
+
 
 def to_device(
     model: PreTrainedModel,
@@ -54,11 +64,13 @@ def to_device(
         model_kwargs.update(
             dict(
                 low_cpu_mem_usage=True,
-                device_map="auto",
                 max_memory={i: "13GiB" for i in range(num_gpus)},
             )
         )
 
+        if config_obj.base_model not in _MODELS_WITH_DEVICE_MAP_AUTO_EXCLUSION:
+            model_kwargs["device_map"] = "auto"
+
         if config_obj.quantization:
             model_kwargs["quantization_config"] = config_obj.quantization.to_bitsandbytes()
 
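A small sketch of the same exclusion applied on the to_device() path; `build_device_kwargs` is an illustrative helper, and the surrounding multi-GPU checks in the real function are omitted.

```python
_MODELS_WITH_DEVICE_MAP_AUTO_EXCLUSION = {"susnato/phi-1_dev", "susnato/phi-1_5_dev"}


def build_device_kwargs(base_model: str, num_gpus: int) -> dict:
    model_kwargs = {
        "low_cpu_mem_usage": True,
        "max_memory": {i: "13GiB" for i in range(num_gpus)},
    }
    # device_map="auto" is now added conditionally instead of always, because the susnato
    # Phi ports cannot be dispatched that way in Transformers 4.36.1.
    if base_model not in _MODELS_WITH_DEVICE_MAP_AUTO_EXCLUSION:
        model_kwargs["device_map"] = "auto"
    return model_kwargs


print(build_device_kwargs("susnato/phi-1_dev", num_gpus=2).keys())  # no device_map
print(build_device_kwargs("mistralai/Mistral-7B-v0.1", num_gpus=2)["device_map"])  # auto
```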
