Allow train.py-like config for eval.py (#1351)
* Allow model key in eval script

* Compatibility

* pre-commit fix

* Fix load_path

* fix

* Refactor as a config transform

* formatting

* fix

* fix pyright

* fix

---------

Co-authored-by: Mihir Patel <[email protected]>
josejg and mvpatel2000 authored Jul 23, 2024
1 parent d812f20 commit eb41a6e
Showing 1 changed file with 49 additions and 0 deletions.
49 changes: 49 additions & 0 deletions llmfoundry/command_utils/eval.py
@@ -175,6 +175,54 @@ def evaluate_model(
    return (trainer, logger_keys, eval_gauntlet_callback, eval_gauntlet_df)


def allow_toplevel_keys(cfg: Dict[str, Any]) -> Dict[str, Any]:
    """Transform the config to allow top-level keys for model configuration.

    This function allows users to use the 'train.py' syntax in 'eval.py'.
    It converts a config with top-level 'model', 'tokenizer', and (optionally)
    'load_path' keys into the nested 'models' list format required by 'eval.py'.

    Input config format (train.py style):
    ```yaml
    model:
      <model_kwargs>
    load_path: /path/to/checkpoint
    tokenizer:
      <tokenizer_kwargs>
    ```

    Output config format (eval.py style):
    ```yaml
    models:
      - model:
          <model_kwargs>
        tokenizer:
          <tokenizer_kwargs>
        load_path: /path/to/checkpoint
    ```
    """
    if 'model' in cfg:
        # Top-level 'model' (train.py style) and nested 'models' (eval.py style)
        # are mutually exclusive.
        if 'models' in cfg:
            raise ValueError(
                'Please specify either model or models in the config, not both',
            )
        default_name = cfg.get('model').get('name')  # type: ignore
        # Wrap the top-level keys into a single entry of the 'models' list.
        model_cfg = {
            'model': cfg.pop('model'),
            'tokenizer': cfg.pop('tokenizer', None),
            'model_name': cfg.pop('model_name', default_name),
        }
        if 'tokenizer' not in model_cfg or model_cfg['tokenizer'] is None:
            raise ValueError(
                'When specifying model, "tokenizer" must be provided in the config',
            )
        if 'load_path' in cfg:
            model_cfg['load_path'] = cfg.pop('load_path')
        cfg['models'] = [model_cfg]

    return cfg
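
For context (not part of the commit), here is a minimal sketch of what the new transform does to a plain config dict. The values are made up, but the key handling mirrors the function above:

```python
# Illustrative only: apply allow_toplevel_keys to a train.py-style config dict.
from llmfoundry.command_utils.eval import allow_toplevel_keys

train_style = {
    'model': {
        'name': 'hf_causal_lm',                 # hypothetical model config
        'pretrained_model_name_or_path': 'gpt2',
    },
    'tokenizer': {'name': 'gpt2'},              # required when 'model' is used
    'load_path': '/path/to/checkpoint',         # optional, moved into the model entry
}

eval_style = allow_toplevel_keys(train_style)
# eval_style now has the nested eval.py shape:
# {'models': [{'model': {...}, 'tokenizer': {...},
#              'model_name': 'hf_causal_lm', 'load_path': '/path/to/checkpoint'}]}
```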


def evaluate(cfg: DictConfig) -> Tuple[list[Trainer], pd.DataFrame]:
    # Run user provided code if specified
    for code_path in cfg.get('code_paths', []):
@@ -184,6 +232,7 @@ def evaluate(cfg: DictConfig) -> Tuple[list[Trainer], pd.DataFrame]:
        cfg,
        EvalConfig,
        EVAL_CONFIG_KEYS,
        transforms=[allow_toplevel_keys],
        icl_tasks_required=True,

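The hunk above wires the new function in through a `transforms` argument. The helper that consumes that argument sits outside the visible diff, so the following is a hypothetical sketch (the name `apply_transforms` and its signature are illustrative, not the actual llmfoundry helper) of how such a transform list is typically run over the raw config before validation:

```python
from typing import Any, Callable, Dict, List

ConfigTransform = Callable[[Dict[str, Any]], Dict[str, Any]]


def apply_transforms(
    cfg: Dict[str, Any],
    transforms: List[ConfigTransform],
) -> Dict[str, Any]:
    """Run each transform in order on the raw config dict.

    With transforms=[allow_toplevel_keys], a train.py-style config is rewritten
    into the nested 'models' list shape before any eval.py schema checks see it.
    """
    for transform in transforms:
        cfg = transform(cfg)
    return cfg
```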
