Commit d9e5fad

and more nits
ArthurZucker committed Oct 11, 2023
1 parent 8bcb345 commit d9e5fad
Showing 2 changed files with 5 additions and 15 deletions.
14 changes: 2 additions & 12 deletions src/transformers/tokenization_utils_base.py
@@ -2172,22 +2172,17 @@ def _from_pretrained(
 
         additional_special_tokens = init_kwargs.pop("additional_special_tokens", None) or []
         added_tokens_decoder = {}
-        # added_tokens_map = {}
         legacy_saved = "added_tokens_decoder" not in init_kwargs
         if not legacy_saved:
             for idx, token in init_kwargs["added_tokens_decoder"].items():
                 if isinstance(token, dict):
                     token = AddedToken(**token)
                 if isinstance(token, AddedToken):
                     added_tokens_decoder[int(idx)] = token
-                    # added_tokens_map[str(token)] = token
                 else:
                     raise ValueError(
                         f"Found a {token.__class__} in the saved `added_tokens_decoder`, should be a dictionary."
                     )
-            # for key in cls.SPECIAL_TOKENS_ATTRIBUTES:
-            #     if key in init_kwargs:
-            #         init_kwargs[key] = added_tokens_map[init_kwargs[key]]
         else:
             # begin legacy: read the added_tokens_file and update kwargs with special_tokens_map if modified
             if special_tokens_map_file is not None:
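Note: the non-legacy branch above rebuilds AddedToken objects from the plain dicts stored under "added_tokens_decoder" in a saved config. A minimal, self-contained sketch of that round-trip follows; the dataclass is only a stand-in for the real AddedToken class, which carries more options.

from dataclasses import dataclass

@dataclass(frozen=True)
class AddedToken:  # stand-in for the library class, illustration only
    content: str
    lstrip: bool = False
    rstrip: bool = False
    special: bool = False

# A saved tokenizer config stores string indices mapping to plain dicts.
saved = {
    "0": {"content": "<unk>", "special": True},
    "2": {"content": "</s>", "special": True},
}

added_tokens_decoder = {}
for idx, token in saved.items():
    if isinstance(token, dict):
        token = AddedToken(**token)  # rebuild the object from its dict form
    added_tokens_decoder[int(idx)] = token  # keys become ints

print(added_tokens_decoder)  # {0: AddedToken(content='<unk>', ...), 2: ...}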
@@ -2200,22 +2195,17 @@ def _from_pretrained(
                             continue
                         if isinstance(value, dict):
                             value = AddedToken(**value)
-                            value.special = True
                             init_kwargs[key] = value
                         elif key == "additional_special_tokens" and isinstance(value, list):
                             for token in value:
-                                token = AddedToken(**token, special=True) if isinstance(token, dict) else token
+                                token = AddedToken(**token) if isinstance(token, dict) else token
                                 if token not in additional_special_tokens:
                                     additional_special_tokens.append(token)
                         else:
                             init_kwargs[key] = value
             all_special_strings = [str(token) for token in additional_special_tokens]
             # also add the other special tokens
-            all_special_strings += [
-                str(init_kwargs.get(key, ""))
-                for key in cls.SPECIAL_TOKENS_ATTRIBUTES
-                if init_kwargs.get(key, None) not in all_special_strings
-            ]
+            all_special_strings += [str(init_kwargs[key]) for key in cls.SPECIAL_TOKENS_ATTRIBUTES & init_kwargs.keys()]
             # slow -> slow|fast, legacy: convert the `"added_tokens.json"` file to `added_tokens_decoder`.
             if added_tokens_file is not None:
                 with open(added_tokens_file, encoding="utf-8") as added_tokens_handle:
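Note: the one-line replacement for all_special_strings relies on dict.keys() returning a set-like view, which supports set intersection even with a plain list on the left-hand side. A small sketch with made-up values; the real attribute list lives on the tokenizer class:

# Illustrative subset; not the actual transformers attribute list.
SPECIAL_TOKENS_ATTRIBUTES = ["bos_token", "eos_token", "unk_token", "pad_token"]

init_kwargs = {"eos_token": "</s>", "unk_token": "<unk>", "model_max_length": 512}

# `list & keys_view` resolves through the view's reflected set operation
# and yields the set of attribute names actually present in the kwargs.
present = SPECIAL_TOKENS_ATTRIBUTES & init_kwargs.keys()  # {'eos_token', 'unk_token'}

all_special_strings = [str(init_kwargs[key]) for key in present]
print(sorted(all_special_strings))  # ['</s>', '<unk>']

Unlike the old comprehension, keys absent from init_kwargs are simply skipped rather than contributing empty strings; the new form also drops the old membership check against the existing list.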
6 changes: 3 additions & 3 deletions tests/test_tokenization_common.py
@@ -4071,20 +4071,20 @@ def test_added_tokens_serialization(self):
                 if self.rust_tokenizer_class is not None:
                     # fast from pretrained ignore the lstrip rstrip
                     tokenizer_fast = self.rust_tokenizer_class.from_pretrained(tmp_dir_2)
                     self.assertEquals(tokenizer_fast._eos_token, str(new_eos))
-                    self.assertEquals(tokenizer.added_tokens_decoder[tokenizer.eos_token_id], new_eos)
+                    self.assertIn(new_eos, tokenizer_fast.added_tokens_decoder.values())
                     self.assertEqual(EXPECTED_ADDED_TOKENS_DECODER, tokenizer_fast.added_tokens_decoder)
                     with tempfile.TemporaryDirectory() as tmp_dir_3:
                         tokenizer_fast.save_pretrained(tmp_dir_3)
                         tokenizer_fast = self.rust_tokenizer_class.from_pretrained(
                             pretrained_name, eos_token=new_eos, use_fast=True
                         )
                         self.assertEquals(tokenizer._eos_token, new_eos)
-                        self.assertEquals(tokenizer.added_tokens_decoder[tokenizer.eos_token_id], new_eos)
+                        self.assertIn(new_eos, tokenizer.added_tokens_decoder.values())
                         self.assertEqual(EXPECTED_ADDED_TOKENS_DECODER, tokenizer.added_tokens_decoder)
 
                 tokenizer = self.tokenizer_class.from_pretrained(tmp_dir_2)
                 self.assertEquals(tokenizer._eos_token, new_eos)
-                self.assertEquals(tokenizer.added_tokens_decoder[tokenizer.eos_token_id], new_eos)
+                self.assertIn(new_eos, tokenizer.added_tokens_decoder.values())
                 self.assertEqual(EXPECTED_ADDED_TOKENS_DECODER, tokenizer.added_tokens_decoder)

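Note: the test now checks membership in added_tokens_decoder.values() instead of indexing the decoder by eos_token_id. A hedged, self-contained sketch of why membership is the more robust assertion; plain strings stand in for AddedToken objects and the ids are invented:

# The same special token need not receive the same id in every tokenizer.
slow_decoder = {32000: "[NEW_EOS]"}
fast_decoder = {32005: "[NEW_EOS]"}  # same token, different (invented) id

new_eos = "[NEW_EOS]"

assert slow_decoder[32000] == new_eos    # indexing assumes a particular id
assert new_eos in slow_decoder.values()  # membership does not
assert new_eos in fast_decoder.values()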
