From 87a92bfe7df11858dee09fbd0f89b79d0ef85413 Mon Sep 17 00:00:00 2001
From: Saaketh
Date: Mon, 25 Sep 2023 15:48:07 -0700
Subject: [PATCH] added multi-dataset tests, linting

---
 scripts/train/train.py          | 30 ++++++++++++------------------
 tests/test_data_prep_scripts.py |  4 ++--
 tests/test_train_inputs.py      |  9 ++++-----
 tests/test_training.py          | 57 ++++++++++++++++++++++++++++++++++++++++++++++---------------
 4 files changed, 60 insertions(+), 40 deletions(-)

diff --git a/scripts/train/train.py b/scripts/train/train.py
index 73e36cf79a..b31c15467e 100644
--- a/scripts/train/train.py
+++ b/scripts/train/train.py
@@ -36,9 +36,13 @@ def validate_config(cfg: DictConfig):
         eval_loader = cfg.eval_loader
         if isinstance(eval_loader, ListConfig):
             for loader in eval_loader:
+                if loader.label is None:
+                    raise ValueError(
+                        'When specifying multiple evaluation datasets, each one must include the \
+                            `label` attribute.')
                 loaders.append(loader)
         else:
-            loaders.append(cfg.eval_loader)
+            loaders.append(eval_loader)
     for loader in loaders:
         if loader.name == 'text':
             if cfg.model.name in ['hf_prefix_lm', 'hf_t5']:
@@ -471,25 +475,15 @@ def main(cfg: DictConfig) -> Trainer:
     evaluators = []
     eval_loaders = []
     if eval_loader_config is not None:
-        if isinstance(eval_loader_config, ListConfig):
-            for eval_config in eval_loader_config:
-                if eval_config.label is None:
-                    raise ValueError(
-                        'When specifying multiple evaluation datasets, each one must include the \
-                            `label` attribute.')
-                eval_dataloader = build_dataloader(eval_config, tokenizer,
-                                                   device_eval_batch_size)
-                eval_loader = Evaluator(
-                    label=f'eval/{eval_config.label}',
-                    dataloader=eval_dataloader,
-                    metric_names=[],  # we will add these after model is created
-                )
-                eval_loaders.append(eval_loader)
-        else:
-            eval_dataloader = build_dataloader(eval_loader_config, tokenizer,
+        is_multi_eval = isinstance(eval_loader_config, ListConfig)
+        eval_configs = eval_loader_config if is_multi_eval else [
+            eval_loader_config
+        ]
+        for eval_config in eval_configs:
+            eval_dataloader = build_dataloader(eval_config, tokenizer,
                                                device_eval_batch_size)
             eval_loader = Evaluator(
-                label='eval',
+                label=f'eval/{eval_config.label}' if is_multi_eval else 'eval',
                 dataloader=eval_dataloader,
                 metric_names=[],  # we will add these after model is created
             )
diff --git a/tests/test_data_prep_scripts.py b/tests/test_data_prep_scripts.py
index 52ab42806f..4c555ea9a2 100644
--- a/tests/test_data_prep_scripts.py
+++ b/tests/test_data_prep_scripts.py
@@ -37,13 +37,13 @@ def test_download_script_from_api():
 
 def test_json_script_from_api():
     # test calling it directly
-    path = os.path.join(os.getcwd(), 'my-copy-c4-3')
+    path = os.path.join(os.getcwd(), 'my-copy-arxiv-1')
     shutil.rmtree(path, ignore_errors=True)
     main_json(
         Namespace(
             **{
                 'path': 'scripts/data_prep/example_data/arxiv.jsonl',
-                'out_root': './my-copy-c4-3',
+                'out_root': './my-copy-arxiv-1',
                 'compression': None,
                 'split': 'train',
                 'concat_tokens': None,
diff --git a/tests/test_train_inputs.py b/tests/test_train_inputs.py
index a11a09bdd0..bf90f48ef0 100644
--- a/tests/test_train_inputs.py
+++ b/tests/test_train_inputs.py
@@ -166,13 +166,12 @@ def test_no_label_multiple_eval_datasets(self, cfg: DictConfig) -> None:
         cfg.train_loader.dataset.local = data_local
         # Set up multiple eval datasets
         first_eval_loader = cfg.eval_loader
-        first_eval_loader.label = 'eval_1'
+        first_eval_loader.dataset.local = data_local
         second_eval_loader = copy.deepcopy(first_eval_loader)
-        cfg.eval_loader = om.create([first_eval_loader, second_eval_loader])
-        for loader in cfg.eval_loader:
-            loader.dataset.local = data_local
         # Set the first eval dataloader to have no label
-        cfg.eval_loader[0].label = None
+        first_eval_loader.label = None
+        second_eval_loader.label = 'eval_1'
+        cfg.eval_loader = om.create([first_eval_loader, second_eval_loader])
         with pytest.raises(ValueError) as exception_info:
             main(cfg)
         assert str(
diff --git a/tests/test_training.py b/tests/test_training.py
index bfff13bfc8..f83c4ebd3c 100644
--- a/tests/test_training.py
+++ b/tests/test_training.py
@@ -17,6 +17,8 @@
 sys.path.append(repo_dir)
 
 from scripts.data_prep.convert_dataset_hf import main as main_hf  # noqa: E402
+from scripts.data_prep.convert_dataset_json import \
+    main as main_json  # noqa: E402
 from scripts.train.train import main  # noqa: E402
 
 
@@ -53,6 +55,30 @@ def create_c4_dataset_xsmall(prefix: str) -> str:
     return c4_dir
 
 
+def create_arxiv_dataset(prefix: str) -> str:
+    """Creates an arxiv dataset."""
+    arxiv_dir = os.path.join(os.getcwd(), f'my-copy-arxiv-{prefix}')
+    shutil.rmtree(arxiv_dir, ignore_errors=True)
+    downloaded_split = 'train'
+
+    main_json(
+        Namespace(
+            **{
+                'path': 'data_prep/example_data/arxiv.jsonl',
+                'out_root': arxiv_dir,
+                'compression': None,
+                'split': downloaded_split,
+                'concat_tokens': None,
+                'bos_text': None,
+                'eos_text': None,
+                'no_wrap': False,
+                'num_workers': None
+            }))
+
+    assert os.path.exists(arxiv_dir)
+    return arxiv_dir
+
+
 def gpt_tiny_cfg(dataset_name: str, device: str):
     """Create gpt tiny cfg."""
     conf_path: str = os.path.join(repo_dir,
@@ -154,14 +180,17 @@ def test_train_gauntlet(set_correct_cwd: Any):
 
 
 def test_train_multi_eval(set_correct_cwd: Any):
-    """Test training run with a small dataset."""
-    dataset_name = create_c4_dataset_xsmall('cpu-gauntlet')
-    test_cfg = gpt_tiny_cfg(dataset_name, 'cpu')
+    """Test training run with multiple eval datasets."""
+    c4_dataset_name = create_c4_dataset_xsmall('multi-eval')
+    test_cfg = gpt_tiny_cfg(c4_dataset_name, 'cpu')
     # Set up multiple eval dataloaders
     first_eval_loader = test_cfg.eval_loader
-    first_eval_loader.label = 'eval_1'
+    first_eval_loader.label = 'c4'
+    # Create second eval dataloader using the arxiv dataset.
     second_eval_loader = copy.deepcopy(first_eval_loader)
-    second_eval_loader.label = 'eval_2'
+    arxiv_dataset_name = create_arxiv_dataset('multi-eval')
+    second_eval_loader.data_local = arxiv_dataset_name
+    second_eval_loader.label = 'arxiv'
     test_cfg.eval_loader = om.create([first_eval_loader, second_eval_loader])
     test_cfg.eval_subset_num_batches = 1  # -1 to evaluate on all batches
@@ -179,23 +208,21 @@ def test_train_multi_eval(set_correct_cwd: Any):
     print(inmemorylogger.data.keys())
 
     # Checks for first eval dataloader
-    assert 'metrics/eval/eval_1/LanguageCrossEntropy' in inmemorylogger.data.keys(
-    )
+    assert 'metrics/eval/c4/LanguageCrossEntropy' in inmemorylogger.data.keys()
     assert isinstance(
-        inmemorylogger.data['metrics/eval/eval_1/LanguageCrossEntropy'], list)
+        inmemorylogger.data['metrics/eval/c4/LanguageCrossEntropy'], list)
     assert len(
-        inmemorylogger.data['metrics/eval/eval_1/LanguageCrossEntropy'][-1]) > 0
+        inmemorylogger.data['metrics/eval/c4/LanguageCrossEntropy'][-1]) > 0
     assert isinstance(
-        inmemorylogger.data['metrics/eval/eval_1/LanguageCrossEntropy'][-1],
-        tuple)
+        inmemorylogger.data['metrics/eval/c4/LanguageCrossEntropy'][-1], tuple)
 
     # Checks for second eval dataloader
-    assert 'metrics/eval/eval_2/LanguageCrossEntropy' in inmemorylogger.data.keys(
+    assert 'metrics/eval/arxiv/LanguageCrossEntropy' in inmemorylogger.data.keys(
     )
     assert isinstance(
-        inmemorylogger.data['metrics/eval/eval_2/LanguageCrossEntropy'], list)
+        inmemorylogger.data['metrics/eval/arxiv/LanguageCrossEntropy'], list)
     assert len(
-        inmemorylogger.data['metrics/eval/eval_2/LanguageCrossEntropy'][-1]) > 0
+        inmemorylogger.data['metrics/eval/arxiv/LanguageCrossEntropy'][-1]) > 0
     assert isinstance(
-        inmemorylogger.data['metrics/eval/eval_2/LanguageCrossEntropy'][-1],
+        inmemorylogger.data['metrics/eval/arxiv/LanguageCrossEntropy'][-1],
         tuple)
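
Note on usage (reviewer illustration, not part of the patch): with this change `eval_loader` may be either a single config or a list of labeled configs, and the unified loop in `scripts/train/train.py` derives each `Evaluator` label from that shape. Below is a minimal, self-contained Python sketch of just that labeling logic, using stand-in `c4`/`arxiv` configs; `build_dataloader`, `Evaluator`, and the Trainer wiring are omitted.

    # Sketch only: mirrors the label normalization added by this patch.
    from omegaconf import ListConfig, OmegaConf

    def evaluator_labels(eval_loader_config):
        # A single loader keeps the legacy 'eval' label; a list of loaders
        # yields 'eval/<label>' per entry, which is why validate_config now
        # requires `label` when multiple eval datasets are specified.
        is_multi_eval = isinstance(eval_loader_config, ListConfig)
        eval_configs = eval_loader_config if is_multi_eval else [
            eval_loader_config
        ]
        return [
            f'eval/{cfg.label}' if is_multi_eval else 'eval'
            for cfg in eval_configs
        ]

    single = OmegaConf.create({'name': 'text', 'label': None})
    multi = OmegaConf.create([{'name': 'text', 'label': 'c4'},
                              {'name': 'text', 'label': 'arxiv'}])

    print(evaluator_labels(single))  # ['eval']
    print(evaluator_labels(multi))   # ['eval/c4', 'eval/arxiv']

Metrics for each labeled dataset are then logged under `metrics/eval/<label>/...`, which is exactly what `test_train_multi_eval` asserts for the `c4` and `arxiv` loaders above.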