Bump Version to 0.10.0.dev0 (#1255)
* bump version

* typo

* Update config_utils.py

These changes are necessary as the deprecation broke compatibility with `update_batch_size`.

* Update config_utils.py

fix typo

* typo

* typo I

* update tests

* typo II

* typo III

* bump composer version

* undo composer bump for separate pr

* fix test

* fix tests II

* yolo

* tye-o

* pyrite

* we resolve later

* revert new . syntax

---------

Co-authored-by: v-chen_data <[email protected]>
Co-authored-by: Milo Cress <[email protected]>
Co-authored-by: Saaketh Narayan <[email protected]>
Co-authored-by: Daniel King <[email protected]>
5 people authored Jun 7, 2024
1 parent 14f296c commit bea61fb
Showing 6 changed files with 28 additions and 31 deletions.
2 changes: 1 addition & 1 deletion llmfoundry/__init__.py
@@ -71,4 +71,4 @@
     'utils',
 ]
 
-__version__ = '0.9.0.dev0'
+__version__ = '0.10.0.dev0'
13 changes: 8 additions & 5 deletions llmfoundry/utils/config_utils.py
@@ -67,6 +67,7 @@ class EvalConfig:
     # Logging parameters
     python_log_level: Optional[str] = 'debug'
     loggers: Optional[Dict[str, Any]] = None
+    console_log_interval: Union[int, str] = '1ba'
     log_config: bool = True
 
     # Model/run parameters
@@ -180,6 +181,11 @@ class TrainConfig:
     # Variables to ignore
     variables: Optional[Dict[str, Any]] = None
 
+    # Fields created by `update_batch_size_info`
+    n_gpus: int = MISSING
+    device_train_batch_size: int = MISSING
+    device_train_grad_accum: str = MISSING
+
 
 TRAIN_CONFIG_KEYS = {field.name for field in fields(TrainConfig)}
 
@@ -242,7 +248,6 @@ def make_dataclass_and_log_config(
     icl_tasks_required: bool = False,
 ) -> Tuple[Dict[str, Any], T]:
     """Converts a DictConfig to a dataclass and creates a logged config."""
-    # Resolve all interpolation variables as early as possible
     unstructured_config = om.to_container(cfg, resolve=True)
     assert isinstance(unstructured_config, dict)
     assert all(isinstance(k, str) for k in unstructured_config.keys())
@@ -289,11 +294,9 @@ def make_dataclass_and_log_config(
         unstructured_config['variables'] = {}
 
     for key in extraneous_keys:
-        warnings.warn(
-            f'Unused parameter {key} found in cfg. Please check your yaml to ensure this parameter is necessary. Interpreting {key} as a variable for logging purposes. Top-level variables are deprecated and will not be supported in future releases. Please place any variables under the `variables` key.',
-            category=DeprecationWarning,
+        raise ValueError(
+            f'Unused parameter {key} found in cfg. Please check your yaml to ensure this parameter is necessary. Please place any variables under the `variables` key.',
         )
-        unstructured_config['variables'][key] = unstructured_config.pop(key)
 
     dataclass_dict_config: DictConfig = om.structured(
         dataclass_constructor(**unstructured_config),
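The behavioral core of this commit is the last hunk above: unknown top-level keys were previously folded into `variables` with a DeprecationWarning, and now fail fast. A minimal sketch of the two config shapes, using a hypothetical `run_name_prefix` key that is not a real llm-foundry parameter:

from omegaconf import OmegaConf as om

# An unknown top-level key: after this commit, passing a config like this
# through make_dataclass_and_log_config raises ValueError.
bad_cfg = om.create({
    'run_name_prefix': 'llm',  # hypothetical extraneous key
    'max_seq_len': 2048,
})

# The supported shape: variables live under the `variables` key and are
# referenced via OmegaConf interpolation where needed.
good_cfg = om.create({
    'variables': {'run_name_prefix': 'llm'},
    'run_name': '${variables.run_name_prefix}-test',
    'max_seq_len': 2048,
})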
1 change: 0 additions & 1 deletion scripts/train/train.py
@@ -553,6 +553,5 @@ def main(cfg: DictConfig) -> Trainer:
         yaml_cfg = om.load(f)
     cli_cfg = om.from_cli(args_list)
     cfg = om.merge(yaml_cfg, cli_cfg)
-    om.resolve(cfg)
     assert isinstance(cfg, DictConfig)
     main(cfg)
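The dropped `om.resolve(cfg)` is redundant rather than load-bearing: as the config_utils.py hunks above show, `make_dataclass_and_log_config` already resolves interpolations via `om.to_container(cfg, resolve=True)`. A small self-contained sketch of that behavior, with assumed example values:

from omegaconf import OmegaConf as om

cfg = om.create({
    'variables': {'data_dir': '/tmp/c4'},   # assumed example value
    'data_local': '${variables.data_dir}',  # interpolation
})

# resolve=True materializes interpolations during conversion, so a separate
# om.resolve(cfg) call beforehand is unnecessary.
container = om.to_container(cfg, resolve=True)
assert container['data_local'] == '/tmp/c4'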
10 changes: 9 additions & 1 deletion tests/a_scripts/eval/test_eval.py
@@ -13,7 +13,7 @@
 
 from llmfoundry.utils import build_tokenizer
 from llmfoundry.utils.builders import build_composer_model
-from llmfoundry.utils.config_utils import to_dict_container
+from llmfoundry.utils.config_utils import EVAL_CONFIG_KEYS, to_dict_container
 from scripts.eval.eval import main  # noqa: E402
 from tests.data_utils import create_c4_dataset_xxsmall, gpt_tiny_cfg
 
@@ -134,6 +134,14 @@ def test_loader_eval(
     test_cfg.eval_interval = '1ba'
     test_cfg.loggers = om.DictConfig({'inmemory': om.DictConfig({})})
 
+    # This test uses a training yaml with training-only keys present.
+    # We exclude these keys before calling `main` from the eval script.
+    allowed_keys = EVAL_CONFIG_KEYS
+    present_keys = set(test_cfg.keys())
+    keys_to_pop = present_keys.difference(allowed_keys)
+
+    [test_cfg.pop(key) for key in keys_to_pop]
+
     trainers, eval_gauntlet_df = main(test_cfg)
 
     assert eval_gauntlet_df is None
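The list comprehension added above is used purely for its side effect. The same filtering as a plain helper, assuming `EVAL_CONFIG_KEYS` is a set of `EvalConfig` field names analogous to `TRAIN_CONFIG_KEYS` (the helper name is ours, not the repo's):

from llmfoundry.utils.config_utils import EVAL_CONFIG_KEYS

def restrict_to_eval_keys(test_cfg) -> None:
    # Drop train-only top-level keys so the stricter ValueError check
    # introduced in this commit does not trip on them during eval.
    for key in set(test_cfg.keys()) - set(EVAL_CONFIG_KEYS):
        test_cfg.pop(key)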
15 changes: 4 additions & 11 deletions tests/a_scripts/eval/test_eval_inputs.py
@@ -2,7 +2,6 @@
 # SPDX-License-Identifier: Apache-2.0
 import copy
 import os
-import warnings
 
 import omegaconf
 import pytest
@@ -42,12 +41,13 @@ def test_mispelled_mandatory_params_fail(self, cfg: DictConfig) -> None:
                 omegaconf.errors.InterpolationKeyError,
                 omegaconf.errors.MissingMandatoryValue,
                 TypeError,
+                ValueError,
             )):
                 cfg[p + '-mispelled'] = cfg.pop(p)
                 main(cfg)
             cfg[p] = cfg.pop(p + '-mispelled')
 
-    def test_optional_mispelled_params_raise_warning(
+    def test_optional_mispelled_params_raise_error(
         self,
         cfg: DictConfig,
     ) -> None:
@@ -67,15 +67,8 @@ def test_optional_mispelled_params_raise_warning(
         orig_value = cfg.pop(param, None)
         updated_param = param + '-mispelling'
         cfg[updated_param] = orig_value
-        with warnings.catch_warnings(record=True) as warning_list:
-            try:
-                main(cfg)
-            except:
-                pass
-            assert any(
-                f'Unused parameter {updated_param} found in cfg.' in
-                str(warning.message) for warning in warning_list
-            )
+        with pytest.raises(ValueError):
+            main(cfg)
         # restore configs.
         cfg = copy.deepcopy(old_cfg)
 
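For reference, a generic sketch of the two test idioms this commit swaps between; `call` stands in for invoking `main(cfg)` with a misspelled optional key:

import warnings
import pytest

def old_style(call) -> None:
    # Before: run, swallow any failure, then assert the DeprecationWarning
    # was recorded.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter('always')
        try:
            call()
        except Exception:
            pass
        assert any(
            issubclass(w.category, DeprecationWarning) for w in caught
        )

def new_style(call) -> None:
    # After: the misspelled key should raise ValueError outright.
    with pytest.raises(ValueError):
        call()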
18 changes: 6 additions & 12 deletions tests/a_scripts/train/test_train_inputs.py
@@ -3,7 +3,6 @@
 import copy
 import json
 import os
-import warnings
 
 import omegaconf
 import pytest
@@ -63,7 +62,9 @@ def cfg(self, foundry_dir: str) -> DictConfig:
     def test_misspelled_mandatory_params_fail(self, cfg: DictConfig) -> None:
         """Check that mandatory misspelled inputs fail to train."""
         cfg.trai_loader = cfg.pop('train_loader')
-        with pytest.raises((omegaconf.errors.MissingMandatoryValue, TypeError)):
+        with pytest.raises(
+            (omegaconf.errors.MissingMandatoryValue, TypeError, ValueError),
+        ):
             main(cfg)
 
     def test_missing_mandatory_parameters_fail(self, cfg: DictConfig) -> None:
@@ -89,7 +90,7 @@ def test_missing_mandatory_parameters_fail(self, cfg: DictConfig) -> None:
             main(cfg)
         cfg[param] = orig_param
 
-    def test_optional_misspelled_params_raise_warning(
+    def test_optional_misspelled_params_raise_error(
         self,
         cfg: DictConfig,
     ) -> None:
@@ -113,15 +114,8 @@ def test_optional_misspelled_params_raise_warning(
         orig_value = cfg.pop(param, None)
         updated_param = param + '-misspelling'
         cfg[updated_param] = orig_value
-        with warnings.catch_warnings(record=True) as warning_list:
-            try:
-                main(cfg)
-            except:
-                pass
-            assert any(
-                f'Unused parameter {updated_param} found in cfg.' in
-                str(warning.message) for warning in warning_list
-            )
+        with pytest.raises(ValueError):
+            main(cfg)
         # restore configs.
         cfg = copy.deepcopy(old_cfg)
 
