From c891bed3675c4b2facbc8c9eb7aea9d4423dba87 Mon Sep 17 00:00:00 2001 From: Daniel King <43149077+dakinggg@users.noreply.github.com> Date: Wed, 22 May 2024 17:54:27 -0400 Subject: [PATCH] Make config/class properties on ComposerMPTForCausalLM (#1227) --- llmfoundry/models/mpt/modeling_mpt.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index cb62b462c2..f6b36a3f94 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -19,6 +19,7 @@ MutableMapping, Optional, Tuple, + Type, Union, ) @@ -1082,9 +1083,7 @@ def __init__( additional_train_metrics = additional_train_metrics or [] - model = MPTForCausalLM( - MPTConfig(use_train_metrics=use_train_metrics, **kwargs), - ) + model = self.model_class(self.config_class(**kwargs),) use_train_metrics = use_train_metrics train_metric_names = DEFAULT_CAUSAL_LM_TRAIN_METRICS + additional_train_metrics @@ -1134,6 +1133,14 @@ def __init__( f'Specified loss_fn={self.loss_fn} not recognized. `loss_fn` must be one of [`fused_crossentropy`, `torch_crossentropy`].', ) + @property + def model_class(self) -> Type[MPTForCausalLM]: + return MPTForCausalLM + + @property + def config_class(self) -> Type[MPTConfig]: + return MPTConfig + def get_targets(self, batch: Mapping) -> torch.Tensor: targets = torch.roll(batch['labels'], shifts=-1) targets[:, -1] = -100