Lower sequence generation length on code gen to be dependent on max canonical solution length (mosaicml#2682)

* sequentialize generations_per_sample

* fix bug

* lower generation length

* lower generation length

* lower generation length

* fix gen len

* restore

* restore

* restore

* fix tests

* fix test
bmosaicml authored Dec 4, 2023
1 parent b11f7b6 commit 5d20db1
Showing 2 changed files with 10 additions and 4 deletions.
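
The gist of the change: rather than always generating max_seq_len - max_prompt_length tokens, the code-eval dataloader now measures the longest tokenized canonical solution and uses it, plus a buffer, as the upper bound on generation_length. A minimal sketch of that logic follows; the buffer value is assumed for illustration, since the diff references _MAX_ANSWER_BUFFER_LENGTH but does not show its value.

ANSWER_BUFFER = 10  # assumed stand-in for _MAX_ANSWER_BUFFER_LENGTH

def generation_length(max_answer_length: int, max_seq_len: int, max_prompt_length: int) -> int:
    # Cap decoding at the longest expected answer (plus buffer), but never
    # exceed the room left in the context window after the longest prompt.
    return min(max_answer_length + ANSWER_BUFFER, max_seq_len - max_prompt_length)

# Example: 2048-token context, 1900-token longest prompt, answers of at most
# 120 tokens -> decode 130 tokens instead of the previous 148.
assert generation_length(120, 2048, 1900) == 130
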
8 changes: 7 additions & 1 deletion composer/datasets/in_context_learning_evaluation.py
@@ -986,6 +986,7 @@ def __init__(
         self.max_prompt_length = 0
         self.top_p = top_p
         self.top_k = top_k
+        self.max_answer_length = 0
         fewshot_rng = random.Random(fewshot_random_seed)
         self.encoded_dataset = self.prep_examples(num_fewshot, prompt_string, example_delimiter, code_prelimiter,
                                                   fewshot_rng)
@@ -1009,6 +1010,7 @@ def prep_examples(self, num_fewshot: int, prompt_string: str, example_delimiter:
"""
max_prompt_length = 0
examples = []
max_answer_length = 0
for sample_idx in tqdm(range(len(self.samples))):
encoded_example = {}

@@ -1050,8 +1052,12 @@ def prep_examples(self, num_fewshot: int, prompt_string: str, example_delimiter:
             max_prompt_length = max(
                 max_prompt_length,
                 len(encoded_example['preamble']['input_ids'] + encoded_example['prompt']['input_ids']))
+            max_answer_length = max(
+                max_answer_length,
+                len(self.tokenizer(encoded_example['canonical_solution'], add_special_tokens=False)['input_ids']))

         self.max_prompt_length = max_prompt_length
+        self.max_answer_length = max_answer_length + _MAX_ANSWER_BUFFER_LENGTH
         return examples

     def __getitem__(self, index):
@@ -1101,7 +1107,7 @@ def collate_fn(self, data):
             'test_outputs': test_outputs, # list of test outputs
             'languages': languages, # list of languages
             'pass_at_k': self.pass_at_k,
-            'generation_length': self.max_seq_len - self.max_prompt_length,
+            'generation_length': min(self.max_answer_length, self.max_seq_len - self.max_prompt_length),
             'generation_kwargs': {
                 'pad_token_id': self.pad_tok_id,
                 'num_beams': 1, # single beam
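
For reference, the new bookkeeping above amounts to tokenizing each sample's canonical solution and tracking the maximum token count. A self-contained sketch, assuming a Hugging Face tokenizer; the tokenizer choice and sample data are illustrative, not Composer internals.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('gpt2')  # any HF tokenizer works

canonical_solutions = [
    '    return a + b\n',
    '    return [x * 2 for x in items]\n',
]

max_answer_length = 0
for solution in canonical_solutions:
    # Token count of the reference answer, without special tokens, mirroring
    # tokenizer(..., add_special_tokens=False) in the diff above.
    n_tokens = len(tokenizer(solution, add_special_tokens=False)['input_ids'])
    max_answer_length = max(max_answer_length, n_tokens)

# The dataset then stores this plus the buffer constant:
# self.max_answer_length = max_answer_length + _MAX_ANSWER_BUFFER_LENGTH
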
6 changes: 3 additions & 3 deletions tests/datasets/test_in_context_learning_datasets.py
@@ -775,7 +775,7 @@ def test_code_eval_sentpiece_dataloader(dataset_uri, tmp_path, num_fewshot, prom
     assert tuple(batch['attention_mask'].shape) == (batch_size, max_prompt_length)
     assert batch['mode'] == 'generate'
     # the maximum generation length from the small test data
-    assert batch['generation_length'] == seqlen - max_prompt_length
+    assert batch['generation_length'] == 129
     assert any(item[0] != tokenizer.eos_token_id for item in batch['input_ids']) # longest should be pushed left

     decoded_batch = tokenizer.batch_decode(batch['input_ids'])
@@ -860,7 +860,7 @@ def test_code_eval_test_cases(dataset_uri, tmp_path):
     assert tuple(batch['attention_mask'].shape) == (batch_size, max_prompt_length)
     assert batch['mode'] == 'generate'
     # the maximum generation length from the small test data
-    assert batch['generation_length'] == seqlen - max_prompt_length
+    assert batch['generation_length'] == 129
     assert any(item[0] != tokenizer.eos_token_id for item in batch['input_ids']) # longest should be pushed left

     mod = types.ModuleType('test_module')
@@ -938,7 +938,7 @@ def test_code_eval_task_dataloader(dataset_uri, tmp_path, num_fewshot, prompt_st
     assert tuple(batch['attention_mask'].shape) == (batch_size, max_prompt_length)
     assert batch['mode'] == 'generate'
     # the maximum generation length from the small test data
-    assert batch['generation_length'] == seqlen - max_prompt_length
+    assert batch['generation_length'] == 122
     assert any(item[0] != tokenizer.eos_token_id for item in batch['input_ids']) # longest should be pushed left

     decoded_batch = tokenizer.batch_decode(batch['input_ids'])
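
The updated tests pin generation_length to fixed values (129 and 122) instead of seqlen - max_prompt_length. A hedged illustration of what such a fixture-derived expectation encodes; the token counts and buffer below are invented, since the diff shows neither the fixture contents nor the buffer value.

def expected_generation_length(solution_token_counts, buffer, seqlen, max_prompt_length):
    # Longest tokenized canonical solution in the fixture, plus the buffer,
    # capped by the room left in the context window.
    return min(max(solution_token_counts) + buffer, seqlen - max_prompt_length)

# For example, if the longest fixture solution were 119 tokens and the buffer
# were 10, the dataloader would report 129 as long as enough context remains.
assert expected_generation_length([40, 87, 119], 10, 2048, 1700) == 129
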
