After extra tokens are added, decoded results have no whitespace between tokens #826
Comments
Encoding is a destructive process, meaning information about the original string can be lost. For a simple example, imagine the tokenizer lowercases its input: there is then no way to know which letters were originally uppercase. That being said, here are two easy fixes.

Easy fix 1:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('t5-small')
new_tokens = [' new_token_1', ' new_token_2']  # Extra space
tokenizer.add_tokens(new_tokens)
a = tokenizer.encode('This new_token_1 differs from new_token_2')
print(tokenizer.decode(a))
# >>> 'This new_token_1 differs from new_token_2</s>'
```

This has the drawback that sentences starting with the added token (i.e. with no preceding space to match) won't be handled.

Easy fix 2:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('t5-small')
new_tokens = ['new_token_1', 'new_token_2']
tokenizer.add_tokens(new_tokens)

string = 'This new_token_1 differs from new_token_2'
inputs = tokenizer(string, return_offsets_mapping=True)

inputs["input_ids"][1]
# >>> 32100  # The added token
start, stop = inputs["offset_mapping"][1]
# >>> (5, 16)
string[start:stop]
# >>> "new_token_1"
```

That sort of logic will work regardless of the destructiveness of the decoding. |
Hi @Narsil, thanks for your input. I think my problem description was not clear enough, so I have edited it.

Easy Fix 1 cannot achieve what I want. Manually adding whitespace before the new tokens misleads the tokenizer, as it can then only recognize " new_token_1" (with the leading space) instead of "new_token_1":

```python
from transformers import AutoTokenizer

tokenizer_without_space = AutoTokenizer.from_pretrained('t5-small')
tokenizer_with_space = AutoTokenizer.from_pretrained('t5-small')

new_tokens_without_space = ['new_token_1', 'new_token_2']
new_tokens_with_space = [' new_token_1', ' new_token_2']  # Manually add space before tokens

tokenizer_without_space.add_tokens(new_tokens_without_space)
tokenizer_with_space.add_tokens(new_tokens_with_space)

test_string = 'new_token_1 differs from new_token_2'
print(f"Tokenizer without prefix space: {tokenizer_without_space.encode(test_string)}")
print(f"Tokenizer with prefix space: {tokenizer_with_space.encode(test_string)}")
# >>> Tokenizer without prefix space: [32100, 7641, 7, 45, 32101, 1]
# >>> Tokenizer with prefix space: [126, 834, 235, 2217, 834, 536, 7641, 7, 45, 32101, 1]
```

As for Easy Fix 2, it cannot solve the issue in the edited description either. |
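To make that mismatch easy to see, here is a small illustrative sketch (mine, not part of the reply above) that prints the actual tokens produced by the prefix-space tokenizer; because the added token carries a leading space, a sentence-initial "new_token_1" falls back to ordinary subword pieces:

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained('t5-small')
tok.add_tokens([' new_token_1', ' new_token_2'])  # prefix-space variant

# The added token includes a leading space, so it cannot match the string
# 'new_token_1 ...' at position 0; that word is split into subword pieces.
ids = tok.encode('new_token_1 differs from new_token_2')
print(tok.convert_ids_to_tokens(ids))
```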
I mentioned the caveat of the first fix too. For the second one, if you don't have access to the original string, then there's no option to magically add spaces sometimes: "[CLS][SEP]" is what will be decoded by default. As I mentioned, the decode is a choice given the ids. Could you share a bit more about why adding that space is so important in your use case? It might be important enough to add such an option. |
For me personally, I have also hit this issue. When calculating metrics like the ROUGE score, if you are predicting special tokens that are missing spaces, the ROUGE score won't be quite as accurate when compared against the ground truth. I noticed that the T5 and Pegasus tokenizers show the behaviour OP described (missing spaces), while the BART and ProphetNet tokenizers behave the way OP wants (spaces between special tokens). |
Shouldn't ROUGE be calculated on pure text and exclude special tokens? It seems to my untrained eye that including any kind of special token while doing benchmarks is very risky at best. The only exception I can think of is... |
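For readers who want that "pure text" route, a minimal sketch (my own, with hypothetical generated ids) of decoding predictions before scoring, so the metric never sees special tokens:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('t5-small')

# Hypothetical ids as they might come out of model.generate(); the exact
# values are placeholders, not taken from the discussion above.
generated_ids = [[100, 7641, 7, 45, 1]]

# skip_special_tokens=True drops </s>, <pad>, etc., so ROUGE (or any other
# text metric) is computed on plain strings only.
predictions = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
print(predictions)
```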
It's a bit of a long story, but basically my output keeps special tokens during the ROUGE calculation to avoid cases where the model fails to produce the full word in the output; with the special token it only has to predict a single token, which simplifies my problem. This use-case is super specific to an experiment I'm doing though haha, definitely not recommended usually. Just thought I'd add my experience 👍🏻 |
This is a bug @Narsil |
As mentioned before, there's no way to recover the original string. The flag is global, meaning you either ALWAYS add a space or NEVER add one. As I said, decoding is a choice given the ids. @SaulLu I think we could actually add this and reduce the list of differences between slow and fast. Wdyt? |
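In the meantime, one workaround sketch (entirely my own, not an official API) is to post-process at the id level: decode the runs of ordinary ids normally and splice the added tokens back in with explicit spaces.

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('t5-small')
added = ['new_token_1', 'new_token_2']
tokenizer.add_tokens(added)
added_ids = set(tokenizer.convert_tokens_to_ids(added))

def decode_with_spaces(ids):
    """Decode runs of ordinary ids with the tokenizer and join the added
    tokens with explicit spaces around them (post-processing only)."""
    pieces, run = [], []
    for i in ids:
        if i in added_ids:
            if run:
                pieces.append(tokenizer.decode(run))
                run = []
            pieces.append(tokenizer.convert_ids_to_tokens(i))
        else:
            run.append(i)
    if run:
        pieces.append(tokenizer.decode(run))
    return ' '.join(pieces)

ids = tokenizer.encode('This new_token_1 differs from new_token_2')
print(decode_with_spaces(ids))
```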
@Narsil For what it's worth, the slow tokenizer (T5Tokenizer) does insert the spaces:

```python
from transformers import T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained('t5-small')  # switch from AutoTokenizer to T5Tokenizer
new_tokens = ['new_token_1', 'new_token_2']
tokenizer.add_tokens(new_tokens)

existing_token_ids = tokenizer.encode('differs from')
print(existing_token_ids[:-1])
# >>> [7641, 7, 45]

new_token_ids = tokenizer.encode('new_token_1 new_token_2')
print(new_token_ids[:-1])
# >>> [32100, 32101]

print(tokenizer.decode([32100, 32101]))
# >>> new_token_1 new_token_2  # there is a space between "new_token_1" and "new_token_2"

print(tokenizer.decode([32100, 7641, 45, 32101]))
# >>> new_token_1 differ from new_token_2  # again, there is a space between "from" and "new_token_2"
```
|
Yes, I agree with you. There are still some differences between slow and fast tokenizers; some are easier to fix than others (like this one), others are more tricky. We have an umbrella issue in transformers for these, and the question is whether we should aim for full 1:1 parity (it's a lot of work and will imply small but real breaking changes, which we really try to avoid). Thanks for reporting that the slow tokenizer does add spaces, I wasn't aware of this (since some don't enable it, and it's usually only for certain special tokens). |
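If I remember correctly, the spaces in the slow path come from a spaces_between_special_tokens argument on the slow tokenizer's decode (default True); a quick check along those lines is sketched below (my own illustration, and the flag name should be verified against your transformers version):

```python
from transformers import T5Tokenizer

slow = T5Tokenizer.from_pretrained('t5-small')
slow.add_tokens(['new_token_1', 'new_token_2'])

ids = slow.encode('new_token_1 differs from new_token_2')

# With the flag on (the default), chunks around added tokens are joined with
# spaces; turning it off should reproduce the fast tokenizer's behaviour.
print(slow.decode(ids, spaces_between_special_tokens=True))
print(slow.decode(ids, spaces_between_special_tokens=False))
```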
Is this error related to huggingface/trl#588? |
Hey! I think #1357 should help with the extra spaces |
This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 5 days. |
Problem
I added a set of extra tokens to a tokenizer (t5-small). The decoder is expected to output tokens mostly sampled from this set. However, the decoded string has no whitespace between the tokens.
Environment
tokenizers == 0.10.3
Reproduce
The tokenizer has no problem decoding tokens that already exist in the pre-trained vocab (e.g. "differs" and "from"), but it fails to insert whitespace before the newly added tokens. I learned from old issues (#73, #232 and this comment) that one possible way is to train a tokenizer on my own. I am wondering if there is any other quick fix for this problem?
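A minimal reproduction along these lines (reconstructed for illustration; the exact decoded output may differ slightly by version):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('t5-small')
tokenizer.add_tokens(['new_token_1', 'new_token_2'])

ids = tokenizer.encode('new_token_1 differs from new_token_2')
print(tokenizer.decode(ids))
# Expected: 'new_token_1 differs from new_token_2</s>'
# Observed: the whitespace before the added tokens is dropped, e.g.
#           'new_token_1 differs fromnew_token_2</s>'
```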