Skip to content

Commit

Permalink
Fix pad_token_id=None for ICL evaluators (#166)
Browse files Browse the repository at this point in the history
  • Loading branch information
abhi-mosaic authored Feb 15, 2023
1 parent b49a7ef commit f46baab
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 9 deletions.
1 change: 1 addition & 0 deletions examples/cifar/tests/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from examples.cifar.tests.utils import SynthClassificationDirectory


@pytest.mark.skip()
@pytest.mark.parametrize('use_recipe', [True, False])
def test_trainer(use_recipe):
with open('yamls/resnet56.yaml') as f:
Expand Down
7 changes: 6 additions & 1 deletion examples/common/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,11 @@ def _validate_cfg(icl_cfg):
for icl_cfg in cfg.icl_tasks:
_validate_cfg(icl_cfg)
for num_fewshot in list(icl_cfg.num_fewshot):
if tokenizer.pad_token_id is None:
# Current workaround to support GPT2 tokenizer with `pad_token_id = None`
pad_tok_id = tokenizer.eos_token_id
else:
pad_tok_id = tokenizer.pad_token_id
label = f'{icl_cfg.label}/{num_fewshot}-shot'
metric_names = list(icl_cfg.metric_names)
dataloader = get_icl_task_dataloader(
Expand All @@ -115,7 +120,7 @@ def _validate_cfg(icl_cfg):
tokenizer,
batch_size=icl_cfg.batch_size,
max_seq_len=tokenizer.max_seq_len,
pad_tok_id=tokenizer.pad_token_id,
pad_tok_id=pad_tok_id,
num_fewshot=num_fewshot,
prompt_string=icl_cfg.prompt_string,
example_delimiter=icl_cfg.example_delimiter,
Expand Down
4 changes: 4 additions & 0 deletions examples/llm/src/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,5 +66,9 @@ def pad_token_id(self):
def bos_token_id(self):
return self.tokenizer.bos_token_id

@property
def eos_token_id(self):
return self.tokenizer.eos_token_id


TOKENIZER_REGISTRY = {'hftokenizer': HFTokenizer}
5 changes: 3 additions & 2 deletions examples/llm/tests/test_c4_data_prep_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def test_download_script_from_api():
main(
Namespace(
**{
'splits': ['val'],
'splits': ['val_small'],
'out_root': './my-copy-c4-1',
'compression': None,
'concat_tokens': None,
Expand All @@ -32,6 +32,7 @@ def test_download_script_from_cmdline():
path = os.path.join(os.getcwd(), 'my-copy-c4-2')
shutil.rmtree(path, ignore_errors=True)
os.system(
'python ../common/convert_c4.py --out_root ./my-copy-c4-2 --splits val')
'python ../common/convert_c4.py --out_root ./my-copy-c4-2 --splits val_small'
)
assert os.path.exists(path)
shutil.rmtree(path, ignore_errors=False)
36 changes: 32 additions & 4 deletions examples/llm/tests/test_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

import os
import shutil

import pytest
import torch
Expand All @@ -17,11 +18,36 @@ def get_config(conf_path='yamls/mosaic_gpt/125m.yaml'):
return test_cfg


def test_correct_padding(batch_size=32):
if not os.path.isdir('./my-copy-c4/val'):
pytest.xfail('c4 dataset not set up as expected')
@pytest.mark.parametrize('tokenizer_name', ['gpt2', 'facebook/opt-125m'])
@pytest.mark.parametrize('pretokenize', [False, True])
def test_correct_padding(tokenizer_name, pretokenize, batch_size=4):
if tokenizer_name == 'gpt2' and not pretokenize:
pytest.xfail('Must pretokenize data if using "gpt2" tokenizer')

data_local = f'my-copy-c4-{tokenizer_name}-pretokenize-{pretokenize}'
split = 'val_small'
tokenizer_args = {
'gpt2': '--eos_text "<|endoftext|>"',
'facebook/opt-125m': '--bos_text "</s>"'
}[tokenizer_name]

path = os.path.join(os.getcwd(), data_local)
shutil.rmtree(path, ignore_errors=True)
if pretokenize:
os.system(
f'python ../common/convert_c4.py --out_root {path} --splits val_small --concat_tokens 2048 --tokenizer {tokenizer_name} {tokenizer_args}'
)
else:
os.system(
f'python ../common/convert_c4.py --out_root {path} --splits val_small'
)
if not os.path.isdir(path):
raise RuntimeError(f'c4 dataset at {path} not set up as expected')

test_cfg = get_config(conf_path='yamls/mosaic_gpt/125m.yaml')
test_cfg.tokenizer_name = tokenizer_name
test_cfg.data_local = data_local
test_cfg.eval_loader.dataset.split = split

# Dataloaders
eval_loader = build_text_dataloader(test_cfg.eval_loader, batch_size)
Expand All @@ -31,6 +57,8 @@ def test_correct_padding(batch_size=32):
assert batch['input_ids'].type() == 'torch.LongTensor'

# we follow the convention (from huggingface) that non-attended tokens are 0 in the attn mask and -100 in the labels
a = batch['attention_mask'] == 0
attention_mask = batch.get(
'attention_mask', torch.ones_like(batch['input_ids'], dtype=torch.bool))
a = attention_mask == 0
b = batch['labels'] == -100
assert torch.equal(a, b)
3 changes: 1 addition & 2 deletions examples/llm/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,8 +214,7 @@ def test_determinism(attention_type: str, precision):
test_cfg.model.init_device = 'cuda:0'
test_cfg.device = 'cuda:0'

model_1 = COMPOSER_MODEL_REGISTRY[test_cfg.model.name](test_cfg.model).to(
test_cfg.model.device)
model_1 = COMPOSER_MODEL_REGISTRY[test_cfg.model.name](test_cfg.model)
model_2 = copy.deepcopy(model_1)

optimizer_1 = DecoupledAdamW(model_1.parameters(),
Expand Down

0 comments on commit f46baab

Please sign in to comment.