Commit

updates
ArthurZucker committed Oct 11, 2023
1 parent 112e4b1 commit 65aa232
Showing 2 changed files with 7 additions and 5 deletions.
10 changes: 6 additions & 4 deletions src/transformers/models/t5/tokenization_t5.py
@@ -153,9 +153,9 @@ def __init__(
         legacy=None,
         **kwargs,
     ) -> None:
-        pad_token = AddedToken(pad_token, rstrip=True, lstrip=True)
-        unk_token = AddedToken(unk_token, rstrip=True, lstrip=True)
-        eos_token = AddedToken(eos_token, rstrip=True, lstrip=True)
+        pad_token = AddedToken(pad_token, rstrip=True, lstrip=True) if isinstance(pad_token, str) else pad_token
+        unk_token = AddedToken(unk_token, rstrip=True, lstrip=True) if isinstance(unk_token, str) else unk_token
+        eos_token = AddedToken(eos_token, rstrip=True, lstrip=True) if isinstance(eos_token, str) else eos_token
 
         self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
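
For context, a minimal sketch (not part of the commit; `normalize_token` is an invented helper name) of what the added `isinstance` guard buys: bare strings get wrapped with explicit stripping behavior, while a caller-supplied `AddedToken` keeps its own flags instead of being wrapped a second time.

```python
# Minimal sketch, not from the repository: normalize_token is a hypothetical
# helper illustrating the guard added in the diff above.
from transformers import AddedToken

def normalize_token(token):
    # Only bare strings are wrapped; an AddedToken passed by the caller
    # keeps whatever lstrip/rstrip flags it was constructed with.
    return AddedToken(token, rstrip=True, lstrip=True) if isinstance(token, str) else token

print(normalize_token("<pad>"))                            # str -> wrapped AddedToken
print(normalize_token(AddedToken("<pad>", lstrip=False)))  # returned as-is
```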

@@ -167,7 +167,9 @@ def __init__(
 
         if additional_special_tokens is not None:
             extra_tokens = [x for x in additional_special_tokens if "<extra_id_" in str(x)]
-            if extra_ids > 0 and extra_ids != len(extra_tokens):
+            if len(extra_tokens) < 1:
+                additional_special_tokens += [f"<extra_id_{i}>" for i in range(extra_ids)]
+            elif extra_ids > 0 and extra_ids != len(extra_tokens):
                 raise ValueError(
                     f"Both extra_ids ({extra_ids}) and additional_special_tokens ({additional_special_tokens}) are"
                     " provided to T5Tokenizer. In this case the additional_special_tokens must include the extra_ids"
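
A hedged sketch of the new fallback branch (standalone function with an invented name, not from the repo): when the caller's `additional_special_tokens` contain no `<extra_id_*>` sentinels, the sentinels are now generated from `extra_ids` instead of raising; a genuine count mismatch still errors out.

```python
# Hypothetical standalone version of the branch above, for illustration only.
def resolve_extra_tokens(extra_ids, additional_special_tokens):
    extra_tokens = [x for x in additional_special_tokens if "<extra_id_" in str(x)]
    if len(extra_tokens) < 1:
        # No sentinels supplied: generate <extra_id_0> ... <extra_id_{n-1}>.
        additional_special_tokens += [f"<extra_id_{i}>" for i in range(extra_ids)]
    elif extra_ids > 0 and extra_ids != len(extra_tokens):
        # Sentinels supplied, but the count disagrees with extra_ids.
        raise ValueError("additional_special_tokens must include the extra_ids tokens")
    return additional_special_tokens

print(resolve_extra_tokens(3, ["<custom>"]))
# -> ['<custom>', '<extra_id_0>', '<extra_id_1>', '<extra_id_2>']
```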
2 changes: 1 addition & 1 deletion tests/test_tokenization_common.py
@@ -4110,7 +4110,7 @@ def test_additional_special_tokens_serialization(self):
 
         # make sure the token was added
         self.assertIn(new_eos, tokenizer.added_tokens_decoder.values())
-        self.assertEqual(new_eos, tokenizer.added_tokens_decoder[self.added_tokens_encoder[str(new_eos)]])
+        self.assertEqual(new_eos, tokenizer.added_tokens_decoder[tokenizer.added_tokens_encoder[str(new_eos)]])
 
         # At this point if you save the tokenizer and reload it, the token will be saved as special
         # it does not matter if you set the attribute
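
The one-line test fix replaces `self` with `tokenizer`: `added_tokens_encoder` is an attribute of the tokenizer object, not of the test case. A hedged round-trip of the two mappings involved (the checkpoint name and token are illustrative, not from the test):

```python
# Illustrative only; "t5-small" and "<new_eos>" are assumptions, not from the test.
from transformers import AddedToken, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")
new_eos = AddedToken("<new_eos>", rstrip=False, lstrip=True)
tokenizer.add_special_tokens({"eos_token": new_eos})

# added_tokens_encoder maps token string -> id;
# added_tokens_decoder maps id -> AddedToken. Both live on the tokenizer.
token_id = tokenizer.added_tokens_encoder[str(new_eos)]
assert tokenizer.added_tokens_decoder[token_id] == new_eos
```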
