Skip to content

Commit

Permalink
GGUF: Fix llama 3 GGUF (#31358)
Browse files Browse the repository at this point in the history
* Create push-important-models.yml

* llama3 support for GGUF

* fixup

* Update src/transformers/integrations/ggml.py

* fix pre-tokenizer

* fix

* fix

* fix

* fix

* fix

* fix

* address final comment

* handle special tokens + add tests
  • Loading branch information
younesbelkada authored Jun 20, 2024
1 parent 35b112d commit 6d43061
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 18 deletions.
90 changes: 75 additions & 15 deletions src/transformers/integrations/ggml.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from array import array

import numpy as np
from tokenizers import Tokenizer, decoders
from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers
from tokenizers.models import BPE

from .. import AddedToken
Expand Down Expand Up @@ -540,15 +540,26 @@ def __init__(self, dict_):
self.merges = merges
else:
self.merges = [tuple(merge.split(" ")) for merge in self.merges]
if not hasattr(self, "scores"):
self.scores = [None for _ in range(len(self.tokens))]

if not hasattr(self, "added_tokens"):
self.added_tokens = []

if not hasattr(self, "unk_token_id"):
self.unk_token_id = None

# Llama2 uses the field `unknown_token_id`
if hasattr(self, "unknown_token_id") and self.unk_token_id is None:
self.unk_token_id = self.unknown_token_id


class GGUFLlamaConverter(LlamaConverter):
def __init__(self, tokenizer_dict):
self.proto = GGUFTokenizerSkeleton(tokenizer_dict)
self.original_tokenizer = self.proto
self.additional_kwargs = {}
self.is_llama_3_tokenizer = getattr(self.proto, "tokenizer_type", "llama") != "llama"

def vocab(self, proto):
return list(zip(proto.tokens, proto.scores))
Expand All @@ -560,22 +571,50 @@ def tokenizer(self, proto):
vocab_scores = self.vocab(self.proto)
merges = self.merges(self.proto)
bpe_vocab = {word: i for i, (word, _score) in enumerate(vocab_scores)}
tokenizer = Tokenizer(
BPE(bpe_vocab, merges, unk_token=proto.tokens[proto.unk_token_id], fuse_unk=True, byte_fallback=True)
)
tokenizer.add_special_tokens(
[
AddedToken("<unk>", normalized=False, special=True),
AddedToken("<s>", normalized=False, special=True),
AddedToken("</s>", normalized=False, special=True),
]
)

unk_token = proto.tokens[proto.unk_token_id] if proto.unk_token_id is not None else None
bos_token = proto.tokens[proto.bos_token_id] if getattr(proto, "bos_token_id", None) is not None else None
eos_token = proto.tokens[proto.bos_token_id] if getattr(proto, "eos_token_id", None) is not None else None

tokenizer = Tokenizer(BPE(bpe_vocab, merges, unk_token=unk_token, fuse_unk=True, byte_fallback=True))

special_tokens = []

if not hasattr(self.proto, "token_type"):
if unk_token is not None:
special_tokens.append(AddedToken(unk_token, normalized=False, special=True))

if bos_token is not None:
special_tokens.append(AddedToken(bos_token, normalized=False, special=True))

if eos_token is not None:
special_tokens.append(AddedToken(eos_token, normalized=False, special=True))
else:
# 3 stands for special tokens
special_tokens_idx = np.where(np.array(self.proto.token_type) == 3)[0]

for idx in special_tokens_idx:
special_tokens.append(AddedToken(self.proto.tokens[idx], normalized=False, special=True))

if len(special_tokens) != 0:
tokenizer.add_special_tokens(special_tokens)

if len(self.proto.added_tokens) != 0:
tokenizer.add_special_tokens(
[AddedToken(added_token, normalized=False, special=False) for added_token in self.added_tokens]
tokenizer.add_tokens(
[AddedToken(added_token, normalized=False, special=False) for added_token in self.proto.added_tokens]
)

self.additional_kwargs["unk_token"] = unk_token
self.additional_kwargs["eos_token"] = bos_token
self.additional_kwargs["bos_token"] = eos_token

if self.is_llama_3_tokenizer:
self.additional_kwargs["add_prefix_space"] = False
self.additional_kwargs["clean_up_tokenization_spaces"] = True

self.additional_kwargs["legacy"] = False
self.original_tokenizer.legacy = False

return tokenizer

def decoder(self, replacement, add_prefix_space):
Expand All @@ -584,14 +623,34 @@ def decoder(self, replacement, add_prefix_space):
decoders.Fuse(),
decoders.Replace("▁", " "),
]

if self.is_llama_3_tokenizer:
sequence += [decoders.ByteLevel(add_prefix_space=False, trim_offsets=False, use_regex=True)]

if add_prefix_space:
sequence += [decoders.Strip(content=" ", left=1)]
return decoders.Sequence(sequence)

def converted(self):
tokenizer = super().converted()

# HACK: patch the llama-3 tokenizer to use the correspinding pre-tokenizer
# and normalizer
if self.is_llama_3_tokenizer:
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(
add_prefix_space=False, trim_offsets=False, use_regex=True
)
# This is tricky as the additional kwargs are passed after legacy is force-set in LlamaTokenizer's
# init.
tokenizer.normalizer = normalizers.Sequence([])

return tokenizer


class GGUFQwen2Converter(Qwen2Converter):
def __init__(self, tokenizer_dict):
self.original_tokenizer = GGUFTokenizerSkeleton(tokenizer_dict)
self.additional_kwargs = {}

def converted(self) -> Tokenizer:
vocab = {word: i for i, word in enumerate(self.original_tokenizer.tokens)}
Expand Down Expand Up @@ -629,5 +688,6 @@ def convert_gguf_tokenizer(architecture, tokenizer_dict) -> Tokenizer:
[`~tokenization_utils_base.PreTrainedTokenizerFast`]
"""
tokenizer_class_name = architecture
converter_class = GGUF_TO_FAST_CONVERTERS[tokenizer_class_name]
return converter_class(tokenizer_dict).converted()
converter = GGUF_TO_FAST_CONVERTERS[tokenizer_class_name](tokenizer_dict)
fast_tokenizer = converter.converted()
return fast_tokenizer, converter.additional_kwargs
3 changes: 2 additions & 1 deletion src/transformers/models/llama/tokenization_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,8 @@ def __init__(
" expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you."
" If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it"
" means, and thoroughly read the reason why this was added as explained in"
" https://github.com/huggingface/transformers/pull/24565"
" https://github.com/huggingface/transformers/pull/24565 - if you loaded a llama tokenizer from a GGUF file"
" you can ignore this message"
)
legacy = True

Expand Down
3 changes: 2 additions & 1 deletion src/transformers/models/llama/tokenization_llama_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,8 @@ def __init__(
" expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you."
" If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it"
" means, and thoroughly read the reason why this was added as explained in"
" https://github.com/huggingface/transformers/pull/24565"
" https://github.com/huggingface/transformers/pull/24565 - if you loaded a llama tokenizer from a GGUF file"
" you can ignore this message."
)
legacy = True
self.legacy = legacy
Expand Down
7 changes: 6 additions & 1 deletion src/transformers/tokenization_utils_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,11 @@ def __init__(self, *args, **kwargs):
gguf_param = load_gguf_checkpoint(kwargs.get("vocab_file"))
architecture = gguf_param["config"]["model_type"]
tokenizer_dict = gguf_param["tokenizer"]
fast_tokenizer = convert_gguf_tokenizer(architecture, tokenizer_dict)
fast_tokenizer, additional_kwargs = convert_gguf_tokenizer(architecture, tokenizer_dict)

if len(additional_kwargs) > 0:
kwargs.update(additional_kwargs)

elif self.slow_tokenizer_class is not None:
# We need to create and convert a slow tokenizer to build the backend
slow_tokenizer = self.slow_tokenizer_class(*args, **kwargs)
Expand Down Expand Up @@ -184,6 +188,7 @@ def __init__(self, *args, **kwargs):
tokens_to_add += [
token for token in self.all_special_tokens_extended if token not in encoder and token not in tokens_to_add
]

if len(tokens_to_add) > 0:
# super hack: if a token.special is set, tokenizer ignores it for now so FIXME @ArthurZ
# Accumulate added tokens into batches of special/non-special tokens, because calling add_tokens() for
Expand Down
21 changes: 21 additions & 0 deletions tests/quantization/ggml/test_ggml.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class GgufIntegrationTests(unittest.TestCase):
model_id = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF"
mistral_model_id = "TheBloke/Mistral-7B-Instruct-v0.2-GGUF"
qwen2_model_id = "Qwen/Qwen1.5-0.5B-Chat-GGUF"
llama3_model_id = "NousResearch/Meta-Llama-3-8B-GGUF"

q4_0_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q4_0.gguf"
q4_k_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf"
Expand All @@ -43,6 +44,7 @@ class GgufIntegrationTests(unittest.TestCase):

q4_0_mistral_model_id = "mistral-7b-instruct-v0.2.Q4_0.gguf"
q4_0_qwen2_model_id = "qwen1_5-0_5b-chat-q4_0.gguf"
q4_llama3_model_id = "Meta-Llama-3-8B-Q4_K_M.gguf"

example_text = "Hello"

Expand Down Expand Up @@ -171,6 +173,25 @@ def test_qwen2_q4_0(self):
EXPECTED_TEXT = "Hello.jsoup\n\nI am a beginner"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)

def test_llama3_q4_0_tokenizer(self):
tokenizer_gguf = AutoTokenizer.from_pretrained(self.llama3_model_id, gguf_file=self.q4_llama3_model_id)
special_sentence = "สวัสดี"
predicted_text = tokenizer_gguf.decode(tokenizer_gguf.encode(special_sentence, return_tensors="pt")[0])
self.assertEqual(predicted_text, "<|begin_of_text|>" + special_sentence)

def test_llama3_q4_0(self):
tokenizer = AutoTokenizer.from_pretrained(self.llama3_model_id, gguf_file=self.q4_llama3_model_id)
model = AutoModelForCausalLM.from_pretrained(
self.llama3_model_id, gguf_file=self.q4_llama3_model_id, device_map="auto", torch_dtype=torch.float16
)

text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
out = model.generate(**text, max_new_tokens=10)

EXPECTED_TEXT = "Hello, I am new to this forum. I am"

self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)

def test_tokenization_xnli(self):
import tqdm
from datasets import load_dataset
Expand Down

0 comments on commit 6d43061

Please sign in to comment.