Skip to content

Commit

Permalink
Add code import to train/eval scripts (#1002)
Browse files Browse the repository at this point in the history
  • Loading branch information
dakinggg authored Mar 11, 2024
1 parent 2fc5d33 commit d61c53d
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 0 deletions.
37 changes: 37 additions & 0 deletions llmfoundry/registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import importlib.util
import os
from pathlib import Path
from types import ModuleType
from typing import Union

__all__ = ['import_file']


def import_file(loc: Union[str, Path]) -> ModuleType:
    """Import a module from a file. Used to run arbitrary python code.

    The file is executed under the fixed module name ``'python_code'`` and
    the resulting module object is returned to the caller.

    Args:
        loc (str / Path): Path to the file.

    Returns:
        ModuleType: The module object.

    Raises:
        FileNotFoundError: If ``loc`` does not exist.
        ImportError: If no import spec or loader can be created for ``loc``
            (for example, the file name has no recognized Python suffix).
        RuntimeError: If executing the file raises an exception.
    """
    if not os.path.exists(loc):
        raise FileNotFoundError(f'File {loc} does not exist.')

    spec = importlib.util.spec_from_file_location('python_code', str(loc))

    # ``spec_from_file_location`` returns None when it cannot pick a loader
    # for the path. Raise explicitly rather than ``assert`` so the check
    # still runs when Python is started with ``-O``.
    if spec is None or spec.loader is None:
        raise ImportError(f'Could not create an import spec for {loc}')

    module = importlib.util.module_from_spec(spec)

    try:
        spec.loader.exec_module(module)
    except Exception as e:
        # User code can fail in arbitrary ways; surface one consistent error
        # that names the offending file and chains the original cause.
        raise RuntimeError(f'Error executing {loc}') from e
    return module
11 changes: 11 additions & 0 deletions scripts/eval/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

install()
from llmfoundry.models.model_registry import COMPOSER_MODEL_REGISTRY
from llmfoundry.registry import import_file
from llmfoundry.utils.builders import (add_metrics_to_eval_loaders,
build_evaluators, build_logger,
build_tokenizer)
Expand Down Expand Up @@ -188,6 +189,16 @@ def evaluate_model(


def main(cfg: DictConfig) -> Tuple[List[Trainer], pd.DataFrame]:
# Run user provided code if specified
code_paths = pop_config(cfg,
'code_paths',
must_exist=False,
default_value=[],
convert=True)
# Import any user provided code
for code_path in code_paths:
import_file(code_path)

om.resolve(cfg)

# Create copy of config for logging
Expand Down
11 changes: 11 additions & 0 deletions scripts/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from llmfoundry import COMPOSER_MODEL_REGISTRY
from llmfoundry.callbacks import AsyncEval
from llmfoundry.data.dataloader import build_dataloader
from llmfoundry.registry import import_file
from llmfoundry.utils.builders import (add_metrics_to_eval_loaders,
build_algorithm, build_callback,
build_evaluators, build_logger,
Expand Down Expand Up @@ -158,6 +159,16 @@ def main(cfg: DictConfig) -> Trainer:
'torch.distributed.*_base is a private function and will be deprecated.*'
)

# Run user provided code if specified
code_paths = pop_config(cfg,
'code_paths',
must_exist=False,
default_value=[],
convert=True)
# Import any user provided code
for code_path in code_paths:
import_file(code_path)

# Check for incompatibilities between the model and data loaders
validate_config(cfg)

Expand Down
44 changes: 44 additions & 0 deletions tests/test_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import os
import pathlib

import pytest

from llmfoundry.registry import import_file


def test_registry_init_code(tmp_path: pathlib.Path):
    """Check that ``import_file`` executes the code in the given file.

    The imported snippet sets an environment variable, which is the
    observable side effect asserted on.
    """
    env_key = 'TEST_ENVIRON_REGISTRY_KEY'
    register_code = """
import os
os.environ['TEST_ENVIRON_REGISTRY_KEY'] = 'test'
"""

    with open(tmp_path / 'init_code.py', 'w') as _f:
        _f.write(register_code)

    try:
        import_file(tmp_path / 'init_code.py')

        assert os.environ[env_key] == 'test'
    finally:
        # Always remove the variable, even when the import or the assertion
        # fails, so it cannot leak into other tests in the same process.
        os.environ.pop(env_key, None)


def test_registry_init_code_fails(tmp_path: pathlib.Path):
    """Check that errors raised by the imported file surface as RuntimeError.

    The snippet references the undefined name ``asdf`` after setting an
    environment variable, so the import fails part-way through.
    """
    env_key = 'TEST_ENVIRON_REGISTRY_KEY'
    register_code = """
import os
os.environ['TEST_ENVIRON_REGISTRY_KEY'] = 'test'
asdf
"""

    with open(tmp_path / 'init_code.py', 'w') as _f:
        _f.write(register_code)

    try:
        with pytest.raises(RuntimeError,
                           match='Error executing .*init_code.py'):
            import_file(tmp_path / 'init_code.py')
    finally:
        # The snippet sets the variable before it fails; remove it so the
        # side effect does not leak into other tests in the same process.
        os.environ.pop(env_key, None)


def test_registry_init_code_dne(tmp_path: pathlib.Path):
    """Importing a file that was never created raises FileNotFoundError."""
    missing_file = tmp_path / 'init_code.py'
    expect_missing = pytest.raises(FileNotFoundError,
                                   match='File .* does not exist')
    with expect_missing:
        import_file(missing_file)

0 comments on commit d61c53d

Please sign in to comment.