From d9e5fad7c204575a81992db8c39c5b8c69b7ff61 Mon Sep 17 00:00:00 2001 From: Arthur Date: Wed, 11 Oct 2023 23:04:32 +0200 Subject: [PATCH] and more nits --- src/transformers/tokenization_utils_base.py | 14 ++------------ tests/test_tokenization_common.py | 6 +++--- 2 files changed, 5 insertions(+), 15 deletions(-) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index bfb0015b894574..543e089997406a 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -2172,7 +2172,6 @@ def _from_pretrained( additional_special_tokens = init_kwargs.pop("additional_special_tokens", None) or [] added_tokens_decoder = {} - # added_tokens_map = {} legacy_saved = "added_tokens_decoder" not in init_kwargs if not legacy_saved: for idx, token in init_kwargs["added_tokens_decoder"].items(): @@ -2180,14 +2179,10 @@ def _from_pretrained( token = AddedToken(**token) if isinstance(token, AddedToken): added_tokens_decoder[int(idx)] = token - # added_tokens_map[str(token)] = token else: raise ValueError( f"Found a {token.__class__} in the saved `added_tokens_decoder`, should be a dictionary." ) - # for key in cls.SPECIAL_TOKENS_ATTRIBUTES: - # if key in init_kwargs: - # init_kwargs[key] = added_tokens_map[init_kwargs[key]] else: # begin legacy: read the added_tokens_file and update kwargs with special_tokens_map if modified if special_tokens_map_file is not None: @@ -2200,22 +2195,17 @@ def _from_pretrained( continue if isinstance(value, dict): value = AddedToken(**value) - value.special = True init_kwargs[key] = value elif key == "additional_special_tokens" and isinstance(value, list): for token in value: - token = AddedToken(**token, special=True) if isinstance(token, dict) else token + token = AddedToken(**token) if isinstance(token, dict) else token if token not in additional_special_tokens: additional_special_tokens.append(token) else: init_kwargs[key] = value all_special_strings = [str(token) for token in additional_special_tokens] # also add the other special tokens - all_special_strings += [ - str(init_kwargs.get(key, "")) - for key in cls.SPECIAL_TOKENS_ATTRIBUTES - if init_kwargs.get(key, None) not in all_special_strings - ] + all_special_strings += [str(init_kwargs[key]) for key in cls.SPECIAL_TOKENS_ATTRIBUTES & init_kwargs.keys()] # slow -> slow|fast, legacy: convert the `"added_tokens.json"` file to `added_tokens_decoder`. if added_tokens_file is not None: with open(added_tokens_file, encoding="utf-8") as added_tokens_handle: diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py index 308963ae1bc73f..3eb884ae045dc3 100644 --- a/tests/test_tokenization_common.py +++ b/tests/test_tokenization_common.py @@ -4071,7 +4071,7 @@ def test_added_tokens_serialization(self): if self.rust_tokenizer_class is not None: # fast from pretrained ignore the lstrip rstrip tokenizer_fast = self.rust_tokenizer_class.from_pretrained(tmp_dir_2) - self.assertEquals(tokenizer_fast._eos_token, str(new_eos)) + self.assertEquals(tokenizer.added_tokens_decoder[tokenizer.eos_token_id], new_eos) self.assertIn(new_eos, tokenizer_fast.added_tokens_decoder.values()) self.assertEqual(EXPECTED_ADDED_TOKENS_DECODER, tokenizer_fast.added_tokens_decoder) with tempfile.TemporaryDirectory() as tmp_dir_3: @@ -4079,12 +4079,12 @@ def test_added_tokens_serialization(self): tokenizer_fast = self.rust_tokenizer_class.from_pretrained( pretrained_name, eos_token=new_eos, use_fast=True ) - self.assertEquals(tokenizer._eos_token, new_eos) + self.assertEquals(tokenizer.added_tokens_decoder[tokenizer.eos_token_id], new_eos) self.assertIn(new_eos, tokenizer.added_tokens_decoder.values()) self.assertEqual(EXPECTED_ADDED_TOKENS_DECODER, tokenizer.added_tokens_decoder) tokenizer = self.tokenizer_class.from_pretrained(tmp_dir_2) - self.assertEquals(tokenizer._eos_token, new_eos) + self.assertEquals(tokenizer.added_tokens_decoder[tokenizer.eos_token_id], new_eos) self.assertIn(new_eos, tokenizer.added_tokens_decoder.values()) self.assertEqual(EXPECTED_ADDED_TOKENS_DECODER, tokenizer.added_tokens_decoder)