
GemmaTokenizerFast word_ids() returns only zeros #31437

Closed
1 of 4 tasks
Alienmaster opened this issue Jun 15, 2024 · 10 comments

Comments

@Alienmaster

System Info


  • transformers version: 4.41.2
  • Platform: Linux-5.15.0-86-generic-x86_64-with-glibc2.31
  • Python version: 3.10.13
  • Huggingface_hub version: 0.23.1
  • Safetensors version: 0.4.2
  • Accelerate version: 0.28.0
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.2.1+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: yes
  • Using distributed or parallel set-up in script?: no

Who can help?

@ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

The word_ids() method returns only zeros for the actual tokens instead of the correct word IDs.

from transformers import AutoTokenizer

sentence = "I love my cat"
tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")  # revision a0eac5b
encoded = tokenizer(sentence, return_tensors="pt")
print(encoded.word_ids())
# observed: [None, 0, 0, 0, 0]

I tried several of the configuration variations suggested in the issues linked in #28881, but for Gemma none of them changes the result. The Llama 3 tokenizer outputs the correct values with this code.

Expected behavior

The output of word_ids should look like
[None, 0, 1, 2, 3]

@ArthurZucker
Collaborator

Hey! Will have a look, thanks for reporting.

@ArthurZucker
Collaborator

ArthurZucker commented Jul 16, 2024

It seems that we need this:

from tokenizers.pre_tokenizers import Sequence, Split

tokenizer._tokenizer.pre_tokenizer = Sequence([Split("▁", "merged_with_next")])
encoded = tokenizer(sentence, return_tensors="pt")
print(encoded.word_ids())
# [None, 0, 1, 2, 3]
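A toy model of why adding a pre-tokenizer changes word_ids(), assuming (as the fix above implies) that the stock Gemma fast tokenizer has no pre-tokenizer: the fast tokenizer numbers the pieces produced by pre-tokenization, and each token inherits the index of the piece it came from. `simulate_word_ids` is a hypothetical helper for illustration, not a transformers or tokenizers API, and the token lists are hand-written from the outputs shown in this thread.

```python
from typing import List, Optional


def simulate_word_ids(
    pieces: List[List[str]], add_bos: bool = True
) -> List[Optional[int]]:
    """Toy model: pieces[i] holds the model tokens produced for pre-tokenized piece i.

    Hypothetical helper for illustration only. Every token inherits the index
    of the piece it came from; the optional <bos> special token gets None.
    """
    word_ids: List[Optional[int]] = [None] if add_bos else []
    for piece_idx, tokens in enumerate(pieces):
        word_ids.extend([piece_idx] * len(tokens))
    return word_ids


# No pre-tokenizer: the whole normalized string "I▁love▁my▁cat" is a single piece.
print(simulate_word_ids([["I", "▁love", "▁my", "▁cat"]]))
# [None, 0, 0, 0, 0]

# Split("▁", "merged_with_next"): one piece per word.
print(simulate_word_ids([["I"], ["▁love"], ["▁my"], ["▁cat"]]))
# [None, 0, 1, 2, 3]
```

Under this model the all-zeros output is not a bug in word_ids() itself; it simply reflects that, without a pre-tokenizer, the whole sentence counts as one "word".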

@ArthurZucker
Collaborator

#32191 should fix this!

@xenova
Contributor

xenova commented Jul 30, 2024

Not included in #32191, since the proposed fix breaks encoding. This will need to be done in a follow-up PR :)

@ArthurZucker
Collaborator

Small update:

from tokenizers import Regex, pre_tokenizers

pre_tokenizer = pre_tokenizers.Split(Regex("(?<!▁)▁"), "merged_with_next")
pre_tokenizer.pre_tokenize_str(sentence.replace(" ", "▁"))
# [('I', (0, 1)), ('▁love', (1, 6)), ('▁my', (6, 9)), ('▁cat', (9, 13))]
tokenizer._tokenizer.pre_tokenizer = pre_tokenizer  # assumed step, as in the earlier snippet
tokenizer(sentence).tokens()
# ['<bos>', 'I', '▁love', '▁my', '▁cat']
tokenizer(sentence).word_ids()
# [None, 0, 1, 2, 3]

This gives somewhat acceptable results.
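To see what the `(?<!▁)` lookbehind changes, here is a pure-Python approximation of a "merged_with_next" split. This is an assumption about the behavior of `Split(..., "merged_with_next")`, not the library's code, and `split_merged_with_next` is a hypothetical helper. It also assumes, as the `sentence.replace(" ", "▁")` call above mimics, that spaces reach the pre-tokenizer as "▁", so two consecutive spaces arrive as "▁▁".

```python
import re
from typing import List


def split_merged_with_next(text: str, pattern: str) -> List[str]:
    """Pure-Python approximation of tokenizers' Split(pattern, "merged_with_next").

    Every regex match starts a new piece, so each delimiter is glued to the
    text that follows it.
    """
    pieces: List[str] = []
    start = 0
    for match in re.finditer(pattern, text):
        if match.start() > start:
            pieces.append(text[start:match.start()])
            start = match.start()
    if start < len(text):
        pieces.append(text[start:])
    return pieces


print(split_merged_with_next("I▁love▁my▁cat", r"(?<!▁)▁"))
# ['I', '▁love', '▁my', '▁cat']

# Plain "▁": a run of spaces leaves a standalone "▁" piece (its own "word").
print(split_merged_with_next("I▁▁love", "▁"))
# ['I', '▁', '▁love']

# Lookbehind: only the first "▁" of a run splits, so the run stays with the next word.
print(split_merged_with_next("I▁▁love", r"(?<!▁)▁"))
# ['I', '▁▁love']
```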

@ArthurZucker
Collaborator

But I don't recommend relying on word IDs: the word separations are "brittle", and this pre-tokenizer changes the tokenization output.

@Alienmaster
Author

Is there a better way to link the tokens to words than word_ids?

@ArthurZucker
Collaborator

Offsets!

@ArthurZucker
Collaborator

The offset mapping gives you the exact span of the original string that each token corresponds to.
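Following the suggestion to use offsets, a minimal sketch of recovering word indices from an offset mapping without changing the tokenizer. `word_ids_from_offsets` is a hypothetical helper, not a transformers API; it assumes words are whitespace-delimited runs and that special tokens get an empty (0, 0) span. The offsets below are hand-written for illustration; with a fast tokenizer, real ones come from `tokenizer(sentence, return_offsets_mapping=True)["offset_mapping"]`.

```python
import re
from typing import List, Optional, Tuple


def word_ids_from_offsets(
    text: str, offsets: List[Tuple[int, int]]
) -> List[Optional[int]]:
    """Map each token to the index of the whitespace-delimited word it falls in.

    Hypothetical helper: words are maximal runs of non-whitespace characters.
    """
    # Char-index -> word-index lookup; whitespace characters map to None.
    char_to_word: List[Optional[int]] = [None] * len(text)
    for word_idx, match in enumerate(re.finditer(r"\S+", text)):
        for pos in range(match.start(), match.end()):
            char_to_word[pos] = word_idx

    word_ids: List[Optional[int]] = []
    for start, end in offsets:
        # The first non-whitespace character in the span decides the word, so a
        # token whose span includes a leading space still lands on its word.
        # Empty spans (special tokens) and whitespace-only spans yield None.
        word = next(
            (
                char_to_word[pos]
                for pos in range(start, min(end, len(text)))
                if char_to_word[pos] is not None
            ),
            None,
        )
        word_ids.append(word)
    return word_ids


sentence = "I love my cat"
# Illustrative offsets for ['<bos>', 'I', '▁love', '▁my', '▁cat'].
offsets = [(0, 0), (0, 1), (1, 6), (6, 9), (9, 13)]
print(word_ids_from_offsets(sentence, offsets))
# [None, 0, 1, 2, 3]
```

Because the word boundaries are computed from the original string, this leaves the tokenization itself untouched, which avoids the concern raised above about the pre-tokenizer change.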


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.
