Add flag to disable train metrics (#642)
* free mem

* lint

* lint
mvpatel2000 authored Oct 3, 2023
Parent: a0e64ba, commit: 9025b83
Showing 2 changed files with 7 additions and 5 deletions.
llmfoundry/models/hf/hf_causal_lm.py (4 additions, 4 deletions)
@@ -65,10 +65,7 @@ def __init__(self, om_model_config: Union[DictConfig,
                                            nn.Module],
                  tokenizer: PreTrainedTokenizerBase):
         # set up training and eval metrics
-        train_metrics = [
-            LanguageCrossEntropy(),
-            LanguagePerplexity(),
-        ]
+        train_metrics = [LanguageCrossEntropy(), LanguagePerplexity()]
         eval_metrics = [
             LanguageCrossEntropy(),
             LanguagePerplexity(),
@@ -92,6 +89,9 @@ def __init__(self, om_model_config: Union[DictConfig,
                 'which is not significantly slower and not compatible with the LLM foundry training code, rather than the code release by MosaicML.'
             )
 
+        if not om_model_config.get('use_train_metrics', True):
+            train_metrics = []
+
         # load the model config
         trust_remote_code = om_model_config.get('trust_remote_code', True)
         use_auth_token = om_model_config.get('use_auth_token', False)
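
The new flag is small enough to restate in a runnable form. Below is a minimal sketch of the added behavior, assuming an OmegaConf model config (the config contents are illustrative, not from the commit); the commit message ("free mem") suggests the goal is to avoid allocating train-metric state:

from composer.metrics.nlp import LanguageCrossEntropy, LanguagePerplexity
from omegaconf import DictConfig

# Illustrative config: only the `use_train_metrics` key is introduced by
# this commit; everything else about the config shape is assumed.
om_model_config = DictConfig({'use_train_metrics': False})

# Default metric set, as built in hf_causal_lm.py.
train_metrics = [LanguageCrossEntropy(), LanguagePerplexity()]

# The added logic: an explicit `use_train_metrics: false` empties the list,
# so no train-metric state is created or updated during training.
if not om_model_config.get('use_train_metrics', True):
    train_metrics = []

assert train_metrics == []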
llmfoundry/models/mpt/modeling_mpt.py (3 additions, 1 deletion)
@@ -694,7 +694,9 @@ def __init__(
         hf_config = MPTConfig.from_dict(resolved_om_model_config)
         model = MPTForCausalLM(hf_config)
 
-        train_metrics = [LanguageCrossEntropy(), LanguagePerplexity()]
+        use_train_metrics = om_model_config.get('use_train_metrics', True)
+        train_metrics = [LanguageCrossEntropy(),
+                         LanguagePerplexity()] if use_train_metrics else []
         eval_metrics = [
             LanguageCrossEntropy(),
             LanguagePerplexity(),
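
For completeness, a hedged sketch of how the flag might be set from a training YAML; the surrounding keys are assumptions for illustration, and only `use_train_metrics` is confirmed by this commit:

from omegaconf import OmegaConf

# Hypothetical excerpt of an llm-foundry training config, inlined as YAML.
cfg = OmegaConf.create("""
model:
  name: mpt_causal_lm
  use_train_metrics: false
""")

# Both changed files read the flag the same way and default it to True,
# so omitting the key keeps the old behavior.
assert cfg.model.get('use_train_metrics', True) is False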
