
Sentence start got unexpected space #25285

Closed
lucasjinreal opened this issue May 26, 2023 · 5 comments

Comments

@lucasjinreal

I have some input_ids from encoding a prompt. After appending some tgt_ids to input_ids, the newly decoded sentence contains weird extra spaces.

Here is the code:

from transformers import LlamaTokenizer

# any LLama tokenizer
tokenizer = LlamaTokenizer.from_pretrained("checkpoints/BiLLa-7B-LLM/tokenizer.model")

prefix = "Human: \n用python写一段快排\n\nAssistant: \n"
output = "OK, I will do for u!"
sentence_ids = tokenizer.encode(prefix, add_special_tokens=False)
b = tokenizer.decode(sentence_ids)

print(sentence_ids)
print(b)

input_ids = sentence_ids + tokenizer.encode(output, add_special_tokens=False)
input_ids += [tokenizer.eos_token_id]

o = tokenizer.decode(input_ids)
print(input_ids)
print()
print(o)

My output:

[12968, 29901, 29871, 13, 30406, 4691, 31479, 30287, 31559, 32815, 32996, 13, 13, 7900, 22137, 29901, 29871, 13]
Human: 
用python写一段快排

Assistant:

[12968, 29901, 29871, 13, 30406, 4691, 31479, 30287, 31559, 32815, 32996, 13, 13, 7900, 22137, 29901, 29871, 13, 9280, 29892, 306, 674, 437, 363, 318, 29991, 2]
 Human:
用python写一段快排

Assistant:
 OK, I will do for u!</s>

As you can see, there is a space before both Human and OK, which is not expected.

Why?

@chris-ha458

Two things could be the problem here:

  1. A token that carries a prefix space (the ▁ metasymbol for Unigram, or Ġ for byte-level BPE, etc.)
  2. A module somewhere in the tokenizer chain that adds a prefix
    (add_prefix_space = True)

I checked, and token 12968 does not have a prefix space, so it is not 1.
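
For readers who want to repeat that kind of check, here is a minimal sketch of testing a raw token string for a prefix-space marker. The helper name `has_prefix_space` is hypothetical and not part of any library; the only assumption is that the markers are `▁` (U+2581, SentencePiece) and `Ġ` (U+0120, byte-level BPE).

```python
# Markers that tokenizers use to record "this token starts with a space".
SENTENCEPIECE_SPACE = "\u2581"  # "▁", used by SentencePiece (Unigram/BPE)
BYTE_LEVEL_SPACE = "\u0120"     # "Ġ", used by byte-level BPE


def has_prefix_space(token: str) -> bool:
    """Return True if the raw token string begins with a space marker.

    Hypothetical helper for illustration; it inspects the raw vocabulary
    string, not the decoded text.
    """
    return token.startswith((SENTENCEPIECE_SPACE, BYTE_LEVEL_SPACE))


for raw in ["\u2581OK", "\u0120OK", "OK"]:
    print(repr(raw), has_prefix_space(raw))
```

With a real transformers tokenizer, the raw strings would come from `tokenizer.convert_ids_to_tokens(ids)` rather than `tokenizer.decode(ids)`, since decoding turns the marker back into ordinary whitespace and can hide it.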

@lucasjinreal
Author

@chris-ha458 Thanks for taking a look.

I isolated the problem and dug a little deeper. This is what I found:

Encoding one sentence in two parts (such as question + answer) with no space between them and then concatenating the ids does not give the same result as encoding the whole sentence at once.

I don't know whether this is expected, but it is not what I expected.

For details, please run this script with any LLaMA tokenizer:

from transformers import LlamaTokenizer

# any LLama tokenizer
tokenizer = LlamaTokenizer.from_pretrained("checkpoints/BiLLa-7B-LLM/tokenizer.model")

def test1():
    prefix = "Human:\n用 python 写一段快排\n\nAssistant:"
    output = "OK, I will do for u!"
    sentence_ids = tokenizer.encode(prefix, add_special_tokens=False)
    # b = tokenizer.decode(sentence_ids)
    print(sentence_ids)
    d = tokenizer.encode(output, add_special_tokens=False)
    print(d)
    input_ids = sentence_ids + d
    # input_ids += [tokenizer.eos_token_id]
    o = tokenizer.decode(input_ids)
    print(input_ids)
    print(o)


def test2():
    print('---------------- test2')
    prefix = "Human:\n用 python 写一段快排\n\nAssistant:"
    output = "OK, I will do for u!"

    sentence = prefix + output
    sentence_ids = tokenizer.encode(sentence, add_special_tokens=False)
    b = tokenizer.decode(sentence_ids)
    print(sentence_ids)
    print(b)

    c = tokenizer.decode([12968])
    print(c)
    c = tokenizer.decode([9280])
    print(c)
    c = tokenizer.decode([8949])
    print(c)


if __name__ == '__main__':
    test1()
    test2()

Here is the interesting part:

The two ways of encoding the same sentence give different ids:

[12968, 29901, 13, 30406, 3017, 29871, 31479, 30287, 31559, 32815, 32996, 13, 13, 7900, 22137, 29901, 9280, 29892, 306, 674, 437, 363, 318, 29991]
[12968, 29901, 13, 30406, 3017, 29871, 31479, 30287, 31559, 32815, 32996, 13, 13, 7900, 22137, 29901, 8949, 29892, 306, 674, 437, 363, 318, 29991]

When I decode the ids that differ (the ones that might be causing the space), they turn out to be the same character.

So I am totally lost here.
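
One way to picture the observation above is with a toy model. The sketch below is not the LlamaTokenizer implementation; it is a simplified greedy tokenizer that assumes SentencePiece-style behaviour, where a `▁` dummy prefix is added at the start of every encode call. The vocabulary, the ids, and the names `toy_encode` and `toy_decode` are all made up for illustration.

```python
SPACE = "\u2581"  # "▁", the SentencePiece whitespace metasymbol

# Hypothetical toy vocabulary; the ids are invented and do not match LLaMA.
VOCAB = {SPACE + "Assistant": 0, ":": 1, SPACE + "OK": 2, "OK": 3}
ID_TO_TOKEN = {i: t for t, i in VOCAB.items()}


def toy_encode(text):
    """Greedy longest-match encoding with a dummy prefix on every call."""
    text = SPACE + text.replace(" ", SPACE)
    ids, pos = [], 0
    while pos < len(text):
        # Try the longest remaining piece first, then shrink it.
        for end in range(len(text), pos, -1):
            piece = text[pos:end]
            if piece in VOCAB:
                ids.append(VOCAB[piece])
                pos = end
                break
        else:
            raise ValueError(f"no token for {text[pos:]!r}")
    return ids


def toy_decode(ids):
    """Join the pieces, map the metasymbol to a space, drop one leading space."""
    text = "".join(ID_TO_TOKEN[i] for i in ids).replace(SPACE, " ")
    return text[1:] if text.startswith(" ") else text


split_ids = toy_encode("Assistant:") + toy_encode("OK")  # [0, 1, 2]
whole_ids = toy_encode("Assistant:OK")                   # [0, 1, 3]
print(split_ids, repr(toy_decode(split_ids)))  # 'Assistant: OK'
print(whole_ids, repr(toy_decode(whole_ids)))  # 'Assistant:OK'
print(toy_decode([2]) == toy_decode([3]))      # True: both print as "OK"
```

In this toy model the two ids that differ look identical when decoded on their own, because the leading marker is stripped at the start of a decoded string, yet the marker becomes a visible space once the token sits in the middle of a sequence. Whether the real slow tokenizer behaves this way in every case is an assumption of the sketch.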

@ArthurZucker
Collaborator

Hey! This issue has nothing to do with the tokenizers library, since it uses the slow tokenizer. I believe this will be fixed by #25224.

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions bot closed this as completed Sep 7, 2023