Skip to content

Commit

Permalink
Fix split_batch bug with empty generation_kwargs (mosaicml#2913)
Browse files Browse the repository at this point in the history
  • Loading branch information
maxisawesome authored Jan 29, 2024
1 parent 28a9a23 commit 01eb20d
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 6 deletions.
13 changes: 8 additions & 5 deletions composer/datasets/in_context_learning_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,8 @@ def __init__(
self.tokenize_labels = tokenize_labels
self.batch_mapping = batch_mapping or {}
self.base_batch = base_batch or {}
self.update_generation_kwargs(generation_kwargs or {})
if generation_kwargs:
self.update_generation_kwargs(generation_kwargs)

self.static_keys = static_keys
self.list_keys = list_keys
Expand Down Expand Up @@ -358,9 +359,9 @@ def update_generation_kwargs(self, generation_kwargs: Dict) -> None:
Args:
dict: Keyword arguments that will be written into base_batch['generation_kwargs']
"""
if 'generation_kwargs' not in self.base_batch:
self.base_batch['generation_kwargs'] = {}
if generation_kwargs:
if 'generation_kwargs' not in self.base_batch:
self.base_batch['generation_kwargs'] = {}
self.base_batch['generation_kwargs'].update(generation_kwargs)

def read_dataset(self,
Expand Down Expand Up @@ -702,7 +703,8 @@ def __init__(self,
'input_ids': self.context_key,
'labels': 'aliases',
}
self.update_generation_kwargs(kwargs.get('generation_kwargs', {}))
if 'generation_kwargs' in kwargs:
self.update_generation_kwargs(kwargs['generation_kwargs'])

def read_dataset(
self,
Expand Down Expand Up @@ -1292,7 +1294,8 @@ def __init__(
'eos_token_id': self.tokenizer.eos_token_id
}
}
self.update_generation_kwargs(kwargs.get('generation_kwargs', {}))
if 'generation_kwargs' in kwargs:
self.update_generation_kwargs(kwargs['generation_kwargs'])

def _set_max_prompt_and_answer_lengths(self):
"""
Expand Down
50 changes: 49 additions & 1 deletion tests/datasets/test_in_context_learning_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,55 @@ def test_update_generation_kwargs_no_kwargs(tiny_gpt2_tokenizer, tmp_path):
destination_path=str(tmp_path / 'test_dataset_lm_juggernaut.jsonl'),
hf_loading_vars=hf_loading_vars,
hf_parsing_map=hf_parsing_map)
assert not dl.base_batch['generation_kwargs']
assert not 'generation_kwargs' in dl.base_batch


def test_update_generation_kwargs_no_kwargs_qa_dataset(tmp_path):
# Build an InContextLearningQATaskDataset with generation_kwargs=None and
# check that base_batch['generation_kwargs'] still holds exactly 3 entries.
# Skipped when the optional `datasets` or `transformers` packages are missing.
pytest.importorskip('datasets')
# The fixture file lives in the `local_data` directory next to this test module.
local_data = os.path.join(os.path.dirname(__file__), 'local_data')
dataset_uri = f'{local_data}/triviaqa_small.jsonl'
transformers = pytest.importorskip('transformers')
tokenizer = transformers.AutoTokenizer.from_pretrained('facebook/opt-125m')  # type: ignore reportUnboundVariable

# Gather every participant's tmp_path and use the first gathered entry as the
# destination directory (presumably so all distributed ranks share one path
# -- confirm against dist.all_gather_object).
tmp_path_to_broadcast = str(os.path.abspath(tmp_path))
gathered_paths = dist.all_gather_object(tmp_path_to_broadcast)
dl = InContextLearningQATaskDataset(dataset_uri=dataset_uri,
tokenizer=tokenizer,
max_seq_len=1024,
pad_tok_id=tokenizer.eos_token_id,
num_fewshot=0,
fewshot_random_seed=1234,
prompt_string='',
example_delimiter='\n',
continuation_delimiter=': ',
destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'),
generation_kwargs=None)
# NOTE(review): the 3 expected entries are presumably defaults installed by
# InContextLearningQATaskDataset.__init__ -- confirm against the class; this
# count must be updated if those defaults change.
assert len(dl.base_batch['generation_kwargs']) == 3


def test_update_generation_kwargs_with_kwargs_qa_dataset(tmp_path):
# Build an InContextLearningQATaskDataset with generation_kwargs={'temperature': 0.9}
# and check that the value lands in base_batch['generation_kwargs'], which
# must then hold exactly 4 entries.
# Skipped when the optional `datasets` or `transformers` packages are missing.
pytest.importorskip('datasets')
# The fixture file lives in the `local_data` directory next to this test module.
local_data = os.path.join(os.path.dirname(__file__), 'local_data')
dataset_uri = f'{local_data}/triviaqa_small.jsonl'
transformers = pytest.importorskip('transformers')
tokenizer = transformers.AutoTokenizer.from_pretrained('facebook/opt-125m')  # type: ignore reportUnboundVariable

# Gather every participant's tmp_path and use the first gathered entry as the
# destination directory (presumably so all distributed ranks share one path
# -- confirm against dist.all_gather_object).
tmp_path_to_broadcast = str(os.path.abspath(tmp_path))
gathered_paths = dist.all_gather_object(tmp_path_to_broadcast)
dl = InContextLearningQATaskDataset(dataset_uri=dataset_uri,
tokenizer=tokenizer,
max_seq_len=1024,
pad_tok_id=tokenizer.eos_token_id,
num_fewshot=0,
fewshot_random_seed=1234,
prompt_string='',
example_delimiter='\n',
continuation_delimiter=': ',
destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'),
generation_kwargs={'temperature': 0.9})
assert 'generation_kwargs' in dl.base_batch
# The user-supplied value must be present in the stored kwargs.
assert dl.base_batch['generation_kwargs']['temperature'] == 0.9
# NOTE(review): 4 is presumably the class's default entries plus the
# user-supplied 'temperature' -- confirm against InContextLearningQATaskDataset.__init__.
assert len(dl.base_batch['generation_kwargs']) == 4


@pytest.mark.filterwarnings(
Expand Down

0 comments on commit 01eb20d

Please sign in to comment.