
Commit e583620

Merge remote-tracking branch 'origin/main' into release/v0.9.0
v-chen_data committed Jun 7, 2024
2 parents e83c10f + db70135 · commit e583620
Showing 7 changed files with 29 additions and 32 deletions.
Dockerfile (2 changes: 1 addition & 1 deletion)
@@ -13,7 +13,7 @@ ADD https://raw.githubusercontent.com/mosaicml/llm-foundry/$BRANCH_NAME/setup.py
 RUN rm setup.py
 
 # Install TransformerEngine
-RUN NVTE_FRAMEWORK=pytorch CMAKE_BUILD_PARALLEL_LEVEL=3 MAX_JOBS=3 pip install git+https://github.com/cli99/TransformerEngine.git@6b21f606f2459d49c2113d69236d68d334edeb4c
+RUN NVTE_FRAMEWORK=pytorch CMAKE_BUILD_PARALLEL_LEVEL=3 MAX_JOBS=3 pip install git+https://github.com/NVIDIA/TransformerEngine.git@0edf30b87159e82048b5f248e4b379aebb8f364a
 
 # Install and uninstall foundry to cache foundry requirements
 RUN git clone -b $BRANCH_NAME https://github.com/mosaicml/llm-foundry.git
llmfoundry/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -71,4 +71,4 @@
     'utils',
 ]
 
-__version__ = '0.9.0.dev0'
+__version__ = '0.10.0.dev0'
llmfoundry/utils/config_utils.py (13 changes: 8 additions & 5 deletions)
@@ -67,6 +67,7 @@ class EvalConfig:
     # Logging parameters
     python_log_level: Optional[str] = 'debug'
     loggers: Optional[Dict[str, Any]] = None
+    console_log_interval: Union[int, str] = '1ba'
     log_config: bool = True
 
     # Model/run parameters
@@ -180,6 +181,11 @@ class TrainConfig:
     # Variables to ignore
     variables: Optional[Dict[str, Any]] = None
 
+    # Fields created by `update_batch_size_info`
+    n_gpus: int = MISSING
+    device_train_batch_size: int = MISSING
+    device_train_grad_accum: str = MISSING
+
 
 TRAIN_CONFIG_KEYS = {field.name for field in fields(TrainConfig)}
 
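The three new fields lean on omegaconf's MISSING sentinel (assuming the MISSING used in config_utils is omegaconf's): they are mandatory but deferred, and reading one before `update_batch_size_info` fills it in raises an error. A minimal sketch of that behavior with a toy dataclass; the batch-size arithmetic is illustrative, not the real helper:

from dataclasses import dataclass

from omegaconf import MISSING, OmegaConf
from omegaconf.errors import MissingMandatoryValue


@dataclass
class ToyTrainConfig:  # hypothetical stand-in for TrainConfig
    global_train_batch_size: int = 256
    n_gpus: int = MISSING  # filled in later, as by `update_batch_size_info`
    device_train_batch_size: int = MISSING


cfg = OmegaConf.structured(ToyTrainConfig)
try:
    _ = cfg.n_gpus
except MissingMandatoryValue:
    print('n_gpus must be set before it is read')

cfg.n_gpus = 8
cfg.device_train_batch_size = cfg.global_train_batch_size // cfg.n_gpus
print(cfg.device_train_batch_size)  # 32
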
@@ -242,7 +248,6 @@ def make_dataclass_and_log_config(
     icl_tasks_required: bool = False,
 ) -> Tuple[Dict[str, Any], T]:
     """Converts a DictConfig to a dataclass and creates a logged config."""
-    # Resolve all interpolation variables as early as possible
     unstructured_config = om.to_container(cfg, resolve=True)
     assert isinstance(unstructured_config, dict)
     assert all(isinstance(k, str) for k in unstructured_config.keys())
@@ -289,11 +294,9 @@ def make_dataclass_and_log_config(
         unstructured_config['variables'] = {}
 
     for key in extraneous_keys:
-        warnings.warn(
-            f'Unused parameter {key} found in cfg. Please check your yaml to ensure this parameter is necessary. Interpreting {key} as a variable for logging purposes. Top-level variables are deprecated and will not be supported in future releases. Please place any variables under the `variables` key.',
-            category=DeprecationWarning,
+        raise ValueError(
+            f'Unused parameter {key} found in cfg. Please check your yaml to ensure this parameter is necessary. Please place any variables under the `variables` key.',
         )
-        unstructured_config['variables'][key] = unstructured_config.pop(key)
 
     dataclass_dict_config: DictConfig = om.structured(
         dataclass_constructor(**unstructured_config),
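
The net effect of the last hunk: an unrecognized top-level key is now a hard error, where it previously produced a DeprecationWarning and was silently folded under `variables`. A self-contained sketch of the stricter behavior, with ToyConfig and TOY_KEYS as hypothetical stand-ins for TrainConfig and TRAIN_CONFIG_KEYS:

from dataclasses import dataclass, field, fields
from typing import Any, Dict

from omegaconf import OmegaConf as om


@dataclass
class ToyConfig:
    max_seq_len: int = 2048
    variables: Dict[str, Any] = field(default_factory=dict)


TOY_KEYS = {f.name for f in fields(ToyConfig)}

cfg = om.create({'max_seq_len': 2048, 'run_name': 'smoke-test'})
unstructured_config = om.to_container(cfg, resolve=True)

extraneous_keys = set(unstructured_config.keys()) - TOY_KEYS
for key in extraneous_keys:
    raise ValueError(  # previously warnings.warn(..., DeprecationWarning)
        f'Unused parameter {key} found in cfg. Please place any variables '
        'under the `variables` key.',
    )

Running this raises ValueError for 'run_name'; nesting it under `variables` instead would pass.
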
scripts/train/train.py (1 change: 0 additions & 1 deletion)
@@ -553,6 +553,5 @@ def main(cfg: DictConfig) -> Trainer:
         yaml_cfg = om.load(f)
     cli_cfg = om.from_cli(args_list)
     cfg = om.merge(yaml_cfg, cli_cfg)
-    om.resolve(cfg)
     assert isinstance(cfg, DictConfig)
     main(cfg)
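
With om.resolve(cfg) deleted, interpolation resolution is left to make_dataclass_and_log_config, which (as shown above) converts with om.to_container(cfg, resolve=True). An isolated illustration of the omegaconf behavior, using made-up keys rather than an llm-foundry yaml:

from omegaconf import OmegaConf as om

yaml_cfg = om.create({'model_name': 'mpt', 'run_name': '${model_name}-run'})
cli_cfg = om.from_dotlist(['model_name=gpt'])
cfg = om.merge(yaml_cfg, cli_cfg)

# Interpolations survive the merge unresolved...
print(om.to_yaml(cfg))  # run_name: ${model_name}-run

# ...and resolve against the merged values during conversion, so an
# explicit om.resolve(cfg) beforehand is not needed on this path.
resolved = om.to_container(cfg, resolve=True)
print(resolved['run_name'])  # gpt-run
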
tests/a_scripts/eval/test_eval.py (10 changes: 9 additions & 1 deletion)
@@ -13,7 +13,7 @@
 
 from llmfoundry.utils import build_tokenizer
 from llmfoundry.utils.builders import build_composer_model
-from llmfoundry.utils.config_utils import to_dict_container
+from llmfoundry.utils.config_utils import EVAL_CONFIG_KEYS, to_dict_container
 from scripts.eval.eval import main  # noqa: E402
 from tests.data_utils import create_c4_dataset_xxsmall, gpt_tiny_cfg
@@ -134,6 +134,14 @@ def test_loader_eval(
     test_cfg.eval_interval = '1ba'
     test_cfg.loggers = om.DictConfig({'inmemory': om.DictConfig({})})
 
+    # This test uses a training yaml with training-only keys present.
+    # We exclude these keys before calling `main` from the eval script.
+    allowed_keys = EVAL_CONFIG_KEYS
+    present_keys = set(test_cfg.keys())
+    keys_to_pop = present_keys.difference(allowed_keys)
+
+    [test_cfg.pop(key) for key in keys_to_pop]
+
     trainers, eval_gauntlet_df = main(test_cfg)
 
     assert eval_gauntlet_df is None
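
EVAL_CONFIG_KEYS is presumably derived the same way as TRAIN_CONFIG_KEYS above, i.e. {field.name for field in fields(EvalConfig)}. The filtering idiom from this test, restated on a plain dict with a hypothetical two-key allowed set:

test_cfg = {
    'models': [],                     # eval-side key
    'max_seq_len': 128,               # eval-side key
    'optimizer': {'lr': 1e-4},        # training-only key
    'scheduler': {'name': 'cosine'},  # training-only key
}
allowed_keys = {'models', 'max_seq_len'}  # stand-in for EVAL_CONFIG_KEYS

keys_to_pop = set(test_cfg.keys()).difference(allowed_keys)
for key in keys_to_pop:
    test_cfg.pop(key)

assert set(test_cfg) == allowed_keys

A plain for loop replaces the test's list comprehension here, since that comprehension is evaluated only for its pop() side effects.
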
tests/a_scripts/eval/test_eval_inputs.py (15 changes: 4 additions & 11 deletions)
@@ -2,7 +2,6 @@
 # SPDX-License-Identifier: Apache-2.0
 import copy
 import os
-import warnings
 
 import omegaconf
 import pytest
@@ -42,12 +41,13 @@ def test_mispelled_mandatory_params_fail(self, cfg: DictConfig) -> None:
                 omegaconf.errors.InterpolationKeyError,
                 omegaconf.errors.MissingMandatoryValue,
                 TypeError,
+                ValueError,
             )):
                 cfg[p + '-mispelled'] = cfg.pop(p)
                 main(cfg)
                 cfg[p] = cfg.pop(p + '-mispelled')
 
-    def test_optional_mispelled_params_raise_warning(
+    def test_optional_mispelled_params_raise_error(
         self,
         cfg: DictConfig,
     ) -> None:
@@ -67,15 +67,8 @@ def test_optional_mispelled_params_raise_warning(
             orig_value = cfg.pop(param, None)
             updated_param = param + '-mispelling'
             cfg[updated_param] = orig_value
-            with warnings.catch_warnings(record=True) as warning_list:
-                try:
-                    main(cfg)
-                except:
-                    pass
-                assert any(
-                    f'Unused parameter {updated_param} found in cfg.' in
-                    str(warning.message) for warning in warning_list
-                )
+            with pytest.raises(ValueError):
+                main(cfg)
             # restore configs.
             cfg = copy.deepcopy(old_cfg)
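
A possible tightening of this pattern (an assumption, not part of the commit): pytest.raises accepts a match= regex, which would pin the failure to the unused-parameter message rather than any ValueError raised during startup. Sketch with a toy validator standing in for main(cfg):

import pytest


def validate(cfg: dict, allowed: set) -> None:
    # Toy stand-in for the config validation that main(cfg) performs.
    for key in set(cfg) - allowed:
        raise ValueError(f'Unused parameter {key} found in cfg.')


def test_unused_param_message() -> None:
    with pytest.raises(ValueError, match='Unused parameter'):
        validate({'max_seq_len-mispelling': 2048}, allowed={'max_seq_len'})
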
tests/a_scripts/train/test_train_inputs.py (18 changes: 6 additions & 12 deletions)
@@ -3,7 +3,6 @@
 import copy
 import json
 import os
-import warnings
 
 import omegaconf
 import pytest
@@ -63,7 +62,9 @@ def cfg(self, foundry_dir: str) -> DictConfig:
     def test_misspelled_mandatory_params_fail(self, cfg: DictConfig) -> None:
         """Check that mandatory misspelled inputs fail to train."""
         cfg.trai_loader = cfg.pop('train_loader')
-        with pytest.raises((omegaconf.errors.MissingMandatoryValue, TypeError)):
+        with pytest.raises(
+            (omegaconf.errors.MissingMandatoryValue, TypeError, ValueError),
+        ):
             main(cfg)
 
     def test_missing_mandatory_parameters_fail(self, cfg: DictConfig) -> None:
@@ -89,7 +90,7 @@ def test_missing_mandatory_parameters_fail(self, cfg: DictConfig) -> None:
                 main(cfg)
             cfg[param] = orig_param
 
-    def test_optional_misspelled_params_raise_warning(
+    def test_optional_misspelled_params_raise_error(
         self,
         cfg: DictConfig,
     ) -> None:
@@ -113,15 +114,8 @@ def test_optional_misspelled_params_raise_warning(
             orig_value = cfg.pop(param, None)
             updated_param = param + '-misspelling'
             cfg[updated_param] = orig_value
-            with warnings.catch_warnings(record=True) as warning_list:
-                try:
-                    main(cfg)
-                except:
-                    pass
-                assert any(
-                    f'Unused parameter {updated_param} found in cfg.' in
-                    str(warning.message) for warning in warning_list
-                )
+            with pytest.raises(ValueError):
+                main(cfg)
             # restore configs.
             cfg = copy.deepcopy(old_cfg)
