Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
aspfohl committed Nov 9, 2023
1 parent 9fbe7a1 commit 547ec21
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 48 deletions.
42 changes: 32 additions & 10 deletions llmfoundry/callbacks/async_eval_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# SPDX-License-Identifier: Apache-2.0
import logging
import os
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, List, Optional, Union

from composer.core import Callback, Event, State, Time
from composer.loggers import Logger
Expand All @@ -18,12 +18,12 @@

MAX_RUN_NAME_LENGTH = 40

# Note: train parameter names. See comments if they are different from eval
REQUIRED_PARAMS_FOR_EVAL = {
    'device_eval_batch_size',
    'icl_tasks',  # only required for eval, may not be specified in pure training
    'max_seq_len',
    'model',  # converted into models
    'tokenizer',  # converted into models
    'save_folder',  # required, but used as load_path
}
OPTIONAL_PARAMS_FOR_EVAL = {
Expand All @@ -39,7 +39,9 @@


def get_run_name(previous_run_name: str, count: int) -> str:
    """Build the eval run name from the training run name.

    Strips the trailing ``-<uuid>`` segment that the platform appends to run
    names, truncates to ``MAX_RUN_NAME_LENGTH``, and prefixes ``eval<count>-``.

    Args:
        previous_run_name: Name of the training run, e.g. ``my-run-ab12cd``.
        count: Monotonic launch counter, used to keep eval run names unique.

    Returns:
        The derived eval run name.
    """
    *name_without_uuid_suffix, _ = previous_run_name.split('-')
    if name_without_uuid_suffix:
        base_name = '-'.join(name_without_uuid_suffix)
    else:
        # No '-' in the name: there is no uuid suffix to strip, so keep
        # the whole name instead of collapsing it to an empty string.
        base_name = previous_run_name
    name_suffix = base_name[:MAX_RUN_NAME_LENGTH]
    return f'eval{count}-{name_suffix}'


def get_load_path(save_folder: str,
Expand All @@ -53,6 +55,24 @@ def get_load_path(save_folder: str,
return f'{save_folder}/{save_latest_filename}'


def get_eval_models_dict(
    model: Dict[str, Any],
    tokenizer: Dict[str, Any],
) -> List[Dict[str, Any]]:
    """Convert training ``model``/``tokenizer`` configs into the eval
    script's ``models`` list format.

    Args:
        model: Training model config. Its ``name`` becomes ``model_name`` and
            any entries under ``cfg_overrides`` are flattened into the model
            config itself.
        tokenizer: Tokenizer config; attached to the entry only if non-empty.

    Returns:
        A single-element list containing the eval-formatted model dict.
    """
    # Shallow-copy so the caller's dict is not mutated by the pop/merge below.
    model = dict(model)
    name = model.get('name')

    # Flatten cfg_overrides into the top level of the model config.
    cfg_overrides = model.pop('cfg_overrides', {})
    model.update(cfg_overrides)

    new_model: Dict[str, Any] = {'model_name': name, 'model': model}

    if tokenizer:
        new_model['tokenizer'] = tokenizer

    return [new_model]


class AsyncEval(Callback):
"""Run the eval loop asynchronously as part of a MosaicML platform run.
Expand All @@ -74,6 +94,7 @@ def __init__(
self.check_interval = create_interval_scheduler(interval)
self.compute = compute
self.count = 0
self.last_launch: Optional[Time] = None

# Run these during init to fail fast in any of the error cases
self.current_run = self._get_current_run()
Expand All @@ -88,8 +109,10 @@ def __init__(
def run_event(self, event: Event, state: State, logger: Logger) -> None:
    """Launch an async eval run when the interval check fires.

    Tracks the batch of the last launch so that multiple events arriving
    at the same batch timestamp do not trigger a duplicate launch.
    """
    del logger
    if state.get_elapsed_duration() is not None and self.check_interval(
            state, event) and self.last_launch != state.timestamp.batch:
        self.launch_run()

        self.last_launch = state.timestamp.batch
        self.count += 1

def _get_current_run(self) -> Run:
Expand Down Expand Up @@ -134,10 +157,9 @@ def get_eval_parameters(
subset_keys.pop('save_folder'),
parameters.get('save_latest_filename', None))

# Rename the keys to match the eval script
subset_keys['models'] = [subset_keys.pop('model')]
if 'fsdp_config' in subset_keys:
subset_keys['fsdp_dict_cfg'] = subset_keys.pop('fsdp_config')
# Create new eval models list
subset_keys['models'] = get_eval_models_dict(
subset_keys.pop('model'), subset_keys.pop('tokenizer'))

subset_keys['run_name'] = get_run_name(run_name, 0)
return subset_keys
Expand Down
Empty file added tests/callbacks/__init__.py
Empty file.
99 changes: 61 additions & 38 deletions tests/callbacks/test_async_eval_callback.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from unittest.mock import patch

import pytest

from llmfoundry.callbacks import AsyncEval
from llmfoundry.callbacks.async_eval_callback import get_run_name
from llmfoundry.callbacks.async_eval_callback import AsyncEval, get_run_name
from mcli import Run, RunConfig, RunStatus

RUN_NAME = 'foo_bar'
RUN_NAME = 'foo_bar-1234'


def test_get_run_name():
    # The trailing '-<uuid>' segment is stripped before the eval prefix is added.
    a = get_run_name('foo-1234', 0)
    assert a == 'eval0-foo'

    # Long names are truncated to MAX_RUN_NAME_LENGTH (40) after stripping
    # the final '-1234' segment of the repeated name.
    b = get_run_name(50 * 'foo-1234', 1)
    assert b == 'eval1-foo-1234foo-1234foo-1234foo-1234foo-1234'


Expand Down Expand Up @@ -49,31 +51,55 @@ def test_fails_when_no_run_name():
AsyncEval(interval='2ba')


# Minimal training config satisfying REQUIRED_PARAMS_FOR_EVAL; shared by the
# get_eval_parameters tests below and by FAKE_RUN's RunConfig.
BASIC_PARAMS = {
    'device_eval_batch_size': 2,
    'icl_tasks': 'icl_task_example',
    'max_seq_len': 3,
    'model': {
        'name': 'model_example',
        'cfg_overrides': {
            'attn_config': {
                'foo': 'bar'
            }
        }
    },
    'tokenizer': {
        'tokenizer_example': 'tokenizer_example',
    },
    'save_folder': 'save_folder_example',
}


def test_get_eval_parameters():
with pytest.raises(
Exception,
match='Missing the following required parameters for async eval:'):
AsyncEval.get_eval_parameters(None, {}, RUN_NAME)

# minimal example
params = AsyncEval.get_eval_parameters(
None, {
'device_eval_batch_size': 2,
'icl_tasks': 'icl_task_example',
'max_seq_len': 3,
'model': {
'model_example': 'model_example'
},
'save_folder': 'save_folder_example',
}, RUN_NAME)
params = AsyncEval.get_eval_parameters(None, BASIC_PARAMS, RUN_NAME)
assert params == {
'device_eval_batch_size': 2,
'icl_tasks': 'icl_task_example',
'max_seq_len': 3,
'load_path': 'save_folder_example/latest-rank0.pt',
'run_name': 'eval0-foo_bar',
'device_eval_batch_size':
2,
'icl_tasks':
'icl_task_example',
'max_seq_len':
3,
'load_path':
'save_folder_example/latest-rank0.pt',
'run_name':
'eval0-foo_bar',
'models': [{
'model_example': 'model_example'
'model_name': 'model_example',
'model': {
'name': 'model_example',
'attn_config': {
'foo': 'bar'
},
},
'tokenizer': {
'tokenizer_example': 'tokenizer_example'
},
}],
}

Expand All @@ -82,13 +108,7 @@ def test_get_eval_parameters():
None,
{
# required
'device_eval_batch_size': 2,
'icl_tasks': 'icl_task_example',
'max_seq_len': 3,
'model': {
'model_example': 'model_example'
},
'save_folder': 'save_folder_example',
**BASIC_PARAMS,
# optional
'dist_timeout': 1,
'eval_gauntlet': 'eval_gauntlet_example',
Expand All @@ -113,7 +133,16 @@ def test_get_eval_parameters():
'run_name': 'eval0-foo_bar',
'dist_timeout': 1,
'models': [{
'model_example': 'model_example'
'model_name': 'model_example',
'model': {
'name': 'model_example',
'attn_config': {
'foo': 'bar'
},
},
'tokenizer': {
'tokenizer_example': 'tokenizer_example'
},
}],
'eval_gauntlet': 'eval_gauntlet_example',
'fsdp_dict_cfg': {
Expand All @@ -133,7 +162,7 @@ def test_get_eval_parameters():
FAKE_RUN = Run(
run_uid='123',
name=RUN_NAME,
image="fake-image",
image='fake-image',
status=RunStatus.RUNNING,
created_at='2021-01-01',
updated_at='2021-01-01',
Expand All @@ -142,7 +171,7 @@ def test_get_eval_parameters():
preemptible=False,
retry_on_system_failure=True,
cluster='c1z2',
gpu_type="a100",
gpu_type='a100',
gpus=16,
cpus=0,
node_count=2,
Expand All @@ -151,13 +180,7 @@ def test_get_eval_parameters():
name=RUN_NAME,
image='fake-image',
command='echo hi',
parameters={
'device_eval_batch_size': 2,
'icl_tasks': 'icl_task_example',
'max_seq_len': 3,
'model': 'model_example',
'save_folder': 'save_folder_example',
},
parameters=BASIC_PARAMS,
),
)

Expand Down

0 comments on commit 547ec21

Please sign in to comment.