GGUF: Fix llama 3 GGUF #31358

Merged: 18 commits, Jun 20, 2024
90 changes: 75 additions & 15 deletions src/transformers/integrations/ggml.py
@@ -21,7 +21,7 @@
from array import array

import numpy as np
from tokenizers import Tokenizer, decoders
from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers
from tokenizers.models import BPE

from .. import AddedToken
@@ -540,15 +540,26 @@ def __init__(self, dict_):
self.merges = merges
else:
self.merges = [tuple(merge.split(" ")) for merge in self.merges]
if not hasattr(self, "scores"):
self.scores = [None for _ in range(len(self.tokens))]

if not hasattr(self, "added_tokens"):
self.added_tokens = []

if not hasattr(self, "unk_token_id"):
self.unk_token_id = None

# Llama2 uses the field `unknown_token_id`
if hasattr(self, "unknown_token_id") and self.unk_token_id is None:
self.unk_token_id = self.unknown_token_id


class GGUFLlamaConverter(LlamaConverter):
def __init__(self, tokenizer_dict):
self.proto = GGUFTokenizerSkeleton(tokenizer_dict)
self.original_tokenizer = self.proto
self.additional_kwargs = {}
self.is_llama_3_tokenizer = getattr(self.proto, "tokenizer_type", "llama") != "llama"

def vocab(self, proto):
return list(zip(proto.tokens, proto.scores))
@@ -560,22 +571,50 @@ def tokenizer(self, proto):
vocab_scores = self.vocab(self.proto)
merges = self.merges(self.proto)
bpe_vocab = {word: i for i, (word, _score) in enumerate(vocab_scores)}
tokenizer = Tokenizer(
BPE(bpe_vocab, merges, unk_token=proto.tokens[proto.unk_token_id], fuse_unk=True, byte_fallback=True)
)
tokenizer.add_special_tokens(
[
AddedToken("<unk>", normalized=False, special=True),
AddedToken("<s>", normalized=False, special=True),
AddedToken("</s>", normalized=False, special=True),
]
)

unk_token = proto.tokens[proto.unk_token_id] if proto.unk_token_id is not None else None
bos_token = proto.tokens[proto.bos_token_id] if getattr(proto, "bos_token_id", None) is not None else None
eos_token = proto.tokens[proto.eos_token_id] if getattr(proto, "eos_token_id", None) is not None else None

tokenizer = Tokenizer(BPE(bpe_vocab, merges, unk_token=unk_token, fuse_unk=True, byte_fallback=True))

special_tokens = []

if not hasattr(self.proto, "token_type"):
if unk_token is not None:
special_tokens.append(AddedToken(unk_token, normalized=False, special=True))

if bos_token is not None:
special_tokens.append(AddedToken(bos_token, normalized=False, special=True))

if eos_token is not None:
special_tokens.append(AddedToken(eos_token, normalized=False, special=True))
else:
# token_type == 3 marks special (control) tokens in GGUF metadata
special_tokens_idx = np.where(np.array(self.proto.token_type) == 3)[0]

for idx in special_tokens_idx:
special_tokens.append(AddedToken(self.proto.tokens[idx], normalized=False, special=True))

if len(special_tokens) != 0:
tokenizer.add_special_tokens(special_tokens)

if len(self.proto.added_tokens) != 0:
tokenizer.add_special_tokens(
[AddedToken(added_token, normalized=False, special=False) for added_token in self.added_tokens]
tokenizer.add_tokens(
[AddedToken(added_token, normalized=False, special=False) for added_token in self.proto.added_tokens]
)

self.additional_kwargs["unk_token"] = unk_token
self.additional_kwargs["eos_token"] = bos_token
self.additional_kwargs["bos_token"] = eos_token

if self.is_llama_3_tokenizer:
self.additional_kwargs["add_prefix_space"] = False
self.additional_kwargs["clean_up_tokenization_spaces"] = True

self.additional_kwargs["legacy"] = False
self.original_tokenizer.legacy = False

return tokenizer

def decoder(self, replacement, add_prefix_space):
@@ -584,14 +623,34 @@ def decoder(self, replacement, add_prefix_space):
decoders.Fuse(),
decoders.Replace("▁", " "),
]

if self.is_llama_3_tokenizer:
sequence += [decoders.ByteLevel(add_prefix_space=False, trim_offsets=False, use_regex=True)]

if add_prefix_space:
sequence += [decoders.Strip(content=" ", left=1)]
return decoders.Sequence(sequence)

def converted(self):
tokenizer = super().converted()

# HACK: patch the llama-3 tokenizer to use the corresponding pre-tokenizer
# and normalizer
if self.is_llama_3_tokenizer:
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(
add_prefix_space=False, trim_offsets=False, use_regex=True
)
# This is tricky as the additional kwargs are passed after legacy is force-set in LlamaTokenizer's
# init.
tokenizer.normalizer = normalizers.Sequence([])

return tokenizer


class GGUFQwen2Converter(Qwen2Converter):
def __init__(self, tokenizer_dict):
self.original_tokenizer = GGUFTokenizerSkeleton(tokenizer_dict)
self.additional_kwargs = {}

def converted(self) -> Tokenizer:
vocab = {word: i for i, word in enumerate(self.original_tokenizer.tokens)}
@@ -629,5 +688,6 @@ def convert_gguf_tokenizer(architecture, tokenizer_dict) -> Tokenizer:
[`~tokenization_utils_base.PreTrainedTokenizerFast`]
"""
tokenizer_class_name = architecture
converter_class = GGUF_TO_FAST_CONVERTERS[tokenizer_class_name]
return converter_class(tokenizer_dict).converted()
converter = GGUF_TO_FAST_CONVERTERS[tokenizer_class_name](tokenizer_dict)
fast_tokenizer = converter.converted()
return fast_tokenizer, converter.additional_kwargs
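
For reference, here is a minimal sketch (not part of the diff) of how the `token_type == 3` check above picks out special tokens from GGUF metadata. The toy vocabulary is an assumption; `3` is the GGUF control-token type the converter relies on.

import numpy as np

# Toy GGUF-style metadata (illustrative values only).
tokens = ["<|begin_of_text|>", "<|end_of_text|>", "Hello", "world"]
token_type = [3, 3, 1, 1]  # 1 = normal token, 3 = control/special token

special_tokens_idx = np.where(np.array(token_type) == 3)[0]
special_tokens = [tokens[idx] for idx in special_tokens_idx]
print(special_tokens)  # ['<|begin_of_text|>', '<|end_of_text|>']
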
3 changes: 2 additions & 1 deletion src/transformers/models/llama/tokenization_llama.py
@@ -158,7 +158,8 @@ def __init__(
" expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you."
" If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it"
" means, and thoroughly read the reason why this was added as explained in"
" https://github.com/huggingface/transformers/pull/24565"
" https://github.com/huggingface/transformers/pull/24565 - if you loaded a llama tokenizer from a GGUF file"
" you can ignore this message"
)
legacy = True

3 changes: 2 additions & 1 deletion src/transformers/models/llama/tokenization_llama_fast.py
@@ -145,7 +145,8 @@ def __init__(
" expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you."
" If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it"
" means, and thoroughly read the reason why this was added as explained in"
" https://github.com/huggingface/transformers/pull/24565"
" https://github.com/huggingface/transformers/pull/24565 - if you loaded a llama tokenizer from a GGUF file"
" you can ignore this message."
)
legacy = True
self.legacy = legacy
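
Note: the legacy warning above only fires when `legacy` is left unset, and GGUF-converted llama tokenizers now receive `legacy=False` through the converter's `additional_kwargs`, so the message does not apply to them. A minimal sketch of silencing it explicitly for an ordinary (non-GGUF) checkpoint; the checkpoint path is a placeholder:

from transformers import LlamaTokenizerFast

# "path/to/llama-checkpoint" stands in for any local or Hub llama checkpoint.
tokenizer = LlamaTokenizerFast.from_pretrained("path/to/llama-checkpoint", legacy=False)
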
7 changes: 6 additions & 1 deletion src/transformers/tokenization_utils_fast.py
@@ -121,7 +121,11 @@ def __init__(self, *args, **kwargs):
gguf_param = load_gguf_checkpoint(kwargs.get("vocab_file"))
architecture = gguf_param["config"]["model_type"]
tokenizer_dict = gguf_param["tokenizer"]
fast_tokenizer = convert_gguf_tokenizer(architecture, tokenizer_dict)
fast_tokenizer, additional_kwargs = convert_gguf_tokenizer(architecture, tokenizer_dict)
Collaborator comment: yeah, sounds good to me TBH


if len(additional_kwargs) > 0:
kwargs.update(additional_kwargs)

elif self.slow_tokenizer_class is not None:
# We need to create and convert a slow tokenizer to build the backend
slow_tokenizer = self.slow_tokenizer_class(*args, **kwargs)
@@ -184,6 +188,7 @@ def __init__(self, *args, **kwargs):
tokens_to_add += [
token for token in self.all_special_tokens_extended if token not in encoder and token not in tokens_to_add
]

if len(tokens_to_add) > 0:
# super hack: if a token.special is set, tokenizer ignores it for now so FIXME @ArthurZ
# Accumulate added tokens into batches of special/non-special tokens, because calling add_tokens() for
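
Taken together with the ggml.py changes, the effect is that GGUF-specific settings (bos/eos/unk tokens, `add_prefix_space`, `legacy`) are merged into the tokenizer kwargs automatically. A usage sketch with the model and file ids used in the tests below, assuming the `gguf` dependency is installed; the decoded output is an expectation based on this PR's tests, not a guarantee:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "NousResearch/Meta-Llama-3-8B-GGUF",
    gguf_file="Meta-Llama-3-8B-Q4_K_M.gguf",
)
# The converter's additional_kwargs are applied here instead of coming from a
# tokenizer_config.json.
encoded = tokenizer("Hello", return_tensors="pt")
print(tokenizer.decode(encoded["input_ids"][0]))  # expected: "<|begin_of_text|>Hello"
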
21 changes: 21 additions & 0 deletions tests/quantization/ggml/test_ggml.py
@@ -32,6 +32,7 @@ class GgufIntegrationTests(unittest.TestCase):
model_id = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF"
mistral_model_id = "TheBloke/Mistral-7B-Instruct-v0.2-GGUF"
qwen2_model_id = "Qwen/Qwen1.5-0.5B-Chat-GGUF"
llama3_model_id = "NousResearch/Meta-Llama-3-8B-GGUF"

q4_0_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q4_0.gguf"
q4_k_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf"
@@ -43,6 +44,7 @@ class GgufIntegrationTests(unittest.TestCase):

q4_0_mistral_model_id = "mistral-7b-instruct-v0.2.Q4_0.gguf"
q4_0_qwen2_model_id = "qwen1_5-0_5b-chat-q4_0.gguf"
q4_llama3_model_id = "Meta-Llama-3-8B-Q4_K_M.gguf"

example_text = "Hello"

@@ -171,6 +173,25 @@ def test_qwen2_q4_0(self):
EXPECTED_TEXT = "Hello.jsoup\n\nI am a beginner"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)

def test_llama3_q4_0_tokenizer(self):
tokenizer_gguf = AutoTokenizer.from_pretrained(self.llama3_model_id, gguf_file=self.q4_llama3_model_id)
special_sentence = "สวัสดี"
predicted_text = tokenizer_gguf.decode(tokenizer_gguf.encode(special_sentence, return_tensors="pt")[0])
self.assertEqual(predicted_text, "<|begin_of_text|>" + special_sentence)

def test_llama3_q4_0(self):
Collaborator comment: The string that you are testing could be improved, specifically to take the special tokens into account, etc. But alright otherwise!

tokenizer = AutoTokenizer.from_pretrained(self.llama3_model_id, gguf_file=self.q4_llama3_model_id)
model = AutoModelForCausalLM.from_pretrained(
self.llama3_model_id, gguf_file=self.q4_llama3_model_id, device_map="auto", torch_dtype=torch.float16
)

text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
out = model.generate(**text, max_new_tokens=10)

EXPECTED_TEXT = "Hello, I am new to this forum. I am"

self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)

def test_tokenization_xnli(self):
import tqdm
from datasets import load_dataset