
After extra tokens are added, decoded results have no whitespace between tokens #826

Closed
YiweiJiang2015 opened this issue Nov 10, 2021 · 13 comments

@YiweiJiang2015

YiweiJiang2015 commented Nov 10, 2021

Problem

I added a set of extra tokens to a tokenizer (t5-small). The decoder is expected to output tokens mostly sampled from this set. However, the decoded string has no whitespace between these tokens.

Environment

tokenizers == 0.10.3

Reproduce

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained('t5-small')
new_tokens = ['new_token_1', 'new_token_2']
tokenizer.add_tokens(new_tokens)
existing_token_ids = tokenizer.encode('differs from')
print(existing_token_ids[:-1])
>>> [7641, 7, 45]
new_token_ids = tokenizer.encode('new_token_1 new_token_2')
print(new_token_ids[:-1])
>>> [32100, 32101]
print(tokenizer.decode([32100, 32101]))
>>> 'new_token_1new_token_2</s>' # there is no space between "new_token_1" and "new_token_2"
print(tokenizer.decode([32100, 7641, 45, 32101])) 
>>> new_token_1 differ fromnew_token_2 # again, there is no space between "from" and "new_token_2"

The tokenizer has no problem decoding tokens that already exist in the pre-trained vocab (e.g. "differs" and "from"), but it fails to insert whitespace before the newly added tokens. I learned from old issues (#73, #232 and this comment) that one possible fix is to train a tokenizer of my own. Is there any other quick solution to this problem?

@Narsil
Collaborator

Narsil commented Nov 11, 2021

Encoding is a destructive process, meaning tokenizer.decode(tokenizer.encode(mystring)) won't, in the general case, give you back your original string.

For a simple example, imagine the tokenizer lowercases its input: then there's no way to know which letters were originally uppercase.
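
As a quick sketch of how case information disappears (assuming an uncased checkpoint such as bert-base-uncased is available):

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
print(tokenizer.decode(tokenizer.encode('Hello World'), skip_special_tokens=True))
>>> 'hello world'  # the original capitalization is gone, so the exact input cannot be recovered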

That being said:

Easy fix 1

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained('t5-small')
new_tokens = [' new_token_1', ' new_token_2']  # Extra space
tokenizer.add_tokens(new_tokens)
a = tokenizer.encode('This new_token_1 differs from new_token_2')
print(tokenizer.decode(a))
>>> 'This new_token_1 differs from new_token_2</s>'

This has the drawback that sentences starting with "new_token_1" won't trigger the added token (since the entry is now " new_token_1", with a prefixed space).
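
As a quick sketch of that caveat (same t5-small setup, reusing the space-prefixed token from fix 1), a sentence that begins with the token no longer matches it:

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained('t5-small')
tokenizer.add_tokens([' new_token_1'])  # space-prefixed entry from fix 1
print(tokenizer.encode('new_token_1 differs'))
# the leading 'new_token_1' has no preceding space, so it is split into
# sub-word pieces instead of being mapped to the added token id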

Easy fix 2:
Don't use decoding to see which tokens were matched; instead use the offset mapping:

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained('t5-small')
new_tokens = ['new_token_1', 'new_token_2']
string = 'This new_token_1 differs from new_token_2'
inputs = tokenizer(string, return_offsets_mapping=True)
inputs["input_ids"][1]
>>> 32100  # The added token
start, stop = inputs["offset_mapping"][1]
>>> (5, 16)
string[start:stop]
>>> "new_token_1"

That sort of logic will work regardless of the destructiveness of the encode operation because you're doing lookups in the original string now.
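
For instance, a sketch (same t5-small setup as above) that recovers the exact surface form of every token straight from the original string:

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained('t5-small')
tokenizer.add_tokens(['new_token_1', 'new_token_2'])
string = 'This new_token_1 differs from new_token_2'
inputs = tokenizer(string, return_offsets_mapping=True)
# each slice is looked up in the original string, so added tokens keep their
# exact boundaries; special tokens like </s> typically get empty (0, 0) offsets
# and are skipped here
pieces = [string[start:stop] for start, stop in inputs['offset_mapping'] if stop > start]
print(pieces)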

@YiweiJiang2015
Author

YiweiJiang2015 commented Nov 11, 2021

Hi @Narsil, thanks for your input. I think my problem description was not clear enough, so I have edited it.

Easy fix 1 cannot achieve what I want. Manually adding whitespace before the new tokens misleads the tokenizer, as it then only recognizes "~new_token_1" instead of "new_token_1" (~ represents whitespace). For example:

from transformers import AutoTokenizer
tokenizer_without_space = AutoTokenizer.from_pretrained('t5-small')
tokenizer_with_space = AutoTokenizer.from_pretrained('t5-small')
new_tokens_without_space =  ['new_token_1', 'new_token_2']
new_tokens_with_space = [' new_token_1', ' new_token_2']  # Manually add space before tokens
tokenizer_without_space.add_tokens(new_tokens_without_space)
tokenizer_with_space.add_tokens(new_tokens_with_space)

test_string = 'new_token_1 differs from new_token_2'
print(f"Tokenizer without prefix space: {tokenizer_without_space.encode(test_string)}")
print(f"Tokenizer with prefix space: {tokenizer_with_space.encode(test_string)}")
>>> Tokenizer without prefix space: [32100, 7641, 7, 45, 32101, 1]
>>> Tokenizer with prefix space: [126, 834, 235, 2217, 834, 536, 7641, 7, 45, 32101, 1]

tokenizer_with_space will break "new_token_1" into pieces (i.e., ["new", "_", "to", "ken", "_", "1"]).

As for easy fix 2, it cannot solve the issue in the edited description either.

@Narsil
Collaborator

Narsil commented Nov 11, 2021

I mentioned the caveat of the first fix too.

For the second one, if you don't have access to the original string then there's no option to magically add spaces only sometimes.

"[CLS][SEP]" is what will be decoded by default.
Adding a space to give "[CLS] [SEP]" might break the invariant tokenizer.encode(tokenizer.decode(ids)) == ids, which does always hold (contrary to the other direction, which loses information).

As I mentioned, the encode process is destructive and loses information: you cannot recover the string you expect from just the ids. You could add some logic yourself to insert a space when a token isn't the first one, but there's no guarantee that this matches the original string. Currently there is no option to do that in the library.
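
A rough sketch of what such hand-rolled logic could look like (decode_with_spaces is a hypothetical helper; the spacing it produces is a guess rather than the original string, and the expected output assumes the t5-small setup from the issue):

from transformers import AutoTokenizer

def decode_with_spaces(tokenizer, ids):
    # Hypothetical helper: decode each id on its own and join with single spaces.
    # This forces a space before added tokens, but the spacing is only a guess
    # (and it also inserts spaces inside words that were split into sub-words).
    pieces = [tokenizer.decode([i], skip_special_tokens=True) for i in ids]
    return ' '.join(p for p in pieces if p)

tokenizer = AutoTokenizer.from_pretrained('t5-small')
tokenizer.add_tokens(['new_token_1', 'new_token_2'])
print(decode_with_spaces(tokenizer, [32100, 7641, 45, 32101]))
>>> 'new_token_1 differ from new_token_2'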

Could you share a bit more about why adding that space is so important in your use case? It might be important enough to add such an option.

@logan-markewich

For me personally, I have also hit this issue. When calculating things like ROUGE score, if the predicted special tokens are missing spaces, the score won't be quite as accurate when compared against the ground truth.

I noticed that the T5 and Pegasus tokenizers have the behaviour OP described (missing spaces), while BART and ProphetNet tokenizers behave the way OP wants (spaces between special tokens).

@Narsil
Collaborator

Narsil commented Dec 8, 2021

Shouldn't ROUGE be calculated on pure text and exclude special tokens? It seems to my untrained eye that including any kind of special token while doing benchmarks is very risky at best. The only exception I can think of is maybe <unk>, and it would probably always need special treatment since <unk> is a placeholder for 1 to n tokens (hence it can stand in for any number of words).
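
For what it's worth, a minimal sketch of dropping special tokens at decode time before scoring (using the existing skip_special_tokens flag, with the t5-small setup from above):

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained('t5-small')
ids = tokenizer.encode('differs from')
print(tokenizer.decode(ids))
>>> 'differs from</s>'
print(tokenizer.decode(ids, skip_special_tokens=True))
>>> 'differs from'  # special tokens such as </s> are stripped before computing ROUGE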

@logan-markewich

It's a bit of a long story, but basically my output keeps special tokens during the ROUGE calculation: instead of having the model spell out a full word in the output (and possibly failing to), it only has to predict a single token. This simplifies my problem.

This use-case is super specific to an experiment I'm doing though haha, definitely not recommended usually. Just thought I'd add my experience 👍🏻

@edchengg

edchengg commented Jun 2, 2022

This is a bug @Narsil.
Using T5Tokenizer and installing the sentencepiece library will fix this issue.
DO NOT use AutoTokenizer.

@Narsil
Collaborator

Narsil commented Jun 2, 2022

@edchengg ,

As mentioned before, there's no way to recover the original string.
That being said, the slow tokenizers do have an option called spaces_between_special_tokens, which this library doesn't.

The flag is global, meaning you either ALWAYS add a space, or NEVER. As I said, decoding is a choice given the ids.
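
For reference, a sketch of that flag on the slow tokenizer (assuming sentencepiece is installed and the new tokens get ids 32100/32101 as in the thread; the flag is passed to decode):

from transformers import T5Tokenizer  # slow tokenizer
tokenizer = T5Tokenizer.from_pretrained('t5-small')
tokenizer.add_tokens(['new_token_1', 'new_token_2'])
print(tokenizer.decode([32100, 32101], spaces_between_special_tokens=True))
>>> 'new_token_1 new_token_2'
print(tokenizer.decode([32100, 32101], spaces_between_special_tokens=False))
>>> 'new_token_1new_token_2'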

@SaulLu I think we could add this actually and reduce the list of differences between slow and fast. Wdyt ?

@edchengg

edchengg commented Jun 2, 2022

@Narsil
To be clear
I understand there is no way to recover the original string.
But the goal in this thread is to recover the whitespace before special tokens.
And the solution is to call T5Tokenizer directly instead of using AutoTokenizer.

from transformers import T5Tokenizer
tokenizer = T5Tokenizer.from_pretrained('t5-small') #switch from AutoTokenizer to T5Tokenizer
new_tokens = ['new_token_1', 'new_token_2']
tokenizer.add_tokens(new_tokens)
existing_token_ids = tokenizer.encode('differs from')
print(existing_token_ids[:-1])
>>> [7641, 7, 45]
new_token_ids = tokenizer.encode('new_token_1 new_token_2')
print(new_token_ids[:-1])
>>> [32100, 32101]
print(tokenizer.decode([32100, 32101]))
>>> new_token_1 new_token_2 # there is a space between "new_token_1" and "new_token_2"
print(tokenizer.decode([32100, 7641, 45, 32101])) 
>>> new_token_1 differ from new_token_2  # again, there is a space between "from" and "new_token_2"

@Narsil
Collaborator

Narsil commented Jun 3, 2022

@edchengg ,

Yes, I agree with you. There are still some differences between slow and fast tokenizers; some are easier to fix than others (like this one), while others are trickier. We have an umbrella issue in transformers for these, and the question is whether we should aim for full 1:1 parity (it's a lot of work and would imply small but real breaking changes, which we really try to avoid).

Thanks for reporting that the slow tokenizer does add spaces; I wasn't aware of this (some don't enable it, and for tokens like [SEP][CLS] it usually doesn't matter as much).

@datorresb

Is this error related to huggingface/trl#588?

@ArthurZucker
Collaborator

Hey! I think #1357 should help with the extra spaces


This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 5 days.

github-actions bot added the Stale label Mar 11, 2024
github-actions bot closed this as not planned Mar 17, 2024