Skip to content

Commit

Permalink
om
Browse files Browse the repository at this point in the history
  • Loading branch information
aspfohl committed Jan 23, 2024
1 parent b269d2c commit 52efb19
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 13 deletions.
15 changes: 3 additions & 12 deletions llmfoundry/callbacks/async_eval_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import logging
import os
from pathlib import Path
from typing import Any, Dict, Optional, Union, cast
from typing import Any, Dict, Optional, Union

from composer.callbacks import CheckpointSaver
from composer.core import Callback, Event, State, Time, TimeUnit
Expand All @@ -18,8 +18,6 @@
RUN_NAME_ENV_VAR)
from composer.utils import dist
from composer.utils.misc import create_interval_scheduler
from omegaconf import DictConfig
from omegaconf import OmegaConf as om

from mcli import Run, RunConfig, create_run, get_run

Expand All @@ -37,7 +35,7 @@
'eval_gauntlet',
'eval_loader',
'fsdp_config',
'eval_subset_num_batches', # converted to subset_num_batches
'eval_subset_num_batches',
'icl_subset_num_batches',
'loggers',
'precision',
Expand Down Expand Up @@ -129,10 +127,6 @@ def get_eval_parameters(
subset_keys[key] = parameters[key]
looking_for.remove(key)

if 'eval_subset_num_batches' in subset_keys:
subset_keys['subset_num_batches'] = subset_keys.pop(
'eval_subset_num_batches')

if looking_for:
raise Exception(
f'Missing the following required parameters for async eval: {looking_for}'
Expand Down Expand Up @@ -186,10 +180,7 @@ def validate_interval(interval: Union[str, int, Time],
def validate_eval_run_config(
eval_run_config: Optional[Dict[str, Any]]) -> Dict[str, Any]:

if isinstance(eval_run_config, DictConfig):
parsed_run_config = om.to_container(eval_run_config)
run_config = cast(Dict[str, Any], parsed_run_config)
elif eval_run_config is None:
if eval_run_config is None:
return {}
else:
run_config = eval_run_config.copy()
Expand Down
3 changes: 2 additions & 1 deletion scripts/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,8 @@ def main(cfg: DictConfig) -> Trainer:
callback_configs: Optional[DictConfig] = pop_config(cfg,
'callbacks',
must_exist=False,
default_value=None)
default_value=None,
convert=True)
algorithm_configs: Optional[DictConfig] = pop_config(cfg,
'algorithms',
must_exist=False,
Expand Down

0 comments on commit 52efb19

Please sign in to comment.