Skip to content

Commit

Permalink
Add elastic resumption to regressions.py and refactor common code
Browse files Browse the repository at this point in the history
  • Loading branch information
irenedea committed Sep 18, 2023
1 parent 477bbb1 commit 27dca35
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 112 deletions.
140 changes: 103 additions & 37 deletions .github/workflows/regressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,25 +6,99 @@
import os
import subprocess

from mcli import RunConfig, RunStatus, create_run, wait_for_run_status

DIR_PATH = os.path.dirname(os.path.abspath(__file__))
REGRESSIONS_DIR = os.path.join(DIR_PATH, 'regression_yamls')

from mcli import RunConfig, create_run
COMMIT_HASH = subprocess.check_output(['git', 'rev-parse',
'HEAD']).strip().decode('utf-8')
TIMESTAMP = datetime.datetime.now().strftime('%m-%d-%Y::%H:%M:%S')


def _get_regression_config(yaml_name: str) -> RunConfig:
"""Get the yaml config from regressions directory."""
return RunConfig.from_file(os.path.join(REGRESSIONS_DIR, yaml_name))


def _set_general_configs(config: RunConfig, cluster: str, wandb_entity: str,
wandb_project: str, git_repo: str, git_branch: str):
"""Set general configuration arguments."""
config.cluster = cluster
wandb_group = f'{TIMESTAMP}::{COMMIT_HASH}'
wandb_config = {
'entity': wandb_entity,
'project': wandb_project,
'group': wandb_group
}
config.parameters['loggers'] = config.parameters.get('loggers', {})
config.parameters['loggers']['wandb'] = wandb_config
config.integrations[0]['git_repo'] = git_repo
config.integrations[0]['git_branch'] = git_branch


def test_elastic_resumption(cluster: str, save_folder: str, wandb_entity: str,
wandb_project: str, git_repo: str, git_branch: str):
"""Regression test for elastic resumption."""

def create_run_and_wait(gpus: int, resume: bool, subdir: str):
config = _get_regression_config('mpt-125m-elastic-resumption.yaml')

# Add the command to train our model
composer_command = '\ncomposer train/train.py /mnt/config/parameters.yaml'
if resume:
composer_command += ' autoresume=true'
else:
composer_command += ' autoresume=false'
config.command += composer_command

# Add suffix to name
name_suffix = f'-{gpus}'
if resume:
name_suffix += '-resume'
config.name += name_suffix

def get_configs(cluster: str, mpt_7b_ckpt_path: str, wandb_entity: str,
wandb_project: str, git_repo: str, git_branch: str):
print(f'Running regression tests on {git_repo} {git_branch}.')
eval_7b_hf = RunConfig.from_file(
os.path.join(REGRESSIONS_DIR, 'eval-7b-hf.yaml'))
eval_7b_composer = RunConfig.from_file(
os.path.join(REGRESSIONS_DIR, 'eval-7b-composer.yaml'))
llama2_finetune = RunConfig.from_file(
os.path.join(REGRESSIONS_DIR, 'llama2-finetune.yaml'))
mpt_125m_chinchilla = RunConfig.from_file(
os.path.join(REGRESSIONS_DIR, 'mpt-125m-chinchilla.yaml'))
mpt_125m_sharded_resumption = RunConfig.from_file(
os.path.join(REGRESSIONS_DIR, 'mpt-125m-sharded-resumption.yaml'))
# Set other parameters
config.compute['gpus'] = gpus
config.parameters['save_folder'] = os.path.join(save_folder, subdir)
config.parameters['max_duration'] = '20ba' if resume else '10ba'

_set_general_configs(config,
cluster=cluster,
wandb_entity=wandb_entity,
wandb_project=wandb_project,
git_repo=git_repo,
git_branch=git_branch)

# Start run
run = create_run(config)
wait_for_run_status(
run,
RunStatus.COMPLETED) # Wait for the run to complete or terminate.
if run.status != RunStatus.COMPLETED:
raise Exception(
f'Failure on run {run.name}. Run status is {run.status}. ' +
'Terminating elastic resumption regression test.')

# Test 1 node => 2 node elastic resumption
subdir = f'1_to_2_node_{TIMESTAMP}_{COMMIT_HASH}'
create_run_and_wait(gpus=8, resume=False, subdir=subdir)
create_run_and_wait(gpus=16, resume=True, subdir=subdir)

# Test 2 node => 1 node elastic resumption
subdir = f'2_to_1_node_{TIMESTAMP}_{COMMIT_HASH}'
create_run_and_wait(gpus=16, resume=False, subdir=subdir)
create_run_and_wait(gpus=8, resume=True, subdir=subdir)


def test_basic(cluster: str, mpt_7b_ckpt_path: str, wandb_entity: str,
wandb_project: str, git_repo: str, git_branch: str):
eval_7b_hf = _get_regression_config('eval-7b-hf.yaml')
eval_7b_composer = _get_regression_config('eval-7b-composer.yaml')
llama2_finetune = _get_regression_config('llama2-finetune.yaml')
mpt_125m_chinchilla = _get_regression_config('mpt-125m-chinchilla.yaml')
mpt_125m_sharded_resumption = _get_regression_config(
'mpt-125m-sharded-resumption.yaml')

# make specific changes
eval_7b_composer.parameters['models'][0]['load_path'] = mpt_7b_ckpt_path
Expand All @@ -34,25 +108,14 @@ def get_configs(cluster: str, mpt_7b_ckpt_path: str, wandb_entity: str,
mpt_125m_sharded_resumption
]

commit_hash = subprocess.check_output(['git', 'rev-parse',
'HEAD']).strip().decode('utf-8')
timestamp = datetime.datetime.now().strftime('%m-%d-%Y::%H:%M:%S')
wandb_group = f'{timestamp}::{commit_hash}'

# make general changes
wandb_config = {
'entity': wandb_entity,
'project': wandb_project,
'group': wandb_group
}
for config in all_configs:
config.cluster = cluster
config.parameters['loggers'] = config.parameters.get('loggers', {})
config.parameters['loggers']['wandb'] = wandb_config
config.integrations[0]['git_repo'] = git_repo
config.integrations[0]['git_branch'] = git_branch

return all_configs, []
_set_general_configs(config,
cluster=cluster,
wandb_entity=wandb_entity,
wandb_project=wandb_project,
git_repo=git_repo,
git_branch=git_branch)
create_run(config)


if __name__ == '__main__':
Expand All @@ -61,13 +124,16 @@ def get_configs(cluster: str, mpt_7b_ckpt_path: str, wandb_entity: str,
parser.add_argument('--mpt-7b-ckpt-path', type=str)
parser.add_argument('--wandb-entity', type=str)
parser.add_argument('--wandb-project', type=str)
parser.add_argument('--remote-save-folder', type=str)
parser.add_argument('--git-repo', type=str, default='mosaicml/llm-foundry')
parser.add_argument('--git-branch', type=str, default='main')

args = parser.parse_args()

run_configs, _ = get_configs(args.cluster, args.mpt_7b_ckpt_path,
args.wandb_entity, args.wandb_project,
args.git_repo, args.git_branch)
for run_config in run_configs:
run = create_run(run_config)
print(f'Running regression tests on {args.git_repo} {args.git_branch}.')

test_basic(args.cluster, args.mpt_7b_ckpt_path, args.wandb_entity,
args.wandb_project, args.git_repo, args.git_branch)
test_elastic_resumption(args.cluster, args.remote_save_folder,
args.wandb_entity, args.wandb_project,
args.git_repo, args.git_branch)
75 changes: 0 additions & 75 deletions .github/workflows/test_elastic_resumption.py

This file was deleted.

0 comments on commit 27dca35

Please sign in to comment.