Enabled generalizing build_inner_model in ComposerHFCausalLM (#1450)
gupta-abhay authored Aug 13, 2024
1 parent 7c47e70 commit 67f1498
Showing 1 changed file with 18 additions and 6 deletions.
24 changes: 18 additions & 6 deletions llmfoundry/models/hf/hf_causal_lm.py
@@ -26,6 +26,7 @@
     PreTrainedModel,
     PreTrainedTokenizerBase,
 )
+from transformers.models.auto.auto_factory import _BaseAutoModelClass
 
 from llmfoundry.metrics import (
     DEFAULT_CAUSAL_LM_EVAL_METRICS,
@@ -193,6 +194,8 @@ def build_inner_model(
         config_overrides: Dict[str, Any],
         load_in_8bit: bool,
         pretrained: bool,
+        model_cls: Union[_BaseAutoModelClass,
+                         PreTrainedModel] = AutoModelForCausalLM,
         prepare_for_fsdp: bool = False,
     ) -> Union[PreTrainedModel, 'PeftModel']:
         """Builds the inner model for the ComposerHFCausalLM.
@@ -207,7 +210,8 @@ def build_inner_model(
             config_overrides (Dict[str, Any]): The configuration overrides.
             load_in_8bit (bool): Whether to load in 8-bit.
             pretrained (bool): Whether the model is pretrained.
-            prepare_for_fsdp (bool, optional): Whether to prepare the model for FSDP wrapping. Default: False.
+            model_cls (Union[_BaseAutoModelClass, PreTrainedModel]): The HF class to build the model with. Default: ``AutoModelForCausalLM``.
+            prepare_for_fsdp (bool, optional): Whether to prepare the model for FSDP wrapping. Default: ``False``.
 
         Returns:
             Union[PreTrainedModel, 'PeftModel']: The built inner model.
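
Aside: both arms of the new Union are satisfied by the standard transformers factories, since every Auto* class derives from _BaseAutoModelClass. A minimal sketch verifying this (not part of the commit):

    from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM
    from transformers.models.auto.auto_factory import _BaseAutoModelClass

    # Every Auto* factory in transformers subclasses _BaseAutoModelClass,
    # so any of them satisfies the new `model_cls` annotation.
    assert issubclass(AutoModelForCausalLM, _BaseAutoModelClass)
    assert issubclass(AutoModelForSeq2SeqLM, _BaseAutoModelClass)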
@@ -231,6 +235,14 @@ def build_inner_model(
                 + 'Please `pip install llm-foundry[gpu]`.',
             )
 
+        if not (
+            hasattr(model_cls, 'from_pretrained') and
+            hasattr(model_cls, 'from_config')
+        ):
+            raise AttributeError(
+                f'{model_cls=} is missing `from_pretrained` and `from_config` support.',
+            )
+
         # Hugging Face copies the modules into the
         # transformers modules cache. On particular systems, this operation seems to cause contention between
         # the different processes. To avoid this contention, we first create the config and generation config on local rank
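
The guard added above duck-types model_cls rather than checking inheritance, so any class exposing both alternate constructors qualifies. A standalone sketch of the same check; the helper name supports_hf_loading is illustrative, not part of the commit:

    from transformers import AutoModelForCausalLM

    def supports_hf_loading(model_cls: type) -> bool:
        # Mirrors the new guard: both alternate constructors must exist.
        return (
            hasattr(model_cls, 'from_pretrained') and
            hasattr(model_cls, 'from_config')
        )

    assert supports_hf_loading(AutoModelForCausalLM)
    assert not supports_hf_loading(dict)  # build_inner_model would raise AttributeError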
@@ -280,7 +292,7 @@ def build_inner_model(
                 with init_empty_weights(include_buffers=False):
                     with warnings.catch_warnings():
                         warnings.simplefilter('ignore', UserWarning)
-                        AutoModelForCausalLM.from_pretrained(
+                        model_cls.from_pretrained(
                             pretrained_model_name_or_path,
                             trust_remote_code=trust_remote_code,
                             use_auth_token=use_auth_token,
@@ -290,7 +302,7 @@ def build_inner_model(
                         )
             else:
                 with init_empty_weights(include_buffers=False):
-                    AutoModelForCausalLM.from_config(
+                    model_cls.from_config(
                         config,
                         trust_remote_code=trust_remote_code,
                         attn_implementation=requested_attention_implementation,
@@ -301,7 +313,7 @@ def build_inner_model(
         # initialize the model on the correct device
         if resolved_init_device == 'cpu':
             if pretrained:
-                model = AutoModelForCausalLM.from_pretrained(
+                model = model_cls.from_pretrained(
                     pretrained_model_name_or_path,
                     trust_remote_code=trust_remote_code,
                     use_auth_token=use_auth_token,
@@ -310,7 +322,7 @@ def build_inner_model(
                     config=config,
                 )
             else:
-                model = AutoModelForCausalLM.from_config(
+                model = model_cls.from_config(
                     config,
                     trust_remote_code=trust_remote_code,
                     attn_implementation=requested_attention_implementation,
@@ -321,7 +333,7 @@ def build_inner_model(
                     'Setting cfg.pretrained=True is not supported when init_device="meta".',
                 )
             with init_empty_weights(include_buffers=False):
-                model = AutoModelForCausalLM.from_config(
+                model = model_cls.from_config(
                     config,
                     trust_remote_code=trust_remote_code,
                     attn_implementation=requested_attention_implementation,
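
Taken together, the hunks route every load through the injected class. A sketch of how the generalized method might be called with a non-causal-LM auto class; only model_cls and prepare_for_fsdp are visible in this diff, so the remaining keyword arguments (and the staticmethod-style call) are assumptions based on the surrounding signature:

    from transformers import AutoModelForSeq2SeqLM

    from llmfoundry.models.hf.hf_causal_lm import ComposerHFCausalLM

    # Assumed leading arguments; only `model_cls` and `prepare_for_fsdp`
    # appear in this diff. An encoder-decoder auto class passes the
    # duck-typing guard and reuses the same loading path.
    model = ComposerHFCausalLM.build_inner_model(
        pretrained_model_name_or_path='google/flan-t5-small',
        pretrained_lora_id_or_path=None,
        trust_remote_code=False,
        init_device='cpu',
        use_flash_attention_2=False,
        use_auth_token=False,
        config_overrides={},
        load_in_8bit=False,
        pretrained=True,
        model_cls=AutoModelForSeq2SeqLM,
        prepare_for_fsdp=False,
    )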
