From 01eb20d45fa87b3992111d2a3c610cbc397bc402 Mon Sep 17 00:00:00 2001 From: Max Marion Date: Mon, 29 Jan 2024 12:57:13 -0800 Subject: [PATCH] Fix split_batch bug with empty generation_kwargs (#2913) --- .../in_context_learning_evaluation.py | 13 +++-- .../test_in_context_learning_datasets.py | 50 ++++++++++++++++++- 2 files changed, 57 insertions(+), 6 deletions(-) diff --git a/composer/datasets/in_context_learning_evaluation.py b/composer/datasets/in_context_learning_evaluation.py index 4e0e30f1ff..2fc75cf899 100644 --- a/composer/datasets/in_context_learning_evaluation.py +++ b/composer/datasets/in_context_learning_evaluation.py @@ -317,7 +317,8 @@ def __init__( self.tokenize_labels = tokenize_labels self.batch_mapping = batch_mapping or {} self.base_batch = base_batch or {} - self.update_generation_kwargs(generation_kwargs or {}) + if generation_kwargs: + self.update_generation_kwargs(generation_kwargs) self.static_keys = static_keys self.list_keys = list_keys @@ -358,9 +359,9 @@ def update_generation_kwargs(self, generation_kwargs: Dict) -> None: Args: dict: Keyword arguments that be written into base_batch['generation_kwargs'] """ - if 'generation_kwargs' not in self.base_batch: - self.base_batch['generation_kwargs'] = {} if generation_kwargs: + if 'generation_kwargs' not in self.base_batch: + self.base_batch['generation_kwargs'] = {} self.base_batch['generation_kwargs'].update(generation_kwargs) def read_dataset(self, @@ -702,7 +703,8 @@ def __init__(self, 'input_ids': self.context_key, 'labels': 'aliases', } - self.update_generation_kwargs(kwargs.get('generation_kwargs', {})) + if 'generation_kwargs' in kwargs: + self.update_generation_kwargs(kwargs['generation_kwargs']) def read_dataset( self, @@ -1292,7 +1294,8 @@ def __init__( 'eos_token_id': self.tokenizer.eos_token_id } } - self.update_generation_kwargs(kwargs.get('generation_kwargs', {})) + if 'generation_kwargs' in kwargs: + self.update_generation_kwargs(kwargs['generation_kwargs']) def _set_max_prompt_and_answer_lengths(self): """ diff --git a/tests/datasets/test_in_context_learning_datasets.py b/tests/datasets/test_in_context_learning_datasets.py index 11a9a94c7b..de2114dd76 100644 --- a/tests/datasets/test_in_context_learning_datasets.py +++ b/tests/datasets/test_in_context_learning_datasets.py @@ -273,7 +273,55 @@ def test_update_generation_kwargs_no_kwargs(tiny_gpt2_tokenizer, tmp_path): destination_path=str(tmp_path / 'test_dataset_lm_juggernaut.jsonl'), hf_loading_vars=hf_loading_vars, hf_parsing_map=hf_parsing_map) - assert not dl.base_batch['generation_kwargs'] + assert not 'generation_kwargs' in dl.base_batch + + +def test_update_generation_kwargs_no_kwargs_qa_dataset(tmp_path): + pytest.importorskip('datasets') + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + dataset_uri = f'{local_data}/triviaqa_small.jsonl' + transformers = pytest.importorskip('transformers') + tokenizer = transformers.AutoTokenizer.from_pretrained('facebook/opt-125m') # type: ignore reportUnboundVariable + + tmp_path_to_broadcast = str(os.path.abspath(tmp_path)) + gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) + dl = InContextLearningQATaskDataset(dataset_uri=dataset_uri, + tokenizer=tokenizer, + max_seq_len=1024, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=0, + fewshot_random_seed=1234, + prompt_string='', + example_delimiter='\n', + continuation_delimiter=': ', + destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'), + generation_kwargs=None) + assert len(dl.base_batch['generation_kwargs']) == 3 + + +def test_update_generation_kwargs_with_kwargs_qa_dataset(tmp_path): + pytest.importorskip('datasets') + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + dataset_uri = f'{local_data}/triviaqa_small.jsonl' + transformers = pytest.importorskip('transformers') + tokenizer = transformers.AutoTokenizer.from_pretrained('facebook/opt-125m') # type: ignore reportUnboundVariable + + tmp_path_to_broadcast = str(os.path.abspath(tmp_path)) + gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) + dl = InContextLearningQATaskDataset(dataset_uri=dataset_uri, + tokenizer=tokenizer, + max_seq_len=1024, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=0, + fewshot_random_seed=1234, + prompt_string='', + example_delimiter='\n', + continuation_delimiter=': ', + destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'), + generation_kwargs={'temperature': 0.9}) + assert 'generation_kwargs' in dl.base_batch + assert dl.base_batch['generation_kwargs']['temperature'] == 0.9 + assert len(dl.base_batch['generation_kwargs']) == 4 @pytest.mark.filterwarnings(