
Commit d3aae1a
changed mutable typing, deprecation version, and backwards compatibility for model cls
jdchang1 committed Aug 20, 2024
1 parent 50c9ddf commit d3aae1a
Showing 5 changed files with 27 additions and 25 deletions.
12 changes: 6 additions & 6 deletions llmfoundry/metrics/__init__.py
@@ -36,13 +36,13 @@
metrics.register('language_perplexity', func=LanguagePerplexity)
metrics.register('masked_accuracy', func=MaskedAccuracy)

-DEFAULT_CAUSAL_LM_TRAIN_METRICS = [
+DEFAULT_CAUSAL_LM_TRAIN_METRICS = (
'language_cross_entropy',
'language_perplexity',
'token_accuracy',
-]
+)

-DEFAULT_CAUSAL_LM_EVAL_METRICS = [
+DEFAULT_CAUSAL_LM_EVAL_METRICS = (
'language_cross_entropy',
'language_perplexity',
'token_accuracy',
@@ -51,12 +51,12 @@
'mc_expected_calibration_error',
'mc_accuracy',
'qa_accuracy',
-]
+)

-DEFAULT_ENC_DEC_METRICS = [
+DEFAULT_ENC_DEC_METRICS = (
'language_cross_entropy',
'masked_accuracy',
-]
+)

__all__ = [
'TokenAccuracy',
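The changes above replace list literals with tuples for the default metric collections. A module- or class-level list is a single shared object, so any consumer that appends to it in place silently mutates the default for everyone else; a tuple cannot be modified, which forces the explicit-copy pattern that build_metrics uses below (list(cls.default_train_metrics) + ...). A minimal sketch of the difference, using toy class names rather than the repository's:

# Toy illustration (not llm-foundry code) of why tuple defaults are safer.

class WithListDefault:
    default_train_metrics: list = []      # one shared, mutable object


class WithTupleDefault:
    default_train_metrics: tuple = ()     # immutable, cannot be extended in place


# Appending to the shared list changes the default for every consumer.
WithListDefault.default_train_metrics.append('token_accuracy')
assert WithListDefault.default_train_metrics == ['token_accuracy']

# With a tuple, callers must build their own list, as build_metrics does with
# list(cls.default_train_metrics) + (additional_train_metrics or []).
combined = list(WithTupleDefault.default_train_metrics) + ['token_accuracy']
assert combined == ['token_accuracy']
assert WithTupleDefault.default_train_metrics == ()   # default is untouched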
30 changes: 16 additions & 14 deletions llmfoundry/models/hf/hf_base.py
@@ -53,8 +53,8 @@ class BaseHuggingFaceModel(HuggingFaceModel):

model_cls: Union[_BaseAutoModelClass,
PreTrainedModel] = AutoModelForCausalLM
-default_train_metrics: list = []
-default_eval_metrics: list = []
+default_train_metrics: tuple = ()
+default_eval_metrics: tuple = ()

def __init__(
self,
@@ -193,13 +193,13 @@ def build_metrics(
"""
from llmfoundry.utils.builders import build_metric

-train_metric_names = cls.default_train_metrics + (
+train_metric_names = list(cls.default_train_metrics) + (
additional_train_metrics or []
)
train_metrics = [
build_metric(metric, {}) for metric in train_metric_names
] if use_train_metrics else []
-eval_metric_names = cls.default_eval_metrics + (
+eval_metric_names = list(cls.default_eval_metrics) + (
additional_eval_metrics or []
)
eval_metrics = [
@@ -220,8 +220,8 @@ def build_inner_model(
config_overrides: dict[str, Any],
load_in_8bit: bool,
pretrained: bool,
-model_cls: Union[_BaseAutoModelClass,
-PreTrainedModel] = AutoModelForCausalLM,
+model_cls: Optional[Union[_BaseAutoModelClass,
+PreTrainedModel]] = None,
prepare_for_fsdp: bool = False,
) -> Union[PreTrainedModel, 'PeftModel']:
"""Builds the inner model for the ComposerHFCausalLM.
@@ -260,12 +260,14 @@ def build_inner_model(
+ 'Please `pip install llm-foundry[gpu]`.',
)

+model_cls = cls.model_cls if model_cls is None else model_cls

if not (
-hasattr(cls.model_cls, 'from_pretrained') and
-hasattr(cls.model_cls, 'from_config')
+hasattr(model_cls, 'from_pretrained') and
+hasattr(model_cls, 'from_config')
):
raise AttributeError(
-f'{cls.model_cls=} is missing `from_pretrained` and `from_config` support.',
+f'{model_cls=} is missing `from_pretrained` and `from_config` support.',
)

# Hugging Face copies the modules into the
@@ -314,7 +316,7 @@ def build_inner_model(
with init_empty_weights(include_buffers=False):
with warnings.catch_warnings():
warnings.simplefilter('ignore', UserWarning)
-cls.model_cls.from_pretrained(
+model_cls.from_pretrained(
pretrained_model_name_or_path,
trust_remote_code=trust_remote_code,
use_auth_token=use_auth_token,
@@ -324,7 +326,7 @@ def build_inner_model(
)
else:
with init_empty_weights(include_buffers=False):
-cls.model_cls.from_config(
+model_cls.from_config(
config,
trust_remote_code=trust_remote_code,
attn_implementation=requested_attention_implementation,
@@ -335,7 +337,7 @@ def build_inner_model(
# initialize the model on the correct device
if resolved_init_device == 'cpu':
if pretrained:
-model = cls.model_cls.from_pretrained(
+model = model_cls.from_pretrained(
pretrained_model_name_or_path,
trust_remote_code=trust_remote_code,
use_auth_token=use_auth_token,
@@ -344,7 +346,7 @@ def build_inner_model(
config=config,
)
else:
-model = cls.model_cls.from_config(
+model = model_cls.from_config(
config,
trust_remote_code=trust_remote_code,
attn_implementation=requested_attention_implementation,
@@ -355,7 +357,7 @@ def build_inner_model(
'Setting cfg.pretrained=True is not supported when init_device="meta".',
)
with init_empty_weights(include_buffers=False):
-model = cls.model_cls.from_config(
+model = model_cls.from_config(
config,
trust_remote_code=trust_remote_code,
attn_implementation=requested_attention_implementation,
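The hf_base.py changes above make model_cls an optional argument to build_inner_model: when the caller passes nothing, the method falls back to the class attribute (model_cls = cls.model_cls if model_cls is None else model_cls), so existing call sites keep their behavior while subclasses or callers can still supply a different class. A simplified sketch of the pattern with stand-in classes (illustrative names, not part of llm-foundry):

from typing import Optional, Type


class FakeAutoModel:
    """Stand-in for AutoModelForCausalLM exposing the duck-typed constructors."""

    @classmethod
    def from_pretrained(cls, name: str) -> 'FakeAutoModel':
        return cls()

    @classmethod
    def from_config(cls, config: dict) -> 'FakeAutoModel':
        return cls()


class BaseModelSketch:
    # Class-level default, analogous to model_cls = AutoModelForCausalLM.
    model_cls: Type = FakeAutoModel

    @classmethod
    def build_inner_model(cls, model_cls: Optional[Type] = None) -> object:
        # Backwards-compatible fallback introduced by this commit.
        model_cls = cls.model_cls if model_cls is None else model_cls
        # Duck-typed check: anything with from_pretrained/from_config is accepted.
        if not (
            hasattr(model_cls, 'from_pretrained') and
            hasattr(model_cls, 'from_config')
        ):
            raise AttributeError(
                f'{model_cls=} is missing `from_pretrained` and `from_config` support.',
            )
        return model_cls.from_config({})


BaseModelSketch.build_inner_model()                         # uses the class default
BaseModelSketch.build_inner_model(model_cls=FakeAutoModel)  # explicit override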
4 changes: 2 additions & 2 deletions llmfoundry/models/hf/hf_causal_lm.py
@@ -58,8 +58,8 @@ class ComposerHFCausalLM(BaseHuggingFaceModel):

model_cls: Union[_BaseAutoModelClass,
PreTrainedModel] = AutoModelForCausalLM
-default_train_metrics: list = DEFAULT_CAUSAL_LM_TRAIN_METRICS
-default_eval_metrics: list = DEFAULT_CAUSAL_LM_EVAL_METRICS
+default_train_metrics: tuple = DEFAULT_CAUSAL_LM_TRAIN_METRICS
+default_eval_metrics: tuple = DEFAULT_CAUSAL_LM_EVAL_METRICS

def __init__(
self,
4 changes: 2 additions & 2 deletions llmfoundry/models/hf/hf_t5.py
@@ -47,8 +47,8 @@ class ComposerHFT5(BaseHuggingFaceModel):

model_cls: Union[_BaseAutoModelClass,
PreTrainedModel] = AutoModelForSeq2SeqLM
-default_train_metrics: list = DEFAULT_ENC_DEC_METRICS
-default_eval_metrics: list = []
+default_train_metrics: tuple = DEFAULT_ENC_DEC_METRICS
+default_eval_metrics: tuple = []

def __init__(
self,
2 changes: 1 addition & 1 deletion llmfoundry/models/hf/model_wrapper.py
@@ -48,7 +48,7 @@ def __init__(
warnings.warn(
VersionedDeprecationWarning(
'`HuggingFaceModelWithFSDP` is deprecated. In the future please use `BaseHuggingFaceModel`.',
-remove_version='0.12.0',
+remove_version='0.13.0',
),
)
super().__init__(
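The model_wrapper.py change only moves the planned removal of the deprecated HuggingFaceModelWithFSDP wrapper from release 0.12.0 to 0.13.0. For context, here is a rough sketch of how a versioned deprecation warning of this shape can be implemented; this is an assumed, simplified stand-in, not llm-foundry's actual VersionedDeprecationWarning:

import warnings


class VersionedDeprecationWarningSketch(DeprecationWarning):
    """Simplified stand-in: appends the scheduled removal release to the message."""

    def __init__(self, message: str, remove_version: str) -> None:
        super().__init__(f'{message} It will be removed in version {remove_version}.')


def deprecated_entry_point() -> None:
    warnings.warn(
        VersionedDeprecationWarningSketch(
            '`HuggingFaceModelWithFSDP` is deprecated. In the future please use `BaseHuggingFaceModel`.',
            remove_version='0.13.0',
        ),
    )


deprecated_entry_point()  # warns that the wrapper goes away in version 0.13.0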

