Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Nit-added-tokens #26538

Merged
merged 16 commits into from
Oct 3, 2023
11 changes: 8 additions & 3 deletions src/transformers/tokenization_utils_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2199,6 +2199,11 @@ def _from_pretrained(

if isinstance(token, AddedToken):
added_tokens_decoder[int(idx)] = token
if str(token) in additional_special_tokens:
# at this point if the token is in `additional_special_tokens` as an str, should be updated
additional_special_tokens.remove(str(token))
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only use the default legacy values for AddedToken if the token is not already in the added tokens decoder

if token.special and token not in additional_special_tokens:
additional_special_tokens.append(token)
else:
raise ValueError(
f"Found a {token.__class__} in the saved `added_tokens_decoder`, should be a dictionary."
Expand Down Expand Up @@ -2227,7 +2232,7 @@ def _from_pretrained(
if added_tokens_file is not None:
with open(added_tokens_file, encoding="utf-8") as added_tokens_handle:
added_tok_encoder = json.load(added_tokens_handle)
# legacy: we have to init with (rstrip=True, lstrip=True)
# legacy: we have to init with (rstrip=True, lstrip=True) (if the token is new? Failing test)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might have to update this. The tests are shitty and the default is biting us

strip = True if "Fast" not in cls.__name__ else False
added_tokens_decoder = {
index: AddedToken(token, rstrip=strip, lstrip=strip) for token, index in added_tok_encoder.items()
Expand Down Expand Up @@ -2382,8 +2387,8 @@ def save_pretrained(
tokenizer_config = copy.deepcopy(self.init_kwargs)

# TODO: Ensure the modified attributes (those are also in the __init__ kwargs) will give identical tokenizers
# target_keys = self.init_kwargs.keys()
target_keys = ["model_max_length", "clean_up_tokenization_spaces", "additional_special_tokens"]
target_keys = list(self.init_kwargs.keys())
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

when saving; we should overwrite the init_kwargs with the content of self. Don't know why it was not the case before

target_keys += ["model_max_length", "clean_up_tokenization_spaces", "additional_special_tokens"]
for k in target_keys:
if hasattr(self, k):
tokenizer_config[k] = getattr(self, k)
Expand Down
Loading