added multi-dataset tests, linting
snarayan21 committed Sep 25, 2023
1 parent 783e365 commit 87a92bf
Showing 4 changed files with 60 additions and 40 deletions.
scripts/train/train.py (30 changes: 12 additions & 18 deletions)
@@ -36,9 +36,13 @@ def validate_config(cfg: DictConfig):
         eval_loader = cfg.eval_loader
         if isinstance(eval_loader, ListConfig):
             for loader in eval_loader:
+                if loader.label is None:
+                    raise ValueError(
+                        'When specifying multiple evaluation datasets, each one must include the \
+                            `label` attribute.')
                 loaders.append(loader)
         else:
-            loaders.append(cfg.eval_loader)
+            loaders.append(eval_loader)
     for loader in loaders:
         if loader.name == 'text':
             if cfg.model.name in ['hf_prefix_lm', 'hf_t5']:
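For context, a minimal sketch of the `eval_loader` shape this validation now enforces, built directly with OmegaConf. The dataset paths are hypothetical placeholders; only the `label` field is what the new check inspects.

from omegaconf import OmegaConf

# Hypothetical multi-dataset eval config (paths are placeholders).
cfg = OmegaConf.create({
    'eval_loader': [
        {'name': 'text', 'label': 'c4', 'dataset': {'local': '/tmp/c4'}},
        {'name': 'text', 'label': 'arxiv', 'dataset': {'local': '/tmp/arxiv'}},
    ],
})

for loader in cfg.eval_loader:
    # OmegaConf returns None for keys missing from a non-struct config,
    # which is why the check above can test `loader.label is None`.
    assert loader.label is not None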
@@ -471,25 +475,15 @@ def main(cfg: DictConfig) -> Trainer:
     evaluators = []
     eval_loaders = []
     if eval_loader_config is not None:
-        if isinstance(eval_loader_config, ListConfig):
-            for eval_config in eval_loader_config:
-                if eval_config.label is None:
-                    raise ValueError(
-                        'When specifying multiple evaluation datasets, each one must include the \
-                            `label` attribute.')
-                eval_dataloader = build_dataloader(eval_config, tokenizer,
-                                                   device_eval_batch_size)
-                eval_loader = Evaluator(
-                    label=f'eval/{eval_config.label}',
-                    dataloader=eval_dataloader,
-                    metric_names=[],  # we will add these after model is created
-                )
-                eval_loaders.append(eval_loader)
-        else:
-            eval_dataloader = build_dataloader(eval_loader_config, tokenizer,
+        is_multi_eval = isinstance(eval_loader_config, ListConfig)
+        eval_configs = eval_loader_config if is_multi_eval else [
+            eval_loader_config
+        ]
+        for eval_config in eval_configs:
+            eval_dataloader = build_dataloader(eval_config, tokenizer,
                                                device_eval_batch_size)
             eval_loader = Evaluator(
-                label='eval',
+                label=f'eval/{eval_config.label}' if is_multi_eval else 'eval',
                 dataloader=eval_dataloader,
                 metric_names=[],  # we will add these after model is created
             )
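The Evaluator label becomes part of the logged metric key, which is why the tests below assert keys like metrics/eval/c4/LanguageCrossEntropy. A sketch of the naming rule as a hypothetical helper (not part of the diff itself):

def evaluator_label(label: str, is_multi_eval: bool) -> str:
    # Mirrors the conditional label expression above.
    return f'eval/{label}' if is_multi_eval else 'eval'

assert evaluator_label('c4', is_multi_eval=True) == 'eval/c4'
assert evaluator_label('c4', is_multi_eval=False) == 'eval'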
tests/test_data_prep_scripts.py (4 changes: 2 additions & 2 deletions)
@@ -37,13 +37,13 @@ def test_download_script_from_api():

 def test_json_script_from_api():
     # test calling it directly
-    path = os.path.join(os.getcwd(), 'my-copy-c4-3')
+    path = os.path.join(os.getcwd(), 'my-copy-arxiv-1')
     shutil.rmtree(path, ignore_errors=True)
     main_json(
         Namespace(
             **{
                 'path': 'scripts/data_prep/example_data/arxiv.jsonl',
-                'out_root': './my-copy-c4-3',
+                'out_root': './my-copy-arxiv-1',
                 'compression': None,
                 'split': 'train',
                 'concat_tokens': None,
tests/test_train_inputs.py (9 changes: 4 additions & 5 deletions)
@@ -166,13 +166,12 @@ def test_no_label_multiple_eval_datasets(self, cfg: DictConfig) -> None:
         cfg.train_loader.dataset.local = data_local
         # Set up multiple eval datasets
         first_eval_loader = cfg.eval_loader
-        first_eval_loader.label = 'eval_1'
+        first_eval_loader.dataset.local = data_local
         second_eval_loader = copy.deepcopy(first_eval_loader)
-        cfg.eval_loader = om.create([first_eval_loader, second_eval_loader])
-        for loader in cfg.eval_loader:
-            loader.dataset.local = data_local
         # Set the first eval dataloader to have no label
-        cfg.eval_loader[0].label = None
+        first_eval_loader.label = None
+        second_eval_loader.label = 'eval_1'
+        cfg.eval_loader = om.create([first_eval_loader, second_eval_loader])
         with pytest.raises(ValueError) as exception_info:
             main(cfg)
         assert str(
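The hunk is truncated here by the diff view, so the exact message assertion is elided. For reference, a sketch of the pytest.raises pattern it uses, with a hypothetical function and message:

import pytest

def might_fail():
    raise ValueError('some message')

# exception_info exposes the raised exception so its text can be checked.
with pytest.raises(ValueError) as exception_info:
    might_fail()
assert 'some message' in str(exception_info.value)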
tests/test_training.py (57 changes: 42 additions & 15 deletions)
Expand Up @@ -17,6 +17,8 @@
sys.path.append(repo_dir)

from scripts.data_prep.convert_dataset_hf import main as main_hf # noqa: E402
from scripts.data_prep.convert_dataset_json import \
main as main_json # noqa: E402
from scripts.train.train import main # noqa: E402


@@ -53,6 +55,30 @@ def create_c4_dataset_xsmall(prefix: str) -> str:
     return c4_dir


+def create_arxiv_dataset(prefix: str) -> str:
+    """Creates an arxiv dataset."""
+    arxiv_dir = os.path.join(os.getcwd(), f'my-copy-arxiv-{prefix}')
+    shutil.rmtree(arxiv_dir, ignore_errors=True)
+    downloaded_split = 'train'
+
+    main_json(
+        Namespace(
+            **{
+                'path': 'data_prep/example_data/arxiv.jsonl',
+                'out_root': arxiv_dir,
+                'compression': None,
+                'split': downloaded_split,
+                'concat_tokens': None,
+                'bos_text': None,
+                'eos_text': None,
+                'no_wrap': False,
+                'num_workers': None
+            }))
+
+    assert os.path.exists(arxiv_dir)
+    return arxiv_dir
+
+
 def gpt_tiny_cfg(dataset_name: str, device: str):
     """Create gpt tiny cfg."""
     conf_path: str = os.path.join(repo_dir,
@@ -154,14 +180,17 @@ def test_train_gauntlet(set_correct_cwd: Any):


 def test_train_multi_eval(set_correct_cwd: Any):
-    """Test training run with a small dataset."""
-    dataset_name = create_c4_dataset_xsmall('cpu-gauntlet')
-    test_cfg = gpt_tiny_cfg(dataset_name, 'cpu')
+    """Test training run with multiple eval datasets."""
+    c4_dataset_name = create_c4_dataset_xsmall('multi-eval')
+    test_cfg = gpt_tiny_cfg(c4_dataset_name, 'cpu')
     # Set up multiple eval dataloaders
     first_eval_loader = test_cfg.eval_loader
-    first_eval_loader.label = 'eval_1'
+    first_eval_loader.label = 'c4'
+    # Create second eval dataloader using the arxiv dataset.
     second_eval_loader = copy.deepcopy(first_eval_loader)
-    second_eval_loader.label = 'eval_2'
+    arxiv_dataset_name = create_arxiv_dataset('multi-eval')
+    second_eval_loader.data_local = arxiv_dataset_name
+    second_eval_loader.label = 'arxiv'
     test_cfg.eval_loader = om.create([first_eval_loader, second_eval_loader])
     test_cfg.eval_subset_num_batches = 1  # -1 to evaluate on all batches

@@ -179,23 +208,21 @@ def test_train_multi_eval(set_correct_cwd: Any):
     print(inmemorylogger.data.keys())

     # Checks for first eval dataloader
-    assert 'metrics/eval/eval_1/LanguageCrossEntropy' in inmemorylogger.data.keys(
-    )
+    assert 'metrics/eval/c4/LanguageCrossEntropy' in inmemorylogger.data.keys()
     assert isinstance(
-        inmemorylogger.data['metrics/eval/eval_1/LanguageCrossEntropy'], list)
+        inmemorylogger.data['metrics/eval/c4/LanguageCrossEntropy'], list)
     assert len(
-        inmemorylogger.data['metrics/eval/eval_1/LanguageCrossEntropy'][-1]) > 0
+        inmemorylogger.data['metrics/eval/c4/LanguageCrossEntropy'][-1]) > 0
     assert isinstance(
-        inmemorylogger.data['metrics/eval/eval_1/LanguageCrossEntropy'][-1],
-        tuple)
+        inmemorylogger.data['metrics/eval/c4/LanguageCrossEntropy'][-1], tuple)

     # Checks for second eval dataloader
-    assert 'metrics/eval/eval_2/LanguageCrossEntropy' in inmemorylogger.data.keys(
+    assert 'metrics/eval/arxiv/LanguageCrossEntropy' in inmemorylogger.data.keys(
     )
     assert isinstance(
-        inmemorylogger.data['metrics/eval/eval_2/LanguageCrossEntropy'], list)
+        inmemorylogger.data['metrics/eval/arxiv/LanguageCrossEntropy'], list)
     assert len(
-        inmemorylogger.data['metrics/eval/eval_2/LanguageCrossEntropy'][-1]) > 0
+        inmemorylogger.data['metrics/eval/arxiv/LanguageCrossEntropy'][-1]) > 0
     assert isinstance(
-        inmemorylogger.data['metrics/eval/eval_2/LanguageCrossEntropy'][-1],
+        inmemorylogger.data['metrics/eval/arxiv/LanguageCrossEntropy'][-1],
         tuple)
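For reference, these assertions lean on the storage shape of Composer's InMemoryLogger. A sketch, assuming each metric accumulates as a list of (timestamp, value) entries under its namespaced key:

from composer.loggers import InMemoryLogger

logger = InMemoryLogger()
# After trainer.fit(...), metrics land in logger.data, e.g. (assumed shape):
#   logger.data['metrics/eval/c4/LanguageCrossEntropy']
#     -> [(timestamp, value), ...]
# hence the tests index [-1] for the latest entry and check it is a tuple.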
