diff --git a/llmfoundry/models/hf/hf_causal_lm.py b/llmfoundry/models/hf/hf_causal_lm.py
index d5ef2435f9..a822bb02ab 100644
--- a/llmfoundry/models/hf/hf_causal_lm.py
+++ b/llmfoundry/models/hf/hf_causal_lm.py
@@ -65,10 +65,8 @@ def __init__(self, om_model_config: Union[DictConfig, nn.Module],
                  tokenizer: PreTrainedTokenizerBase):
 
         # set up training and eval metrics
-        train_metrics = [
-            LanguageCrossEntropy(),
-            LanguagePerplexity(),
-        ]
+        use_train_metrics = om_model_config.get('use_train_metrics', True)
+        train_metrics = [LanguageCrossEntropy(), LanguagePerplexity()] if use_train_metrics else []
         eval_metrics = [
             LanguageCrossEntropy(),
             LanguagePerplexity(),
diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py
index b1dff15398..389bf6883d 100644
--- a/llmfoundry/models/mpt/modeling_mpt.py
+++ b/llmfoundry/models/mpt/modeling_mpt.py
@@ -694,7 +694,8 @@ def __init__(
         hf_config = MPTConfig.from_dict(resolved_om_model_config)
         model = MPTForCausalLM(hf_config)
 
-        train_metrics = [LanguageCrossEntropy(), LanguagePerplexity()]
+        use_train_metrics = resolved_om_model_config.get('use_train_metrics', True)
+        train_metrics = [LanguageCrossEntropy(), LanguagePerplexity()] if use_train_metrics else []
         eval_metrics = [
             LanguageCrossEntropy(),
             LanguagePerplexity(),
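For reviewers, a minimal standalone sketch of the pattern both hunks introduce (assuming `composer` and `omegaconf`, which llm-foundry already depends on; this snippet is illustrative and not part of the patch). Reading the flag with `.get('use_train_metrics', True)` keeps existing configs backward compatible: an absent key still yields the full metric list.

```python
# Illustrative sketch, not llm-foundry source: shows how the new
# `use_train_metrics` flag gates the training metrics.
from composer.metrics.nlp import LanguageCrossEntropy, LanguagePerplexity
from omegaconf import OmegaConf

# A config that opts out of per-batch training metrics.
om_model_config = OmegaConf.create({'use_train_metrics': False})

# Same pattern as both hunks: an absent key defaults to True, so
# existing configs continue to get train metrics unchanged.
use_train_metrics = om_model_config.get('use_train_metrics', True)
train_metrics = [LanguageCrossEntropy(),
                 LanguagePerplexity()] if use_train_metrics else []

assert train_metrics == []  # no train metrics computed when the flag is False
```

Note that only `train_metrics` is gated while `eval_metrics` is left untouched, so cross entropy and perplexity are still reported at evaluation time; the flag just skips the per-batch metric updates during training.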