From fb00288ffc627171f60107cd9c7ffbeeb6388e78 Mon Sep 17 00:00:00 2001 From: Lysandre Date: Fri, 19 Apr 2024 17:37:49 +0200 Subject: [PATCH 01/29] Adds support for loading GGUF files Co-authored-by: Younes Belkada Co-authored-by: 99991 <99991@users.noreply.github.com> --- docker/transformers-all-latest-gpu/Dockerfile | 4 + docs/source/en/gguf.md | 94 +++++ src/transformers/configuration_utils.py | 15 +- src/transformers/convert_slow_tokenizer.py | 5 +- src/transformers/integrations/__init__.py | 16 + src/transformers/integrations/ggml.py | 387 ++++++++++++++++++ .../modeling_gguf_pytorch_utils.py | 160 ++++++++ src/transformers/modeling_utils.py | 78 +++- .../models/auto/tokenization_auto.py | 14 +- src/transformers/testing_utils.py | 8 + src/transformers/tokenization_utils_base.py | 83 ++-- src/transformers/tokenization_utils_fast.py | 7 + src/transformers/utils/__init__.py | 1 + src/transformers/utils/import_utils.py | 5 + tests/quantization/ggml/__init__.py | 0 tests/quantization/ggml/test_ggml.py | 107 +++++ utils/check_inits.py | 1 + 17 files changed, 930 insertions(+), 55 deletions(-) create mode 100644 docs/source/en/gguf.md create mode 100644 src/transformers/integrations/ggml.py create mode 100644 src/transformers/modeling_gguf_pytorch_utils.py create mode 100644 tests/quantization/ggml/__init__.py create mode 100644 tests/quantization/ggml/test_ggml.py diff --git a/docker/transformers-all-latest-gpu/Dockerfile b/docker/transformers-all-latest-gpu/Dockerfile index 4f596c3c1cf9a4..d2656274485640 100644 --- a/docker/transformers-all-latest-gpu/Dockerfile +++ b/docker/transformers-all-latest-gpu/Dockerfile @@ -45,6 +45,10 @@ RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/opt # For video model testing RUN python3 -m pip install --no-cache-dir decord av==9.2.0 +# For GGUF tests +RUN python3 -m pip install --no-cache-dir gguf + + # For `dinat` model # The `XXX` part in `torchXXX` needs to match `PYTORCH` (to some extent) RUN python3 -m pip install --no-cache-dir natten==0.15.1+torch220$CUDA -f https://shi-labs.com/natten/wheels diff --git a/docs/source/en/gguf.md b/docs/source/en/gguf.md new file mode 100644 index 00000000000000..f766a00d4ae137 --- /dev/null +++ b/docs/source/en/gguf.md @@ -0,0 +1,94 @@ + + +# GGUF and interaction with Transformers + +The GGUF file format is used to store models for inference with [GGML](https://github.com/ggerganov/ggml) and other +libraries that depend on it, like the very popular [llama.cpp](https://github.com/ggerganov/llama.cpp) or +[whisper.cpp](https://github.com/ggerganov/whisper.cpp). + +It is a file format [supported by the Hugging Face Hub](https://huggingface.co/docs/hub/en/gguf) with features +allowing for quick inspection of tensors and metadata within the file. + +This file format is designed as a "single-file-format" where a single file usually contains both the configuration +attributes, the tokenizer vocabulary and other attributes, as well as all tensors to be loaded in the model. These +files come in different formats according to the quantization type of the file. We briefly go over some of them +[here](https://huggingface.co/docs/hub/en/gguf#quantization-types). + +## Support within Transformers + +We have added the ability to load `gguf` files within `transformers` in order to offer further training/fine-tuning +capabilities to gguf models, before converting back those models to `gguf` to use within the `ggml` ecosystem. 
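As a point of reference, the sketch below shows what such a single file contains, using the `gguf` Python package that the loading code relies on. It is only an illustration: the repository and filename are the TinyLlama GGUF files used in the example later in this guide.

```py
# Hedged sketch: inspect the metadata and tensor listing of a GGUF file with
# the `gguf` package (the same GGUFReader used by the integration below).
from huggingface_hub import hf_hub_download
from gguf import GGUFReader

path = hf_hub_download(
    repo_id="TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",
    filename="tinyllama-1.1b-chat-v1.0.Q6_K.gguf",
)

reader = GGUFReader(path)

# Metadata keys cover the model config, the tokenizer and general file information
print(list(reader.fields.keys())[:10])

# Every tensor entry exposes its name, shape and quantization type
for tensor in reader.tensors[:5]:
    print(tensor.name, tensor.shape, tensor.tensor_type)
```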
When +loading a model, we first dequantize it to fp32, before loading the weights to be used in PyTorch. + +> [!NOTE] +> The support is still very exploratory and we welcome contributions in order to solidify it across quantization types +> and model architectures. + +For now, here are the supported model architectures and quantization types: + +### Supported quantization types + +The initial supported quantization types are decided according to the popular quantized files that have been shared +on the Hub. + +- F32 +- Q4_0 +- Q4_K +- Q6_K +- Q8_0 + +We take example from the excellent [99991/pygguf](https://github.com/99991/pygguf) Python parser to dequantize the +weights. + +### Supported model architectures + +For now the supported model architectures are the architectures that have been very popular on the Hub, namely: + +- LLaMa +- Mistral +- Gemma + +## Example usage + +In order to load `gguf` files in `transformers`, you should specify the `from_gguf` argument to the `from_pretrained` +methods of both tokenizers and models. Here is how one would load a tokenizer and a model, which can be loaded +from the exact same file: + +```py +from transformers import AutoTokenizer, AutoModelForCausalLM + +model_id = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF" +filename = "tinyllama-1.1b-chat-v1.0.Q6_K.gguf" + +tokenizer = AutoTokenizer.from_pretrained(model_id, from_gguf=filename) +model = AutoModelForCausalLM.from_pretrained(model_id, from_gguf=filename) +``` + +Now you have access to the full, unquantized version of the model in the PyTorch ecosystem, where you can combine it +with a plethora of other tools. + +In order to convert back to a `gguf` file, we recommend using the +[`convert-hf-to-gguf.py` file](https://github.com/ggerganov/llama.cpp/blob/master/convert-hf-to-gguf.py) from llama.cpp. + +Here's how you would complete the script above to save the model and export it back to `gguf`: + +```py +tokenizer.save_pretrained('directory') +model.save_pretrained('directory') + +!python ${path_to_llama_cpp}/convert-hf-to-gguf.py ${directory} +``` \ No newline at end of file diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index dd2ed9d695e73b..f277c7ebfabfe4 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -27,6 +27,7 @@ from . import __version__ from .dynamic_module_utils import custom_object_save +from .modeling_gguf_pytorch_utils import load_gguf_checkpoint from .utils import ( CONFIG_NAME, PushToHubMixin, @@ -658,6 +659,8 @@ def _get_config_dict( from_auto_class = kwargs.pop("_from_auto", False) commit_hash = kwargs.pop("_commit_hash", None) + from_gguf = kwargs.get("from_gguf", None) + if trust_remote_code is True: logger.warning( "The argument `trust_remote_code` is to be used with Auto classes. 
It has no effect here and is" @@ -676,10 +679,10 @@ def _get_config_dict( resolved_config_file = pretrained_model_name_or_path is_local = True elif is_remote_url(pretrained_model_name_or_path): - configuration_file = pretrained_model_name_or_path + configuration_file = pretrained_model_name_or_path if from_gguf is None else from_gguf resolved_config_file = download_url(pretrained_model_name_or_path) else: - configuration_file = kwargs.pop("_configuration_file", CONFIG_NAME) + configuration_file = kwargs.pop("_configuration_file", CONFIG_NAME) if from_gguf is None else from_gguf try: # Load from local folder or from cache or download from model Hub and cache @@ -712,8 +715,12 @@ def _get_config_dict( ) try: - # Load config dict - config_dict = cls._dict_from_json_file(resolved_config_file) + if from_gguf is None: + # Load config dict + config_dict = cls._dict_from_json_file(resolved_config_file) + else: + config_dict = load_gguf_checkpoint(resolved_config_file, return_tensors=False)["config"] + config_dict["_commit_hash"] = commit_hash except (json.JSONDecodeError, UnicodeDecodeError): raise EnvironmentError( diff --git a/src/transformers/convert_slow_tokenizer.py b/src/transformers/convert_slow_tokenizer.py index 88f9e5f19a5c06..df7ce2f11672b8 100644 --- a/src/transformers/convert_slow_tokenizer.py +++ b/src/transformers/convert_slow_tokenizer.py @@ -1394,9 +1394,8 @@ def tokenizer(self, proto): def normalizer(self, proto): sequence = [] - if hasattr(self.original_tokenizer, "add_prefix_space"): - if self.original_tokenizer.add_prefix_space: - sequence += [normalizers.Prepend(prepend="▁")] + if getattr(self.original_tokenizer, "add_prefix_space", True): + sequence += [normalizers.Prepend(prepend="▁")] sequence += [normalizers.Replace(pattern=" ", content="▁")] return normalizers.Sequence(sequence) diff --git a/src/transformers/integrations/__init__.py b/src/transformers/integrations/__init__.py index 0dc2975aa963e1..dc4397938660ae 100644 --- a/src/transformers/integrations/__init__.py +++ b/src/transformers/integrations/__init__.py @@ -42,6 +42,14 @@ "set_hf_deepspeed_config", "unset_hf_deepspeed_config", ], + "ggml": [ + "GGUF_CONFIG_MAPPING", + "GGUF_TENSOR_MAPPING", + "GGUF_TOKENIZER_MAPPING", + "_gguf_parse_value", + "load_dequant_gguf_tensor", + "load_gguf", + ], "integration_utils": [ "INTEGRATION_TO_CALLBACK", "AzureMLCallback", @@ -111,6 +119,14 @@ set_hf_deepspeed_config, unset_hf_deepspeed_config, ) + from .ggml import ( + GGUF_CONFIG_MAPPING, + GGUF_TENSOR_MAPPING, + GGUF_TOKENIZER_MAPPING, + _gguf_parse_value, + load_dequant_gguf_tensor, + load_gguf, + ) from .integration_utils import ( INTEGRATION_TO_CALLBACK, AzureMLCallback, diff --git a/src/transformers/integrations/ggml.py b/src/transformers/integrations/ggml.py new file mode 100644 index 00000000000000..59423f2fc6d0bd --- /dev/null +++ b/src/transformers/integrations/ggml.py @@ -0,0 +1,387 @@ +# coding=utf-8 +# Copyright 2024 The ggml.ai team and The HuggingFace Inc. team. and pygguf author (github.com/99991) +# https://github.com/99991/pygguf +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +""" +Integration with GGML / The file is copied and adapted from https://github.com/99991/pygguf +with extra methods beings exposed +""" +from array import array + +import numpy as np +from tokenizers import Tokenizer, decoders +from tokenizers.models import BPE + +from .. import AddedToken +from ..convert_slow_tokenizer import LlamaConverter + + +# Listed here: https://github.com/ggerganov/ggml/blob/master/docs/gguf.md +GGML_TYPES = { + "F32": 0, + "Q4_0": 2, + "Q8_0": 8, + "Q4_K": 12, + "Q6_K": 14, +} + +# The Blocksizes are reported in bytes +# Check out: https://github.com/ggerganov/llama.cpp/blob/8a56075b07a8b571bf95a912ffdce4c928c2b414/gguf-py/gguf/constants.py#L801 +GGML_BLOCK_SIZES = { + "Q8_0": 2 + 32, # Q8_0 uses a blocksize of 32 (int8 tensors) + 2 bytes allocated for the scales + "Q4_K": 144, + "Q4_0": 2 + + 16, # Q4_0 uses a blocksize of 32 but the 4-bit tensors are packed into 8-bit tensors + 2 bytes for the scales + "Q6_K": 210, +} + +# Listed here: https://github.com/ggerganov/ggml/blob/master/docs/gguf.md +DATA_TYPES = { + "uint32": 4, + "int32": 5, + "float32": 6, + "bool": 7, + "string": 8, + "array": 9, + "uint64": 10, +} + +GGUF_TENSOR_MAPPING = { + "llama": { + "token_embd": "model.embed_tokens", + "blk": "model.layers", + "ffn_up": "mlp.up_proj", + "ffn_down": "mlp.down_proj", + "ffn_gate": "mlp.gate_proj", + "ffn_norm": "post_attention_layernorm", + "attn_norm": "input_layernorm", + "attn_q": "self_attn.q_proj", + "attn_v": "self_attn.v_proj", + "attn_k": "self_attn.k_proj", + "attn_output": "self_attn.o_proj", + "output.weight": "lm_head.weight", + "output_norm": "model.norm", + }, + "mistral": { + "token_embd": "model.embed_tokens", + "blk": "model.layers", + "ffn_up": "mlp.up_proj", + "ffn_down": "mlp.down_proj", + "ffn_gate": "mlp.gate_proj", + "ffn_norm": "post_attention_layernorm", + "attn_norm": "input_layernorm", + "attn_q": "self_attn.q_proj", + "attn_v": "self_attn.v_proj", + "attn_k": "self_attn.k_proj", + "attn_output": "self_attn.o_proj", + "output.weight": "lm_head.weight", + "output_norm": "model.norm", + }, +} + + +GGUF_CONFIG_MAPPING = { + "general": { + "architecture": "model_type", + "name": "_model_name_or_path", + }, + "llama": { + "context_length": "max_position_embeddings", + "block_count": "num_hidden_layers", + "feed_forward_length": "intermediate_size", + "embedding_length": "hidden_size", + "rope.dimension_count": None, + "rope.freq_base": "rope_theta", + "attention.head_count": "num_attention_heads", + "attention.head_count_kv": "num_key_value_heads", + "attention.layer_norm_rms_epsilon": "rms_norm_eps", + "vocab_size": "vocab_size", + }, + "mistral": { + "context_length": "max_position_embeddings", + "block_count": "num_hidden_layers", + "feed_forward_length": "intermediate_size", + "embedding_length": "hidden_size", + "rope.dimension_count": None, + "rope.freq_base": "rope_theta", + "attention.head_count": "num_attention_heads", + "attention.head_count_kv": "num_key_value_heads", + "attention.layer_norm_rms_epsilon": "rms_norm_eps", + "vocab_size": "vocab_size", + }, + "tokenizer": { + "ggml.model": "model_type", + "ggml.bos_token_id": "bos_token_id", + "ggml.eos_token_id": "eos_token_id", + "ggml.unknown_token_id": "unk_token_id", + "ggml.padding_token_id": "pad_token_id", + }, +} + +GGUF_TOKENIZER_MAPPING = { + "tokenizer": { + "ggml.model": "tokenizer_type", + "ggml.tokens": "tokens", + "ggml.scores": "scores", + 
"ggml.token_type": "token_type", + "ggml.merges": "merges", + "ggml.bos_token_id": "bos_token_id", + "ggml.eos_token_id": "eos_token_id", + "ggml.unknown_token_id": "unk_token_id", + "ggml.padding_token_id": "pad_token_id", + "ggml.add_space_prefix": "add_prefix_space", + }, + "tokenizer_config": { + "chat_template": "chat_template", + "ggml.model": "model_type", + "ggml.bos_token_id": "bos_token_id", + "ggml.eos_token_id": "eos_token_id", + "ggml.unknown_token_id": "unk_token_id", + "ggml.padding_token_id": "pad_token_id", + }, +} + + +def _gguf_parse_value(_value, data_type): + if not isinstance(data_type, list): + data_type = [data_type] + if len(data_type) == 1: + data_type = data_type[0] + array_data_type = None + else: + if data_type[0] != 9: + raise ValueError("Received multiple types, but therefore expect the first type to indicate an array.") + data_type, array_data_type = data_type + + if data_type in [0, 1, 2, 3, 4, 5, 10, 11]: + _value = int(_value[0]) + elif data_type in [6, 12]: + _value = float(_value[0]) + elif data_type in [7]: + _value = bool(_value[0]) + elif data_type in [8]: + _value = array("B", list(_value)).tobytes().decode() + elif data_type in [9]: + _value = _gguf_parse_value(_value, array_data_type) + return _value + + +def dequantize_q4_k(data): + # C implementation + # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L1929 + # C struct definition + # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L116 + block_size = GGML_BLOCK_SIZES["Q4_K"] + num_blocks = len(data) // block_size + + data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2) + data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size) + + # Casting to float32 because float16 is very slow on CPU + scale_factors = data_f16[:, 0].reshape(num_blocks, 1, 1).astype(np.float32) + scale_offsets = data_f16[:, 1].reshape(num_blocks, 1, 1).astype(np.float32) + qs1 = data_u8[:, 4:16].reshape(num_blocks, 12, 1) + qs2 = data_u8[:, 16:].reshape(num_blocks, 4, 32) + + # Dequantize scales and offsets (6 bits and 4 + 2 bits) + factors = scale_factors * np.concatenate( + [qs1[:, 0:4] & 0b111111, (qs1[:, 8:] & 15) | ((qs1[:, 0:4] >> 6) << 4)], axis=1 + ) + offsets = scale_offsets * np.concatenate( + [qs1[:, 4:8] & 0b111111, (qs1[:, 8:] >> 4) | ((qs1[:, 4:8] >> 6) << 4)], axis=1 + ) + + # Interleave low and high quantized bits + qs2 = np.stack([qs2 & 0xF, qs2 >> 4], axis=2).reshape(num_blocks, 8, 32) + # Dequantize final weights using scales and offsets + return factors * qs2 - offsets + + +def dequantize_q4_0(data): + # C implementation + # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L1086 + # C struct definition + # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L11 + block_size = GGML_BLOCK_SIZES["Q4_0"] + num_blocks = len(data) // block_size + + data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2) + data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size) + + # The scales are stored on the first 2 bytes and the rest corresponds to the quants + scales = data_f16[:, 0].reshape(num_blocks, 1).astype(np.float32) + # scales = np.nan_to_num(scales) + # the rest of the bytes corresponds to the quants - we discard the first two bytes + quants = data_u8[:, 2:] + + ql = (quants[:, :] & 0xF).astype(np.int8) - 8 + qr = 
(quants[:, :] >> 4).astype(np.int8) - 8 + + # Use hstack + quants = np.hstack([ql, qr]) + + return (scales * quants).astype(np.float32) + + +def dequantize_q6_k(data): + # C implementation + # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L2275 + # C struct definition + # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L152 + block_size = GGML_BLOCK_SIZES["Q6_K"] + num_blocks = len(data) // block_size + + data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2) + data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size) + data_i8 = np.frombuffer(data, dtype=np.int8).reshape(num_blocks, block_size) + + scales = data_f16[:, -1].reshape(num_blocks, 1).astype(np.float32) + + # TODO use uint8 and cast later? + ql = data_u8[:, :128].astype(np.int16) + qh = data_u8[:, 128:192].astype(np.int16) + sc = data_i8[:, 192:208, np.newaxis].astype(np.float32) + + # Unpack bits, subtraction requires signed data type + q1 = (ql[:, :32] & 0xF) | (((qh[:, :32] >> 0) & 3) << 4) - 32 + q2 = (ql[:, 32:64] & 0xF) | (((qh[:, :32] >> 2) & 3) << 4) - 32 + q3 = (ql[:, :32] >> 4) | (((qh[:, :32] >> 4) & 3) << 4) - 32 + q4 = (ql[:, 32:64] >> 4) | (((qh[:, :32] >> 6) & 3) << 4) - 32 + q5 = (ql[:, 64:96] & 0xF) | (((qh[:, 32:] >> 0) & 3) << 4) - 32 + q6 = (ql[:, 96:128] & 0xF) | (((qh[:, 32:] >> 2) & 3) << 4) - 32 + q7 = (ql[:, 64:96] >> 4) | (((qh[:, 32:] >> 4) & 3) << 4) - 32 + q8 = (ql[:, 96:128] >> 4) | (((qh[:, 32:] >> 6) & 3) << 4) - 32 + + # Dequantize + return scales * np.concatenate( + [ + sc[:, 0] * q1[:, :16], + sc[:, 1] * q1[:, 16:], + sc[:, 2] * q2[:, :16], + sc[:, 3] * q2[:, 16:], + sc[:, 4] * q3[:, :16], + sc[:, 5] * q3[:, 16:], + sc[:, 6] * q4[:, :16], + sc[:, 7] * q4[:, 16:], + sc[:, 8] * q5[:, :16], + sc[:, 9] * q5[:, 16:], + sc[:, 10] * q6[:, :16], + sc[:, 11] * q6[:, 16:], + sc[:, 12] * q7[:, :16], + sc[:, 13] * q7[:, 16:], + sc[:, 14] * q8[:, :16], + sc[:, 15] * q8[:, 16:], + ], + axis=1, + ) + + +def dequantize_q8_0(data): + # C struct definition + # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L43 + block_size = GGML_BLOCK_SIZES["Q8_0"] + num_blocks = len(data) // block_size + + scales = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, 1 + 16)[:, :1].astype(np.float32) + qs = np.frombuffer(data, dtype=np.int8).reshape(num_blocks, 2 + 32)[:, 2:] + + return scales * qs + + +def load_dequant_gguf_tensor(shape, ggml_type, data): + if ggml_type == GGML_TYPES["F32"]: + values = data + elif ggml_type == GGML_TYPES["Q8_0"]: + values = dequantize_q8_0(data) + elif ggml_type == GGML_TYPES["Q4_0"]: + values = dequantize_q4_0(data) + elif ggml_type == GGML_TYPES["Q4_K"]: + values = dequantize_q4_k(data) + elif ggml_type == GGML_TYPES["Q6_K"]: + values = dequantize_q6_k(data) + else: + raise NotImplementedError( + f"ggml_type {ggml_type} not implemented - please raise an issue on huggingface transformers: https://github.com/huggingface/transformers/issues/new/choose" + ) + + return values.reshape(shape[::-1]) + + +class GGUFTokenizerSkeleton: + def __init__(self, dict_): + for k, v in dict_.items(): + setattr(self, k, v) + + +class GGUFLlamaConverter(LlamaConverter): + def __init__(self, tokenizer_dict): + self.proto = GGUFTokenizerSkeleton(tokenizer_dict) + self.original_tokenizer = self.proto + + def vocab(self, proto): + return list(zip(proto.tokens, proto.scores)) + + def merges(self, proto): + 
return [tuple(merge.split(" ")) for merge in proto.merges] + + def tokenizer(self, proto): + vocab_scores = self.vocab(self.proto) + merges = self.merges(self.proto) + bpe_vocab = {word: i for i, (word, _score) in enumerate(vocab_scores)} + tokenizer = Tokenizer( + BPE(bpe_vocab, merges, unk_token=proto.tokens[proto.unk_token_id], fuse_unk=True, byte_fallback=True) + ) + tokenizer.add_special_tokens( + [ + AddedToken("", normalized=False, special=True), + AddedToken("", normalized=False, special=True), + AddedToken("", normalized=False, special=True), + ] + ) + return tokenizer + + def decoder(self, replacement, add_prefix_space): + sequence = [ + decoders.ByteFallback(), + decoders.Fuse(), + decoders.Replace("▁", " "), + ] + print(add_prefix_space) + add_prefix_space = False + if add_prefix_space: + sequence += [decoders.Strip(content=" ", left=1)] + return decoders.Sequence(sequence) + + +GGUF_TO_FAST_CONVERTERS = { + "llama": GGUFLlamaConverter, +} + + +def convert_gguf_tokenizer(tokenizer_dict) -> Tokenizer: + """ + Utilities to convert a slow tokenizer instance in a fast tokenizer instance. + + Args: + transformer_tokenizer ([`~tokenization_utils_base.PreTrainedTokenizer`]): + Instance of a slow tokenizer to convert in the backend tokenizer for + [`~tokenization_utils_base.PreTrainedTokenizerFast`]. + + Return: + A instance of [`~tokenizers.Tokenizer`] to be used as the backend tokenizer of a + [`~tokenization_utils_base.PreTrainedTokenizerFast`] + """ + tokenizer_class_name = tokenizer_dict["tokenizer_type"] + converter_class = GGUF_TO_FAST_CONVERTERS[tokenizer_class_name] + return converter_class(tokenizer_dict).converted() diff --git a/src/transformers/modeling_gguf_pytorch_utils.py b/src/transformers/modeling_gguf_pytorch_utils.py new file mode 100644 index 00000000000000..b8721f3ee515f3 --- /dev/null +++ b/src/transformers/modeling_gguf_pytorch_utils.py @@ -0,0 +1,160 @@ +# coding=utf-8 +# Copyright 2024 The ggml.ai team and The HuggingFace Inc. team. and pygguf author (github.com/99991) +# https://github.com/99991/pygguf +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
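The tensor mapping tables above are applied with a simple substring rename when the checkpoint is loaded; the new `modeling_gguf_pytorch_utils` module that follows performs that rename while de-quantizing each tensor. Below is a minimal, self-contained sketch of the renaming step; only a handful of mapping entries are reproduced and the input name is a hypothetical llama.cpp-style tensor name.

```py
# Sketch of the substring-based renaming applied to GGUF tensor names.
# Entries copied from GGUF_TENSOR_MAPPING["llama"] above; the example input
# name is a hypothetical llama.cpp-style tensor name.
llama_mapping = {
    "token_embd": "model.embed_tokens",
    "blk": "model.layers",
    "attn_q": "self_attn.q_proj",
    "attn_norm": "input_layernorm",
    "output_norm": "model.norm",
}

def rename_gguf_tensor(name: str) -> str:
    for gguf_name, hf_name in llama_mapping.items():
        if gguf_name in name:
            name = name.replace(gguf_name, hf_name)
    return name

print(rename_gguf_tensor("blk.0.attn_q.weight"))
# -> model.layers.0.self_attn.q_proj.weight
```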
+ +import numpy as np +import torch +from tqdm import tqdm + +from .integrations import ( + GGUF_CONFIG_MAPPING, + GGUF_TENSOR_MAPPING, + GGUF_TOKENIZER_MAPPING, + _gguf_parse_value, + load_dequant_gguf_tensor, +) +from .utils.logging import get_logger + + +logger = get_logger(__name__) + + +GGUF_TO_TRANSFORMERS_MAPPING = { + "ignore": { + "GGUF": { + "version": "version", + "tensor_count": "tensor_count", + "kv_count": "kv_count", + }, + "general": {"file_type": "file_type", "quantization_version": "quantization_version"}, + }, + "config": GGUF_CONFIG_MAPPING, + "tensors": GGUF_TENSOR_MAPPING, + "tokenizer": {"tokenizer": GGUF_TOKENIZER_MAPPING["tokenizer"]}, + "tokenizer_config": {"tokenizer": GGUF_TOKENIZER_MAPPING["tokenizer_config"]}, +} + +GGUF_SUPPORTED_ARCHITECTURES = list(GGUF_TO_TRANSFORMERS_MAPPING["tensors"].keys()) + + +def read_field(reader, field): + value = reader.fields[field] + return [_gguf_parse_value(value.parts[_data_index], value.types) for _data_index in value.data] + + +def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=True): + """ + Load a GGUF file and return a dictionary of parsed parameters containing tensors, the parsed + tokenizer and config attributes. + + Args: + gguf_checkpoint_path (`str`): + The path the to GGUF file to load + return_tensors (`bool`, defaults to `True`): + Whether to read the tensors from the file and return them. Not doing so is faster + and only loads the metadata in memory. + """ + try: + from gguf import GGUFReader + except (ImportError, ModuleNotFoundError): + logger.error( + "Loading a GGUF checkpoint in PyTorch, requires both PyTorch and GGUF to be installed. Please see " + "https://pytorch.org/ and https://github.com/ggerganov/llama.cpp/tree/master/gguf-py for installation instructions." 
+ ) + raise + + reader = GGUFReader(gguf_checkpoint_path) + fields = reader.fields + reader_keys = list(fields.keys()) + + parsed_parameters = {k: {} for k in GGUF_TO_TRANSFORMERS_MAPPING} + + architecture = read_field(reader, "general.architecture")[0] + model_name = read_field(reader, "general.name") + + if "llama" in architecture and "mistral" in model_name: + updated_architecture = "mistral" + else: + updated_architecture = architecture + + if architecture not in GGUF_SUPPORTED_ARCHITECTURES: + raise ValueError(f"Architecture {architecture} not supported") + + # List all key-value pairs in a columnized format + for gguf_key, field in reader.fields.items(): + gguf_key = gguf_key.replace(architecture, updated_architecture) + split = gguf_key.split(".") + prefix = split[0] + config_key = ".".join(split[1:]) + + value = [_gguf_parse_value(field.parts[_data_index], field.types) for _data_index in field.data] + + if len(value) == 1: + value = value[0] + + if isinstance(value, str) and architecture in value: + value = value.replace(architecture, updated_architecture) + + for parameter in GGUF_TO_TRANSFORMERS_MAPPING: + parameter_renames = GGUF_TO_TRANSFORMERS_MAPPING[parameter] + if prefix in parameter_renames and config_key in parameter_renames[prefix]: + renamed_config_key = parameter_renames.get(prefix, {}).get(config_key) + if renamed_config_key == -1: + continue + + if renamed_config_key is not None: + parsed_parameters[parameter][renamed_config_key] = value + + if gguf_key in reader_keys: + reader_keys.remove(gguf_key) + + if gguf_key in reader_keys: + logger.info(f"Some keys were not parsed and added into account {gguf_key} | {value}") + + if return_tensors: + tensor_key_mapping = GGUF_TO_TRANSFORMERS_MAPPING["tensors"][architecture] + + for tensor in tqdm(reader.tensors, desc="Converting and de-quantizing GGUF tensors..."): + renamed_tensor_name = tensor.name + + for tensor_name_mapping in GGUF_TO_TRANSFORMERS_MAPPING["tensors"]: + if tensor_name_mapping in renamed_tensor_name: + renamed_tensor_name = renamed_tensor_name.replace( + tensor_name_mapping, GGUF_TO_TRANSFORMERS_MAPPING["tensors"][tensor_name_mapping] + ) + + shape = tensor.shape + name = tensor.name + + weights = load_dequant_gguf_tensor(shape=shape, ggml_type=tensor.tensor_type, data=tensor.data) + + if architecture == "llama" and (".attn_k." in name or ".attn_q." 
in name): + num_heads = parsed_parameters["config"]["num_attention_heads"] + tmp_shape = (int(shape[-1] // num_heads // 2), num_heads, 2, shape[0]) + weights = weights.reshape(tmp_shape) + weights = weights.transpose(0, 2, 1, 3) + weights = weights.reshape(shape[::-1]) + + for tensor_name in tensor_key_mapping: + if tensor_name in name: + name = name.replace(tensor_name, tensor_key_mapping[tensor_name]) + + # Use copy to avoid errors with numpy and pytorch + parsed_parameters["tensors"][name] = torch.from_numpy(np.copy(weights)) + + if len(reader_keys) > 0: + logger.info(f"Some keys of the GGUF file were not considered: {reader_keys}") + + return parsed_parameters \ No newline at end of file diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index e4fcd2ebc11e6e..578607ede42ee8 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2984,6 +2984,10 @@ def from_pretrained( adapter_name = kwargs.pop("adapter_name", "default") use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False) + from_gguf = kwargs.pop("from_gguf", None) + # Cache path to the GGUF file + gguf_path = None + if is_fsdp_enabled(): low_cpu_mem_usage = True @@ -3147,6 +3151,7 @@ def from_pretrained( kwarg_attn_imp = kwargs.pop("attn_implementation", None) if kwarg_attn_imp is not None and config._attn_implementation != kwarg_attn_imp: config._attn_implementation = kwarg_attn_imp + model_kwargs = kwargs pre_quantized = getattr(config, "quantization_config", None) is not None @@ -3185,7 +3190,12 @@ def from_pretrained( keep_in_fp32_modules = None use_keep_in_fp32_modules = False - if pretrained_model_name_or_path is not None: + if from_gguf is not None and hf_quantizer is not None: + raise ValueError( + "You cannot combine Quantization and loading a model from a GGUF file, try again by making sure you did not passed a `quantization_config` or that you did not loaded a quantized model from the Hub." 
+ ) + + if pretrained_model_name_or_path is not None and from_gguf is None: pretrained_model_name_or_path = str(pretrained_model_name_or_path) is_local = os.path.isdir(pretrained_model_name_or_path) if is_local: @@ -3427,6 +3437,36 @@ def from_pretrained( resolved_archive_file = archive_file else: logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}") + elif from_gguf is not None: + from .modeling_gguf_pytorch_utils import load_gguf_checkpoint + + # Case 1: the GGUF file is present locally + if os.path.isfile(from_gguf): + gguf_path = from_gguf + # Case 2: The GGUF path is a location on the Hub + # Load from URL or cache if already cached + else: + cached_file_kwargs = { + "cache_dir": cache_dir, + "force_download": force_download, + "proxies": proxies, + "resume_download": resume_download, + "local_files_only": local_files_only, + "token": token, + "user_agent": user_agent, + "revision": revision, + "subfolder": subfolder, + "_raise_exceptions_for_gated_repo": False, + "_raise_exceptions_for_missing_entries": False, + "_commit_hash": commit_hash, + } + + gguf_path = cached_file(pretrained_model_name_or_path, from_gguf, **cached_file_kwargs) + + state_dict = load_gguf_checkpoint(gguf_path, return_tensors=True)["tensors"] + + resolved_archive_file = None + is_sharded = False else: resolved_archive_file = None @@ -3521,7 +3561,8 @@ def from_pretrained( loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"] else: loaded_state_dict_keys = list(state_dict.keys()) - if low_cpu_mem_usage or (use_keep_in_fp32_modules and is_accelerate_available()): + + if gguf_path is None and (low_cpu_mem_usage or (use_keep_in_fp32_modules and is_accelerate_available())): # In case some weights need to be kept in float32 and accelerate is not installed, # we later on want to take the path where state_dict is not None, that is the one # that do not require accelerate. 
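For reference, the `from_gguf` branch above can be reproduced by hand, which is a convenient way to look at the de-quantized state dict before it is handed to `_load_pretrained_model`. A rough sketch, assuming an environment where this patch is installed; the repository and filename are the TinyLlama files used in the tests:

```py
# Hedged sketch of what `from_pretrained(..., from_gguf=...)` does internally:
# download the GGUF file, then parse it into a full fp32 state dict.
from huggingface_hub import hf_hub_download
from transformers.modeling_gguf_pytorch_utils import load_gguf_checkpoint

gguf_path = hf_hub_download(
    repo_id="TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",
    filename="tinyllama-1.1b-chat-v1.0.Q4_0.gguf",
)

parsed = load_gguf_checkpoint(gguf_path, return_tensors=True)
state_dict = parsed["tensors"]

print(parsed["config"]["num_hidden_layers"])
print(state_dict["model.embed_tokens.weight"].dtype)  # expected torch.float32 after de-quantization
```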
@@ -3667,6 +3708,7 @@ def from_pretrained( # restore default dtype if dtype_orig is not None: torch.set_default_dtype(dtype_orig) + ( model, missing_keys, @@ -3690,6 +3732,7 @@ def from_pretrained( dtype=torch_dtype, hf_quantizer=hf_quantizer, keep_in_fp32_modules=keep_in_fp32_modules, + gguf_path=gguf_path, ) # make sure token embedding weights are still tied if needed @@ -3776,9 +3819,12 @@ def _load_pretrained_model( dtype=None, hf_quantizer=None, keep_in_fp32_modules=None, + gguf_path=None, ): is_safetensors = False is_quantized = hf_quantizer is not None + state_dict_folder = None + state_dict_index = None if device_map is not None and "disk" in device_map.values(): archive_file = ( @@ -4036,6 +4082,8 @@ def _find_mismatched_keys( for p, f in weight_map.items() if p.startswith(start_prefix) and param_device_map[p[len(start_prefix) :]] == "disk" } + else: + offload_index = None if state_dict is not None: # Whole checkpoint @@ -4047,11 +4095,29 @@ def _find_mismatched_keys( remove_prefix_from_model, ignore_mismatched_sizes, ) - error_msgs = _load_state_dict_into_model(model_to_load, state_dict, start_prefix) - offload_index = None - else: - # Sharded checkpoint or whole but low_cpu_mem_usage==True + if gguf_path is None: + error_msgs = _load_state_dict_into_model(model_to_load, state_dict, start_prefix) + # For GGUF models `state_dict` is never set to None as the state dict is always small + else: + error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model( + model_to_load, + state_dict, + loaded_keys, + start_prefix, + expected_keys, + device_map=device_map, + offload_folder=offload_folder, + offload_index=offload_index, + state_dict_folder=state_dict_folder, + state_dict_index=state_dict_index, + dtype=dtype, + hf_quantizer=hf_quantizer, + is_safetensors=is_safetensors, + keep_in_fp32_modules=keep_in_fp32_modules, + unexpected_keys=unexpected_keys, + ) + else: # This should always be a list but, just to be sure. if not isinstance(resolved_archive_file, list): resolved_archive_file = [resolved_archive_file] diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 99706afe1655e3..9e1984cfe9591c 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -23,6 +23,7 @@ from ...configuration_utils import PretrainedConfig from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code +from ...modeling_gguf_pytorch_utils import load_gguf_checkpoint from ...tokenization_utils import PreTrainedTokenizer from ...tokenization_utils_base import TOKENIZER_CONFIG_FILE from ...utils import ( @@ -770,6 +771,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): use_fast = kwargs.pop("use_fast", True) tokenizer_type = kwargs.pop("tokenizer_type", None) trust_remote_code = kwargs.pop("trust_remote_code", None) + from_gguf = kwargs.get("from_gguf", None) # First, let's see whether the tokenizer_type is passed so that we can leverage it if tokenizer_type is not None: @@ -816,9 +818,14 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): # If that did not work, let's try to use the config. 
if config_tokenizer_class is None: if not isinstance(config, PretrainedConfig): - config = AutoConfig.from_pretrained( - pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs - ) + if from_gguf is None: + config = AutoConfig.from_pretrained( + pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs + ) + else: + gguf_path = cached_file(pretrained_model_name_or_path, from_gguf, **kwargs) + config_dict = load_gguf_checkpoint(gguf_path, return_tensors=True)["config"] + config = AutoConfig.for_model(**config_dict) config_tokenizer_class = config.tokenizer_class if hasattr(config, "auto_map") and "AutoTokenizer" in config.auto_map: tokenizer_auto_map = config.auto_map["AutoTokenizer"] @@ -876,6 +883,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): model_type = config_class_to_model_type(type(config).__name__) if model_type is not None: tokenizer_class_py, tokenizer_class_fast = TOKENIZER_MAPPING[type(config)] + if tokenizer_class_fast and (use_fast or tokenizer_class_py is None): return tokenizer_class_fast.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) else: diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 8297cb981ef1fb..7c6627682ef522 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -73,6 +73,7 @@ is_ftfy_available, is_g2p_en_available, is_galore_torch_available, + is_gguf_available, is_ipex_available, is_jieba_available, is_jinja_available, @@ -375,6 +376,13 @@ def require_accelerate(test_case, min_version: str = ACCELERATE_MIN_VERSION): )(test_case) +def require_gguf(test_case): + """ + Decorator marking a test that requires ggguf. These tests are skipped when gguf isn't installed. + """ + return unittest.skipUnless(is_gguf_available(), "test requires gguf")(test_case) + + def require_fsdp(test_case, min_version: str = "1.12.0"): """ Decorator marking a test that requires fsdp. These tests are skipped when fsdp isn't installed. diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index a30daf5f7fbe69..eb232f9e98d658 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -1954,6 +1954,7 @@ def from_pretrained( from_pipeline = kwargs.pop("_from_pipeline", None) from_auto_class = kwargs.pop("_from_auto", False) commit_hash = kwargs.pop("_commit_hash", None) + from_gguf = kwargs.get("from_gguf", False) if use_auth_token is not None: warnings.warn( @@ -1981,7 +1982,7 @@ def from_pretrained( is_local = os.path.isdir(pretrained_model_name_or_path) single_file_id = None if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path): - if len(cls.vocab_files_names) > 1: + if len(cls.vocab_files_names) > 1 and not from_gguf: raise ValueError( f"Calling {cls.__name__}.from_pretrained() with the path to a single file or url is not " "supported for this tokenizer. Use a model identifier or the path to a directory instead." 
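The `AutoTokenizer` branch above only needs the GGUF metadata, not the tensors: the configuration is rebuilt from the parsed fields and handed to `AutoConfig.for_model`. A self-contained sketch of that step, assuming a GGUF file has already been downloaded to a local path:

```py
# Hedged sketch: build a transformers config from GGUF metadata alone.
# `return_tensors=False` keeps the (potentially large) tensors on disk.
from transformers import AutoConfig
from transformers.modeling_gguf_pytorch_utils import load_gguf_checkpoint

gguf_path = "tinyllama-1.1b-chat-v1.0.Q4_0.gguf"  # assumed local GGUF file

config_dict = load_gguf_checkpoint(gguf_path, return_tensors=False)["config"]
config = AutoConfig.for_model(**config_dict)

print(config.model_type)           # "llama" for the TinyLlama file
print(config.num_attention_heads)  # mapped from the GGUF "attention.head_count" field
```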
@@ -1996,42 +1997,45 @@ def from_pretrained( vocab_files[file_id] = pretrained_model_name_or_path single_file_id = file_id else: - # At this point pretrained_model_name_or_path is either a directory or a model identifier name - additional_files_names = { - "added_tokens_file": ADDED_TOKENS_FILE, # kept only for legacy - "special_tokens_map_file": SPECIAL_TOKENS_MAP_FILE, # kept only for legacy - "tokenizer_config_file": TOKENIZER_CONFIG_FILE, - # tokenizer_file used to initialize a slow from a fast. Properly copy the `addedTokens` instead of adding in random orders - "tokenizer_file": FULL_TOKENIZER_FILE, - } - vocab_files = {**cls.vocab_files_names, **additional_files_names} - if "tokenizer_file" in vocab_files: - # Try to get the tokenizer config to see if there are versioned tokenizer files. - fast_tokenizer_file = FULL_TOKENIZER_FILE - resolved_config_file = cached_file( - pretrained_model_name_or_path, - TOKENIZER_CONFIG_FILE, - cache_dir=cache_dir, - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - token=token, - revision=revision, - local_files_only=local_files_only, - subfolder=subfolder, - user_agent=user_agent, - _raise_exceptions_for_gated_repo=False, - _raise_exceptions_for_missing_entries=False, - _raise_exceptions_for_connection_errors=False, - _commit_hash=commit_hash, - ) - commit_hash = extract_commit_hash(resolved_config_file, commit_hash) - if resolved_config_file is not None: - with open(resolved_config_file, encoding="utf-8") as reader: - tokenizer_config = json.load(reader) - if "fast_tokenizer_files" in tokenizer_config: - fast_tokenizer_file = get_fast_tokenizer_file(tokenizer_config["fast_tokenizer_files"]) - vocab_files["tokenizer_file"] = fast_tokenizer_file + if not from_gguf: + # At this point pretrained_model_name_or_path is either a directory or a model identifier name + additional_files_names = { + "added_tokens_file": ADDED_TOKENS_FILE, # kept only for legacy + "special_tokens_map_file": SPECIAL_TOKENS_MAP_FILE, # kept only for legacy + "tokenizer_config_file": TOKENIZER_CONFIG_FILE, + # tokenizer_file used to initialize a slow from a fast. Properly copy the `addedTokens` instead of adding in random orders + "tokenizer_file": FULL_TOKENIZER_FILE, + } + vocab_files = {**cls.vocab_files_names, **additional_files_names} + if "tokenizer_file" in vocab_files: + # Try to get the tokenizer config to see if there are versioned tokenizer files. 
+ fast_tokenizer_file = FULL_TOKENIZER_FILE + resolved_config_file = cached_file( + pretrained_model_name_or_path, + TOKENIZER_CONFIG_FILE, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + token=token, + revision=revision, + local_files_only=local_files_only, + subfolder=subfolder, + user_agent=user_agent, + _raise_exceptions_for_gated_repo=False, + _raise_exceptions_for_missing_entries=False, + _raise_exceptions_for_connection_errors=False, + _commit_hash=commit_hash, + ) + commit_hash = extract_commit_hash(resolved_config_file, commit_hash) + if resolved_config_file is not None: + with open(resolved_config_file, encoding="utf-8") as reader: + tokenizer_config = json.load(reader) + if "fast_tokenizer_files" in tokenizer_config: + fast_tokenizer_file = get_fast_tokenizer_file(tokenizer_config["fast_tokenizer_files"]) + vocab_files["tokenizer_file"] = fast_tokenizer_file + else: + vocab_files["vocab_file"] = from_gguf # Get files from url, cache, or disk depending on the case resolved_vocab_files = {} @@ -2070,7 +2074,7 @@ def from_pretrained( "files are necessary for the tokenizer to operate." ) - if all(full_file_name is None for full_file_name in resolved_vocab_files.values()): + if all(full_file_name is None for full_file_name in resolved_vocab_files.values()) and not from_gguf: raise EnvironmentError( f"Can't load tokenizer for '{pretrained_model_name_or_path}'. If you were trying to load it from " "'https://huggingface.co/models', make sure you don't have a local directory with the same name. " @@ -2119,8 +2123,9 @@ def _from_pretrained( # We instantiate fast tokenizers based on a slow tokenizer if we don't have access to the tokenizer.json # file or if `from_slow` is set to True. from_slow = kwargs.get("from_slow", False) + from_gguf = kwargs.get("from_gguf", False) has_tokenizer_file = resolved_vocab_files.get("tokenizer_file", None) is not None - if (from_slow or not has_tokenizer_file) and cls.slow_tokenizer_class is not None: + if (from_slow or not has_tokenizer_file) and cls.slow_tokenizer_class is not None and not from_gguf: slow_tokenizer = (cls.slow_tokenizer_class)._from_pretrained( copy.deepcopy(resolved_vocab_files), pretrained_model_name_or_path, diff --git a/src/transformers/tokenization_utils_fast.py b/src/transformers/tokenization_utils_fast.py index b1daa1ec1be92f..bb9709c71b304d 100644 --- a/src/transformers/tokenization_utils_fast.py +++ b/src/transformers/tokenization_utils_fast.py @@ -29,6 +29,8 @@ from tokenizers.trainers import BpeTrainer, UnigramTrainer, WordLevelTrainer, WordPieceTrainer from .convert_slow_tokenizer import convert_slow_tokenizer +from .integrations.ggml import convert_gguf_tokenizer +from .modeling_gguf_pytorch_utils import load_gguf_checkpoint from .tokenization_utils import PreTrainedTokenizer from .tokenization_utils_base import ( INIT_TOKENIZER_DOCSTRING, @@ -94,6 +96,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase): def __init__(self, *args, **kwargs): tokenizer_object = kwargs.pop("tokenizer_object", None) slow_tokenizer = kwargs.pop("__slow_tokenizer", None) + from_gguf = kwargs.pop("from_gguf", None) fast_tokenizer_file = kwargs.pop("tokenizer_file", None) from_slow = kwargs.pop("from_slow", False) added_tokens_decoder = kwargs.pop("added_tokens_decoder", {}) @@ -112,6 +115,10 @@ def __init__(self, *args, **kwargs): elif slow_tokenizer is not None: # We need to convert a slow tokenizer to build the backend fast_tokenizer = convert_slow_tokenizer(slow_tokenizer) + 
elif from_gguf is not None: + # We need to convert a slow tokenizer to build the backend + tokenizer_dict = load_gguf_checkpoint(kwargs.get("vocab_file"))["tokenizer"] + fast_tokenizer = convert_gguf_tokenizer(tokenizer_dict) elif self.slow_tokenizer_class is not None: # We need to create and convert a slow tokenizer to build the backend slow_tokenizer = self.slow_tokenizer_class(*args, **kwargs) diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index 121c4dc1361e4e..ad1dc6e7465791 100644 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -128,6 +128,7 @@ is_ftfy_available, is_g2p_en_available, is_galore_torch_available, + is_gguf_available, is_in_notebook, is_ipex_available, is_jieba_available, diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index a8c45aeac33f16..07451aec5e76c2 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -150,6 +150,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[ _scipy_available = _is_package_available("scipy") _sentencepiece_available = _is_package_available("sentencepiece") _is_seqio_available = _is_package_available("seqio") +_is_gguf_available = _is_package_available("gguf") _sklearn_available = importlib.util.find_spec("sklearn") is not None if _sklearn_available: try: @@ -799,6 +800,10 @@ def is_seqio_available(): return _is_seqio_available +def is_gguf_available(): + return _is_gguf_available + + def is_protobuf_available(): if importlib.util.find_spec("google") is None: return False diff --git a/tests/quantization/ggml/__init__.py b/tests/quantization/ggml/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/tests/quantization/ggml/test_ggml.py b/tests/quantization/ggml/test_ggml.py new file mode 100644 index 00000000000000..f276432f8f45ae --- /dev/null +++ b/tests/quantization/ggml/test_ggml.py @@ -0,0 +1,107 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
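The fast-tokenizer path above can also be exercised on its own: only the tokenizer metadata is parsed out of the GGUF file, converted into a `tokenizers.Tokenizer`, and then wrapped. A rough sketch, assuming a locally available llama-style GGUF file and an environment with this patch installed:

```py
# Hedged sketch of the GGUF fast-tokenizer conversion shown above.
from transformers import PreTrainedTokenizerFast
from transformers.integrations.ggml import convert_gguf_tokenizer
from transformers.modeling_gguf_pytorch_utils import load_gguf_checkpoint

gguf_path = "tinyllama-1.1b-chat-v1.0.Q4_0.gguf"  # assumed local GGUF file

tokenizer_dict = load_gguf_checkpoint(gguf_path, return_tensors=False)["tokenizer"]
backend = convert_gguf_tokenizer(tokenizer_dict)  # a tokenizers.Tokenizer instance
tokenizer = PreTrainedTokenizerFast(tokenizer_object=backend)

print(tokenizer("Hello").input_ids)
```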
+import unittest + +from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers.testing_utils import require_gguf, require_torch_gpu, slow, torch_device +from transformers.utils import is_torch_available + + +if is_torch_available(): + import torch + + +@require_gguf +@require_torch_gpu +@slow +class GgufIntegrationTests(unittest.TestCase): + original_model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" + model_id = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF" + mistral_model_id = "TheBloke/Mistral-7B-Instruct-v0.2-GGUF" + + q4_0_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q4_0.gguf" + q4_k_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf" + q6_k_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q6_K.gguf" + q8_0_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q8_0.gguf" + + q4_0_mistral_model_id = "mistral-7b-instruct-v0.2.Q4_0.gguf" + + example_text = "Hello" + + def test_q4_0(self): + tokenizer = AutoTokenizer.from_pretrained(self.model_id, from_gguf=self.q4_0_gguf_model_id) + model = AutoModelForCausalLM.from_pretrained(self.model_id, from_gguf=self.q4_0_gguf_model_id).to(torch_device) + + text = tokenizer(self.example_text, return_tensors="pt").to(torch_device) + out = model.generate(**text, max_new_tokens=10) + + EXPECTED_TEXT = "Hello, World!\n\nStep 3: Add" + self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT) + + def test_q4_k_m(self): + tokenizer = AutoTokenizer.from_pretrained(self.model_id, from_gguf=self.q4_k_gguf_model_id) + model = AutoModelForCausalLM.from_pretrained(self.model_id, from_gguf=self.q4_k_gguf_model_id).to(torch_device) + + text = tokenizer(self.example_text, return_tensors="pt").to(torch_device) + out = model.generate(**text, max_new_tokens=10) + + EXPECTED_TEXT = "Hello, World!\n\n5. Python:\n" + self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT) + + def test_q6_k(self): + tokenizer = AutoTokenizer.from_pretrained(self.model_id, from_gguf=self.q6_k_gguf_model_id) + model = AutoModelForCausalLM.from_pretrained(self.model_id, from_gguf=self.q6_k_gguf_model_id).to(torch_device) + + text = tokenizer(self.example_text, return_tensors="pt").to(torch_device) + out = model.generate(**text, max_new_tokens=10) + + EXPECTED_TEXT = "Hello, World!\n\nStep 3: Add" + self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT) + + def test_q6_k_fp16(self): + tokenizer = AutoTokenizer.from_pretrained(self.model_id, from_gguf=self.q6_k_gguf_model_id) + model = AutoModelForCausalLM.from_pretrained( + self.model_id, from_gguf=self.q6_k_gguf_model_id, torch_dtype=torch.float16 + ).to(torch_device) + + self.assertTrue(model.lm_head.weight.dtype == torch.float16) + + text = tokenizer(self.example_text, return_tensors="pt").to(torch_device) + out = model.generate(**text, max_new_tokens=10) + + EXPECTED_TEXT = "Hello, World!\n\nStep 3: Add" + self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT) + + def test_q8_0(self): + tokenizer = AutoTokenizer.from_pretrained(self.model_id, from_gguf=self.q8_0_gguf_model_id) + model = AutoModelForCausalLM.from_pretrained(self.model_id, from_gguf=self.q8_0_gguf_model_id).to(torch_device) + + text = tokenizer(self.example_text, return_tensors="pt").to(torch_device) + out = model.generate(**text, max_new_tokens=10) + + EXPECTED_TEXT = "Hello, World!\n\n5. 
Use a library" + self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT) + + def test_mistral_q4_0(self): + tokenizer = AutoTokenizer.from_pretrained(self.mistral_model_id, from_gguf=self.q4_0_mistral_model_id) + model = AutoModelForCausalLM.from_pretrained( + self.mistral_model_id, from_gguf=self.q4_0_mistral_model_id, device_map="auto", torch_dtype=torch.float16 + ) + + text = tokenizer(self.example_text, return_tensors="pt").to(torch_device) + out = model.generate(**text, max_new_tokens=10) + + EXPECTED_TEXT = " Hello,\n\nI'm trying to create a" + self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT) diff --git a/utils/check_inits.py b/utils/check_inits.py index b9a637e6354bba..19c23279b9b8d8 100644 --- a/utils/check_inits.py +++ b/utils/check_inits.py @@ -331,6 +331,7 @@ def get_transformers_submodules() -> List[str]: "models.esm.openfold_utils", "modeling_attn_mask_utils", "safetensors_conversion", + "modeling_gguf_pytorch_utils", ] From 81e4324104401813ff6ce68a8c1642c1e3feafe6 Mon Sep 17 00:00:00 2001 From: Younes Belkada Date: Mon, 22 Apr 2024 11:32:57 +0200 Subject: [PATCH 02/29] add q2_k q3_k q5_k support from @99991 --- src/transformers/integrations/ggml.py | 148 +++++++++++++++++- .../modeling_gguf_pytorch_utils.py | 2 +- .../models/auto/tokenization_auto.py | 2 +- tests/quantization/ggml/test_ggml.py | 33 ++++ 4 files changed, 181 insertions(+), 4 deletions(-) diff --git a/src/transformers/integrations/ggml.py b/src/transformers/integrations/ggml.py index 59423f2fc6d0bd..c4ea72b0d7a71e 100644 --- a/src/transformers/integrations/ggml.py +++ b/src/transformers/integrations/ggml.py @@ -32,7 +32,10 @@ "F32": 0, "Q4_0": 2, "Q8_0": 8, + "Q2_K": 10, + "Q3_K": 11, "Q4_K": 12, + "Q5_K": 13, "Q6_K": 14, } @@ -43,7 +46,10 @@ "Q4_K": 144, "Q4_0": 2 + 16, # Q4_0 uses a blocksize of 32 but the 4-bit tensors are packed into 8-bit tensors + 2 bytes for the scales - "Q6_K": 210, + "Q6_K": 210, + "Q2_K":256 // 16 + 256 // 4 + 2 + 2, # See: https://github.com/99991/pygguf/commit/a417edbfc029a1bc270f984a694f9128c5afa8b9 + "Q3_K": 256 // 8 + 256 // 4 + 12 + 2, + "Q5_K": 2 + 2 + 12 + 256 // 8 + 256 // 2, } # Listed here: https://github.com/ggerganov/ggml/blob/master/docs/gguf.md @@ -298,6 +304,139 @@ def dequantize_q8_0(data): return scales * qs +def dequantize_q2_k(data): + # C implementation + # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L1547 + # C struct definition + # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L74 + num_blocks = len(data) // GGML_BLOCK_SIZES["Q2_K"] + + data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, GGML_BLOCK_SIZES["Q2_K"] // 2) + data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, GGML_BLOCK_SIZES["Q2_K"]) + + dmin = data_f16[:, -1].reshape(num_blocks, 1, 1).astype(np.float32) + d = data_f16[:, -2].reshape(num_blocks, 1, 1).astype(np.float32) + scales = data_u8[:, :16].reshape(num_blocks, 16, 1) + qs = data_u8[:, 16:80].reshape(num_blocks, 64) + + tmp = np.stack([ + qs[:, 00:16] >> 0, + qs[:, 16:32] >> 0, + qs[:, 00:16] >> 2, + qs[:, 16:32] >> 2, + qs[:, 00:16] >> 4, + qs[:, 16:32] >> 4, + qs[:, 00:16] >> 6, + qs[:, 16:32] >> 6, + qs[:, 32:48] >> 0, + qs[:, 48:64] >> 0, + qs[:, 32:48] >> 2, + qs[:, 48:64] >> 2, + qs[:, 32:48] >> 4, + qs[:, 48:64] >> 4, + qs[:, 32:48] >> 6, + qs[:, 48:64] >> 6, + ], axis=1) + + return d * (scales & 15) * (tmp & 3) - dmin * (scales >> 4) + +def 
dequantize_q3_k(data): + # C implementation + # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L1723C32-L1723C42 + # C struct definition + # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L95 + num_blocks = len(data) // GGML_BLOCK_SIZES["Q3_K"] + + data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, GGML_BLOCK_SIZES["Q3_K"] // 2) + data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, GGML_BLOCK_SIZES["Q3_K"]) + + d = data_f16[:, -1].reshape(num_blocks, 1, 1).astype(np.float32) + bits = np.unpackbits(data_u8[:, :32].reshape(num_blocks, 32, 1), axis=-1, bitorder="little") + bits = 4 ^ (bits << 2) + qs = data_u8[:, 32:32 + 64].astype(np.int16) + a, b, c = data_u8[:, 96: 96 + 12].reshape(num_blocks, 3, 4).transpose(1, 0, 2) + scales = np.zeros((num_blocks, 4, 4), dtype=np.uint8) + scales[:, 0] = (a & 15) | ((c & 3) << 4) + scales[:, 1] = (b & 15) | (((c >> 2) & 3) << 4) + scales[:, 2] = (a >> 4) | (((c >> 4) & 3) << 4) + scales[:, 3] = (b >> 4) | ((c >> 6) << 4) + scales = scales.reshape(num_blocks, 16, 1).astype(np.int16) + + return d * (scales - 32) * np.stack([ + (((qs[:, 00:16] >> 0) & 3) - bits[:, :16, 0]), + (((qs[:, 16:32] >> 0) & 3) - bits[:, 16:, 0]), + (((qs[:, 00:16] >> 2) & 3) - bits[:, :16, 1]), + (((qs[:, 16:32] >> 2) & 3) - bits[:, 16:, 1]), + (((qs[:, 00:16] >> 4) & 3) - bits[:, :16, 2]), + (((qs[:, 16:32] >> 4) & 3) - bits[:, 16:, 2]), + (((qs[:, 00:16] >> 6) & 3) - bits[:, :16, 3]), + (((qs[:, 16:32] >> 6) & 3) - bits[:, 16:, 3]), + (((qs[:, 32:48] >> 0) & 3) - bits[:, :16, 4]), + (((qs[:, 48:64] >> 0) & 3) - bits[:, 16:, 4]), + (((qs[:, 32:48] >> 2) & 3) - bits[:, :16, 5]), + (((qs[:, 48:64] >> 2) & 3) - bits[:, 16:, 5]), + (((qs[:, 32:48] >> 4) & 3) - bits[:, :16, 6]), + (((qs[:, 48:64] >> 4) & 3) - bits[:, 16:, 6]), + (((qs[:, 32:48] >> 6) & 3) - bits[:, :16, 7]), + (((qs[:, 48:64] >> 6) & 3) - bits[:, 16:, 7]) + ], axis=1) + + +def dequantize_q5_k(data): + # C implementation + # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L2129 + # C struct definition + # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L138 + num_blocks = len(data) // GGML_BLOCK_SIZES["Q5_K"] + + data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, GGML_BLOCK_SIZES["Q5_K"] // 2) + data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, GGML_BLOCK_SIZES["Q5_K"]) + + d = data_f16[:, 0].reshape(num_blocks, 1).astype(np.float32) + dmin = data_f16[:, 1].reshape(num_blocks, 1).astype(np.float32) + scales = data_u8[:, 4:16].reshape(num_blocks, 12, 1) + qh = data_u8[:, 16: 16 + 32].reshape(num_blocks, 32, 1) + qs = data_u8[:, 48: 48 + 128].reshape(num_blocks, 4, 32) + + bits = np.unpackbits(qh, axis=-1, bitorder="little") + + qs_hi_4 = qs >> 4 + qs_lo_4 = qs & 15 + + scales_lo_6 = scales[:, :8] & 63 + scales_hi_6 = scales[:, :8] >> 6 + scales_lo_4 = scales[:, 8:] & 15 + scales_hi_4 = scales[:, 8:] >> 4 + + m1 = dmin * scales_lo_6[:, 4] + m2 = dmin * scales_lo_6[:, 5] + m3 = dmin * scales_lo_6[:, 6] + m4 = dmin * scales_lo_6[:, 7] + m5 = dmin * (scales_hi_4[:, 0] | (scales_hi_6[:, 4] << 4)) + m6 = dmin * (scales_hi_4[:, 1] | (scales_hi_6[:, 5] << 4)) + m7 = dmin * (scales_hi_4[:, 2] | (scales_hi_6[:, 6] << 4)) + m8 = dmin * (scales_hi_4[:, 3] | (scales_hi_6[:, 7] << 4)) + + d1 = d * scales_lo_6[:, 0] + d2 = d * scales_lo_6[:, 1] + d3 = d * scales_lo_6[:, 2] 
+ d4 = d * scales_lo_6[:, 3] + d5 = d * (scales_lo_4[:, 0] | (scales_hi_6[:, 0] << 4)) + d6 = d * (scales_lo_4[:, 1] | (scales_hi_6[:, 1] << 4)) + d7 = d * (scales_lo_4[:, 2] | (scales_hi_6[:, 2] << 4)) + d8 = d * (scales_lo_4[:, 3] | (scales_hi_6[:, 3] << 4)) + + return np.concatenate([ + d1 * (qs_lo_4[:, 0] + (bits[:, :, 0] << 4)) - m1, + d2 * (qs_hi_4[:, 0] + (bits[:, :, 1] << 4)) - m2, + d3 * (qs_lo_4[:, 1] + (bits[:, :, 2] << 4)) - m3, + d4 * (qs_hi_4[:, 1] + (bits[:, :, 3] << 4)) - m4, + d5 * (qs_lo_4[:, 2] + (bits[:, :, 4] << 4)) - m5, + d6 * (qs_hi_4[:, 2] + (bits[:, :, 5] << 4)) - m6, + d7 * (qs_lo_4[:, 3] + (bits[:, :, 6] << 4)) - m7, + d8 * (qs_hi_4[:, 3] + (bits[:, :, 7] << 4)) - m8, + ], axis=1) + def load_dequant_gguf_tensor(shape, ggml_type, data): if ggml_type == GGML_TYPES["F32"]: @@ -310,6 +449,12 @@ def load_dequant_gguf_tensor(shape, ggml_type, data): values = dequantize_q4_k(data) elif ggml_type == GGML_TYPES["Q6_K"]: values = dequantize_q6_k(data) + elif ggml_type == GGML_TYPES["Q2_K"]: + values = dequantize_q2_k(data) + elif ggml_type == GGML_TYPES["Q3_K"]: + values = dequantize_q3_k(data) + elif ggml_type == GGML_TYPES["Q5_K"]: + values = dequantize_q5_k(data) else: raise NotImplementedError( f"ggml_type {ggml_type} not implemented - please raise an issue on huggingface transformers: https://github.com/huggingface/transformers/issues/new/choose" @@ -357,7 +502,6 @@ def decoder(self, replacement, add_prefix_space): decoders.Fuse(), decoders.Replace("▁", " "), ] - print(add_prefix_space) add_prefix_space = False if add_prefix_space: sequence += [decoders.Strip(content=" ", left=1)] diff --git a/src/transformers/modeling_gguf_pytorch_utils.py b/src/transformers/modeling_gguf_pytorch_utils.py index b8721f3ee515f3..1732569177f43e 100644 --- a/src/transformers/modeling_gguf_pytorch_utils.py +++ b/src/transformers/modeling_gguf_pytorch_utils.py @@ -54,7 +54,7 @@ def read_field(reader, field): return [_gguf_parse_value(value.parts[_data_index], value.types) for _data_index in value.data] -def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=True): +def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False): """ Load a GGUF file and return a dictionary of parsed parameters containing tensors, the parsed tokenizer and config attributes. 
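All of the block formats handled above follow the same pattern: fixed-size byte blocks of packed quants plus per-block scales, reassembled into float32. The self-contained sketch below illustrates the round trip for the simplest case, Q8_0, whose 34-byte layout (a float16 scale followed by 32 int8 quants) matches `GGML_BLOCK_SIZES["Q8_0"]` and `dequantize_q8_0` above; the quantizer is only a reference helper for the check, not code from this patch.

```py
# Reference round trip for the Q8_0 block format: 2-byte fp16 scale + 32 int8 quants.
import numpy as np

def quantize_q8_0(x: np.ndarray) -> bytes:
    blocks = x.reshape(-1, 32).astype(np.float32)
    out = b""
    for block in blocks:
        scale = float(np.max(np.abs(block))) / 127.0 or 1.0  # avoid a zero scale
        quants = np.clip(np.round(block / scale), -127, 127).astype(np.int8)
        out += np.float16(scale).tobytes() + quants.tobytes()
    return out

def dequantize_q8_0(data: bytes) -> np.ndarray:
    num_blocks = len(data) // 34
    scales = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, 17)[:, :1].astype(np.float32)
    quants = np.frombuffer(data, dtype=np.int8).reshape(num_blocks, 34)[:, 2:]
    return scales * quants

x = np.random.randn(4, 32).astype(np.float32)
x_hat = dequantize_q8_0(quantize_q8_0(x)).reshape(4, 32)
print(np.max(np.abs(x - x_hat)))  # roughly bounded by half a quantization step
```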
diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 9e1984cfe9591c..822a680bc4bf33 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -824,7 +824,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): ) else: gguf_path = cached_file(pretrained_model_name_or_path, from_gguf, **kwargs) - config_dict = load_gguf_checkpoint(gguf_path, return_tensors=True)["config"] + config_dict = load_gguf_checkpoint(gguf_path, return_tensors=False)["config"] config = AutoConfig.for_model(**config_dict) config_tokenizer_class = config.tokenizer_class if hasattr(config, "auto_map") and "AutoTokenizer" in config.auto_map: diff --git a/tests/quantization/ggml/test_ggml.py b/tests/quantization/ggml/test_ggml.py index f276432f8f45ae..bde527e1b6f2d5 100644 --- a/tests/quantization/ggml/test_ggml.py +++ b/tests/quantization/ggml/test_ggml.py @@ -33,6 +33,9 @@ class GgufIntegrationTests(unittest.TestCase): q4_0_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q4_0.gguf" q4_k_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf" + q2_k_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q2_K.gguf" + q3_k_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q3_K_L.gguf" + q5_k_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q5_K_M.gguf" q6_k_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q6_K.gguf" q8_0_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q8_0.gguf" @@ -40,6 +43,36 @@ class GgufIntegrationTests(unittest.TestCase): example_text = "Hello" + def test_q2_k(self): + tokenizer = AutoTokenizer.from_pretrained(self.model_id, from_gguf=self.q2_k_gguf_model_id) + model = AutoModelForCausalLM.from_pretrained(self.model_id, from_gguf=self.q2_k_gguf_model_id).to(torch_device) + + text = tokenizer(self.example_text, return_tensors="pt").to(torch_device) + out = model.generate(**text, max_new_tokens=10) + + EXPECTED_TEXT = " Hello, World!\n\n[10:0" + self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT) + + def test_q3_k(self): + tokenizer = AutoTokenizer.from_pretrained(self.model_id, from_gguf=self.q3_k_gguf_model_id) + model = AutoModelForCausalLM.from_pretrained(self.model_id, from_gguf=self.q3_k_gguf_model_id).to(torch_device) + + text = tokenizer(self.example_text, return_tensors="pt").to(torch_device) + out = model.generate(**text, max_new_tokens=10) + + EXPECTED_TEXT = " Hello, World!\n\n```\n<|user" + self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT) + + def test_q5_k(self): + tokenizer = AutoTokenizer.from_pretrained(self.model_id, from_gguf=self.q5_k_gguf_model_id) + model = AutoModelForCausalLM.from_pretrained(self.model_id, from_gguf=self.q5_k_gguf_model_id).to(torch_device) + + text = tokenizer(self.example_text, return_tensors="pt").to(torch_device) + out = model.generate(**text, max_new_tokens=10) + + EXPECTED_TEXT = "Hello, World!\n\nStep 3: Add" + self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT) + def test_q4_0(self): tokenizer = AutoTokenizer.from_pretrained(self.model_id, from_gguf=self.q4_0_gguf_model_id) model = AutoModelForCausalLM.from_pretrained(self.model_id, from_gguf=self.q4_0_gguf_model_id).to(torch_device) From 8a0d5b8838837c11794f3f2f8faaaaf7724af061 Mon Sep 17 00:00:00 2001 From: Younes Belkada Date: Mon, 22 Apr 2024 11:35:27 +0200 Subject: [PATCH 03/29] fix tests --- tests/quantization/ggml/test_ggml.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff 
--git a/tests/quantization/ggml/test_ggml.py b/tests/quantization/ggml/test_ggml.py index bde527e1b6f2d5..28b56e3b5bc5a4 100644 --- a/tests/quantization/ggml/test_ggml.py +++ b/tests/quantization/ggml/test_ggml.py @@ -70,7 +70,7 @@ def test_q5_k(self): text = tokenizer(self.example_text, return_tensors="pt").to(torch_device) out = model.generate(**text, max_new_tokens=10) - EXPECTED_TEXT = "Hello, World!\n\nStep 3: Add" + EXPECTED_TEXT = " Hello, World!\n\nStep 3: Add" self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT) def test_q4_0(self): @@ -80,7 +80,7 @@ def test_q4_0(self): text = tokenizer(self.example_text, return_tensors="pt").to(torch_device) out = model.generate(**text, max_new_tokens=10) - EXPECTED_TEXT = "Hello, World!\n\nStep 3: Add" + EXPECTED_TEXT = " Hello, World!\n\nStep 3: Add" self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT) def test_q4_k_m(self): @@ -90,7 +90,7 @@ def test_q4_k_m(self): text = tokenizer(self.example_text, return_tensors="pt").to(torch_device) out = model.generate(**text, max_new_tokens=10) - EXPECTED_TEXT = "Hello, World!\n\n5. Python:\n" + EXPECTED_TEXT = " Hello, World!\n\n5. Python:\n" self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT) def test_q6_k(self): @@ -100,7 +100,7 @@ def test_q6_k(self): text = tokenizer(self.example_text, return_tensors="pt").to(torch_device) out = model.generate(**text, max_new_tokens=10) - EXPECTED_TEXT = "Hello, World!\n\nStep 3: Add" + EXPECTED_TEXT = " Hello, World!\n\nStep 3: Add" self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT) def test_q6_k_fp16(self): @@ -114,7 +114,7 @@ def test_q6_k_fp16(self): text = tokenizer(self.example_text, return_tensors="pt").to(torch_device) out = model.generate(**text, max_new_tokens=10) - EXPECTED_TEXT = "Hello, World!\n\nStep 3: Add" + EXPECTED_TEXT = " Hello, World!\n\nStep 3: Add" self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT) def test_q8_0(self): @@ -124,7 +124,7 @@ def test_q8_0(self): text = tokenizer(self.example_text, return_tensors="pt").to(torch_device) out = model.generate(**text, max_new_tokens=10) - EXPECTED_TEXT = "Hello, World!\n\n5. Use a library" + EXPECTED_TEXT = " Hello, World!\n\n5. 
Use a library" self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT) def test_mistral_q4_0(self): From 08534f34477c735d6a2757e9c17121d0d1802e59 Mon Sep 17 00:00:00 2001 From: Lysandre Date: Mon, 22 Apr 2024 14:18:55 +0200 Subject: [PATCH 04/29] Update doc --- docs/source/en/gguf.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/docs/source/en/gguf.md b/docs/source/en/gguf.md index f766a00d4ae137..54e83978d9dec1 100644 --- a/docs/source/en/gguf.md +++ b/docs/source/en/gguf.md @@ -60,7 +60,6 @@ For now the supported model architectures are the architectures that have been v - LLaMa - Mistral -- Gemma ## Example usage @@ -91,4 +90,4 @@ tokenizer.save_pretrained('directory') model.save_pretrained('directory') !python ${path_to_llama_cpp}/convert-hf-to-gguf.py ${directory} -``` \ No newline at end of file +``` From ebd9944de7467f0c0a5307a33a2dc4952522f4ac Mon Sep 17 00:00:00 2001 From: Lysandre Date: Mon, 22 Apr 2024 14:20:05 +0200 Subject: [PATCH 05/29] Style --- src/transformers/integrations/ggml.py | 122 ++++++++++-------- .../modeling_gguf_pytorch_utils.py | 2 +- 2 files changed, 71 insertions(+), 53 deletions(-) diff --git a/src/transformers/integrations/ggml.py b/src/transformers/integrations/ggml.py index c4ea72b0d7a71e..c6b424a5f40781 100644 --- a/src/transformers/integrations/ggml.py +++ b/src/transformers/integrations/ggml.py @@ -46,8 +46,11 @@ "Q4_K": 144, "Q4_0": 2 + 16, # Q4_0 uses a blocksize of 32 but the 4-bit tensors are packed into 8-bit tensors + 2 bytes for the scales - "Q6_K": 210, - "Q2_K":256 // 16 + 256 // 4 + 2 + 2, # See: https://github.com/99991/pygguf/commit/a417edbfc029a1bc270f984a694f9128c5afa8b9 + "Q6_K": 210, + "Q2_K": 256 // 16 + + 256 // 4 + + 2 + + 2, # See: https://github.com/99991/pygguf/commit/a417edbfc029a1bc270f984a694f9128c5afa8b9 "Q3_K": 256 // 8 + 256 // 4 + 12 + 2, "Q5_K": 2 + 2 + 12 + 256 // 8 + 256 // 2, } @@ -304,6 +307,7 @@ def dequantize_q8_0(data): return scales * qs + def dequantize_q2_k(data): # C implementation # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L1547 @@ -319,27 +323,31 @@ def dequantize_q2_k(data): scales = data_u8[:, :16].reshape(num_blocks, 16, 1) qs = data_u8[:, 16:80].reshape(num_blocks, 64) - tmp = np.stack([ - qs[:, 00:16] >> 0, - qs[:, 16:32] >> 0, - qs[:, 00:16] >> 2, - qs[:, 16:32] >> 2, - qs[:, 00:16] >> 4, - qs[:, 16:32] >> 4, - qs[:, 00:16] >> 6, - qs[:, 16:32] >> 6, - qs[:, 32:48] >> 0, - qs[:, 48:64] >> 0, - qs[:, 32:48] >> 2, - qs[:, 48:64] >> 2, - qs[:, 32:48] >> 4, - qs[:, 48:64] >> 4, - qs[:, 32:48] >> 6, - qs[:, 48:64] >> 6, - ], axis=1) + tmp = np.stack( + [ + qs[:, 00:16] >> 0, + qs[:, 16:32] >> 0, + qs[:, 00:16] >> 2, + qs[:, 16:32] >> 2, + qs[:, 00:16] >> 4, + qs[:, 16:32] >> 4, + qs[:, 00:16] >> 6, + qs[:, 16:32] >> 6, + qs[:, 32:48] >> 0, + qs[:, 48:64] >> 0, + qs[:, 32:48] >> 2, + qs[:, 48:64] >> 2, + qs[:, 32:48] >> 4, + qs[:, 48:64] >> 4, + qs[:, 32:48] >> 6, + qs[:, 48:64] >> 6, + ], + axis=1, + ) return d * (scales & 15) * (tmp & 3) - dmin * (scales >> 4) + def dequantize_q3_k(data): # C implementation # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L1723C32-L1723C42 @@ -353,8 +361,8 @@ def dequantize_q3_k(data): d = data_f16[:, -1].reshape(num_blocks, 1, 1).astype(np.float32) bits = np.unpackbits(data_u8[:, :32].reshape(num_blocks, 32, 1), axis=-1, bitorder="little") bits = 4 ^ (bits << 2) - qs = data_u8[:, 32:32 + 64].astype(np.int16) - a, b, c 
= data_u8[:, 96: 96 + 12].reshape(num_blocks, 3, 4).transpose(1, 0, 2) + qs = data_u8[:, 32 : 32 + 64].astype(np.int16) + a, b, c = data_u8[:, 96 : 96 + 12].reshape(num_blocks, 3, 4).transpose(1, 0, 2) scales = np.zeros((num_blocks, 4, 4), dtype=np.uint8) scales[:, 0] = (a & 15) | ((c & 3) << 4) scales[:, 1] = (b & 15) | (((c >> 2) & 3) << 4) @@ -362,24 +370,31 @@ def dequantize_q3_k(data): scales[:, 3] = (b >> 4) | ((c >> 6) << 4) scales = scales.reshape(num_blocks, 16, 1).astype(np.int16) - return d * (scales - 32) * np.stack([ - (((qs[:, 00:16] >> 0) & 3) - bits[:, :16, 0]), - (((qs[:, 16:32] >> 0) & 3) - bits[:, 16:, 0]), - (((qs[:, 00:16] >> 2) & 3) - bits[:, :16, 1]), - (((qs[:, 16:32] >> 2) & 3) - bits[:, 16:, 1]), - (((qs[:, 00:16] >> 4) & 3) - bits[:, :16, 2]), - (((qs[:, 16:32] >> 4) & 3) - bits[:, 16:, 2]), - (((qs[:, 00:16] >> 6) & 3) - bits[:, :16, 3]), - (((qs[:, 16:32] >> 6) & 3) - bits[:, 16:, 3]), - (((qs[:, 32:48] >> 0) & 3) - bits[:, :16, 4]), - (((qs[:, 48:64] >> 0) & 3) - bits[:, 16:, 4]), - (((qs[:, 32:48] >> 2) & 3) - bits[:, :16, 5]), - (((qs[:, 48:64] >> 2) & 3) - bits[:, 16:, 5]), - (((qs[:, 32:48] >> 4) & 3) - bits[:, :16, 6]), - (((qs[:, 48:64] >> 4) & 3) - bits[:, 16:, 6]), - (((qs[:, 32:48] >> 6) & 3) - bits[:, :16, 7]), - (((qs[:, 48:64] >> 6) & 3) - bits[:, 16:, 7]) - ], axis=1) + return ( + d + * (scales - 32) + * np.stack( + [ + (((qs[:, 00:16] >> 0) & 3) - bits[:, :16, 0]), + (((qs[:, 16:32] >> 0) & 3) - bits[:, 16:, 0]), + (((qs[:, 00:16] >> 2) & 3) - bits[:, :16, 1]), + (((qs[:, 16:32] >> 2) & 3) - bits[:, 16:, 1]), + (((qs[:, 00:16] >> 4) & 3) - bits[:, :16, 2]), + (((qs[:, 16:32] >> 4) & 3) - bits[:, 16:, 2]), + (((qs[:, 00:16] >> 6) & 3) - bits[:, :16, 3]), + (((qs[:, 16:32] >> 6) & 3) - bits[:, 16:, 3]), + (((qs[:, 32:48] >> 0) & 3) - bits[:, :16, 4]), + (((qs[:, 48:64] >> 0) & 3) - bits[:, 16:, 4]), + (((qs[:, 32:48] >> 2) & 3) - bits[:, :16, 5]), + (((qs[:, 48:64] >> 2) & 3) - bits[:, 16:, 5]), + (((qs[:, 32:48] >> 4) & 3) - bits[:, :16, 6]), + (((qs[:, 48:64] >> 4) & 3) - bits[:, 16:, 6]), + (((qs[:, 32:48] >> 6) & 3) - bits[:, :16, 7]), + (((qs[:, 48:64] >> 6) & 3) - bits[:, 16:, 7]), + ], + axis=1, + ) + ) def dequantize_q5_k(data): @@ -395,8 +410,8 @@ def dequantize_q5_k(data): d = data_f16[:, 0].reshape(num_blocks, 1).astype(np.float32) dmin = data_f16[:, 1].reshape(num_blocks, 1).astype(np.float32) scales = data_u8[:, 4:16].reshape(num_blocks, 12, 1) - qh = data_u8[:, 16: 16 + 32].reshape(num_blocks, 32, 1) - qs = data_u8[:, 48: 48 + 128].reshape(num_blocks, 4, 32) + qh = data_u8[:, 16 : 16 + 32].reshape(num_blocks, 32, 1) + qs = data_u8[:, 48 : 48 + 128].reshape(num_blocks, 4, 32) bits = np.unpackbits(qh, axis=-1, bitorder="little") @@ -426,16 +441,19 @@ def dequantize_q5_k(data): d7 = d * (scales_lo_4[:, 2] | (scales_hi_6[:, 2] << 4)) d8 = d * (scales_lo_4[:, 3] | (scales_hi_6[:, 3] << 4)) - return np.concatenate([ - d1 * (qs_lo_4[:, 0] + (bits[:, :, 0] << 4)) - m1, - d2 * (qs_hi_4[:, 0] + (bits[:, :, 1] << 4)) - m2, - d3 * (qs_lo_4[:, 1] + (bits[:, :, 2] << 4)) - m3, - d4 * (qs_hi_4[:, 1] + (bits[:, :, 3] << 4)) - m4, - d5 * (qs_lo_4[:, 2] + (bits[:, :, 4] << 4)) - m5, - d6 * (qs_hi_4[:, 2] + (bits[:, :, 5] << 4)) - m6, - d7 * (qs_lo_4[:, 3] + (bits[:, :, 6] << 4)) - m7, - d8 * (qs_hi_4[:, 3] + (bits[:, :, 7] << 4)) - m8, - ], axis=1) + return np.concatenate( + [ + d1 * (qs_lo_4[:, 0] + (bits[:, :, 0] << 4)) - m1, + d2 * (qs_hi_4[:, 0] + (bits[:, :, 1] << 4)) - m2, + d3 * (qs_lo_4[:, 1] + (bits[:, :, 2] << 4)) - m3, + d4 * (qs_hi_4[:, 1] 
+ (bits[:, :, 3] << 4)) - m4, + d5 * (qs_lo_4[:, 2] + (bits[:, :, 4] << 4)) - m5, + d6 * (qs_hi_4[:, 2] + (bits[:, :, 5] << 4)) - m6, + d7 * (qs_lo_4[:, 3] + (bits[:, :, 6] << 4)) - m7, + d8 * (qs_hi_4[:, 3] + (bits[:, :, 7] << 4)) - m8, + ], + axis=1, + ) def load_dequant_gguf_tensor(shape, ggml_type, data): diff --git a/src/transformers/modeling_gguf_pytorch_utils.py b/src/transformers/modeling_gguf_pytorch_utils.py index 1732569177f43e..6e11dc84820e37 100644 --- a/src/transformers/modeling_gguf_pytorch_utils.py +++ b/src/transformers/modeling_gguf_pytorch_utils.py @@ -157,4 +157,4 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False): if len(reader_keys) > 0: logger.info(f"Some keys of the GGUF file were not considered: {reader_keys}") - return parsed_parameters \ No newline at end of file + return parsed_parameters From 5c913ecc97cab4c058a7195419507e61099878b1 Mon Sep 17 00:00:00 2001 From: Lysandre Date: Mon, 22 Apr 2024 17:58:13 +0200 Subject: [PATCH 06/29] Docs --- docs/source/en/_toctree.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index edeb85fd6a4a88..324bd761bee15d 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -137,6 +137,8 @@ title: Troubleshoot - local: hf_quantizer title: Contribute new quantization method + - local: gguf + title: Interoperability with GGUF files title: Developer guides - sections: - local: performance From c49f1a8d9dfacd42ba4d90a45af4c932da950cb6 Mon Sep 17 00:00:00 2001 From: Younes Belkada Date: Mon, 22 Apr 2024 19:14:15 +0200 Subject: [PATCH 07/29] fix CI --- src/transformers/modeling_gguf_pytorch_utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/transformers/modeling_gguf_pytorch_utils.py b/src/transformers/modeling_gguf_pytorch_utils.py index 6e11dc84820e37..c08b180827b6e5 100644 --- a/src/transformers/modeling_gguf_pytorch_utils.py +++ b/src/transformers/modeling_gguf_pytorch_utils.py @@ -15,7 +15,6 @@ # limitations under the License. import numpy as np -import torch from tqdm import tqdm from .integrations import ( @@ -25,9 +24,13 @@ _gguf_parse_value, load_dequant_gguf_tensor, ) +from .utils import is_torch_available from .utils.logging import get_logger +if is_torch_available(): + import torch + logger = get_logger(__name__) From 7fa538b319fc41c7aeed7aa7e91365a78492df66 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Mon, 22 Apr 2024 19:42:11 +0200 Subject: [PATCH 08/29] Update docs/source/en/gguf.md --- docs/source/en/gguf.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/source/en/gguf.md b/docs/source/en/gguf.md index 54e83978d9dec1..304288e88136d0 100644 --- a/docs/source/en/gguf.md +++ b/docs/source/en/gguf.md @@ -46,6 +46,8 @@ The initial supported quantization types are decided according to the popular qu on the Hub. - F32 +- Q2_K +- Q3_K - Q4_0 - Q4_K - Q6_K From 548532773d5b580b6381744f13aa1db75f9e2eb7 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Mon, 22 Apr 2024 19:42:26 +0200 Subject: [PATCH 09/29] Update docs/source/en/gguf.md --- docs/source/en/gguf.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/en/gguf.md b/docs/source/en/gguf.md index 304288e88136d0..d5ec642a1904e6 100644 --- a/docs/source/en/gguf.md +++ b/docs/source/en/gguf.md @@ -50,6 +50,7 @@ on the Hub. 
- Q3_K - Q4_0 - Q4_K +- Q5_K - Q6_K - Q8_0 From ca8363e8dfd36ffb3ad6a46a7494c184868d22fd Mon Sep 17 00:00:00 2001 From: Lysandre Date: Tue, 23 Apr 2024 11:54:01 +0200 Subject: [PATCH 10/29] Compute merges --- src/transformers/integrations/ggml.py | 31 ++++++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/src/transformers/integrations/ggml.py b/src/transformers/integrations/ggml.py index c6b424a5f40781..68b1d67511b667 100644 --- a/src/transformers/integrations/ggml.py +++ b/src/transformers/integrations/ggml.py @@ -25,6 +25,11 @@ from .. import AddedToken from ..convert_slow_tokenizer import LlamaConverter +from ..utils import logging +from ..utils.logging import tqdm + + +logger = logging.get_logger(__name__) # Listed here: https://github.com/ggerganov/ggml/blob/master/docs/gguf.md @@ -486,6 +491,30 @@ def __init__(self, dict_): for k, v in dict_.items(): setattr(self, k, v) + if not hasattr(self, "tokens") or not hasattr(self, "scores"): + raise ValueError("tokens and scores need to be passed for a LLaMa tokenizer to be instantiated.") + else: + tokens = self.tokens + scores = self.scores + vocab = {t: scores[i] for i, t in enumerate(tokens)} + + if not hasattr(self, "merges"): + logger.warning("Merges were not in checkpoint, building merges on the fly.") + merges = [] + for merge, piece_score in tqdm(vocab.items()): + local = [] + for index in range(1, len(merge)): + piece_l, piece_r = merge[:index], merge[index:] + if piece_l in tokens and piece_r in tokens: + local.append((piece_l, piece_r, piece_score)) + local = sorted(local, key=lambda x: (vocab[x[0]], vocab[x[1]]), reverse=True) + merges.extend(local) + merges = sorted(merges, key=lambda val: val[2], reverse=True) + merges = [(val[0], val[1]) for val in merges] + self.merges = merges + else: + self.merges = [tuple(merge.split(" ")) for merge in self.merges] + class GGUFLlamaConverter(LlamaConverter): def __init__(self, tokenizer_dict): @@ -496,7 +525,7 @@ def vocab(self, proto): return list(zip(proto.tokens, proto.scores)) def merges(self, proto): - return [tuple(merge.split(" ")) for merge in proto.merges] + return proto.merges def tokenizer(self, proto): vocab_scores = self.vocab(self.proto) From e6c6f6ceff8839d105b4cdf27262575784ef3a87 Mon Sep 17 00:00:00 2001 From: Younes Belkada Date: Tue, 30 Apr 2024 15:29:04 +0200 Subject: [PATCH 11/29] change logic --- src/transformers/tokenization_utils_base.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index eb232f9e98d658..c2ee701229b460 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -1997,7 +1997,9 @@ def from_pretrained( vocab_files[file_id] = pretrained_model_name_or_path single_file_id = file_id else: - if not from_gguf: + if from_gguf: + vocab_files["vocab_file"] = from_gguf + else: # At this point pretrained_model_name_or_path is either a directory or a model identifier name additional_files_names = { "added_tokens_file": ADDED_TOKENS_FILE, # kept only for legacy @@ -2033,9 +2035,7 @@ def from_pretrained( tokenizer_config = json.load(reader) if "fast_tokenizer_files" in tokenizer_config: fast_tokenizer_file = get_fast_tokenizer_file(tokenizer_config["fast_tokenizer_files"]) - vocab_files["tokenizer_file"] = fast_tokenizer_file - else: - vocab_files["vocab_file"] = from_gguf + vocab_files["tokenizer_file"] = fast_tokenizer_file # Get files from url, cache, or disk 
depending on the case resolved_vocab_files = {} From a6cd08ce620861c4f95668e91d0a98aac3a62a97 Mon Sep 17 00:00:00 2001 From: Younes Belkada Date: Tue, 30 Apr 2024 15:32:28 +0200 Subject: [PATCH 12/29] add comment for clarity --- src/transformers/tokenization_utils_base.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index c2ee701229b460..6147e1aae33de9 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -2074,6 +2074,8 @@ def from_pretrained( "files are necessary for the tokenizer to operate." ) + # If one passes a GGUF file path to `from_gguf` there is no need for this check as the tokenizer will be + # loaded directly from the GGUF file. if all(full_file_name is None for full_file_name in resolved_vocab_files.values()) and not from_gguf: raise EnvironmentError( f"Can't load tokenizer for '{pretrained_model_name_or_path}'. If you were trying to load it from " From 66118772aa52a89450cae088cfbd43222cf4dabe Mon Sep 17 00:00:00 2001 From: Younes Belkada Date: Tue, 30 Apr 2024 15:33:11 +0200 Subject: [PATCH 13/29] add comment for clarity --- src/transformers/tokenization_utils_base.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 6147e1aae33de9..60de37ba3b6798 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -2127,6 +2127,9 @@ def _from_pretrained( from_slow = kwargs.get("from_slow", False) from_gguf = kwargs.get("from_gguf", False) has_tokenizer_file = resolved_vocab_files.get("tokenizer_file", None) is not None + + # If one passes a GGUF file path to `from_gguf` there is no need for this check as the tokenizer will be + # loaded directly from the GGUF file. if (from_slow or not has_tokenizer_file) and cls.slow_tokenizer_class is not None and not from_gguf: slow_tokenizer = (cls.slow_tokenizer_class)._from_pretrained( copy.deepcopy(resolved_vocab_files), From 455163bef4b13d00066f3d72298299e8c8bb99ca Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Tue, 30 Apr 2024 15:34:39 +0200 Subject: [PATCH 14/29] Update src/transformers/models/auto/tokenization_auto.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- src/transformers/models/auto/tokenization_auto.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 363e0a4f9f8361..9be485089d7be7 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -819,7 +819,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): # If that did not work, let's try to use the config. 
if config_tokenizer_class is None: if not isinstance(config, PretrainedConfig): - if from_gguf is None: + if from_gguf is None or not from_gguf: config = AutoConfig.from_pretrained( pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs ) From 42d5815e9811c56de8e792d439d8b3d7ba3b0fc8 Mon Sep 17 00:00:00 2001 From: Younes Belkada Date: Tue, 30 Apr 2024 15:38:26 +0200 Subject: [PATCH 15/29] change logic --- src/transformers/configuration_utils.py | 6 +++--- src/transformers/models/auto/tokenization_auto.py | 10 +++++----- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index f277c7ebfabfe4..4098eb15be9f65 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -715,11 +715,11 @@ def _get_config_dict( ) try: - if from_gguf is None: + if from_gguf: + config_dict = load_gguf_checkpoint(resolved_config_file, return_tensors=False)["config"] + else: # Load config dict config_dict = cls._dict_from_json_file(resolved_config_file) - else: - config_dict = load_gguf_checkpoint(resolved_config_file, return_tensors=False)["config"] config_dict["_commit_hash"] = commit_hash except (json.JSONDecodeError, UnicodeDecodeError): diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 9be485089d7be7..34d699aeae15cf 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -819,14 +819,14 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): # If that did not work, let's try to use the config. if config_tokenizer_class is None: if not isinstance(config, PretrainedConfig): - if from_gguf is None or not from_gguf: - config = AutoConfig.from_pretrained( - pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs - ) - else: + if from_gguf: gguf_path = cached_file(pretrained_model_name_or_path, from_gguf, **kwargs) config_dict = load_gguf_checkpoint(gguf_path, return_tensors=False)["config"] config = AutoConfig.for_model(**config_dict) + else: + config = AutoConfig.from_pretrained( + pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs + ) config_tokenizer_class = config.tokenizer_class if hasattr(config, "auto_map") and "AutoTokenizer" in config.auto_map: tokenizer_auto_map = config.auto_map["AutoTokenizer"] From 1d3acecd579dba8fba4636e1f25a9c0a697b96f6 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Tue, 30 Apr 2024 15:39:13 +0200 Subject: [PATCH 16/29] Update src/transformers/modeling_utils.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- src/transformers/modeling_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index ef4039eb4858d8..ea91b19d535bfe 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -3442,7 +3442,7 @@ def from_pretrained( resolved_archive_file = archive_file else: logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}") - elif from_gguf is not None: + elif from_gguf: from .modeling_gguf_pytorch_utils import load_gguf_checkpoint # Case 1: the GGUF file is present locally From af3c42ca16b67436f48f5468f0084d62371fa52f Mon Sep 17 00:00:00 2001 From: Younes Belkada Date: Tue, 30 Apr 2024 15:40:58 
+0200 Subject: [PATCH 17/29] change --- src/transformers/modeling_utils.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index ef4039eb4858d8..b8bd968270e5bf 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -4100,11 +4100,9 @@ def _find_mismatched_keys( remove_prefix_from_model, ignore_mismatched_sizes, ) - - if gguf_path is None: - error_msgs = _load_state_dict_into_model(model_to_load, state_dict, start_prefix) + # For GGUF models `state_dict` is never set to None as the state dict is always small - else: + if gguf_path: error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model( model_to_load, state_dict, @@ -4122,6 +4120,9 @@ def _find_mismatched_keys( keep_in_fp32_modules=keep_in_fp32_modules, unexpected_keys=unexpected_keys, ) + else: + error_msgs = _load_state_dict_into_model(model_to_load, state_dict, start_prefix) + else: # This should always be a list but, just to be sure. if not isinstance(resolved_archive_file, list): From 14ad10c2672210b9bea427951c2b69b0d72ea316 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Tue, 30 Apr 2024 15:42:03 +0200 Subject: [PATCH 18/29] Apply suggestions from code review Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- src/transformers/integrations/ggml.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/transformers/integrations/ggml.py b/src/transformers/integrations/ggml.py index 68b1d67511b667..7a546a5460d38f 100644 --- a/src/transformers/integrations/ggml.py +++ b/src/transformers/integrations/ggml.py @@ -49,13 +49,11 @@ GGML_BLOCK_SIZES = { "Q8_0": 2 + 32, # Q8_0 uses a blocksize of 32 (int8 tensors) + 2 bytes allocated for the scales "Q4_K": 144, - "Q4_0": 2 - + 16, # Q4_0 uses a blocksize of 32 but the 4-bit tensors are packed into 8-bit tensors + 2 bytes for the scales + # Q4_0 uses a blocksize of 32 but the 4-bit tensors are packed into 8-bit tensors + 2 bytes for the scales + "Q4_0": 2 + 16, "Q6_K": 210, - "Q2_K": 256 // 16 - + 256 // 4 - + 2 - + 2, # See: https://github.com/99991/pygguf/commit/a417edbfc029a1bc270f984a694f9128c5afa8b9 + # See: https://github.com/99991/pygguf/commit/a417edbfc029a1bc270f984a694f9128c5afa8b9 + "Q2_K": 256 // 16 + 256 // 4 + 2 + 2, "Q3_K": 256 // 8 + 256 // 4 + 12 + 2, "Q5_K": 2 + 2 + 12 + 256 // 8 + 256 // 2, } @@ -175,7 +173,7 @@ def _gguf_parse_value(_value, data_type): array_data_type = None else: if data_type[0] != 9: - raise ValueError("Received multiple types, but therefore expect the first type to indicate an array.") + raise ValueError("Received multiple types, therefore expected the first type to indicate an array.") data_type, array_data_type = data_type if data_type in [0, 1, 2, 3, 4, 5, 10, 11]: From ab621a74c13e4c9fe71a60d5c8e3ac706d5c9c35 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Tue, 30 Apr 2024 15:42:43 +0200 Subject: [PATCH 19/29] Update src/transformers/modeling_gguf_pytorch_utils.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- src/transformers/modeling_gguf_pytorch_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/modeling_gguf_pytorch_utils.py b/src/transformers/modeling_gguf_pytorch_utils.py index c08b180827b6e5..f9ece551260b10 100644 --- 
a/src/transformers/modeling_gguf_pytorch_utils.py +++ b/src/transformers/modeling_gguf_pytorch_utils.py @@ -113,7 +113,7 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False): for parameter in GGUF_TO_TRANSFORMERS_MAPPING: parameter_renames = GGUF_TO_TRANSFORMERS_MAPPING[parameter] if prefix in parameter_renames and config_key in parameter_renames[prefix]: - renamed_config_key = parameter_renames.get(prefix, {}).get(config_key) + renamed_config_key = parameter_renames[prefix][config_key] if renamed_config_key == -1: continue From 207820a701704876fdd7e4758a9466a446b9dd05 Mon Sep 17 00:00:00 2001 From: Younes Belkada Date: Tue, 30 Apr 2024 15:44:31 +0200 Subject: [PATCH 20/29] put back comment --- src/transformers/modeling_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index c840dd09ec5fca..1bc58717b47273 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -4100,7 +4100,7 @@ def _find_mismatched_keys( remove_prefix_from_model, ignore_mismatched_sizes, ) - + # For GGUF models `state_dict` is never set to None as the state dict is always small if gguf_path: error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model( @@ -4121,6 +4121,7 @@ def _find_mismatched_keys( unexpected_keys=unexpected_keys, ) else: + # Sharded checkpoint or whole but low_cpu_mem_usage==True error_msgs = _load_state_dict_into_model(model_to_load, state_dict, start_prefix) else: From 1fef8ad04dcc47085de0b3af266f703bd39218ff Mon Sep 17 00:00:00 2001 From: Younes Belkada Date: Tue, 30 Apr 2024 15:48:31 +0200 Subject: [PATCH 21/29] add comment about mistral --- src/transformers/modeling_gguf_pytorch_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/transformers/modeling_gguf_pytorch_utils.py b/src/transformers/modeling_gguf_pytorch_utils.py index f9ece551260b10..25d3d6a88e253e 100644 --- a/src/transformers/modeling_gguf_pytorch_utils.py +++ b/src/transformers/modeling_gguf_pytorch_utils.py @@ -87,6 +87,8 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False): architecture = read_field(reader, "general.architecture")[0] model_name = read_field(reader, "general.name") + # in llama.cpp mistral models use the same architecture as llama. We need + # to add this patch to ensure things work correctly on our side. 
if "llama" in architecture and "mistral" in model_name: updated_architecture = "mistral" else: From 9ae7363c859b532c96e9462ffe9b407159ecdc09 Mon Sep 17 00:00:00 2001 From: Younes Belkada Date: Tue, 30 Apr 2024 15:51:30 +0200 Subject: [PATCH 22/29] comments and added tests --- src/transformers/integrations/ggml.py | 2 +- .../modeling_gguf_pytorch_utils.py | 2 +- src/transformers/tokenization_utils_base.py | 4 ++-- tests/quantization/ggml/test_ggml.py | 18 ++++++++++++++++++ 4 files changed, 22 insertions(+), 4 deletions(-) diff --git a/src/transformers/integrations/ggml.py b/src/transformers/integrations/ggml.py index 7a546a5460d38f..145d57488f5a64 100644 --- a/src/transformers/integrations/ggml.py +++ b/src/transformers/integrations/ggml.py @@ -50,7 +50,7 @@ "Q8_0": 2 + 32, # Q8_0 uses a blocksize of 32 (int8 tensors) + 2 bytes allocated for the scales "Q4_K": 144, # Q4_0 uses a blocksize of 32 but the 4-bit tensors are packed into 8-bit tensors + 2 bytes for the scales - "Q4_0": 2 + 16, + "Q4_0": 2 + 16, "Q6_K": 210, # See: https://github.com/99991/pygguf/commit/a417edbfc029a1bc270f984a694f9128c5afa8b9 "Q2_K": 256 // 16 + 256 // 4 + 2 + 2, diff --git a/src/transformers/modeling_gguf_pytorch_utils.py b/src/transformers/modeling_gguf_pytorch_utils.py index 25d3d6a88e253e..1511fbac0976ac 100644 --- a/src/transformers/modeling_gguf_pytorch_utils.py +++ b/src/transformers/modeling_gguf_pytorch_utils.py @@ -87,7 +87,7 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False): architecture = read_field(reader, "general.architecture")[0] model_name = read_field(reader, "general.name") - # in llama.cpp mistral models use the same architecture as llama. We need + # in llama.cpp mistral models use the same architecture as llama. We need # to add this patch to ensure things work correctly on our side. if "llama" in architecture and "mistral" in model_name: updated_architecture = "mistral" diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 60de37ba3b6798..f1095f6e780670 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -2035,7 +2035,7 @@ def from_pretrained( tokenizer_config = json.load(reader) if "fast_tokenizer_files" in tokenizer_config: fast_tokenizer_file = get_fast_tokenizer_file(tokenizer_config["fast_tokenizer_files"]) - vocab_files["tokenizer_file"] = fast_tokenizer_file + vocab_files["tokenizer_file"] = fast_tokenizer_file # Get files from url, cache, or disk depending on the case resolved_vocab_files = {} @@ -2127,7 +2127,7 @@ def _from_pretrained( from_slow = kwargs.get("from_slow", False) from_gguf = kwargs.get("from_gguf", False) has_tokenizer_file = resolved_vocab_files.get("tokenizer_file", None) is not None - + # If one passes a GGUF file path to `from_gguf` there is no need for this check as the tokenizer will be # loaded directly from the GGUF file. if (from_slow or not has_tokenizer_file) and cls.slow_tokenizer_class is not None and not from_gguf: diff --git a/tests/quantization/ggml/test_ggml.py b/tests/quantization/ggml/test_ggml.py index 28b56e3b5bc5a4..94553a7d92fa04 100644 --- a/tests/quantization/ggml/test_ggml.py +++ b/tests/quantization/ggml/test_ggml.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import tempfile import unittest from transformers import AutoModelForCausalLM, AutoTokenizer @@ -53,6 +54,23 @@ def test_q2_k(self): EXPECTED_TEXT = " Hello, World!\n\n[10:0" self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT) + def test_q2_k_serialization(self): + tokenizer = AutoTokenizer.from_pretrained(self.model_id, from_gguf=self.q2_k_gguf_model_id) + model = AutoModelForCausalLM.from_pretrained(self.model_id, from_gguf=self.q2_k_gguf_model_id).to(torch_device) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + tokenizer.save_pretrained(tmpdirname) + + model = AutoModelForCausalLM.from_pretrained(tmpdirname).to(torch_device) + tokenizer = AutoTokenizer.from_pretrained(tmpdirname) + + text = tokenizer(self.example_text, return_tensors="pt").to(torch_device) + out = model.generate(**text, max_new_tokens=10) + + EXPECTED_TEXT = " Hello, World!\n\n[10:0" + self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT) + def test_q3_k(self): tokenizer = AutoTokenizer.from_pretrained(self.model_id, from_gguf=self.q3_k_gguf_model_id) model = AutoModelForCausalLM.from_pretrained(self.model_id, from_gguf=self.q3_k_gguf_model_id).to(torch_device) From 55eb860c7850da9248c5a67eb4cbc55f777929e4 Mon Sep 17 00:00:00 2001 From: Younes Belkada Date: Tue, 14 May 2024 10:52:43 +0200 Subject: [PATCH 23/29] fix unconsistent type --- src/transformers/tokenization_utils_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index a35abd53f36970..2884a46b92da30 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -1968,7 +1968,7 @@ def from_pretrained( from_pipeline = kwargs.pop("_from_pipeline", None) from_auto_class = kwargs.pop("_from_auto", False) commit_hash = kwargs.pop("_commit_hash", None) - from_gguf = kwargs.get("from_gguf", False) + from_gguf = kwargs.get("from_gguf", None) if use_auth_token is not None: warnings.warn( From f7543353a011c9c7adf8942fd6291df100c11d64 Mon Sep 17 00:00:00 2001 From: Younes Belkada Date: Tue, 14 May 2024 10:54:08 +0200 Subject: [PATCH 24/29] more --- src/transformers/tokenization_utils_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 2884a46b92da30..5279d57d10f2de 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -2139,7 +2139,7 @@ def _from_pretrained( # We instantiate fast tokenizers based on a slow tokenizer if we don't have access to the tokenizer.json # file or if `from_slow` is set to True. 
from_slow = kwargs.get("from_slow", False) - from_gguf = kwargs.get("from_gguf", False) + from_gguf = kwargs.get("from_gguf", None) has_tokenizer_file = resolved_vocab_files.get("tokenizer_file", None) is not None # If one passes a GGUF file path to `from_gguf` there is no need for this check as the tokenizer will be From 3bdbb2e2e90389650ce93ff0081ad7c7aca9cbca Mon Sep 17 00:00:00 2001 From: Younes Belkada Date: Wed, 15 May 2024 11:28:49 +0200 Subject: [PATCH 25/29] fix tokenizer --- src/transformers/integrations/ggml.py | 1 - tests/quantization/ggml/test_ggml.py | 58 ++++++++++++++++++++++----- 2 files changed, 48 insertions(+), 11 deletions(-) diff --git a/src/transformers/integrations/ggml.py b/src/transformers/integrations/ggml.py index 145d57488f5a64..2828b7194c776e 100644 --- a/src/transformers/integrations/ggml.py +++ b/src/transformers/integrations/ggml.py @@ -547,7 +547,6 @@ def decoder(self, replacement, add_prefix_space): decoders.Fuse(), decoders.Replace("▁", " "), ] - add_prefix_space = False if add_prefix_space: sequence += [decoders.Strip(content=" ", left=1)] return decoders.Sequence(sequence) diff --git a/tests/quantization/ggml/test_ggml.py b/tests/quantization/ggml/test_ggml.py index 94553a7d92fa04..6a0226745c1ff0 100644 --- a/tests/quantization/ggml/test_ggml.py +++ b/tests/quantization/ggml/test_ggml.py @@ -51,7 +51,7 @@ def test_q2_k(self): text = tokenizer(self.example_text, return_tensors="pt").to(torch_device) out = model.generate(**text, max_new_tokens=10) - EXPECTED_TEXT = " Hello, World!\n\n[10:0" + EXPECTED_TEXT = "Hello, World!\n\n[10:0" self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT) def test_q2_k_serialization(self): @@ -68,7 +68,7 @@ def test_q2_k_serialization(self): text = tokenizer(self.example_text, return_tensors="pt").to(torch_device) out = model.generate(**text, max_new_tokens=10) - EXPECTED_TEXT = " Hello, World!\n\n[10:0" + EXPECTED_TEXT = "Hello, World!\n\n[10:0" self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT) def test_q3_k(self): @@ -78,7 +78,7 @@ def test_q3_k(self): text = tokenizer(self.example_text, return_tensors="pt").to(torch_device) out = model.generate(**text, max_new_tokens=10) - EXPECTED_TEXT = " Hello, World!\n\n```\n<|user" + EXPECTED_TEXT = "Hello, World!\n\n```\n<|user" self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT) def test_q5_k(self): @@ -88,7 +88,7 @@ def test_q5_k(self): text = tokenizer(self.example_text, return_tensors="pt").to(torch_device) out = model.generate(**text, max_new_tokens=10) - EXPECTED_TEXT = " Hello, World!\n\nStep 3: Add" + EXPECTED_TEXT = "Hello, World!\n\nStep 3: Add" self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT) def test_q4_0(self): @@ -98,7 +98,7 @@ def test_q4_0(self): text = tokenizer(self.example_text, return_tensors="pt").to(torch_device) out = model.generate(**text, max_new_tokens=10) - EXPECTED_TEXT = " Hello, World!\n\nStep 3: Add" + EXPECTED_TEXT = "Hello, World!\n\nStep 3: Add" self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT) def test_q4_k_m(self): @@ -108,7 +108,7 @@ def test_q4_k_m(self): text = tokenizer(self.example_text, return_tensors="pt").to(torch_device) out = model.generate(**text, max_new_tokens=10) - EXPECTED_TEXT = " Hello, World!\n\n5. Python:\n" + EXPECTED_TEXT = "Hello, World!\n\n5. 
Python:\n" self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT) def test_q6_k(self): @@ -118,7 +118,7 @@ def test_q6_k(self): text = tokenizer(self.example_text, return_tensors="pt").to(torch_device) out = model.generate(**text, max_new_tokens=10) - EXPECTED_TEXT = " Hello, World!\n\nStep 3: Add" + EXPECTED_TEXT = "Hello, World!\n\nStep 3: Add" self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT) def test_q6_k_fp16(self): @@ -132,7 +132,7 @@ def test_q6_k_fp16(self): text = tokenizer(self.example_text, return_tensors="pt").to(torch_device) out = model.generate(**text, max_new_tokens=10) - EXPECTED_TEXT = " Hello, World!\n\nStep 3: Add" + EXPECTED_TEXT = "Hello, World!\n\nStep 3: Add" self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT) def test_q8_0(self): @@ -142,7 +142,7 @@ def test_q8_0(self): text = tokenizer(self.example_text, return_tensors="pt").to(torch_device) out = model.generate(**text, max_new_tokens=10) - EXPECTED_TEXT = " Hello, World!\n\n5. Use a library" + EXPECTED_TEXT = "Hello, World!\n\n5. Use a library" self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT) def test_mistral_q4_0(self): @@ -154,5 +154,43 @@ def test_mistral_q4_0(self): text = tokenizer(self.example_text, return_tensors="pt").to(torch_device) out = model.generate(**text, max_new_tokens=10) - EXPECTED_TEXT = " Hello,\n\nI'm trying to create a" + EXPECTED_TEXT = "Hello,\n\nI'm trying to create a" self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT) + + def test_tokenization_xnli(self): + import tqdm + from datasets import load_dataset + + gguf_tokenizer = AutoTokenizer.from_pretrained(self.model_id, from_gguf=self.q8_0_gguf_model_id) + original_tokenizer = AutoTokenizer.from_pretrained(self.original_model_id) + + dataset = load_dataset("code_x_glue_ct_code_to_text", "go") + for item in tqdm.tqdm(dataset["validation"]): + string = item["code"] + encoded1 = gguf_tokenizer.encode(string) + encoded2 = original_tokenizer.encode(string) + + self.assertEqual(encoded1, encoded2) + + decoded1 = gguf_tokenizer.decode(encoded1, skip_special_tokens=True) + decoded2 = original_tokenizer.decode(encoded2, skip_special_tokens=True) + + self.assertEqual(decoded1, decoded2) + + dataset = load_dataset("xnli", "all_languages") + + for i, item in enumerate(tqdm.tqdm(dataset["train"])): + for string in item["premise"].values(): + encoded1 = gguf_tokenizer.encode(string) + encoded2 = original_tokenizer.encode(string) + + self.assertEqual(encoded1, encoded2) + + decoded1 = gguf_tokenizer.decode(encoded1, skip_special_tokens=True) + decoded2 = original_tokenizer.decode(encoded2, skip_special_tokens=True) + + self.assertEqual(decoded1, decoded2) + + # Otherwise the test takes too long + if i > 100: + break From 0ab79f6832e3782c0e6f06bbe1cc3a7d35e4d0b7 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Wed, 15 May 2024 13:03:50 +0200 Subject: [PATCH 26/29] Update src/transformers/modeling_utils.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- src/transformers/modeling_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 5823adf84631b6..2a87aebfe18336 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -3204,7 +3204,7 @@ def from_pretrained( if from_gguf is not 
None and hf_quantizer is not None: raise ValueError( - "You cannot combine Quantization and loading a model from a GGUF file, try again by making sure you did not passed a `quantization_config` or that you did not loaded a quantized model from the Hub." + "You cannot combine Quantization and loading a model from a GGUF file, try again by making sure you did not passed a `quantization_config` or that you did not load a quantized model from the Hub." ) if pretrained_model_name_or_path is not None and from_gguf is None: From 65433c403ae8aeaca3333c431c0f05cc0078f0cd Mon Sep 17 00:00:00 2001 From: Younes Belkada Date: Wed, 15 May 2024 13:51:33 +0200 Subject: [PATCH 27/29] address comments about tests and tokenizer + add added_tokens --- src/transformers/integrations/ggml.py | 9 +++++++++ tests/quantization/ggml/test_ggml.py | 29 ++++++++++++++++++++++----- 2 files changed, 33 insertions(+), 5 deletions(-) diff --git a/src/transformers/integrations/ggml.py b/src/transformers/integrations/ggml.py index 2828b7194c776e..13660a0a81a362 100644 --- a/src/transformers/integrations/ggml.py +++ b/src/transformers/integrations/ggml.py @@ -513,6 +513,9 @@ def __init__(self, dict_): else: self.merges = [tuple(merge.split(" ")) for merge in self.merges] + if not hasattr(self, "added_tokens"): + self.added_tokens = [] + class GGUFLlamaConverter(LlamaConverter): def __init__(self, tokenizer_dict): @@ -539,6 +542,12 @@ def tokenizer(self, proto): AddedToken("", normalized=False, special=True), ] ) + + if len(self.proto.added_tokens) != 0: + tokenizer.add_special_tokens( + [AddedToken(added_token, normalized=False, special=False) for added_token in self.added_tokens] + ) + return tokenizer def decoder(self, replacement, add_prefix_space): diff --git a/tests/quantization/ggml/test_ggml.py b/tests/quantization/ggml/test_ggml.py index 6a0226745c1ff0..49466d0d75e675 100644 --- a/tests/quantization/ggml/test_ggml.py +++ b/tests/quantization/ggml/test_ggml.py @@ -15,7 +15,7 @@ import tempfile import unittest -from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers import AddedToken, AutoModelForCausalLM, AutoTokenizer from transformers.testing_utils import require_gguf, require_torch_gpu, slow, torch_device from transformers.utils import is_torch_available @@ -179,7 +179,7 @@ def test_tokenization_xnli(self): dataset = load_dataset("xnli", "all_languages") - for i, item in enumerate(tqdm.tqdm(dataset["train"])): + for i, item in enumerate(tqdm.tqdm(dataset["train"].select(range(100)))): for string in item["premise"].values(): encoded1 = gguf_tokenizer.encode(string) encoded2 = original_tokenizer.encode(string) @@ -191,6 +191,25 @@ def test_tokenization_xnli(self): self.assertEqual(decoded1, decoded2) - # Otherwise the test takes too long - if i > 100: - break + # With special tokens + gguf_tokenizer = AutoTokenizer.from_pretrained(self.model_id, from_gguf=self.q8_0_gguf_model_id) + original_tokenizer = AutoTokenizer.from_pretrained(self.original_model_id) + + gguf_tokenizer.add_special_tokens( + {"additional_special_tokens": [AddedToken("", rstrip=False, lstrip=False)]} + ) + original_tokenizer.add_special_tokens( + {"additional_special_tokens": [AddedToken("", rstrip=False, lstrip=False)]} + ) + + text = "Hello . 
Hello" + + encoded1 = gguf_tokenizer.encode(text) + encoded2 = original_tokenizer.encode(text) + + self.assertEqual(encoded1, encoded2) + + decoded1 = gguf_tokenizer.decode(encoded1, skip_special_tokens=True) + decoded2 = original_tokenizer.decode(encoded2, skip_special_tokens=True) + + self.assertEqual(decoded1, decoded2) From 1b5ae547026620d198f90ef181b078f51b9aeac7 Mon Sep 17 00:00:00 2001 From: Younes Belkada Date: Wed, 15 May 2024 14:07:50 +0200 Subject: [PATCH 28/29] from_gguf -> gguf_file --- src/transformers/configuration_utils.py | 8 ++-- src/transformers/modeling_utils.py | 14 +++--- .../models/auto/tokenization_auto.py | 6 +-- src/transformers/tokenization_utils_base.py | 18 ++++---- src/transformers/tokenization_utils_fast.py | 4 +- tests/quantization/ggml/test_ggml.py | 44 +++++++++---------- 6 files changed, 47 insertions(+), 47 deletions(-) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 4098eb15be9f65..fc7f782e348e8d 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -659,7 +659,7 @@ def _get_config_dict( from_auto_class = kwargs.pop("_from_auto", False) commit_hash = kwargs.pop("_commit_hash", None) - from_gguf = kwargs.get("from_gguf", None) + gguf_file = kwargs.get("gguf_file", None) if trust_remote_code is True: logger.warning( @@ -679,10 +679,10 @@ def _get_config_dict( resolved_config_file = pretrained_model_name_or_path is_local = True elif is_remote_url(pretrained_model_name_or_path): - configuration_file = pretrained_model_name_or_path if from_gguf is None else from_gguf + configuration_file = pretrained_model_name_or_path if gguf_file is None else gguf_file resolved_config_file = download_url(pretrained_model_name_or_path) else: - configuration_file = kwargs.pop("_configuration_file", CONFIG_NAME) if from_gguf is None else from_gguf + configuration_file = kwargs.pop("_configuration_file", CONFIG_NAME) if gguf_file is None else gguf_file try: # Load from local folder or from cache or download from model Hub and cache @@ -715,7 +715,7 @@ def _get_config_dict( ) try: - if from_gguf: + if gguf_file: config_dict = load_gguf_checkpoint(resolved_config_file, return_tensors=False)["config"] else: # Load config dict diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 2a87aebfe18336..d19f928340c1e0 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2993,7 +2993,7 @@ def from_pretrained( adapter_name = kwargs.pop("adapter_name", "default") use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False) - from_gguf = kwargs.pop("from_gguf", None) + gguf_file = kwargs.pop("gguf_file", None) # Cache path to the GGUF file gguf_path = None @@ -3202,12 +3202,12 @@ def from_pretrained( keep_in_fp32_modules = None use_keep_in_fp32_modules = False - if from_gguf is not None and hf_quantizer is not None: + if gguf_file is not None and hf_quantizer is not None: raise ValueError( "You cannot combine Quantization and loading a model from a GGUF file, try again by making sure you did not passed a `quantization_config` or that you did not load a quantized model from the Hub." 
) - if pretrained_model_name_or_path is not None and from_gguf is None: + if pretrained_model_name_or_path is not None and gguf_file is None: pretrained_model_name_or_path = str(pretrained_model_name_or_path) is_local = os.path.isdir(pretrained_model_name_or_path) if is_local: @@ -3449,12 +3449,12 @@ def from_pretrained( resolved_archive_file = archive_file else: logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}") - elif from_gguf: + elif gguf_file: from .modeling_gguf_pytorch_utils import load_gguf_checkpoint # Case 1: the GGUF file is present locally - if os.path.isfile(from_gguf): - gguf_path = from_gguf + if os.path.isfile(gguf_file): + gguf_path = gguf_file # Case 2: The GGUF path is a location on the Hub # Load from URL or cache if already cached else: @@ -3473,7 +3473,7 @@ def from_pretrained( "_commit_hash": commit_hash, } - gguf_path = cached_file(pretrained_model_name_or_path, from_gguf, **cached_file_kwargs) + gguf_path = cached_file(pretrained_model_name_or_path, gguf_file, **cached_file_kwargs) state_dict = load_gguf_checkpoint(gguf_path, return_tensors=True)["tensors"] diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 0dc44ebc4422c5..57f291aa45e8eb 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -773,7 +773,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): use_fast = kwargs.pop("use_fast", True) tokenizer_type = kwargs.pop("tokenizer_type", None) trust_remote_code = kwargs.pop("trust_remote_code", None) - from_gguf = kwargs.get("from_gguf", None) + gguf_file = kwargs.get("gguf_file", None) # First, let's see whether the tokenizer_type is passed so that we can leverage it if tokenizer_type is not None: @@ -820,8 +820,8 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): # If that did not work, let's try to use the config. if config_tokenizer_class is None: if not isinstance(config, PretrainedConfig): - if from_gguf: - gguf_path = cached_file(pretrained_model_name_or_path, from_gguf, **kwargs) + if gguf_file: + gguf_path = cached_file(pretrained_model_name_or_path, gguf_file, **kwargs) config_dict = load_gguf_checkpoint(gguf_path, return_tensors=False)["config"] config = AutoConfig.for_model(**config_dict) else: diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 5279d57d10f2de..395f9859cd68ce 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -1968,7 +1968,7 @@ def from_pretrained( from_pipeline = kwargs.pop("_from_pipeline", None) from_auto_class = kwargs.pop("_from_auto", False) commit_hash = kwargs.pop("_commit_hash", None) - from_gguf = kwargs.get("from_gguf", None) + gguf_file = kwargs.get("gguf_file", None) if use_auth_token is not None: warnings.warn( @@ -1996,7 +1996,7 @@ def from_pretrained( is_local = os.path.isdir(pretrained_model_name_or_path) single_file_id = None if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path): - if len(cls.vocab_files_names) > 1 and not from_gguf: + if len(cls.vocab_files_names) > 1 and not gguf_file: raise ValueError( f"Calling {cls.__name__}.from_pretrained() with the path to a single file or url is not " "supported for this tokenizer. Use a model identifier or the path to a directory instead." 
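For callers, this rename means every loading entry point that previously took `from_gguf=<filename>` now takes `gguf_file=<filename>`, as the updated tests further below reflect. A short usage sketch, with placeholder repository and file names standing in for any Hub repo that ships a GGUF checkpoint:

```py
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholders for illustration: any Hub repo containing a GGUF file, and the filename inside it.
repo_id = "some-org/some-gguf-repo"
filename = "model.Q4_K_M.gguf"

tokenizer = AutoTokenizer.from_pretrained(repo_id, gguf_file=filename)
model = AutoModelForCausalLM.from_pretrained(repo_id, gguf_file=filename)
```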
@@ -2011,8 +2011,8 @@ def from_pretrained( vocab_files[file_id] = pretrained_model_name_or_path single_file_id = file_id else: - if from_gguf: - vocab_files["vocab_file"] = from_gguf + if gguf_file: + vocab_files["vocab_file"] = gguf_file else: # At this point pretrained_model_name_or_path is either a directory or a model identifier name additional_files_names = { @@ -2088,9 +2088,9 @@ def from_pretrained( "files are necessary for the tokenizer to operate." ) - # If one passes a GGUF file path to `from_gguf` there is no need for this check as the tokenizer will be + # If one passes a GGUF file path to `gguf_file` there is no need for this check as the tokenizer will be # loaded directly from the GGUF file. - if all(full_file_name is None for full_file_name in resolved_vocab_files.values()) and not from_gguf: + if all(full_file_name is None for full_file_name in resolved_vocab_files.values()) and not gguf_file: raise EnvironmentError( f"Can't load tokenizer for '{pretrained_model_name_or_path}'. If you were trying to load it from " "'https://huggingface.co/models', make sure you don't have a local directory with the same name. " @@ -2139,12 +2139,12 @@ def _from_pretrained( # We instantiate fast tokenizers based on a slow tokenizer if we don't have access to the tokenizer.json # file or if `from_slow` is set to True. from_slow = kwargs.get("from_slow", False) - from_gguf = kwargs.get("from_gguf", None) + gguf_file = kwargs.get("gguf_file", None) has_tokenizer_file = resolved_vocab_files.get("tokenizer_file", None) is not None - # If one passes a GGUF file path to `from_gguf` there is no need for this check as the tokenizer will be + # If one passes a GGUF file path to `gguf_file` there is no need for this check as the tokenizer will be # loaded directly from the GGUF file. 
-        if (from_slow or not has_tokenizer_file) and cls.slow_tokenizer_class is not None and not from_gguf:
+        if (from_slow or not has_tokenizer_file) and cls.slow_tokenizer_class is not None and not gguf_file:
             slow_tokenizer = (cls.slow_tokenizer_class)._from_pretrained(
                 copy.deepcopy(resolved_vocab_files),
                 pretrained_model_name_or_path,
diff --git a/src/transformers/tokenization_utils_fast.py b/src/transformers/tokenization_utils_fast.py
index bb9709c71b304d..494c049a150bfb 100644
--- a/src/transformers/tokenization_utils_fast.py
+++ b/src/transformers/tokenization_utils_fast.py
@@ -96,7 +96,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
     def __init__(self, *args, **kwargs):
         tokenizer_object = kwargs.pop("tokenizer_object", None)
         slow_tokenizer = kwargs.pop("__slow_tokenizer", None)
-        from_gguf = kwargs.pop("from_gguf", None)
+        gguf_file = kwargs.pop("gguf_file", None)
         fast_tokenizer_file = kwargs.pop("tokenizer_file", None)
         from_slow = kwargs.pop("from_slow", False)
         added_tokens_decoder = kwargs.pop("added_tokens_decoder", {})
@@ -115,7 +115,7 @@ def __init__(self, *args, **kwargs):
         elif slow_tokenizer is not None:
             # We need to convert a slow tokenizer to build the backend
             fast_tokenizer = convert_slow_tokenizer(slow_tokenizer)
-        elif from_gguf is not None:
+        elif gguf_file is not None:
             # We need to convert a slow tokenizer to build the backend
             tokenizer_dict = load_gguf_checkpoint(kwargs.get("vocab_file"))["tokenizer"]
             fast_tokenizer = convert_gguf_tokenizer(tokenizer_dict)
diff --git a/tests/quantization/ggml/test_ggml.py b/tests/quantization/ggml/test_ggml.py
index 49466d0d75e675..09a1ea51d227ec 100644
--- a/tests/quantization/ggml/test_ggml.py
+++ b/tests/quantization/ggml/test_ggml.py
@@ -45,8 +45,8 @@ class GgufIntegrationTests(unittest.TestCase):
     example_text = "Hello"
 
     def test_q2_k(self):
-        tokenizer = AutoTokenizer.from_pretrained(self.model_id, from_gguf=self.q2_k_gguf_model_id)
-        model = AutoModelForCausalLM.from_pretrained(self.model_id, from_gguf=self.q2_k_gguf_model_id).to(torch_device)
+        tokenizer = AutoTokenizer.from_pretrained(self.model_id, gguf_file=self.q2_k_gguf_model_id)
+        model = AutoModelForCausalLM.from_pretrained(self.model_id, gguf_file=self.q2_k_gguf_model_id).to(torch_device)
 
         text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
         out = model.generate(**text, max_new_tokens=10)
@@ -55,8 +55,8 @@ def test_q2_k(self):
         self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
 
     def test_q2_k_serialization(self):
-        tokenizer = AutoTokenizer.from_pretrained(self.model_id, from_gguf=self.q2_k_gguf_model_id)
-        model = AutoModelForCausalLM.from_pretrained(self.model_id, from_gguf=self.q2_k_gguf_model_id).to(torch_device)
+        tokenizer = AutoTokenizer.from_pretrained(self.model_id, gguf_file=self.q2_k_gguf_model_id)
+        model = AutoModelForCausalLM.from_pretrained(self.model_id, gguf_file=self.q2_k_gguf_model_id).to(torch_device)
 
         with tempfile.TemporaryDirectory() as tmpdirname:
             model.save_pretrained(tmpdirname)
@@ -72,8 +72,8 @@ def test_q2_k_serialization(self):
         self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
 
     def test_q3_k(self):
-        tokenizer = AutoTokenizer.from_pretrained(self.model_id, from_gguf=self.q3_k_gguf_model_id)
-        model = AutoModelForCausalLM.from_pretrained(self.model_id, from_gguf=self.q3_k_gguf_model_id).to(torch_device)
+        tokenizer = AutoTokenizer.from_pretrained(self.model_id, gguf_file=self.q3_k_gguf_model_id)
+        model = AutoModelForCausalLM.from_pretrained(self.model_id, gguf_file=self.q3_k_gguf_model_id).to(torch_device)
 
         text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
         out = model.generate(**text, max_new_tokens=10)
@@ -82,8 +82,8 @@ def test_q3_k(self):
         self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
 
     def test_q5_k(self):
-        tokenizer = AutoTokenizer.from_pretrained(self.model_id, from_gguf=self.q5_k_gguf_model_id)
-        model = AutoModelForCausalLM.from_pretrained(self.model_id, from_gguf=self.q5_k_gguf_model_id).to(torch_device)
+        tokenizer = AutoTokenizer.from_pretrained(self.model_id, gguf_file=self.q5_k_gguf_model_id)
+        model = AutoModelForCausalLM.from_pretrained(self.model_id, gguf_file=self.q5_k_gguf_model_id).to(torch_device)
 
         text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
         out = model.generate(**text, max_new_tokens=10)
@@ -92,8 +92,8 @@ def test_q5_k(self):
         self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
 
     def test_q4_0(self):
-        tokenizer = AutoTokenizer.from_pretrained(self.model_id, from_gguf=self.q4_0_gguf_model_id)
-        model = AutoModelForCausalLM.from_pretrained(self.model_id, from_gguf=self.q4_0_gguf_model_id).to(torch_device)
+        tokenizer = AutoTokenizer.from_pretrained(self.model_id, gguf_file=self.q4_0_gguf_model_id)
+        model = AutoModelForCausalLM.from_pretrained(self.model_id, gguf_file=self.q4_0_gguf_model_id).to(torch_device)
 
         text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
         out = model.generate(**text, max_new_tokens=10)
@@ -102,8 +102,8 @@ def test_q4_0(self):
         self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
 
     def test_q4_k_m(self):
-        tokenizer = AutoTokenizer.from_pretrained(self.model_id, from_gguf=self.q4_k_gguf_model_id)
-        model = AutoModelForCausalLM.from_pretrained(self.model_id, from_gguf=self.q4_k_gguf_model_id).to(torch_device)
+        tokenizer = AutoTokenizer.from_pretrained(self.model_id, gguf_file=self.q4_k_gguf_model_id)
+        model = AutoModelForCausalLM.from_pretrained(self.model_id, gguf_file=self.q4_k_gguf_model_id).to(torch_device)
 
         text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
         out = model.generate(**text, max_new_tokens=10)
@@ -112,8 +112,8 @@ def test_q4_k_m(self):
         self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
 
     def test_q6_k(self):
-        tokenizer = AutoTokenizer.from_pretrained(self.model_id, from_gguf=self.q6_k_gguf_model_id)
-        model = AutoModelForCausalLM.from_pretrained(self.model_id, from_gguf=self.q6_k_gguf_model_id).to(torch_device)
+        tokenizer = AutoTokenizer.from_pretrained(self.model_id, gguf_file=self.q6_k_gguf_model_id)
+        model = AutoModelForCausalLM.from_pretrained(self.model_id, gguf_file=self.q6_k_gguf_model_id).to(torch_device)
 
         text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
         out = model.generate(**text, max_new_tokens=10)
@@ -122,9 +122,9 @@ def test_q6_k(self):
         self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
 
     def test_q6_k_fp16(self):
-        tokenizer = AutoTokenizer.from_pretrained(self.model_id, from_gguf=self.q6_k_gguf_model_id)
+        tokenizer = AutoTokenizer.from_pretrained(self.model_id, gguf_file=self.q6_k_gguf_model_id)
         model = AutoModelForCausalLM.from_pretrained(
-            self.model_id, from_gguf=self.q6_k_gguf_model_id, torch_dtype=torch.float16
+            self.model_id, gguf_file=self.q6_k_gguf_model_id, torch_dtype=torch.float16
         ).to(torch_device)
 
         self.assertTrue(model.lm_head.weight.dtype == torch.float16)
@@ -136,8 +136,8 @@ def test_q6_k_fp16(self):
         self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
 
     def test_q8_0(self):
-        tokenizer = AutoTokenizer.from_pretrained(self.model_id, from_gguf=self.q8_0_gguf_model_id)
-        model = AutoModelForCausalLM.from_pretrained(self.model_id, from_gguf=self.q8_0_gguf_model_id).to(torch_device)
+        tokenizer = AutoTokenizer.from_pretrained(self.model_id, gguf_file=self.q8_0_gguf_model_id)
+        model = AutoModelForCausalLM.from_pretrained(self.model_id, gguf_file=self.q8_0_gguf_model_id).to(torch_device)
 
         text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
         out = model.generate(**text, max_new_tokens=10)
@@ -146,9 +146,9 @@ def test_q8_0(self):
         self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
 
     def test_mistral_q4_0(self):
-        tokenizer = AutoTokenizer.from_pretrained(self.mistral_model_id, from_gguf=self.q4_0_mistral_model_id)
+        tokenizer = AutoTokenizer.from_pretrained(self.mistral_model_id, gguf_file=self.q4_0_mistral_model_id)
         model = AutoModelForCausalLM.from_pretrained(
-            self.mistral_model_id, from_gguf=self.q4_0_mistral_model_id, device_map="auto", torch_dtype=torch.float16
+            self.mistral_model_id, gguf_file=self.q4_0_mistral_model_id, device_map="auto", torch_dtype=torch.float16
         )
 
         text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
@@ -161,7 +161,7 @@ def test_tokenization_xnli(self):
         import tqdm
         from datasets import load_dataset
 
-        gguf_tokenizer = AutoTokenizer.from_pretrained(self.model_id, from_gguf=self.q8_0_gguf_model_id)
+        gguf_tokenizer = AutoTokenizer.from_pretrained(self.model_id, gguf_file=self.q8_0_gguf_model_id)
         original_tokenizer = AutoTokenizer.from_pretrained(self.original_model_id)
 
         dataset = load_dataset("code_x_glue_ct_code_to_text", "go")
@@ -192,7 +192,7 @@ def test_tokenization_xnli(self):
             self.assertEqual(decoded1, decoded2)
 
         # With special tokens
-        gguf_tokenizer = AutoTokenizer.from_pretrained(self.model_id, from_gguf=self.q8_0_gguf_model_id)
+        gguf_tokenizer = AutoTokenizer.from_pretrained(self.model_id, gguf_file=self.q8_0_gguf_model_id)
         original_tokenizer = AutoTokenizer.from_pretrained(self.original_model_id)
 
         gguf_tokenizer.add_special_tokens(

From d6b67c6e10c48955630515acd6356f7aa9629f75 Mon Sep 17 00:00:00 2001
From: Younes Belkada
Date: Wed, 15 May 2024 14:09:17 +0200
Subject: [PATCH 29/29] replace on docs too

---
 docs/source/en/gguf.md | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/docs/source/en/gguf.md b/docs/source/en/gguf.md
index d5ec642a1904e6..db05e169edcca7 100644
--- a/docs/source/en/gguf.md
+++ b/docs/source/en/gguf.md
@@ -66,7 +66,7 @@ For now the supported model architectures are the architectures that have been v
 
 ## Example usage
 
-In order to load `gguf` files in `transformers`, you should specify the `from_gguf` argument to the `from_pretrained`
+In order to load `gguf` files in `transformers`, you should specify the `gguf_file` argument to the `from_pretrained`
 methods of both tokenizers and models. Here is how one would load a tokenizer and a model, which can be loaded
 from the exact same file:
 
@@ -76,8 +76,8 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
 model_id = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF"
 filename = "tinyllama-1.1b-chat-v1.0.Q6_K.gguf"
 
-tokenizer = AutoTokenizer.from_pretrained(model_id, from_gguf=filename)
-model = AutoModelForCausalLM.from_pretrained(model_id, from_gguf=filename)
+tokenizer = AutoTokenizer.from_pretrained(model_id, gguf_file=filename)
+model = AutoModelForCausalLM.from_pretrained(model_id, gguf_file=filename)
 ```
 
 Now you have access to the full, unquantized version of the model in the PyTorch ecosystem, where you can combine it