diff --git a/.gitignore b/.gitignore index def74a1e948e50..cf1b692e9c27c8 100644 --- a/.gitignore +++ b/.gitignore @@ -51,6 +51,7 @@ models-mnt /lookup /main /metal +/passkey /perplexity /q8dot /quantize diff --git a/Makefile b/Makefile index 28c6d79bcd7d50..4c7e175bf6cb33 100644 --- a/Makefile +++ b/Makefile @@ -2,7 +2,7 @@ BUILD_TARGETS = \ main quantize quantize-stats perplexity embedding vdot q8dot train-text-from-scratch convert-llama2c-to-ggml \ simple batched batched-bench save-load-state server gguf llama-bench libllava.a llava-cli baby-llama beam-search \ - speculative infill tokenize benchmark-matmult parallel finetune export-lora lookahead lookup tests/test-c.o + speculative infill tokenize benchmark-matmult parallel finetune export-lora lookahead lookup passkey tests/test-c.o # Binaries only useful for tests TEST_TARGETS = \ @@ -665,6 +665,9 @@ lookahead: examples/lookahead/lookahead.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS lookup: examples/lookup/lookup.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) +passkey: examples/passkey/passkey.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS) + $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) + ifdef LLAMA_METAL metal: examples/metal/metal.cpp ggml.o $(OBJS) $(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS) diff --git a/Package.swift b/Package.swift index e33a4ff46cb156..583e2e276e471f 100644 --- a/Package.swift +++ b/Package.swift @@ -21,7 +21,7 @@ let package = Package( name: "llama", dependencies: ["ggml"], path: ".", - exclude: [], + exclude: ["ggml-metal.metal"], sources: [ "llama.cpp", ], diff --git a/README.md b/README.md index ca6d14e175b099..866aa87b4ffc58 100644 --- a/README.md +++ b/README.md @@ -10,6 +10,7 @@ Inference of [LLaMA](https://arxiv.org/abs/2302.13971) model in pure C/C++ ### Hot topics +- New SOTA quantized models, including pure 2-bits: https://huggingface.co/ikawrakow - Collecting Apple Silicon performance stats: - M-series: https://github.com/ggerganov/llama.cpp/discussions/4167 - A-series: https://github.com/ggerganov/llama.cpp/discussions/4508 @@ -118,6 +119,7 @@ as the main playground for developing new features for the [ggml](https://github - Python: [abetlen/llama-cpp-python](https://github.com/abetlen/llama-cpp-python) - Go: [go-skynet/go-llama.cpp](https://github.com/go-skynet/go-llama.cpp) - Node.js: [withcatai/node-llama-cpp](https://github.com/withcatai/node-llama-cpp) +- JS/TS (llama.cpp server client): [lgrammel/modelfusion](https://modelfusion.dev/integration/model-provider/llamacpp) - Ruby: [yoshoku/llama_cpp.rb](https://github.com/yoshoku/llama_cpp.rb) - Rust: [mdrokz/rust-llama.cpp](https://github.com/mdrokz/rust-llama.cpp) - C#/.NET: [SciSharp/LLamaSharp](https://github.com/SciSharp/LLamaSharp) @@ -135,6 +137,7 @@ as the main playground for developing new features for the [ggml](https://github - [semperai/amica](https://github.com/semperai/amica) - [psugihara/FreeChat](https://github.com/psugihara/FreeChat) - [ptsochantaris/emeltal](https://github.com/ptsochantaris/emeltal) +- [iohub/collama](https://github.com/iohub/coLLaMA) --- diff --git a/common/common.cpp b/common/common.cpp index 58155f7f5fe81a..4a6241fb569c71 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -220,6 +220,20 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { break; } params.n_ctx = std::stoi(argv[i]); + } else if (arg == "--grp-attn-n" || arg == "-gan") { + if (++i >= argc) { + invalid_param = true; + break; + } + + params.grp_attn_n = 
std::stoi(argv[i]); + } else if (arg == "--grp-attn-w" || arg == "-gaw") { + if (++i >= argc) { + invalid_param = true; + break; + } + + params.grp_attn_w = std::stoi(argv[i]); } else if (arg == "--rope-freq-base") { if (++i >= argc) { invalid_param = true; @@ -916,6 +930,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" -mg i, --main-gpu i the GPU to use for the model (with split-mode = none),\n"); printf(" or for intermediate results and KV (with split-mode = row) (default: %d)\n", params.main_gpu); #endif + printf(" -gan N, --grp-attn-n N\n"); + printf(" group-attention factor (default: %d)\n", params.grp_attn_n); + printf(" -gaw N, --grp-attn-w N\n"); + printf(" group-attention width (default: %.1f)\n", (double)params.grp_attn_w); printf(" --verbose-prompt print prompt before generation\n"); printf(" -dkvc, --dump-kv-cache\n"); printf(" verbose print of the KV cache\n"); diff --git a/common/common.h b/common/common.h index 758722c11f6d36..5152b36d392ae3 100644 --- a/common/common.h +++ b/common/common.h @@ -63,6 +63,8 @@ struct gpt_params { int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs int32_t n_beams = 0; // if non-zero then use beam search of given width. + int32_t grp_attn_n = 1; // group-attention factor + int32_t grp_attn_w = 512; // group-attention width float rope_freq_base = 0.0f; // RoPE base frequency float rope_freq_scale = 0.0f; // RoPE frequency scaling factor float yarn_ext_factor = -1.0f; // YaRN extrapolation mix factor diff --git a/convert.py b/convert.py index c3f3fc0a1fcd39..3b613eefc6c2c0 100755 --- a/convert.py +++ b/convert.py @@ -17,29 +17,58 @@ import struct import sys import time +import warnings import zipfile from abc import ABCMeta, abstractmethod -from collections import OrderedDict +from argparse import ArgumentParser from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor from dataclasses import dataclass from pathlib import Path -from typing import IO, TYPE_CHECKING, Any, Callable, Iterable, Literal, Optional, TypeVar, cast +from typing import ( + IO, + TYPE_CHECKING, + Any, + Callable, + Iterable, + Literal, + Optional, + Tuple, + TypeVar, +) import numpy as np from sentencepiece import SentencePieceProcessor -if 'NO_LOCAL_GGUF' not in os.environ: - sys.path.insert(1, str(Path(__file__).parent / 'gguf-py')) -import gguf - -if TYPE_CHECKING: - from typing import TypeAlias - -if hasattr(faulthandler, 'register') and hasattr(signal, 'SIGUSR1'): +try: + from transformers import AutoTokenizer +except ModuleNotFoundError as e: + warnings.warn(f"Could not import AutoTokenizer from transformers: {e}") + +# If NO_LOCAL_GGUF is not set, try to import gguf from the local gguf-py directory +if "NO_LOCAL_GGUF" not in os.environ: + # Use absolute path to the gguf-py directory + gguf_py_dir = str(Path(__file__).resolve().parent / "gguf-py") + print(gguf_py_dir) # NOTE: Remove this once path is verified after changes are completed + if gguf_py_dir not in sys.path: + sys.path.insert(1, gguf_py_dir) + +# Import gguf module +try: + import gguf +except ModuleNotFoundError as e: + print(f"Could not import gguf: {e}") + sys.exit(1) + +if TYPE_CHECKING: # NOTE: This isn't necessary. + from typing import TypeAlias # This can technically be omitted. 
+ +if hasattr(faulthandler, "register") and hasattr(signal, "SIGUSR1"): faulthandler.register(signal.SIGUSR1) -NDArray: TypeAlias = 'np.ndarray[Any, Any]' +# NOTE: n-dimensional arrays should be directly referenced +NDArray: TypeAlias = "np.ndarray[Any, Any]" +# Why is this here? LLAMA and GPT are technically the only compatible ARCHs. ARCH = gguf.MODEL_ARCH.LLAMA DEFAULT_CONCURRENCY = 8 @@ -49,6 +78,7 @@ # +# TODO: Clean up and refactor data types @dataclass(frozen=True) class DataType: name: str @@ -153,65 +183,85 @@ def type_for_tensor(self, name: str, tensor: LazyTensor) -> DataType: @dataclass class Params: - n_vocab: int - n_embd: int - n_layer: int - n_ctx: int - n_ff: int - n_head: int - n_head_kv: int - n_experts: int | None = None - n_experts_used: int | None = None - f_norm_eps: float | None = None - - rope_scaling_type: gguf.RopeScalingType | None = None - f_rope_freq_base: float | None = None - f_rope_scale: float | None = None - n_orig_ctx: int | None = None - rope_finetuned: bool | None = None - - ftype: GGMLFileType | None = None + n_vocab: int + n_embd: int + n_layer: int + n_ctx: int + n_ff: int + n_head: int + n_head_kv: int + f_norm_eps: Optional[float] = None + n_experts: Optional[int] = None + n_experts_used: Optional[int] = None + + rope_scaling_type: Optional[gguf.RopeScalingType] = None + f_rope_freq_base: Optional[float] = None + f_rope_scale: Optional[float] = None + n_orig_ctx: Optional[int] = None + rope_finetuned: Optional[bool] = None + + ftype: Optional[GGMLFileType] = None # path to the directory containing the model files - path_model: Path | None = None + path_model: Optional[Path] = None @staticmethod - def guessed(model: LazyModel) -> Params: + def guessed(model: LazyModel) -> "Params": # try transformer naming first - n_vocab, n_embd = model["model.embed_tokens.weight"].shape if "model.embed_tokens.weight" in model else model["tok_embeddings.weight"].shape + n_vocab, n_embd = ( + model["model.embed_tokens.weight"].shape + if "model.embed_tokens.weight" in model + else model["tok_embeddings.weight"].shape + ) # try transformer naming first if "model.layers.0.self_attn.q_proj.weight" in model: - n_layer = next(i for i in itertools.count() if f"model.layers.{i}.self_attn.q_proj.weight" not in model) - elif "model.layers.0.self_attn.W_pack.weight" in model: # next: try baichuan naming - n_layer = next(i for i in itertools.count() if f"model.layers.{i}.self_attn.W_pack.weight" not in model) + n_layer = next( + i + for i in itertools.count() + if f"model.layers.{i}.self_attn.q_proj.weight" not in model + ) + elif ( + "model.layers.0.self_attn.W_pack.weight" in model + ): # next: try baichuan naming + n_layer = next( + i + for i in itertools.count() + if f"model.layers.{i}.self_attn.W_pack.weight" not in model + ) else: - n_layer = next(i for i in itertools.count() if f"layers.{i}.attention.wq.weight" not in model) + n_layer = next( + i + for i in itertools.count() + if f"layers.{i}.attention.wq.weight" not in model + ) if n_layer < 1: - raise Exception("failed to guess 'n_layer'. This model is unknown or unsupported.\n" - "Suggestion: provide 'config.json' of the model in the same directory containing model files.") + raise Exception( + "failed to guess 'n_layer'. This model is unknown or unsupported.\n" + "Suggestion: provide 'config.json' of the model in the same directory containing model files." 
+ ) - n_head = n_embd // 128 # guessed - n_mult = 256 # guessed + n_head = n_embd // 128 # guessed + n_mult = 256 # guessed # TODO: verify this n_ff = int(2 * (4 * n_embd) / 3) n_ff = n_mult * ((n_ff + n_mult - 1) // n_mult) return Params( - n_vocab = n_vocab, - n_embd = n_embd, - n_layer = n_layer, - n_ctx = -1, - n_ff = n_ff, - n_head = n_head, - n_head_kv = n_head, - f_norm_eps = 1e-5, + n_vocab=n_vocab, + n_embd=n_embd, + n_layer=n_layer, + n_ctx=-1, + n_ff=n_ff, + n_head=n_head, + n_head_kv=n_head, + f_norm_eps=1e-5, ) @staticmethod - def loadHFTransformerJson(model: LazyModel, config_path: Path) -> Params: + def load_transformers_config(model: LazyModel, config_path: Path) -> "Params": config = json.load(open(config_path)) rope_scaling_type = f_rope_scale = n_orig_ctx = rope_finetuned = None @@ -224,20 +274,22 @@ def loadHFTransformerJson(model: LazyModel, config_path: Path) -> Params: rope_scaling_type = gguf.RopeScalingType.LINEAR elif typ == "yarn": rope_scaling_type = gguf.RopeScalingType.YARN - n_orig_ctx = rope_scaling['original_max_position_embeddings'] - rope_finetuned = rope_scaling['finetuned'] + n_orig_ctx = rope_scaling["original_max_position_embeddings"] + rope_finetuned = rope_scaling["finetuned"] else: - raise NotImplementedError(f'Unknown rope scaling type: {typ}') + raise NotImplementedError(f"Unknown rope scaling type: {typ}") if "max_sequence_length" in config: n_ctx = config["max_sequence_length"] elif "max_position_embeddings" in config: n_ctx = config["max_position_embeddings"] else: - raise Exception("failed to guess 'n_ctx'. This model is unknown or unsupported.\n" - "Suggestion: provide 'config.json' of the model in the same directory containing model files.") + raise Exception( + "failed to guess 'n_ctx'. This model is unknown or unsupported.\n" + "Suggestion: provide 'config.json' of the model in the same directory containing model files." 
+ ) - n_experts = None + n_experts = None n_experts_used = None if "num_local_experts" in config: @@ -245,30 +297,30 @@ def loadHFTransformerJson(model: LazyModel, config_path: Path) -> Params: n_experts_used = config["num_experts_per_tok"] return Params( - n_vocab = config["vocab_size"], - n_embd = config["hidden_size"], - n_layer = config["num_hidden_layers"], - n_ctx = n_ctx, - n_ff = config["intermediate_size"], - n_head = (n_head := config["num_attention_heads"]), - n_head_kv = config.get("num_key_value_heads", n_head), - n_experts = n_experts, - n_experts_used = n_experts_used, - f_norm_eps = config["rms_norm_eps"], - f_rope_freq_base = config.get("rope_theta"), - rope_scaling_type = rope_scaling_type, - f_rope_scale = f_rope_scale, - n_orig_ctx = n_orig_ctx, - rope_finetuned = rope_finetuned, + n_vocab=config["vocab_size"], + n_embd=config["hidden_size"], + n_layer=config["num_hidden_layers"], + n_ctx=n_ctx, + n_ff=config["intermediate_size"], + n_head=(n_head := config["num_attention_heads"]), + n_head_kv=config.get("num_key_value_heads", n_head), + n_experts=n_experts, + n_experts_used=n_experts_used, + f_norm_eps=config["rms_norm_eps"], + f_rope_freq_base=config.get("rope_theta"), + rope_scaling_type=rope_scaling_type, + f_rope_scale=f_rope_scale, + n_orig_ctx=n_orig_ctx, + rope_finetuned=rope_finetuned, ) # LLaMA v2 70B params.json # {"dim": 8192, "multiple_of": 4096, "ffn_dim_multiplier": 1.3, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": -1} @staticmethod - def loadOriginalParamsJson(model: LazyModel, config_path: Path) -> Params: + def load_torch_params(model: LazyModel, config_path: Path) -> "Params": config = json.load(open(config_path)) - n_experts = None + n_experts = None n_experts_used = None f_rope_freq_base = None @@ -291,129 +343,249 @@ def loadOriginalParamsJson(model: LazyModel, config_path: Path) -> Params: if config.get("moe"): n_ff = model["layers.0.feed_forward.experts.0.w1.weight"].shape[0] - n_experts = config["moe"]["num_experts"] + n_experts = config["moe"]["num_experts"] n_experts_used = config["moe"]["num_experts_per_tok"] f_rope_freq_base = 1e6 return Params( - n_vocab = model["tok_embeddings.weight"].shape[0], - n_embd = config["dim"], - n_layer = config["n_layers"], - n_ctx = n_ctx, - n_ff = n_ff, - n_head = (n_head := config["n_heads"]), - n_head_kv = config.get("n_kv_heads", n_head), - n_experts = n_experts, - n_experts_used = n_experts_used, - f_norm_eps = config["norm_eps"], - f_rope_freq_base = config.get("rope_theta", f_rope_freq_base), + n_vocab=config.get("vocab_size", model["tok_embeddings.weight"].shape[0]), + n_embd=config["dim"], + n_layer=config["n_layers"], + n_ctx=n_ctx, + n_ff=n_ff, + n_head=(n_head := config["n_heads"]), + n_head_kv=config.get("n_kv_heads", n_head), + n_experts=n_experts, + n_experts_used=n_experts_used, + f_norm_eps=config["norm_eps"], + f_rope_freq_base=config.get("rope_theta", f_rope_freq_base), ) @staticmethod - def load(model_plus: ModelPlus) -> Params: - hf_config_path = model_plus.paths[0].parent / "config.json" + def load(model_plus: ModelPlus) -> "Params": + hf_config_path = model_plus.paths[0].parent / "config.json" orig_config_path = model_plus.paths[0].parent / "params.json" if hf_config_path.exists(): - params = Params.loadHFTransformerJson(model_plus.model, hf_config_path) + params = Params.load_transformers_config(model_plus.model, hf_config_path) elif orig_config_path.exists(): - params = Params.loadOriginalParamsJson(model_plus.model, orig_config_path) - elif 
model_plus.format != 'none': + params = Params.load_torch_params(model_plus.model, orig_config_path) + elif model_plus.format != "none": params = Params.guessed(model_plus.model) else: - raise ValueError('Cannot guess params when model format is none') + raise ValueError("Cannot guess params when model format is none") params.path_model = model_plus.paths[0].parent return params -class VocabLoader: - def __init__(self, params: Params, fname_tokenizer: Path) -> None: - try: - from transformers import AutoTokenizer - except ImportError as e: - raise ImportError( - "To use VocabLoader, please install the `transformers` package. " - "You can install it with `pip install transformers`." - ) from e +class BpeVocab: # GPT + def __init__( + self, fname_tokenizer: Path, fname_added_tokens: Optional[Path] + ) -> None: + self.bpe_tokenizer = json.loads( + open(str(fname_tokenizer), encoding="utf-8").read() + ) + added_tokens: dict[str, int] + if fname_added_tokens is not None: + # FIXME: Verify that added tokens here _cannot_ overlap with the main vocab. + added_tokens = json.load(open(fname_added_tokens, encoding="utf-8")) + else: + # Fall back to trying to find the added tokens in tokenizer.json + tokenizer_json_file = fname_tokenizer.parent / "tokenizer.json" + if not tokenizer_json_file.is_file(): + added_tokens = {} + else: + tokenizer_json = json.load(open(tokenizer_json_file, encoding="utf-8")) + added_tokens = dict( + (item["content"], item["id"]) + for item in tokenizer_json.get("added_tokens", []) + # Added tokens here can be duplicates of the main vocabulary. + if item["content"] not in self.bpe_tokenizer + ) + + vocab_size: int = len(self.bpe_tokenizer) + expected_ids = list(range(vocab_size, vocab_size + len(added_tokens))) + actual_ids = sorted(added_tokens.values()) + if expected_ids != actual_ids: + expected_end_id = vocab_size + len(actual_ids) - 1 + raise Exception( + f"Expected the {len(actual_ids)} added token ID(s) to be sequential in the range {vocab_size} - {expected_end_id}; got {actual_ids}" + ) + + items = sorted(added_tokens.items(), key=lambda text_idx: text_idx[1]) + self.added_tokens_list = [text for (text, idx) in items] + self.vocab_size_base: int = vocab_size + self.vocab_size: int = self.vocab_size_base + len(self.added_tokens_list) + self.fname_tokenizer = fname_tokenizer + self.fname_added_tokens = fname_added_tokens + + def bpe_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: + tokenizer = self.bpe_tokenizer + reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.items()} + + for i, _ in enumerate(tokenizer): + yield reverse_vocab[i], 0.0, gguf.TokenType.NORMAL - try: - self.tokenizer = AutoTokenizer.from_pretrained(str(fname_tokenizer), trust_remote_code=True) - except ValueError: - self.tokenizer = AutoTokenizer.from_pretrained(str(fname_tokenizer), use_fast=False, trust_remote_code=True) + def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: + for text in self.added_tokens_list: + score = -1000.0 + yield text.encode("utf-8"), score, gguf.TokenType.CONTROL + + def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: + yield from self.bpe_tokens() + yield from self.added_tokens() - self.added_tokens_dict: OrderedDict[str, int] = OrderedDict() + def __repr__(self) -> str: + return f"" - for tok, tokidx in sorted(self.tokenizer.get_added_vocab().items(), key=lambda x: x[1]): - if tokidx >= params.n_vocab or tokidx < self.tokenizer.vocab_size: - continue - self.added_tokens_dict[tok] = tokidx +class 
SentencePieceVocab: # LlaMa + def __init__( + self, fname_tokenizer: Path, fname_added_tokens: Optional[Path] + ) -> None: + self.sentencepiece_tokenizer = SentencePieceProcessor(str(fname_tokenizer)) + added_tokens: dict[str, int] + if fname_added_tokens is not None: + added_tokens = json.load(open(fname_added_tokens, encoding="utf-8")) + else: + added_tokens = {} + + vocab_size: int = self.sentencepiece_tokenizer.vocab_size() + + new_tokens = { + id: piece for piece, id in added_tokens.items() if id >= vocab_size + } + expected_new_ids = list(range(vocab_size, vocab_size + len(new_tokens))) + actual_new_ids = sorted(new_tokens.keys()) + + if expected_new_ids != actual_new_ids: + raise ValueError( + f"Expected new token IDs {expected_new_ids} to be sequential; got {actual_new_ids}" + ) + + # Token pieces that were added to the base vocabulary. + self.added_tokens_list = [new_tokens[id] for id in actual_new_ids] + self.vocab_size_base = vocab_size + self.vocab_size = self.vocab_size_base + len(self.added_tokens_list) + self.fname_tokenizer = fname_tokenizer + self.fname_added_tokens = fname_added_tokens + + def sentencepiece_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: + tokenizer = self.sentencepiece_tokenizer + for i in range(tokenizer.vocab_size()): + piece = tokenizer.id_to_piece(i) + text: bytes = piece.encode("utf-8") + score: float = tokenizer.get_score(i) + + toktype = gguf.TokenType.NORMAL + if tokenizer.is_unknown(i): + toktype = gguf.TokenType.UNKNOWN + if tokenizer.is_control(i): + toktype = gguf.TokenType.CONTROL + + # NOTE: I think added_tokens are user defined. + # ref: https://github.com/google/sentencepiece/blob/master/src/sentencepiece_model.proto + # if tokenizer.is_user_defined(i): toktype = gguf.TokenType.USER_DEFINED + + if tokenizer.is_unused(i): + toktype = gguf.TokenType.UNUSED + if tokenizer.is_byte(i): + toktype = gguf.TokenType.BYTE + + yield text, score, toktype - self.unk_token_id: int = self.tokenizer.unk_token_id - self.specials: dict[str, int] = { + def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: + for text in self.added_tokens_list: + score = -1000.0 + yield text.encode("utf-8"), score, gguf.TokenType.USER_DEFINED + + def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: + yield from self.sentencepiece_tokens() + yield from self.added_tokens() + + def __repr__(self) -> str: + return f"" + + +class HfVocab: + def __init__( + self, + fname_tokenizer: Path, + fname_added_tokens: Optional[Path] = None, + ) -> None: + print("fname_tokenizer:", fname_tokenizer) + # Allow the tokenizer to default to slow or fast versions. + # Explicitly set tokenizer to use local paths. 
+ self.tokenizer = AutoTokenizer.from_pretrained( + fname_tokenizer, + cache_dir=fname_tokenizer, + local_files_only=True, + ) + + # Initialize lists and dictionaries for added tokens + self.added_tokens_list = [] + self.added_tokens_dict = dict() + self.added_tokens_ids = set() + + # Process added tokens + for tok, tokidx in sorted( + self.tokenizer.get_added_vocab().items(), key=lambda x: x[1] + ): + # Only consider added tokens that are not in the base vocabulary + if tokidx >= self.tokenizer.vocab_size: + self.added_tokens_list.append(tok) + self.added_tokens_dict[tok] = tokidx + self.added_tokens_ids.add(tokidx) + + # Store special tokens and their IDs + self.specials = { tok: self.tokenizer.get_vocab()[tok] for tok in self.tokenizer.all_special_tokens } - self.special_ids: set[int] = set(self.tokenizer.all_special_ids) - self.reverse_vocab = {id: encoded_tok for encoded_tok, id in self.tokenizer.get_vocab().items()} - self.vocab_size_base: int = self.tokenizer.vocab_size - self.vocab_size: int = self.vocab_size_base + len(self.added_tokens_dict) - self.fname_tokenizer: Path = fname_tokenizer - - vocab_file = "tokenizer.model" - path_candidate = find_vocab_file_path(self.fname_tokenizer, vocab_file) - if path_candidate is not None: - self.spm = SentencePieceProcessor(str(path_candidate)) - print(self.spm.vocab_size(), self.vocab_size_base) - else: - self.spm = None + self.special_ids = set(self.tokenizer.all_special_ids) - def hf_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: - added_tokens_ids = set(self.added_tokens_dict.values()) + # Set vocabulary sizes + self.vocab_size_base = self.tokenizer.vocab_size + self.vocab_size = self.vocab_size_base + len(self.added_tokens_list) - for i in range(self.vocab_size_base): - if i in added_tokens_ids: - continue + self.fname_tokenizer = fname_tokenizer + self.fname_added_tokens = fname_added_tokens - text = self.reverse_vocab[i].encode("utf-8") - yield text, self.get_token_score(i), self.get_token_type(i) + def hf_tokens(self) -> Iterable[Tuple[bytes, float, gguf.TokenType]]: + reverse_vocab = { + id: encoded_tok for encoded_tok, id in self.tokenizer.get_vocab().items() + } - def get_token_type(self, token_id: int) -> gguf.TokenType: - toktype = gguf.TokenType.NORMAL + for token_id in range(self.vocab_size_base): + # Skip processing added tokens here + if token_id in self.added_tokens_ids: + continue - if self.spm is not None and token_id < self.spm.vocab_size(): - if self.spm.is_unknown(token_id): - toktype = gguf.TokenType.UNKNOWN - if self.spm.is_control(token_id): - toktype = gguf.TokenType.CONTROL - if self.spm.is_unused(token_id): - toktype = gguf.TokenType.UNUSED - if self.spm.is_byte(token_id): - toktype = gguf.TokenType.BYTE - else: - token = self.reverse_vocab[token_id] - if token_id == self.unk_token_id: - toktype = gguf.TokenType.UNKNOWN - elif token_id in self.special_ids: - toktype = gguf.TokenType.CONTROL - elif len(token) == 6 and token.startswith("<0x") and token.endswith(">"): - toktype = gguf.TokenType.BYTE + # Convert token text to bytes + token_text = reverse_vocab[token_id].encode("utf-8") + + # Yield token text, score, and type + yield token_text, self.get_token_score(token_id), self.get_token_type( + token_id, self.special_ids # Reuse already stored special IDs + ) - return toktype + def get_token_type(self, token_id: int, special_ids: set) -> gguf.TokenType: + # Determine token type based on whether it's a special token + return ( + gguf.TokenType.CONTROL if token_id in special_ids else 
gguf.TokenType.NORMAL + ) def get_token_score(self, token_id: int) -> float: - if self.spm is not None and token_id < self.spm.vocab_size(): - return cast(float, self.spm.get_score(token_id)) - return 0.0 + # Placeholder for actual logic to determine the token's score + # This needs to be implemented based on specific requirements + return -1000.0 # Default score def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: - - for text in self.added_tokens_dict: + for text in self.added_tokens_list: if text in self.specials: - - toktype = self.get_token_type(self.specials[text]) + toktype = self.get_token_type(self.specials[text], self.special_ids) score = self.get_token_score(self.specials[text]) else: @@ -422,45 +594,18 @@ def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: yield text.encode("utf-8"), score, toktype - def has_newline_token(self) -> bool: - return '<0x0A>' in self.tokenizer.vocab or '\n' in self.tokenizer.vocab + def has_newline_token(self): + return "<0x0A>" in self.tokenizer.vocab or "\n" in self.tokenizer.vocab def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: yield from self.hf_tokens() yield from self.added_tokens() - def get_vocab_type(self) -> str: - path_candidates = [] - vocab_file = "tokenizer.model" - path_candidates.append(vocab_file) - path_candidate = find_vocab_file_path(self.fname_tokenizer, vocab_file) - if path_candidate is not None: - return "llama" - - vocab_file = "vocab.json" - path_candidates.append(vocab_file) - path_candidate = find_vocab_file_path(self.fname_tokenizer, vocab_file) - if path_candidate is not None: - return "gpt2" - - vocab_file = "tokenizer.json" - path_candidates.append(vocab_file) - path_candidate = find_vocab_file_path(self.fname_tokenizer, vocab_file) - if path_candidate: - if not self.has_newline_token(): - return "gpt2" - return "llama" - - raise FileNotFoundError( - f"Could not find {path_candidates} in {self.fname_tokenizer} or its parent; " - "if it's in another directory, pass the directory as --vocab-dir" - ) - def __repr__(self) -> str: - return f"" + return f"" -Vocab: TypeAlias = 'VocabLoader' +Vocab: TypeAlias = "BpeVocab | SentencePieceVocab | HfVocab" # @@ -724,13 +869,17 @@ def rebuild_from_type_v2(func, new_type, args, state): CLASSES: dict[tuple[str, str], Any] = { # getattr used here as a workaround for mypy not being smart enough to determine # the staticmethods have a __func__ attribute. 
- ('torch._tensor', '_rebuild_from_type_v2'): getattr(rebuild_from_type_v2, '__func__'), - ('torch._utils', '_rebuild_tensor_v2'): getattr(lazy_rebuild_tensor_v2, '__func__'), - ('torch', 'BFloat16Storage'): LazyStorageKind(DT_BF16), - ('torch', 'HalfStorage'): LazyStorageKind(DT_F16), - ('torch', 'FloatStorage'): LazyStorageKind(DT_F32), - ('torch', 'IntStorage'): LazyStorageKind(DT_I32), - ('torch', 'Tensor'): LazyTensor, + ("torch._tensor", "_rebuild_from_type_v2"): getattr( + rebuild_from_type_v2, "__func__" + ), + ("torch._utils", "_rebuild_tensor_v2"): getattr( + lazy_rebuild_tensor_v2, "__func__" + ), + ("torch", "BFloat16Storage"): LazyStorageKind(DT_BF16), + ("torch", "HalfStorage"): LazyStorageKind(DT_F16), + ("torch", "FloatStorage"): LazyStorageKind(DT_F32), + ("torch", "IntStorage"): LazyStorageKind(DT_I32), + ("torch", "Tensor"): LazyTensor, } def find_class(self, module: str, name: str) -> Any: @@ -839,32 +988,43 @@ def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], conc def check_vocab_size(params: Params, vocab: Vocab, pad_vocab: bool = False) -> None: - if params.n_vocab != vocab.vocab_size: - if params.n_vocab == vocab.vocab_size: - print("Ignoring added_tokens.json since model matches vocab size without it.") - vocab.added_tokens_dict = OrderedDict() - vocab.vocab_size = vocab.vocab_size - return - - if pad_vocab and params.n_vocab > vocab.vocab_size: - pad_count = params.n_vocab - vocab.vocab_size - print(f'Padding vocab with {pad_count} token(s) - through ') - for i in range(1, (params.n_vocab - vocab.vocab_size) + 1): - vocab.added_tokens_dict[f''] = -1 - vocab.vocab_size = params.n_vocab - return - msg = f"Vocab size mismatch (model has {params.n_vocab}, but {vocab.fname_tokenizer}" - msg += f" has {vocab.vocab_size})." - if vocab.vocab_size < params.n_vocab < vocab.vocab_size + 20: - msg += f" Most likely you are missing added_tokens.json (should be in {vocab.fname_tokenizer.parent})." - if vocab.vocab_size < params.n_vocab: - msg += " Possibly try using the --padvocab option." - raise Exception(msg) + # Handle special case where the model's vocab size is not set + if params.n_vocab == -1: + raise ValueError( + f"The model's vocab size is set to -1 in params.json. Please update it manually. Maybe {vocab.vocab_size}?" + ) + + # Check for a vocab size mismatch + if params.n_vocab == vocab.vocab_size: + print("Ignoring added_tokens.json since model matches vocab size without it.") + return + + if pad_vocab and params.n_vocab > vocab.vocab_size: + pad_count = params.n_vocab - vocab.vocab_size + print( + f"Padding vocab with {pad_count} token(s) - through " + ) + for i in range(1, pad_count + 1): + vocab.added_tokens_dict[f""] = -1 + vocab.vocab_size = params.n_vocab + return + + msg = f"Vocab size mismatch (model has {params.n_vocab}, but {vocab.fname_tokenizer} has {vocab.vocab_size})." + if vocab.vocab_size < params.n_vocab < vocab.vocab_size + 20: + msg += f" Most likely you are missing added_tokens.json (should be in {vocab.fname_tokenizer.parent})." + if vocab.vocab_size < params.n_vocab: + msg += " Add the --pad-vocab option and try again." 
+ + raise Exception(msg) class OutputFile: - def __init__(self, fname_out: Path, endianess:gguf.GGUFEndian = gguf.GGUFEndian.LITTLE) -> None: - self.gguf = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH], endianess=endianess) + def __init__( + self, fname_out: Path, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE + ) -> None: + self.gguf = gguf.GGUFWriter( + fname_out, gguf.MODEL_ARCH_NAMES[ARCH], endianess=endianess + ) def add_meta_arch(self, params: Params) -> None: name = "LLaMA" @@ -873,16 +1033,21 @@ def add_meta_arch(self, params: Params) -> None: if params.n_ctx == 4096: name = "LLaMA v2" elif params.path_model is not None: - name = str(params.path_model.parent).split('/')[-1] + name = str(params.path_model.parent).split("/")[-1] - self.gguf.add_name (name) - self.gguf.add_context_length (params.n_ctx) - self.gguf.add_embedding_length (params.n_embd) - self.gguf.add_block_count (params.n_layer) - self.gguf.add_feed_forward_length (params.n_ff) + self.gguf.add_name(name) + self.gguf.add_context_length(params.n_ctx) + self.gguf.add_embedding_length(params.n_embd) + self.gguf.add_block_count(params.n_layer) + self.gguf.add_feed_forward_length(params.n_ff) self.gguf.add_rope_dimension_count(params.n_embd // params.n_head) - self.gguf.add_head_count (params.n_head) - self.gguf.add_head_count_kv (params.n_head_kv) + self.gguf.add_head_count(params.n_head) + self.gguf.add_head_count_kv(params.n_head_kv) + + if params.f_norm_eps is None: + raise ValueError("f_norm_eps is None") + + self.gguf.add_layer_norm_rms_eps(params.f_norm_eps) if params.n_experts: self.gguf.add_expert_count(params.n_experts) @@ -890,11 +1055,6 @@ def add_meta_arch(self, params: Params) -> None: if params.n_experts_used: self.gguf.add_expert_used_count(params.n_experts_used) - if params.f_norm_eps: - self.gguf.add_layer_norm_rms_eps(params.f_norm_eps) - else: - raise ValueError('f_norm_eps is None') - if params.f_rope_freq_base is not None: self.gguf.add_rope_freq_base(params.f_rope_freq_base) @@ -912,18 +1072,44 @@ def add_meta_arch(self, params: Params) -> None: if params.ftype is not None: self.gguf.add_file_type(params.ftype) - def add_meta_vocab(self, vocab: Vocab) -> None: + def handle_tokenizer_model(self, vocab: Vocab) -> str: + # Map the vocab types to the supported tokenizer models + tokenizer_model = { + SentencePieceVocab: "llama", + HfVocab: "llama", + BpeVocab: "gpt2", + }.get(type(vocab)) + + # Block if vocab type is not predefined + if tokenizer_model is None: + raise ValueError("Unknown vocab type: Not supported") + + return tokenizer_model + + def extract_vocabulary_from_model(self, vocab: Vocab) -> Tuple[list, list, list]: tokens = [] scores = [] toktypes = [] + # NOTE: `all_tokens` returns the base vocabulary and added tokens for text, score, toktype in vocab.all_tokens(): tokens.append(text) scores.append(score) toktypes.append(toktype) - vocab_type = vocab.get_vocab_type() - self.gguf.add_tokenizer_model(vocab_type) + return tokens, scores, toktypes + + def add_meta_vocab(self, vocab: Vocab) -> None: + # Handle the tokenizer model + tokenizer_model = self.handle_tokenizer_model(vocab) + + # Ensure that tokenizer_model is added to the GGUF model + self.gguf.add_tokenizer_model(tokenizer_model) + + # Extract model vocabulary for model conversion + tokens, scores, toktypes = self.extract_vocabulary_from_model(vocab) + + # Add extracted token information for model conversion self.gguf.add_token_list(tokens) self.gguf.add_token_scores(scores) self.gguf.add_token_types(toktypes) @@ -933,10 
+1119,14 @@ def add_meta_special_vocab(self, svocab: gguf.SpecialVocab) -> None: def add_tensor_info(self, name: str, tensor: LazyTensor) -> None: n_elements = int(np.prod(tensor.shape)) - raw_dtype = getattr(tensor.data_type, 'ggml_type', None) - data_type = getattr(tensor.data_type, 'quantized_type', None) or tensor.data_type.dtype + raw_dtype = getattr(tensor.data_type, "ggml_type", None) + data_type = ( + getattr(tensor.data_type, "quantized_type", None) or tensor.data_type.dtype + ) data_nbytes = tensor.data_type.elements_to_bytes(n_elements) - self.gguf.add_tensor_info(name, tensor.shape, data_type, data_nbytes, raw_dtype = raw_dtype) + self.gguf.add_tensor_info( + name, tensor.shape, data_type, data_nbytes, raw_dtype=raw_dtype + ) def write_meta(self) -> None: self.gguf.write_header_to_file() @@ -950,11 +1140,14 @@ def close(self) -> None: @staticmethod def write_vocab_only( - fname_out: Path, params: Params, vocab: Vocab, svocab: gguf.SpecialVocab, + fname_out: Path, + params: Params, + vocab: Vocab, + svocab: gguf.SpecialVocab, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, pad_vocab: bool = False, ) -> None: - check_vocab_size(params, vocab, pad_vocab = pad_vocab) + check_vocab_size(params, vocab, pad_vocab=pad_vocab) of = OutputFile(fname_out, endianess=endianess) @@ -982,12 +1175,17 @@ def maybe_do_quantize(item: tuple[DataType, NDArray]) -> NDArray: @staticmethod def write_all( - fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: Vocab, svocab: gguf.SpecialVocab, + fname_out: Path, + ftype: GGMLFileType, + params: Params, + model: LazyModel, + vocab: Vocab, + svocab: gguf.SpecialVocab, concurrency: int = DEFAULT_CONCURRENCY, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, pad_vocab: bool = False, ) -> None: - check_vocab_size(params, vocab, pad_vocab = pad_vocab) + check_vocab_size(params, vocab, pad_vocab=pad_vocab) of = OutputFile(fname_out, endianess=endianess) @@ -1004,18 +1202,30 @@ def write_all( of.write_tensor_info() # tensor data - ndarrays_inner = bounded_parallel_map(OutputFile.do_item, model.items(), concurrency = concurrency) + ndarrays_inner = bounded_parallel_map( + OutputFile.do_item, model.items(), concurrency=concurrency + ) if ftype == GGMLFileType.MostlyQ8_0: - ndarrays = bounded_parallel_map(OutputFile.maybe_do_quantize, ndarrays_inner, concurrency = concurrency, max_workers = concurrency, use_processpool_executor = True) + ndarrays = bounded_parallel_map( + OutputFile.maybe_do_quantize, + ndarrays_inner, + concurrency=concurrency, + max_workers=concurrency, + use_processpool_executor=True, + ) else: ndarrays = map(OutputFile.maybe_do_quantize, ndarrays_inner) start = time.time() - for i, ((name, lazy_tensor), ndarray) in enumerate(zip(model.items(), ndarrays)): + for i, ((name, lazy_tensor), ndarray) in enumerate( + zip(model.items(), ndarrays) + ): elapsed = time.time() - start - size = ' x '.join(f"{dim:6d}" for dim in lazy_tensor.shape) + size = " x ".join(f"{dim:6d}" for dim in lazy_tensor.shape) padi = len(str(len(model))) - print(f"[{i+1:{padi}d}/{len(model)}] Writing tensor {name:38s} | size {size:16} | type {lazy_tensor.data_type.name:4} | T+{int(elapsed):4}") + print( + f"[{i+1:{padi}d}/{len(model)}] Writing tensor {name:38s} | size {size:16} | type {lazy_tensor.data_type.name:4} | T+{int(elapsed):4}" + ) of.gguf.write_tensor_data(ndarray) of.close() @@ -1145,30 +1355,95 @@ def load_some_model(path: Path) -> ModelPlus: return model_plus -def find_vocab_file_path(path: Path, vocab_file: str) -> 
Optional[Path]: - path2 = path / vocab_file - # Use `.parent` instead of /.. to handle the symlink case better. - path3 = path.parent / vocab_file - - if path2.exists(): - return path2 - if path3.exists(): - return path3 +class VocabFactory: + def __init__(self, path: Path): + self.path = path + self.files = { + "tokenizer.model": None, + "vocab.json": None, + "tokenizer.json": None, + } + self._detect_files() + + def _detect_files(self): + for file in self.files.keys(): + file_path = self.path / file + parent_file_path = self.path.parent / file + if file_path.exists(): + self.files[file] = file_path + elif parent_file_path.exists(): + self.files[file] = parent_file_path + + def _select_file(self, vocabtype: Optional[str]) -> Path: + if vocabtype in ["spm", "bpe"]: + # For SentencePiece and BPE, return specific files as before + file_key = "tokenizer.model" if vocabtype == "spm" else "vocab.json" + if self.files[file_key]: + return self.files[file_key] + else: + raise FileNotFoundError(f"{vocabtype} {file_key} not found.") + elif vocabtype == "hfft": + # For Hugging Face Fast Tokenizer, return the directory path instead of a specific file + return self.path + else: + raise ValueError(f"Unsupported vocabulary type {vocabtype}") + + def _create_special_vocab( + self, + vocab: Vocab, + vocabtype: str, + model_parent_path: Path, + ) -> gguf.SpecialVocab: + load_merges = vocabtype == "bpe" + n_vocab = vocab.vocab_size if hasattr(vocab, "vocab_size") else None + return gguf.SpecialVocab( + model_parent_path, + load_merges=load_merges, + special_token_types=None, # Predetermined or passed as a parameter + n_vocab=n_vocab, + ) - return None + def load_vocab( + self, vocabtype: str, model_parent_path: Path + ) -> Tuple[Vocab, gguf.SpecialVocab]: + path = self._select_file(vocabtype) + print(f"Loading vocab file '{path}', type '{vocabtype}'") + + added_tokens_path = path.parent / "added_tokens.json" + if vocabtype == "bpe": + vocab = BpeVocab( + path, added_tokens_path if added_tokens_path.exists() else None + ) + elif vocabtype == "spm": + vocab = SentencePieceVocab( + path, added_tokens_path if added_tokens_path.exists() else None + ) + elif vocabtype == "hfft": + vocab = HfVocab( + path, added_tokens_path if added_tokens_path.exists() else None + ) + else: + raise ValueError(f"Unsupported vocabulary type {vocabtype}") + special_vocab = self._create_special_vocab( + vocab, + vocabtype, + model_parent_path, + ) + return vocab, special_vocab -def default_outfile(model_paths: list[Path], file_type: GGMLFileType) -> Path: +def default_output_file(model_paths: list[Path], file_type: GGMLFileType) -> Path: namestr = { - GGMLFileType.AllF32: "f32", + GGMLFileType.AllF32: "f32", GGMLFileType.MostlyF16: "f16", - GGMLFileType.MostlyQ8_0:"q8_0", + GGMLFileType.MostlyQ8_0: "q8_0", }[file_type] ret = model_paths[0].parent / f"ggml-model-{namestr}.gguf" if ret in model_paths: sys.stderr.write( f"Error: Default output path ({ret}) would overwrite the input. 
" - "Please explicitly specify a path using --outfile.\n") + "Please explicitly specify a path using --outfile.\n" + ) sys.exit(1) return ret @@ -1178,32 +1453,111 @@ def do_dump_model(model_plus: ModelPlus) -> None: print(f"model_plus.format = {model_plus.format!r}") print(f"model_plus.vocab = {model_plus.vocab!r}") for name, lazy_tensor in model_plus.model.items(): - print(f"{name}: shape={lazy_tensor.shape} type={lazy_tensor.data_type}; {lazy_tensor.description}") + print( + f"{name}: shape={lazy_tensor.shape} type={lazy_tensor.data_type}; {lazy_tensor.description}" + ) -def main(args_in: list[str] | None = None) -> None: +def get_argument_parser() -> ArgumentParser: output_choices = ["f32", "f16"] if np.uint32(1) == np.uint32(1).newbyteorder("<"): # We currently only support Q8_0 output on little endian systems. output_choices.append("q8_0") - parser = argparse.ArgumentParser(description="Convert a LLaMa model to a GGML compatible file") - parser.add_argument("--awq-path", type=Path, help="Path to scale awq cache file", default=None) - parser.add_argument("--dump", action="store_true", help="don't convert, just show what's in the model") - parser.add_argument("--dump-single", action="store_true", help="don't convert, just show what's in a single model file") - parser.add_argument("--vocab-only", action="store_true", help="extract only the vocab") - parser.add_argument("--outtype", choices=output_choices, help="output format - note: q8_0 may be very slow (default: f16 or f32 based on input)") - parser.add_argument("--vocab-dir", type=Path, help="directory containing tokenizer.model, if separate from model file") - parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input") - parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.pth, *.pt, *.bin)") - parser.add_argument("--ctx", type=int, help="model training context (default: based on input)") - parser.add_argument("--concurrency", type=int, help=f"concurrency used for conversion (default: {DEFAULT_CONCURRENCY})", default = DEFAULT_CONCURRENCY) - parser.add_argument("--bigendian", action="store_true", help="model is executed on big endian machine") - parser.add_argument("--padvocab", action="store_true", help="add pad tokens when model vocab expects more than tokenizer metadata provides") - - args = parser.parse_args(args_in) + + parser = argparse.ArgumentParser( + description="Convert a LLaMa model to a GGML compatible file" + ) + + parser.add_argument( + "model", + type=Path, + help="Directory containing the model file or the model file itself (*.pth, *.pt, *.bin)", + ) + + parser.add_argument( + "--awq-path", + type=Path, + help="Path to the Activation-aware Weight Quantization cache file", + default=None, + ) + + parser.add_argument( + "--dump", + action="store_true", + help="Display the model content without converting it", + ) + + parser.add_argument( + "--dump-single", + action="store_true", + help="Display the content of a single model file without conversion", + ) + + parser.add_argument( + "--vocab-only", + action="store_true", + help="Extract and output only the vocabulary", + ) + + parser.add_argument( + "--outtype", + choices=output_choices, + help="Output format - note: q8_0 may be very slow (default: f16 or f32 based on input)", + ) + + parser.add_argument( + "--vocab-dir", + type=Path, + help="Directory containing the tokenizer.model, if separate from the model file", + ) + + parser.add_argument( + "--vocab-type", + choices=["spm", "bpe", 
"hfft"], # hfft: Hugging Face Fast Tokenizer + default="spm", + help="The vocabulary format used to define the tokenizer model (default: spm)", + ) + + parser.add_argument( + "--pad-vocab", + action="store_true", + help="Add padding tokens when the model's vocabulary size exceeds the tokenizer metadata", + ) + + parser.add_argument( + "--outfile", + type=Path, + help="Specify the path for the output file (default is based on input)", + ) + + parser.add_argument( + "--ctx", type=int, help="Model training context (default is based on input)" + ) + + parser.add_argument( + "--concurrency", + type=int, + help=f"Concurrency used for conversion (default: {DEFAULT_CONCURRENCY})", + default=DEFAULT_CONCURRENCY, + ) + + parser.add_argument( + "--big-endian", + action="store_true", + help="Indicate that the model is executed on a big-endian machine", + ) + + return parser + + +def main(argv: Optional[list[str]] = None) -> None: + parser = get_argument_parser() + args = parser.parse_args(argv) + if args.awq_path: - sys.path.insert(1, str(Path(__file__).parent / 'awq-py')) + sys.path.insert(1, str(Path(__file__).resolve().parent / "awq-py")) from awq.apply_awq import add_scale_weights + tmp_model_path = args.model / "weighted_model" if tmp_model_path.is_dir(): print(f"{tmp_model_path} exists as a weighted model.") @@ -1222,22 +1576,27 @@ def main(args_in: list[str] | None = None) -> None: if not args.vocab_only: model_plus = load_some_model(args.model) else: - model_plus = ModelPlus(model = {}, paths = [args.model / 'dummy'], format = 'none', vocab = None) + model_plus = ModelPlus( + model={}, paths=[args.model / "dummy"], format="none", vocab=None + ) if args.dump: do_dump_model(model_plus) return + endianess = gguf.GGUFEndian.LITTLE - if args.bigendian: + if args.big_endian: endianess = gguf.GGUFEndian.BIG params = Params.load(model_plus) if params.n_ctx == -1: if args.ctx is None: - raise Exception("The model doesn't have a context size, and you didn't specify one with --ctx\n" - "Please specify one with --ctx:\n" - " - LLaMA v1: --ctx 2048\n" - " - LLaMA v2: --ctx 4096\n") + raise Exception( + "The model doesn't have a context size, and you didn't specify one with --ctx\n" + "Please specify one with --ctx:\n" + " - LLaMA v1: --ctx 2048\n" + " - LLaMA v2: --ctx 4096\n" + ) params.n_ctx = args.ctx if args.outtype: @@ -1249,47 +1608,51 @@ def main(args_in: list[str] | None = None) -> None: print(f"params = {params}") - vocab: Vocab + model_parent_path = model_plus.paths[0].parent + vocab_path = Path(args.vocab_dir or args.model or model_parent_path) + vocab_factory = VocabFactory(vocab_path) + vocab, special_vocab = vocab_factory.load_vocab(args.vocab_type, model_parent_path) + if args.vocab_only: if not args.outfile: raise ValueError("need --outfile if using --vocab-only") - # FIXME: Try to respect vocab_dir somehow? 
- vocab = VocabLoader(params, args.vocab_dir or args.model) - special_vocab = gguf.SpecialVocab(model_plus.paths[0].parent, - load_merges = True, - n_vocab = vocab.vocab_size) outfile = args.outfile - OutputFile.write_vocab_only(outfile, params, vocab, special_vocab, - endianess = endianess, pad_vocab = args.padvocab) + OutputFile.write_vocab_only( + outfile, + params, + vocab, + special_vocab, + endianess=endianess, + pad_vocab=args.pad_vocab, + ) print(f"Wrote {outfile}") return if model_plus.vocab is not None and args.vocab_dir is None: vocab = model_plus.vocab - else: - vocab_dir = args.vocab_dir if args.vocab_dir else model_plus.paths[0].parent - vocab = VocabLoader(params, vocab_dir) - - # FIXME: Try to respect vocab_dir somehow? - print(f"Vocab info: {vocab}") - special_vocab = gguf.SpecialVocab(model_plus.paths[0].parent, - load_merges = True, - n_vocab = vocab.vocab_size) - - print(f"Special vocab info: {special_vocab}") - model = model_plus.model - model = convert_model_names(model, params) - ftype = pick_output_type(model, args.outtype) - model = convert_to_output_type(model, ftype) - outfile = args.outfile or default_outfile(model_plus.paths, ftype) + + model = model_plus.model + model = convert_model_names(model, params) + ftype = pick_output_type(model, args.outtype) + model = convert_to_output_type(model, ftype) + outfile = args.outfile or default_output_file(model_plus.paths, ftype) params.ftype = ftype print(f"Writing {outfile}, format {ftype}") - OutputFile.write_all(outfile, ftype, params, model, vocab, special_vocab, - concurrency = args.concurrency, endianess = endianess, pad_vocab = args.padvocab) + OutputFile.write_all( + outfile, + ftype, + params, + model, + vocab, + special_vocab, + concurrency=args.concurrency, + endianess=endianess, + pad_vocab=args.pad_vocab, + ) print(f"Wrote {outfile}") -if __name__ == '__main__': - main() +if __name__ == "__main__": + main(sys.argv[1:]) # Exclude the first element (script name) from sys.argv diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 4cc13d6e99ce1d..0c71cbdf72a65e 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -31,6 +31,7 @@ else() add_subdirectory(quantize-stats) add_subdirectory(save-load-state) add_subdirectory(simple) + add_subdirectory(passkey) add_subdirectory(speculative) add_subdirectory(lookahead) add_subdirectory(lookup) diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp index 22a4265df77c09..b1775e0b0e8d64 100644 --- a/examples/batched/batched.cpp +++ b/examples/batched/batched.cpp @@ -69,6 +69,7 @@ int main(int argc, char ** argv) { std::vector tokens_list; tokens_list = ::llama_tokenize(model, params.prompt, true); + const int n_kv_req = tokens_list.size() + (n_len - tokens_list.size())*n_parallel; // initialize the context diff --git a/examples/llama.swiftui/README.md b/examples/llama.swiftui/README.md index fa68e6ed8e34db..96cf743d48202f 100644 --- a/examples/llama.swiftui/README.md +++ b/examples/llama.swiftui/README.md @@ -1,7 +1,12 @@ -# llama.swiftui +# llama.cpp/examples/llama.swiftui -Local inference of llama.cpp on an iPhone. -So far I only tested with starcoder 1B model, but it can most likely handle 7B models as well. +Local inference of llama.cpp on an iPhone. This is a sample app that can be used as a starting +point for more advanced projects. 
-https://github.com/bachittle/llama.cpp/assets/39804642/e290827a-4edb-4093-9642-2a5e399ec545 +For usage instructions and performance stats, check the following discussion: https://github.com/ggerganov/llama.cpp/discussions/4508 + +![image](https://github.com/ggerganov/llama.cpp/assets/1991296/2b40284f-8421-47a2-b634-74eece09a299) +Video demonstration: + +https://github.com/bachittle/llama.cpp/assets/39804642/e290827a-4edb-4093-9642-2a5e399ec545 diff --git a/examples/llava/llava-cli.cpp b/examples/llava/llava-cli.cpp index 502b788b14aba2..d94795fe317d47 100644 --- a/examples/llava/llava-cli.cpp +++ b/examples/llava/llava-cli.cpp @@ -243,6 +243,9 @@ int main(int argc, char ** argv) { } auto image_embed = load_image(ctx_llava, ¶ms); + if (!image_embed) { + return 1; + } // process the prompt process_prompt(ctx_llava, image_embed, ¶ms, params.prompt); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index c096f110b32c55..5ea67051f36546 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -439,6 +439,21 @@ int main(int argc, char ** argv) { LOG_TEE("sampling: \n%s\n", llama_sampling_print(sparams).c_str()); LOG_TEE("sampling order: \n%s\n", llama_sampling_order_print(sparams).c_str()); LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep); + + // group-attention state + // number of grouped KV tokens so far (used only if params.grp_attn_n > 1) + int ga_i = 0; + + const int ga_n = params.grp_attn_n; + const int ga_w = params.grp_attn_w; + + if (ga_n != 1) { + GGML_ASSERT(ga_n > 0 && "grp_attn_n must be positive"); // NOLINT + GGML_ASSERT(ga_w % ga_n == 0 && "grp_attn_w must be a multiple of grp_attn_n"); // NOLINT + //GGML_ASSERT(n_ctx_train % ga_w == 0 && "n_ctx_train must be a multiple of grp_attn_w"); // NOLINT + //GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * grp_attn_n"); // NOLINT + LOG_TEE("self-extend: n_ctx_train = %d, grp_attn_n = %d, grp_attn_w = %d\n", n_ctx_train, ga_n, ga_w); + } LOG_TEE("\n\n"); if (params.interactive) { @@ -500,37 +515,61 @@ int main(int argc, char ** argv) { fflush(stdout); } - // infinite text generation via context swapping - // if we run out of context: - // - take the n_keep first tokens from the original prompt (via n_past) - // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches - if (n_past + (int) embd.size() + std::max(0, guidance_offset) > n_ctx) { - if (params.n_predict == -2) { - LOG_TEE("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict); - break; - } + if (ga_n == 1) { + // infinite text generation via context shifting + // if we run out of context: + // - take the n_keep first tokens from the original prompt (via n_past) + // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches + if (n_past + (int) embd.size() + std::max(0, guidance_offset) > n_ctx) { + if (params.n_predict == -2) { + LOG_TEE("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict); + break; + } - const int n_left = n_past - params.n_keep - 1; - const int n_discard = n_left/2; + const int n_left = n_past - params.n_keep - 1; + const int n_discard = n_left/2; - LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n", - n_past, n_left, n_ctx, params.n_keep, n_discard); + LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n", + n_past, 
n_left, n_ctx, params.n_keep, n_discard); - llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1); - llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard); + llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1); + llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard); - n_past -= n_discard; + n_past -= n_discard; - if (ctx_guidance) { - n_past_guidance -= n_discard; + if (ctx_guidance) { + n_past_guidance -= n_discard; + } + + LOG("after swap: n_past = %d, n_past_guidance = %d\n", n_past, n_past_guidance); + + LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str()); + + LOG("clear session path\n"); + path_session.clear(); } + } else { + // context extension via Self-Extend + while (n_past >= ga_i + ga_w) { + const int ib = (ga_n*ga_i)/ga_w; + const int bd = (ga_w/ga_n)*(ga_n - 1); + const int dd = (ga_w/ga_n) - ib*bd - ga_w; - LOG("after swap: n_past = %d, n_past_guidance = %d\n", n_past, n_past_guidance); + LOG("\n"); + LOG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i, n_past, ib*bd, ga_i + ib*bd, n_past + ib*bd); + LOG("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n, (ga_i + ib*bd)/ga_n, (ga_i + ib*bd + ga_w)/ga_n); + LOG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i + ib*bd + ga_w, n_past + ib*bd, dd, ga_i + ib*bd + ga_w + dd, n_past + ib*bd + dd); - LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str()); + llama_kv_cache_seq_shift(ctx, 0, ga_i, n_past, ib*bd); + llama_kv_cache_seq_div (ctx, 0, ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n); + llama_kv_cache_seq_shift(ctx, 0, ga_i + ib*bd + ga_w, n_past + ib*bd, dd); - LOG("clear session path\n"); - path_session.clear(); + n_past -= bd; + + ga_i += ga_w/ga_n; + + LOG("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", n_past + bd, n_past, ga_i); + } } // try to reuse a matching prefix from the loaded session instead of re-eval (via n_past) diff --git a/examples/passkey/CMakeLists.txt b/examples/passkey/CMakeLists.txt new file mode 100644 index 00000000000000..3161bf3ef9a455 --- /dev/null +++ b/examples/passkey/CMakeLists.txt @@ -0,0 +1,5 @@ +set(TARGET passkey) +add_executable(${TARGET} passkey.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_11) diff --git a/examples/passkey/README.md b/examples/passkey/README.md new file mode 100644 index 00000000000000..4a22bb55975be5 --- /dev/null +++ b/examples/passkey/README.md @@ -0,0 +1,12 @@ +# llama.cpp/example/passkey + +See the following PRs for more info: + +- https://github.com/ggerganov/llama.cpp/pull/3856 +- https://github.com/ggerganov/llama.cpp/pull/4810 + +### Usage + +```bash +make -j && ./passkey ./models/llama-7b-v2/ggml-model-f16.gguf 250 +``` diff --git a/examples/passkey/passkey.cpp b/examples/passkey/passkey.cpp new file mode 100644 index 00000000000000..5c0022832146ba --- /dev/null +++ b/examples/passkey/passkey.cpp @@ -0,0 +1,296 @@ +#include "common.h" +#include "llama.h" + +#include +#include +#include +#include + +int main(int argc, char ** argv) { + gpt_params params; + + if (argc == 1 || argv[1][0] == '-') { + printf("usage: %s MODEL_PATH N_JUNK N_GRP I_POS SEED\n" , argv[0]); + return 1 ; + } + + int seed = -1; + + int n_junk = 250; // number of times to repeat the junk text + int n_keep = 32; // number of tokens in the prompt prefix + int n_grp = 1; // if more than 1 - perform LongLM 
SelfExtend + int i_pos = -1; // position of the passkey in the junk text + + if (argc >= 2) { + params.model = argv[1]; + } + + if (argc >= 3) { + n_junk = std::stoi(argv[2]); + } + + if (argc >= 4) { + n_grp = std::stoi(argv[3]); + } + + if (argc >= 5) { + i_pos = std::stoi(argv[4]); + } + + if (argc >= 6) { + seed = std::stoi(argv[5]); + } + + if (seed == -1) { + seed = time(NULL); + } + + srand(seed); + + if (i_pos == -1) { + i_pos = rand() % n_junk; + } + + const std::string prompt_prefix = "There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there."; + const std::string prompt_suffix = " What is the pass key? The pass key is"; + + // generate junk text + params.prompt = prompt_prefix; + + const int passkey = rand() % 50000 + 1; + + for (int i = 0; i < n_junk; i++) { + if (i % n_junk == i_pos) { + params.prompt += " The pass key is " + std::to_string(passkey) + ". Remember it. " + std::to_string(passkey) + " is the pass key."; + } + + params.prompt += " The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again."; + } + + params.prompt += prompt_suffix; + + // init LLM + + llama_backend_init(params.numa); + + // initialize the model + + llama_model_params model_params = llama_model_default_params(); + + model_params.n_gpu_layers = 99; // offload all layers to the GPU + + llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params); + + if (model == NULL) { + fprintf(stderr , "%s: error: unable to load model\n" , __func__); + return 1; + } + + // initialize the context + + llama_context_params ctx_params = llama_context_default_params(); + + ctx_params.seed = seed; + ctx_params.n_ctx = llama_n_ctx_train(model)*n_grp + n_keep; + ctx_params.n_batch = 512; + ctx_params.n_threads = params.n_threads; + ctx_params.n_threads_batch = params.n_threads_batch == -1 ? 
params.n_threads : params.n_threads_batch; + + GGML_ASSERT(ctx_params.n_batch % n_grp == 0 && "n_batch must be divisible by n_grp"); + + llama_context * ctx = llama_new_context_with_model(model, ctx_params); + + if (ctx == NULL) { + fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__); + return 1; + } + + // tokenize the prompt + std::vector tokens_list; + tokens_list = ::llama_tokenize(ctx, params.prompt, true); + + // tokenize the prefix and use it as a sink + const int n_tokens_prefix = ::llama_tokenize(ctx, prompt_prefix, true).size(); + + const int n_tokens_all = tokens_list.size(); + + // we leave a margin of 16 tokens for the generated text - it should contain just the passkey + const int n_predict = 16; + + // total length of the sequences including the prompt + const int n_len = n_tokens_all + n_predict; + + const int n_ctx = llama_n_ctx(ctx) - n_keep; + const int n_kv_req = llama_n_ctx(ctx); + const int n_batch = ctx_params.n_batch; + const int n_batch_grp = ctx_params.n_batch/n_grp; + + LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_kv_req = %d, n_grp = %d, n_batch = %d\n", __func__, n_len, n_ctx, n_kv_req, n_grp, n_batch); + + // print the prompt token-by-token + + LOG_TEE("\n"); + LOG_TEE("prefix tokens: %d\n", n_tokens_prefix); + LOG_TEE("prompt tokens: %d\n", n_tokens_all); + //LOG_TEE("prompt: %s\n", params.prompt.c_str()); + + llama_batch batch = llama_batch_init(512, 0, 1); + + int n_past = 0; + + // fill the KV cache + for (int i = 0; i < n_ctx; i += n_batch) { + if (i > 0 && n_grp > 1) { + // if SelfExtend is enabled, we compress the position from the last batch by a factor of n_grp + const int ib = i/n_batch - 1; + const int bd = n_batch_grp*(n_grp - 1); + + llama_kv_cache_seq_shift(ctx, 0, n_past - n_batch, n_past, ib*bd); + llama_kv_cache_seq_div (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp); + + n_past -= bd; + } + + llama_batch_clear(batch); + + for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) { + llama_batch_add(batch, tokens_list[i + j], n_past++, { 0 }, false); + } + + if (i + n_batch >= n_tokens_all) { + batch.logits[batch.n_tokens - 1] = true; + } + + if (llama_decode(ctx, batch) != 0) { + LOG_TEE("%s: llama_decode() failed\n", __func__); + return 1; + } + + LOG_TEE("%s: processed: [%6d, %6d)\n", __func__, i, std::min(i + n_batch, n_tokens_all)); + + if (i + n_batch >= n_tokens_all) { + break; + } + } + + for (int i = n_ctx; i < n_tokens_all; i += n_batch) { + const int n_discard = n_batch; + + LOG_TEE("%s: shifting KV cache with %d\n", __func__, n_discard); + + llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard); + llama_kv_cache_seq_shift(ctx, 0, n_keep + n_discard, n_ctx, -n_discard); + + n_past -= n_discard; + + llama_batch_clear(batch); + + for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) { + llama_batch_add(batch, tokens_list[i + j], n_past++, { 0 }, false); + } + + if (i + n_batch >= n_tokens_all) { + batch.logits[batch.n_tokens - 1] = true; + } + + if (llama_decode(ctx, batch) != 0) { + LOG_TEE("%s: llama_decode() failed\n", __func__); + return 1; + } + + LOG_TEE("%s: processed: [%6d, %6d)\n", __func__, i, std::min(i + n_batch, n_tokens_all)); + } + + { + const int n_discard = n_past - n_ctx + n_predict; + + if (n_discard > 0) { + LOG_TEE("%s: shifting KV cache with %d to free space for the answer\n", __func__, n_discard); + + llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard); + llama_kv_cache_seq_shift(ctx, 0, n_keep + n_discard, n_ctx, -n_discard); + + n_past -= n_discard; + } + } + 
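+    // at this point n_past is at most n_ctx - n_predict, so the n_predict answer tokens
+    // generated below fit into the remaining KV cache slots without another shift
+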
+ LOG_TEE("\n"); + LOG_TEE("%s: passkey = %d, inserted at position %d / %d (token pos: ~%d)\n", __func__, passkey, i_pos, n_junk, (i_pos * n_tokens_all) / n_junk); + LOG_TEE("\n"); + + // main loop + + int n_cur = n_tokens_all; + int n_decode = 0; + + LOG_TEE("%s", prompt_suffix.c_str()); + fflush(stdout); + + const auto t_main_start = ggml_time_us(); + + while (n_cur <= n_len) { + // sample the next token + { + auto n_vocab = llama_n_vocab(model); + auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1); + + std::vector candidates; + candidates.reserve(n_vocab); + + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f }); + } + + llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; + + // sample the most likely token + const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p); + + // is it an end of stream? + if (new_token_id == llama_token_eos(model) || n_cur == n_len) { + LOG_TEE("\n"); + + break; + } + + LOG_TEE("%s", llama_token_to_piece(ctx, new_token_id).c_str()); + fflush(stdout); + + n_decode += 1; + + // prepare the next batch + llama_batch_clear(batch); + + // push this new token for next evaluation + llama_batch_add(batch, new_token_id, n_past++, { 0 }, true); + } + + n_cur += 1; + + // evaluate the current batch with the transformer model + if (llama_decode(ctx, batch)) { + fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1); + return 1; + } + } + + LOG_TEE("\n"); + + const auto t_main_end = ggml_time_us(); + + LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n", + __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f)); + + llama_print_timings(ctx); + + fprintf(stderr, "\n"); + + llama_batch_free(batch); + + llama_free(ctx); + llama_free_model(model); + + llama_backend_free(); + + return 0; +} diff --git a/examples/server/README.md b/examples/server/README.md index 243e669912cf0e..d85a14f891bc4e 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -23,6 +23,7 @@ Command line options: - `--host`: Set the hostname or ip address to listen. Default `127.0.0.1`. - `--port`: Set the port to listen. Default: `8080`. - `--path`: path from which to serve static files (default examples/server/public) +- `--api-key`: Set an api key for request authorization. By default the server responds to every request. With an api key set, the requests must have the Authorization header set with the api key as Bearer token. - `--embedding`: Enable embedding extraction, Default: disabled. - `-np N`, `--parallel N`: Set the number of slots for process requests (default: 1) - `-cb`, `--cont-batching`: enable continuous batching (a.k.a dynamic batching) (default: disabled) @@ -174,35 +175,44 @@ node index.js `system_prompt`: Change the system prompt (initial prompt of all slots), this is useful for chat applications. [See more](#change-system-prompt-on-runtime) - *Result JSON:* +### Result JSON: - Note: When using streaming mode (`stream`) only `content` and `stop` will be returned until end of completion. +* Note: When using streaming mode (`stream`) only `content` and `stop` will be returned until end of completion. - `content`: Completion result as a string (excluding `stopping_word` if any). In case of streaming mode, will contain the next token as a string. 
- `stop`: Boolean for use with `stream` to check whether the generation has stopped (Note: This is not related to stopping words array `stop` from input options) +- `completion_probabilities`: An array of token probabilities for each completion. The array's length is `n_predict`. Each item in the array has the following structure: - `generation_settings`: The provided options above excluding `prompt` but including `n_ctx`, `model` - - `model`: The path to the model loaded with `-m` - - `prompt`: The provided `prompt` - - `stopped_eos`: Indicating whether the completion has stopped because it encountered the EOS token - - `stopped_limit`: Indicating whether the completion stopped because `n_predict` tokens were generated before stop words or EOS was encountered - - `stopped_word`: Indicating whether the completion stopped due to encountering a stopping word from `stop` JSON array provided - - `stopping_word`: The stopping word encountered which stopped the generation (or "" if not stopped due to a stopping word) - - `timings`: Hash of timing information about the completion such as the number of tokens `predicted_per_second` - - `tokens_cached`: Number of tokens from the prompt which could be re-used from previous completion (`n_past`) - - `tokens_evaluated`: Number of tokens evaluated in total from the prompt - - `truncated`: Boolean indicating if the context size was exceeded during generation, i.e. the number of tokens provided in the prompt (`tokens_evaluated`) plus tokens generated (`tokens predicted`) exceeded the context size (`n_ctx`) +``` +{ + "content": "", + "probs": [ + { + "prob": float, + "tok_str": "" + }, + { + "prob": float, + "tok_str": "" + }, + ... + ] +}, +``` +Notice that each `probs` is an array of length `n_probs`. + +- `content`: Completion result as a string (excluding `stopping_word` if any). In case of streaming mode, will contain the next token as a string. +- `stop`: Boolean for use with `stream` to check whether the generation has stopped (Note: This is not related to stopping words array `stop` from input options) +- `generation_settings`: The provided options above excluding `prompt` but including `n_ctx`, `model` +- `model`: The path to the model loaded with `-m` +- `prompt`: The provided `prompt` +- `stopped_eos`: Indicating whether the completion has stopped because it encountered the EOS token +- `stopped_limit`: Indicating whether the completion stopped because `n_predict` tokens were generated before stop words or EOS was encountered +- `stopped_word`: Indicating whether the completion stopped due to encountering a stopping word from `stop` JSON array provided +- `stopping_word`: The stopping word encountered which stopped the generation (or "" if not stopped due to a stopping word) +- `timings`: Hash of timing information about the completion such as the number of tokens `predicted_per_second` +- `tokens_cached`: Number of tokens from the prompt which could be re-used from previous completion (`n_past`) +- `tokens_evaluated`: Number of tokens evaluated in total from the prompt +- `truncated`: Boolean indicating if the context size was exceeded during generation, i.e. the number of tokens provided in the prompt (`tokens_evaluated`) plus tokens generated (`tokens predicted`) exceeded the context size (`n_ctx`) - **POST** `/tokenize`: Tokenize a given text. 
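+
+As a usage sketch for the `--api-key` option and the `completion_probabilities` output documented
+above: assuming the server was started with an API key and the default host/port, a request to the
+`/completion` endpoint that also asks for per-token probabilities (via `n_probs`) could look like
+this (the model path, key, and prompt below are placeholders):
+
+```bash
+# start the server with an API key (model path and key are example values)
+./server -m models/llama-7b-v2/ggml-model-q4_0.gguf --api-key sk-example-key
+
+# the Bearer token must match the value passed to --api-key
+curl http://127.0.0.1:8080/completion \
+    -H "Authorization: Bearer sk-example-key" \
+    -H "Content-Type: application/json" \
+    -d '{"prompt": "The quick brown fox", "n_predict": 8, "n_probs": 4}'
+```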
diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 61c5f67bc68bf9..484718bb022d4a 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -118,6 +118,7 @@ #endif // defined(GGML_USE_HIPBLAS) +#define CC_PASCAL 600 #define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products #define CC_VOLTA 700 #define CC_OFFSET_AMD 1000000 @@ -479,6 +480,14 @@ typedef struct { } block_q6_K; static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_K block size/padding"); +#define QR2_XXS 8 +#define QI2_XXS (QK_K / (4*QR2_XXS)) +typedef struct { + half d; + uint16_t qs[QK_K/8]; +} block_iq2_xxs; +static_assert(sizeof(block_iq2_xxs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16_t), "wrong iq2_xxs block size/padding"); + #define WARP_SIZE 32 #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses @@ -550,11 +559,12 @@ static std::array g_default_tensor_split = {}; struct cuda_device_capabilities { int cc; // compute capability + size_t smpb; // max. shared memory per block bool vmm; // virtual memory support size_t vmm_granularity; // granularity of virtual memory }; -static cuda_device_capabilities g_device_caps[GGML_CUDA_MAX_DEVICES] = { {0, false, 0} }; +static cuda_device_capabilities g_device_caps[GGML_CUDA_MAX_DEVICES] = { {0, 0, false, 0} }; static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr}; @@ -583,6 +593,19 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) { return a; } +static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) { +#if __CUDA_ARCH__ < CC_PASCAL || (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) + (void) a; + bad_arch(); +#else +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32)); + } + return a; +#endif // __CUDA_ARCH__ < CC_PASCAL || (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) +} + static __device__ __forceinline__ float warp_reduce_max(float x) { #pragma unroll for (int mask = 16; mask > 0; mask >>= 1) { @@ -591,6 +614,19 @@ static __device__ __forceinline__ float warp_reduce_max(float x) { return x; } +static __device__ __forceinline__ half2 warp_reduce_max(half2 x) { +#if __CUDA_ARCH__ < CC_PASCAL || (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) + (void) x; + bad_arch(); +#else +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32)); + } + return x; +#endif // __CUDA_ARCH__ < CC_PASCAL || (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) +} + static __device__ __forceinline__ float op_repeat(const float a, const float b) { return b; GGML_UNUSED(a); @@ -1290,6 +1326,128 @@ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t #endif } +static const __device__ uint64_t kgrid_iq2xxs[256] = { + 0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08, + 0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808, + 0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819, + 0x0808080819081908, 0x0808080819190808, 0x0808080819192b08, 0x08080808192b0819, + 0x08080808192b1908, 0x080808082b080808, 0x080808082b08082b, 0x080808082b082b2b, + 0x080808082b2b082b, 0x0808081908080819, 0x0808081908081908, 0x0808081908190808, + 0x0808081908191919, 0x0808081919080808, 0x080808192b081908, 0x080808192b192b08, + 0x0808082b08080808, 0x0808082b0808082b, 
0x0808082b082b082b, 0x0808082b2b08082b, + 0x0808190808080819, 0x0808190808081908, 0x0808190808190808, 0x08081908082b0819, + 0x08081908082b1908, 0x0808190819080808, 0x080819081908082b, 0x0808190819082b08, + 0x08081908192b0808, 0x080819082b080819, 0x080819082b081908, 0x080819082b190808, + 0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b, 0x0808191908082b08, + 0x08081919082b0808, 0x080819191908192b, 0x08081919192b2b19, 0x080819192b080808, + 0x080819192b190819, 0x0808192b08082b19, 0x0808192b08190808, 0x0808192b19080808, + 0x0808192b2b081908, 0x0808192b2b2b1908, 0x08082b0808080808, 0x08082b0808081919, + 0x08082b0808082b08, 0x08082b0808191908, 0x08082b08082b2b08, 0x08082b0819080819, + 0x08082b0819081908, 0x08082b0819190808, 0x08082b081919082b, 0x08082b082b082b08, + 0x08082b1908081908, 0x08082b1919080808, 0x08082b2b0808082b, 0x08082b2b08191908, + 0x0819080808080819, 0x0819080808081908, 0x0819080808190808, 0x08190808082b0819, + 0x0819080819080808, 0x08190808192b0808, 0x081908082b081908, 0x081908082b190808, + 0x081908082b191919, 0x0819081908080808, 0x0819081908082b08, 0x08190819082b0808, + 0x0819081919190808, 0x0819081919192b2b, 0x081908192b080808, 0x0819082b082b1908, + 0x0819082b19081919, 0x0819190808080808, 0x0819190808082b08, 0x08191908082b0808, + 0x08191908082b1919, 0x0819190819082b19, 0x081919082b080808, 0x0819191908192b08, + 0x08191919192b082b, 0x0819192b08080808, 0x0819192b0819192b, 0x08192b0808080819, + 0x08192b0808081908, 0x08192b0808190808, 0x08192b0819080808, 0x08192b082b080819, + 0x08192b1908080808, 0x08192b1908081919, 0x08192b192b2b0808, 0x08192b2b19190819, + 0x082b080808080808, 0x082b08080808082b, 0x082b080808082b2b, 0x082b080819081908, + 0x082b0808192b0819, 0x082b08082b080808, 0x082b08082b08082b, 0x082b0819082b2b19, + 0x082b081919082b08, 0x082b082b08080808, 0x082b082b0808082b, 0x082b190808080819, + 0x082b190808081908, 0x082b190808190808, 0x082b190819080808, 0x082b19081919192b, + 0x082b191908080808, 0x082b191919080819, 0x082b1919192b1908, 0x082b192b2b190808, + 0x082b2b0808082b08, 0x082b2b08082b0808, 0x082b2b082b191908, 0x082b2b2b19081908, + 0x1908080808080819, 0x1908080808081908, 0x1908080808190808, 0x1908080808192b08, + 0x19080808082b0819, 0x19080808082b1908, 0x1908080819080808, 0x1908080819082b08, + 0x190808081919192b, 0x19080808192b0808, 0x190808082b080819, 0x190808082b081908, + 0x190808082b190808, 0x1908081908080808, 0x19080819082b0808, 0x19080819192b0819, + 0x190808192b080808, 0x190808192b081919, 0x1908082b08080819, 0x1908082b08190808, + 0x1908082b19082b08, 0x1908082b1919192b, 0x1908082b192b2b08, 0x1908190808080808, + 0x1908190808082b08, 0x19081908082b0808, 0x190819082b080808, 0x190819082b192b19, + 0x190819190819082b, 0x19081919082b1908, 0x1908192b08080808, 0x19082b0808080819, + 0x19082b0808081908, 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919, + 0x19082b1908080808, 0x19082b1919192b08, 0x19082b19192b0819, 0x19082b192b08082b, + 0x19082b2b19081919, 0x19082b2b2b190808, 0x1919080808080808, 0x1919080808082b08, + 0x1919080808190819, 0x1919080808192b19, 0x19190808082b0808, 0x191908082b080808, + 0x191908082b082b08, 0x1919081908081908, 0x191908191908082b, 0x191908192b2b1908, + 0x1919082b2b190819, 0x191919082b190808, 0x191919082b19082b, 0x1919191908082b2b, + 0x1919192b08080819, 0x1919192b19191908, 0x19192b0808080808, 0x19192b0808190819, + 0x19192b0808192b19, 0x19192b08192b1908, 0x19192b1919080808, 0x19192b2b08082b08, + 0x192b080808081908, 0x192b080808190808, 0x192b080819080808, 0x192b0808192b2b08, + 0x192b081908080808, 0x192b081919191919, 0x192b082b08192b08, 
0x192b082b192b0808, + 0x192b190808080808, 0x192b190808081919, 0x192b191908190808, 0x192b19190819082b, + 0x192b19192b081908, 0x192b2b081908082b, 0x2b08080808080808, 0x2b0808080808082b, + 0x2b08080808082b2b, 0x2b08080819080819, 0x2b0808082b08082b, 0x2b08081908081908, + 0x2b08081908192b08, 0x2b08081919080808, 0x2b08082b08190819, 0x2b08190808080819, + 0x2b08190808081908, 0x2b08190808190808, 0x2b08190808191919, 0x2b08190819080808, + 0x2b081908192b0808, 0x2b08191908080808, 0x2b0819191908192b, 0x2b0819192b191908, + 0x2b08192b08082b19, 0x2b08192b19080808, 0x2b08192b192b0808, 0x2b082b080808082b, + 0x2b082b1908081908, 0x2b082b2b08190819, 0x2b19080808081908, 0x2b19080808190808, + 0x2b190808082b1908, 0x2b19080819080808, 0x2b1908082b2b0819, 0x2b1908190819192b, + 0x2b1908192b080808, 0x2b19082b19081919, 0x2b19190808080808, 0x2b191908082b082b, + 0x2b19190819081908, 0x2b19191919190819, 0x2b192b082b080819, 0x2b192b19082b0808, + 0x2b2b08080808082b, 0x2b2b080819190808, 0x2b2b08082b081919, 0x2b2b081908082b19, + 0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908, +}; + +static const __device__ uint8_t ksigns_iq2xs[128] = { + 0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15, + 144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159, + 160, 33, 34, 163, 36, 165, 166, 39, 40, 169, 170, 43, 172, 45, 46, 175, + 48, 177, 178, 51, 180, 53, 54, 183, 184, 57, 58, 187, 60, 189, 190, 63, + 192, 65, 66, 195, 68, 197, 198, 71, 72, 201, 202, 75, 204, 77, 78, 207, + 80, 209, 210, 83, 212, 85, 86, 215, 216, 89, 90, 219, 92, 221, 222, 95, + 96, 225, 226, 99, 228, 101, 102, 231, 232, 105, 106, 235, 108, 237, 238, 111, + 240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255, +}; + +static const __device__ uint8_t kmask_iq2xs[8] = {1, 2, 4, 8, 16, 32, 64, 128}; + +inline bool ggml_cuda_supports_mmq(enum ggml_type type) { + switch (type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + return true; + default: + return false; + } +} + +template +static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) { + + const int i = blockIdx.x; + const block_iq2_xxs * x = (const block_iq2_xxs *) vx; + + const int tid = threadIdx.x; +#if QK_K == 256 + const int il = tid/8; // 0...3 + const int ib = tid%8; // 0...7 + dst_t * y = yy + i*QK_K + 32*ib + 8*il; + const uint16_t * q2 = x[i].qs + 4*ib; + const uint8_t * aux8 = (const uint8_t *)q2; + const uint8_t * grid = (const uint8_t *)(kgrid_iq2xxs + aux8[il]); + const uint32_t aux32 = q2[2] | (q2[3] << 16); + const float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.25f; + const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127]; + for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? 
-1.f : 1.f); +#else + assert(false); +#endif + +} + static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) { static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION"); @@ -3823,6 +3981,55 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_mul_mat( return vec_dot_q6_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, x_dmf[i * (WARP_SIZE/QI6_K) + i/QI6_K], &y_df[index_y/QI8_1]); } +static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { +#if QK_K == 256 + const block_iq2_xxs * bq2 = (const block_iq2_xxs *) vbq; + +#if QR2_XXS == 8 + const int ib32 = iqs; + const uint16_t * q2 = bq2->qs + 4*ib32; + const uint8_t * aux8 = (const uint8_t *)q2; + const int8_t * q8 = bq8_1[ib32].qs; + uint32_t aux32 = q2[2] | (q2[3] << 16); + int sumi = 0; + for (int l = 0; l < 4; ++l) { + const uint8_t * grid = (const uint8_t *)(kgrid_iq2xxs + aux8[l]); + const uint8_t signs = ksigns_iq2xs[aux32 & 127]; + for (int j = 0; j < 8; ++j) { + sumi += q8[j] * grid[j] * (signs & kmask_iq2xs[j] ? -1 : 1); + } + q8 += 8; + aux32 >>= 7; + } + const float d = (float)bq2->d * (0.5f + aux32) * (float)bq8_1[ib32].ds.x * 0.25f; + return d * sumi; +#else + // iqs is 0...15 + const int ib32 = iqs/2; + const int il = iqs%2; + const uint16_t * q2 = bq2->qs + 4*ib32; + const uint8_t * aux8 = (const uint8_t *)q2; + const uint8_t * grid1 = (const uint8_t *)(kgrid_iq2xxs + aux8[2*il+0]); + const uint8_t * grid2 = (const uint8_t *)(kgrid_iq2xxs + aux8[2*il+1]); + const uint32_t aux32 = q2[2] | (q2[3] << 16); + const float d = (float)bq2->d * (0.5f + (aux32 >> 28)) * (float)bq8_1[ib32].ds.x * 0.25f; + const uint8_t signs1 = ksigns_iq2xs[(aux32 >> 14*il) & 127]; + const uint8_t signs2 = ksigns_iq2xs[(aux32 >> (14*il + 7)) & 127]; + const int8_t * q8 = bq8_1[ib32].qs + 16*il; + int sumi1 = 0, sumi2 = 0; + for (int j = 0; j < 8; ++j) { + sumi1 += q8[j+0] * grid1[j] * (signs1 & kmask_iq2xs[j] ? -1 : 1); + sumi2 += q8[j+8] * grid2[j] * (signs2 & kmask_iq2xs[j] ? -1 : 1); + } + return d * (sumi1 + sumi2); +#endif +#else + assert(false); + return 0.f; +#endif +} + template static __device__ __forceinline__ void mul_mat_q( @@ -5204,75 +5411,233 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX; } -static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale) { +template +static __global__ void soft_max_f16(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) { +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL + const int ncols_data = ncols_template == 0 ? ncols_par : ncols_template; + const int ncols_smem = GGML_PAD(ncols_data, 2*WARP_SIZE)/2; + const int tid = threadIdx.x; const int rowx = blockIdx.x; const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension - const int block_size = blockDim.x; + const int block_size = block_size_template == 0 ? 
blockDim.x : block_size_template; const int warp_id = threadIdx.x / WARP_SIZE; const int lane_id = threadIdx.x % WARP_SIZE; - __shared__ float buf[CUDA_SOFT_MAX_BLOCK_SIZE/WARP_SIZE]; + extern __shared__ half data_soft_max_f16[]; + half * buf_iw = data_soft_max_f16 + 0; // shared memory buffer for inter-warp communication + // (shared memory) buffer to cache values between iterations: + half2 * vals = vals_smem ? (half2 *) (buf_iw + WARP_SIZE) : (half2 *) (dst + rowx*ncols_data); + // if the buffer is larger than max. shared memory per block, use dst as temp. buffer instead + // in that case col_smem == col_data must be enforced to avoid race conditions + + half2 max_val = make_half2(-INFINITY, -INFINITY); + +#pragma unroll + for (int col0 = 0; col0 < ncols_smem; col0 += block_size) { + const int col_data = 2*col0 + 2*WARP_SIZE*warp_id + lane_id; + const int col_smem = vals_smem ? col0 + tid : col_data; + + const int ix = rowx*ncols_data + col_data; + const int iy = rowy*ncols_data + col_data; + + half2 val; + if (need_check && col_data + 0 >= ncols_data) { + val.x = -INFINITY; + } else { + val.x = x[ix + 0]*scale + (y ? y[iy + 0] : 0.0f); + } + if (need_check && col_data + WARP_SIZE >= ncols_data) { + val.y = -INFINITY; + } else { + val.y = x[ix + WARP_SIZE]*scale + (y ? y[iy + WARP_SIZE] : 0.0f); + } + if (!need_check || col_smem < (vals_smem ? ncols_smem : ncols_data)) { + vals[col_smem] = val; + } + max_val = __hmax2(max_val, val); + } + + // find the max value in the block + max_val = warp_reduce_max(max_val); + if (block_size > WARP_SIZE) { + if (warp_id == 0) { + buf_iw[lane_id] = -INFINITY; + } + __syncthreads(); + + if (lane_id == 0) { + buf_iw[warp_id] = __hmax(max_val.x, max_val.y); + } + __syncthreads(); + + max_val = __half2half2(buf_iw[lane_id]); + max_val = warp_reduce_max(max_val); + } else { + max_val = __half2half2(__hmax(max_val.x, max_val.y)); + } + + half2 tmp = make_half2(0.0f, 0.0f); // partial sums + +#pragma unroll + for (int col0 = 0; col0 < ncols_smem; col0 += block_size) { + const int col_smem = vals_smem ? col0 + tid : 2*col0 + 2*warp_id*WARP_SIZE + lane_id; + + if (ncols_template == 0 && col_smem >= (vals_smem ? ncols_smem : ncols_data)) { + break; + } + + const half2 val = h2exp(vals[col_smem] - max_val); + + tmp += val; + vals[col_smem] = val; + } + + // find the sum of exps in the block + tmp = warp_reduce_sum(tmp); + if (block_size > WARP_SIZE) { + if (warp_id == 0) { + buf_iw[lane_id] = 0.0f; + } + __syncthreads(); + + if (lane_id == 0) { + buf_iw[warp_id] = tmp.x + tmp.y; + } + __syncthreads(); + + tmp = __half2half2(buf_iw[lane_id]); + tmp = warp_reduce_sum(tmp); + } else { + tmp = __half2half2(tmp.x + tmp.y); + } + + const half2 inv_sum = make_half2(1.0f, 1.0f) / tmp; + +#pragma unroll + for (int col0 = 0; col0 < ncols_smem; col0 += block_size) { + const int col_data = 2*col0 + 2*WARP_SIZE*warp_id + lane_id; + const int col_smem = vals_smem ? 
col0 + tid : col_data; + + const int idst = rowx*ncols_data + col_data; + const half2 result = vals[col_smem] * inv_sum; + + if (need_check && col_data + 0 >= ncols_data) { + return; + } + dst[idst] = result.x; + + if (need_check && col_data + WARP_SIZE >= ncols_data) { + return; + } + + dst[idst + WARP_SIZE] = result.y; + } +#else + (void) x; (void) y; (void) dst; (void) ncols_par; (void) nrows_y; (void) scale; + bad_arch(); +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL +} + +template +static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) { + const int ncols = ncols_template == 0 ? ncols_par : ncols_template; + + const int tid = threadIdx.x; + const int rowx = blockIdx.x; + const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension + + const int block_size = block_size_template == 0 ? blockDim.x : block_size_template; + + const int warp_id = threadIdx.x / WARP_SIZE; + const int lane_id = threadIdx.x % WARP_SIZE; + + extern __shared__ float data_soft_max_f32[]; + float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication + // shared memory buffer to cache values between iterations: + float * vals = vals_smem ? buf_iw + WARP_SIZE : dst + rowx*ncols; float max_val = -INFINITY; - for (int col = tid; col < ncols; col += block_size) { +#pragma unroll + for (int col0 = 0; col0 < ncols; col0 += block_size) { + const int col = col0 + tid; + + if (ncols_template == 0 && col >= ncols) { + break; + } + const int ix = rowx*ncols + col; const int iy = rowy*ncols + col; - max_val = max(max_val, x[ix]*scale + (y ? y[iy] : 0.0f)); + + const float val = x[ix]*scale + (y ? y[iy] : 0.0f); + vals[col] = val; + max_val = max(max_val, val); } // find the max value in the block max_val = warp_reduce_max(max_val); if (block_size > WARP_SIZE) { if (warp_id == 0) { - buf[lane_id] = -INFINITY; + buf_iw[lane_id] = -INFINITY; } __syncthreads(); if (lane_id == 0) { - buf[warp_id] = max_val; + buf_iw[warp_id] = max_val; } __syncthreads(); - max_val = buf[lane_id]; + max_val = buf_iw[lane_id]; max_val = warp_reduce_max(max_val); } - float tmp = 0.f; + float tmp = 0.0f; // partial sum - for (int col = tid; col < ncols; col += block_size) { - const int ix = rowx*ncols + col; - const int iy = rowy*ncols + col; - const float val = expf((x[ix]*scale + (y ? 
y[iy] : 0.0f)) - max_val); +#pragma unroll + for (int col0 = 0; col0 < ncols; col0 += block_size) { + const int col = col0 + tid; + + if (ncols_template == 0 && col >= ncols) { + break; + } + + const float val = expf(vals[col] - max_val); tmp += val; - dst[ix] = val; + vals[col] = val; } // find the sum of exps in the block tmp = warp_reduce_sum(tmp); if (block_size > WARP_SIZE) { if (warp_id == 0) { - buf[lane_id] = 0.f; + buf_iw[lane_id] = 0.0f; } __syncthreads(); if (lane_id == 0) { - buf[warp_id] = tmp; + buf_iw[warp_id] = tmp; } __syncthreads(); - tmp = buf[lane_id]; + tmp = buf_iw[lane_id]; tmp = warp_reduce_sum(tmp); } - const float inv_tmp = 1.f / tmp; + const float inv_sum = 1.0f / tmp; - for (int col = tid; col < ncols; col += block_size) { - const int i = rowx*ncols + col; - dst[i] *= inv_tmp; +#pragma unroll + for (int col0 = 0; col0 < ncols; col0 += block_size) { + const int col = col0 + tid; + + if (ncols_template == 0 && col >= ncols) { + return; + } + + const int idst = rowx*ncols + col; + dst[idst] = vals[col] * inv_sum; } } @@ -5662,6 +6027,12 @@ static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cu #endif } +template +static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { + const int nb = k / QK_K; + dequantize_block_iq2_xxs<<>>(vx, y); +} + template static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; @@ -5690,6 +6061,8 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { return dequantize_row_q5_K_cuda; case GGML_TYPE_Q6_K: return dequantize_row_q6_K_cuda; + case GGML_TYPE_IQ2_XXS: + return dequantize_row_iq2_xxs_cuda; case GGML_TYPE_F32: return convert_unary_cuda; default: @@ -5719,6 +6092,8 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { return dequantize_row_q5_K_cuda; case GGML_TYPE_Q6_K: return dequantize_row_q6_K_cuda; + case GGML_TYPE_IQ2_XXS: + return dequantize_row_iq2_xxs_cuda; case GGML_TYPE_F16: return convert_unary_cuda; default: @@ -5913,6 +6288,15 @@ static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, float * <<>>(vx, vy, dst, ncols, nrows); } +static void mul_mat_vec_iq2_xxs_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(block_num_y, 1, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + mul_mat_vec_q + <<>>(vx, vy, dst, ncols, nrows); +} + static void ggml_mul_mat_q4_0_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { @@ -6552,12 +6936,90 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols diag_mask_inf_f32<<>>(x, dst, ncols_x, rows_per_channel, n_past); } +static void soft_max_f16_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) { + int nth = WARP_SIZE; + while (nth < ncols_x/2 && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2; + const dim3 block_dims(nth, 1, 1); + const dim3 block_nums(nrows_x, 1, 1); + const size_t shmem = (GGML_PAD(ncols_x, 2*WARP_SIZE) + WARP_SIZE)*sizeof(half); + 
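+    // the kernel caches one row as half2 in shared memory: ncols_x values padded to a multiple of
+    // 2*WARP_SIZE halves, plus WARP_SIZE halves for the inter-warp reduction buffer (buf_iw);
+    // if this exceeds the per-block limit (smpb), the fallback branch below keeps only buf_iw in
+    // shared memory and uses dst as the temporary buffer for the cached values instead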
static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted."); + if (shmem <= g_device_caps[g_main_device].smpb) { + switch (ncols_x) { + case 32: + soft_max_f16<<>>(x, y, dst, ncols_x, nrows_y, scale); + break; + case 64: + soft_max_f16<<>>(x, y, dst, ncols_x, nrows_y, scale); + break; + case 128: + soft_max_f16<<>>(x, y, dst, ncols_x, nrows_y, scale); + break; + case 256: + soft_max_f16<<>>(x, y, dst, ncols_x, nrows_y, scale); + break; + case 512: + soft_max_f16<<>>(x, y, dst, ncols_x, nrows_y, scale); + break; + case 1024: + soft_max_f16<<>>(x, y, dst, ncols_x, nrows_y, scale); + break; + case 2048: + soft_max_f16<<>>(x, y, dst, ncols_x, nrows_y, scale); + break; + case 4096: + soft_max_f16<<>>(x, y, dst, ncols_x, nrows_y, scale); + break; + default: + soft_max_f16<<>>(x, y, dst, ncols_x, nrows_y, scale); + break; + } + } else { + const size_t shmem_low = WARP_SIZE*sizeof(half); + soft_max_f16<<>>(x, y, dst, ncols_x, nrows_y, scale); + } +} + static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) { int nth = WARP_SIZE; while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2; const dim3 block_dims(nth, 1, 1); const dim3 block_nums(nrows_x, 1, 1); - soft_max_f32<<>>(x, y, dst, ncols_x, nrows_y, scale); + const size_t shmem = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float); + static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted."); + if (shmem < g_device_caps[g_main_device].smpb) { + switch (ncols_x) { + case 32: + soft_max_f32<<>>(x, y, dst, ncols_x, nrows_y, scale); + break; + case 64: + soft_max_f32<<>>(x, y, dst, ncols_x, nrows_y, scale); + break; + case 128: + soft_max_f32<<>>(x, y, dst, ncols_x, nrows_y, scale); + break; + case 256: + soft_max_f32<<>>(x, y, dst, ncols_x, nrows_y, scale); + break; + case 512: + soft_max_f32<<>>(x, y, dst, ncols_x, nrows_y, scale); + break; + case 1024: + soft_max_f32<<>>(x, y, dst, ncols_x, nrows_y, scale); + break; + case 2048: + soft_max_f32<<>>(x, y, dst, ncols_x, nrows_y, scale); + break; + case 4096: + soft_max_f32<<>>(x, y, dst, ncols_x, nrows_y, scale); + break; + default: + soft_max_f32<<>>(x, y, dst, ncols_x, nrows_y, scale); + break; + } + } else { + const size_t shmem_low = WARP_SIZE*sizeof(float); + soft_max_f32<<>>(x, y, dst, ncols_x, nrows_y, scale); + } } static void im2col_f32_f16_cuda(const float* x, half* dst, @@ -6873,6 +7335,7 @@ void ggml_init_cublas() { #else g_device_caps[id].cc = 100*prop.major + 10*prop.minor; #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) + g_device_caps[id].smpb = prop.sharedMemPerBlock; } for (int id = 0; id < g_device_count; ++id) { g_default_tensor_split[id] /= total_vram; @@ -7382,6 +7845,7 @@ static int64_t get_row_rounding(ggml_type type, const std::array= CC_RDNA2 ? 128 : 64; default: GGML_ASSERT(false); @@ -7402,6 +7866,7 @@ static int64_t get_row_rounding(ggml_type type, const std::array= CC_VOLTA ? 
128 : 64; case GGML_TYPE_Q6_K: return 64; @@ -7467,6 +7932,9 @@ static void ggml_cuda_op_mul_mat_vec_q( case GGML_TYPE_Q6_K: mul_mat_vec_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); break; + case GGML_TYPE_IQ2_XXS: + mul_mat_vec_iq2_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + break; default: GGML_ASSERT(false); break; @@ -7874,7 +8342,21 @@ static void ggml_cuda_op_soft_max( float scale = 1.0f; memcpy(&scale, dst->op_params, sizeof(float)); - soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream); +#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) + const bool use_f16_soft_max = false; +#else +#ifdef GGML_CUDA_F16 + const bool use_f16_soft_max = true; +#else + const bool use_f16_soft_max = false; +#endif // GGML_CUDA_F16 +#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) + + if (use_f16_soft_max) { + soft_max_f16_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream); + } else { + soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream); + } (void) dst; } @@ -8703,6 +9185,8 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1 #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) + use_mul_mat_q = use_mul_mat_q && ggml_cuda_supports_mmq(src0->type); + // debug helpers //printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]); //printf(" %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]); diff --git a/ggml-metal.m b/ggml-metal.m index 5e530e280f9ded..1e7ed6bb8b756d 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -88,6 +88,7 @@ GGML_METAL_DECL_KERNEL(get_rows_q5_K); GGML_METAL_DECL_KERNEL(get_rows_q6_K); GGML_METAL_DECL_KERNEL(get_rows_i32); + GGML_METAL_DECL_KERNEL(get_rows_iq2_xxs); GGML_METAL_DECL_KERNEL(rms_norm); GGML_METAL_DECL_KERNEL(group_norm); GGML_METAL_DECL_KERNEL(norm); @@ -106,6 +107,7 @@ GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32); GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32); GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32); + GGML_METAL_DECL_KERNEL(mul_mv_iq2_xxs_f32); GGML_METAL_DECL_KERNEL(mul_mv_id_f32_f32); //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f16); GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32); @@ -121,6 +123,7 @@ GGML_METAL_DECL_KERNEL(mul_mv_id_q4_K_f32); GGML_METAL_DECL_KERNEL(mul_mv_id_q5_K_f32); GGML_METAL_DECL_KERNEL(mul_mv_id_q6_K_f32); + GGML_METAL_DECL_KERNEL(mul_mv_id_iq2_xxs_f32); GGML_METAL_DECL_KERNEL(mul_mm_f32_f32); GGML_METAL_DECL_KERNEL(mul_mm_f16_f32); GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32); @@ -133,6 +136,7 @@ GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32); GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32); GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32); + GGML_METAL_DECL_KERNEL(mul_mm_iq2_xxs_f32); GGML_METAL_DECL_KERNEL(mul_mm_id_f32_f32); GGML_METAL_DECL_KERNEL(mul_mm_id_f16_f32); GGML_METAL_DECL_KERNEL(mul_mm_id_q4_0_f32); @@ -145,6 +149,7 @@ GGML_METAL_DECL_KERNEL(mul_mm_id_q4_K_f32); GGML_METAL_DECL_KERNEL(mul_mm_id_q5_K_f32); GGML_METAL_DECL_KERNEL(mul_mm_id_q6_K_f32); + GGML_METAL_DECL_KERNEL(mul_mm_id_iq2_xxs_f32); GGML_METAL_DECL_KERNEL(rope_f32); GGML_METAL_DECL_KERNEL(rope_f16); GGML_METAL_DECL_KERNEL(alibi_f32); @@ -379,6 +384,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ GGML_METAL_ADD_KERNEL(get_rows_q5_K); GGML_METAL_ADD_KERNEL(get_rows_q6_K); GGML_METAL_ADD_KERNEL(get_rows_i32); + GGML_METAL_ADD_KERNEL(get_rows_iq2_xxs); 
GGML_METAL_ADD_KERNEL(rms_norm); GGML_METAL_ADD_KERNEL(group_norm); GGML_METAL_ADD_KERNEL(norm); @@ -397,6 +403,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32); GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32); GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32); + GGML_METAL_ADD_KERNEL(mul_mv_iq2_xxs_f32); GGML_METAL_ADD_KERNEL(mul_mv_id_f32_f32); //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f16); GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32); @@ -412,6 +419,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ GGML_METAL_ADD_KERNEL(mul_mv_id_q4_K_f32); GGML_METAL_ADD_KERNEL(mul_mv_id_q5_K_f32); GGML_METAL_ADD_KERNEL(mul_mv_id_q6_K_f32); + GGML_METAL_ADD_KERNEL(mul_mv_id_iq2_xxs_f32); if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) { GGML_METAL_ADD_KERNEL(mul_mm_f32_f32); GGML_METAL_ADD_KERNEL(mul_mm_f16_f32); @@ -425,6 +433,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32); GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32); GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32); + GGML_METAL_ADD_KERNEL(mul_mm_iq2_xxs_f32); GGML_METAL_ADD_KERNEL(mul_mm_id_f32_f32); GGML_METAL_ADD_KERNEL(mul_mm_id_f16_f32); GGML_METAL_ADD_KERNEL(mul_mm_id_q4_0_f32); @@ -437,6 +446,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ GGML_METAL_ADD_KERNEL(mul_mm_id_q4_K_f32); GGML_METAL_ADD_KERNEL(mul_mm_id_q5_K_f32); GGML_METAL_ADD_KERNEL(mul_mm_id_q6_K_f32); + GGML_METAL_ADD_KERNEL(mul_mm_id_iq2_xxs_f32); } GGML_METAL_ADD_KERNEL(rope_f32); GGML_METAL_ADD_KERNEL(rope_f16); @@ -502,6 +512,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) { GGML_METAL_DEL_KERNEL(get_rows_q5_K); GGML_METAL_DEL_KERNEL(get_rows_q6_K); GGML_METAL_DEL_KERNEL(get_rows_i32); + GGML_METAL_DEL_KERNEL(get_rows_iq2_xxs); GGML_METAL_DEL_KERNEL(rms_norm); GGML_METAL_DEL_KERNEL(group_norm); GGML_METAL_DEL_KERNEL(norm); @@ -520,6 +531,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) { GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32); GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32); GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32); + GGML_METAL_DEL_KERNEL(mul_mv_iq2_xxs_f32); GGML_METAL_DEL_KERNEL(mul_mv_id_f32_f32); //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f16); GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32); @@ -535,6 +547,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) { GGML_METAL_DEL_KERNEL(mul_mv_id_q4_K_f32); GGML_METAL_DEL_KERNEL(mul_mv_id_q5_K_f32); GGML_METAL_DEL_KERNEL(mul_mv_id_q6_K_f32); + GGML_METAL_DEL_KERNEL(mul_mv_id_iq2_xxs_f32); if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) { GGML_METAL_DEL_KERNEL(mul_mm_f32_f32); GGML_METAL_DEL_KERNEL(mul_mm_f16_f32); @@ -548,6 +561,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) { GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32); GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32); GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32); + GGML_METAL_DEL_KERNEL(mul_mm_iq2_xxs_f32); GGML_METAL_DEL_KERNEL(mul_mm_id_f32_f32); GGML_METAL_DEL_KERNEL(mul_mm_id_f16_f32); GGML_METAL_DEL_KERNEL(mul_mm_id_q4_0_f32); @@ -560,6 +574,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) { GGML_METAL_DEL_KERNEL(mul_mm_id_q4_K_f32); GGML_METAL_DEL_KERNEL(mul_mm_id_q5_K_f32); GGML_METAL_DEL_KERNEL(mul_mm_id_q6_K_f32); + GGML_METAL_DEL_KERNEL(mul_mm_id_iq2_xxs_f32); } GGML_METAL_DEL_KERNEL(rope_f32); GGML_METAL_DEL_KERNEL(rope_f16); @@ -1541,6 +1556,7 @@ bool ggml_metal_graph_compute( case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_K_f32]; 
break; case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_K_f32]; break; case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break; + case GGML_TYPE_IQ2_XXS: [encoder setComputePipelineState:ctx->pipeline_mul_mm_iq2_xxs_f32]; break; default: GGML_ASSERT(false && "MUL MAT-MAT not implemented"); } [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; @@ -1653,6 +1669,12 @@ bool ggml_metal_graph_compute( nth1 = 32; [encoder setComputePipelineState:ctx->pipeline_mul_mv_q6_K_f32]; } break; + case GGML_TYPE_IQ2_XXS: + { + nth0 = 4; + nth1 = 16; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_iq2_xxs_f32]; + } break; default: { GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t); @@ -1686,9 +1708,14 @@ bool ggml_metal_graph_compute( if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || + //src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) { [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } + else if (src0t == GGML_TYPE_IQ2_XXS) { + [encoder setThreadgroupMemoryLength:(256*8+128) atIndex:0]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } else if (src0t == GGML_TYPE_Q4_K) { [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } @@ -1778,6 +1805,7 @@ bool ggml_metal_graph_compute( case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_K_f32]; break; case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_K_f32]; break; case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q6_K_f32]; break; + case GGML_TYPE_IQ2_XXS: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_iq2_xxs_f32]; break; default: GGML_ASSERT(false && "MUL_MAT_ID not implemented"); } [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; @@ -1893,6 +1921,12 @@ bool ggml_metal_graph_compute( nth1 = 32; [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q6_K_f32]; } break; + case GGML_TYPE_IQ2_XXS: + { + nth0 = 4; + nth1 = 16; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_iq2_xxs_f32]; + } break; default: { GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t); @@ -1942,9 +1976,14 @@ bool ggml_metal_graph_compute( if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 || src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 || + //src2t == GGML_TYPE_IQ2_XXS || src2t == GGML_TYPE_Q2_K) { // || src2t == GGML_TYPE_Q4_K) { [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } + else if (src2t == GGML_TYPE_IQ2_XXS) { + [encoder setThreadgroupMemoryLength:(256*8+128) atIndex:0]; + [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } else if (src2t == GGML_TYPE_Q4_K) { [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } @@ -1982,6 +2021,7 @@ bool ggml_metal_graph_compute( case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_K]; break; case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_K]; break; case 
GGML_TYPE_I32: [encoder setComputePipelineState:ctx->pipeline_get_rows_i32]; break; + case GGML_TYPE_IQ2_XXS: [encoder setComputePipelineState:ctx->pipeline_get_rows_iq2_xxs]; break; default: GGML_ASSERT(false && "not implemented"); } diff --git a/ggml-metal.metal b/ggml-metal.metal index a7d3f9efa5787a..229efb8b69db1f 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2446,6 +2446,12 @@ typedef struct { } block_q6_K; // 210 bytes / block +typedef struct { + half d; + uint16_t qs[QK_K/8]; +} block_iq2_xxs; +// 66 bytes / block for QK_K = 256, so 2.0625 bpw + //====================================== dot products ========================= void kernel_mul_mv_q2_K_f32_impl( @@ -3468,6 +3474,221 @@ kernel void kernel_mul_mv_q6_K_f32( kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg); } +// ======================= "True" 2-bit + +constexpr constant static uint64_t kgrid_iq2xxs[256] = { + 0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08, + 0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808, + 0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819, + 0x0808080819081908, 0x0808080819190808, 0x0808080819192b08, 0x08080808192b0819, + 0x08080808192b1908, 0x080808082b080808, 0x080808082b08082b, 0x080808082b082b2b, + 0x080808082b2b082b, 0x0808081908080819, 0x0808081908081908, 0x0808081908190808, + 0x0808081908191919, 0x0808081919080808, 0x080808192b081908, 0x080808192b192b08, + 0x0808082b08080808, 0x0808082b0808082b, 0x0808082b082b082b, 0x0808082b2b08082b, + 0x0808190808080819, 0x0808190808081908, 0x0808190808190808, 0x08081908082b0819, + 0x08081908082b1908, 0x0808190819080808, 0x080819081908082b, 0x0808190819082b08, + 0x08081908192b0808, 0x080819082b080819, 0x080819082b081908, 0x080819082b190808, + 0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b, 0x0808191908082b08, + 0x08081919082b0808, 0x080819191908192b, 0x08081919192b2b19, 0x080819192b080808, + 0x080819192b190819, 0x0808192b08082b19, 0x0808192b08190808, 0x0808192b19080808, + 0x0808192b2b081908, 0x0808192b2b2b1908, 0x08082b0808080808, 0x08082b0808081919, + 0x08082b0808082b08, 0x08082b0808191908, 0x08082b08082b2b08, 0x08082b0819080819, + 0x08082b0819081908, 0x08082b0819190808, 0x08082b081919082b, 0x08082b082b082b08, + 0x08082b1908081908, 0x08082b1919080808, 0x08082b2b0808082b, 0x08082b2b08191908, + 0x0819080808080819, 0x0819080808081908, 0x0819080808190808, 0x08190808082b0819, + 0x0819080819080808, 0x08190808192b0808, 0x081908082b081908, 0x081908082b190808, + 0x081908082b191919, 0x0819081908080808, 0x0819081908082b08, 0x08190819082b0808, + 0x0819081919190808, 0x0819081919192b2b, 0x081908192b080808, 0x0819082b082b1908, + 0x0819082b19081919, 0x0819190808080808, 0x0819190808082b08, 0x08191908082b0808, + 0x08191908082b1919, 0x0819190819082b19, 0x081919082b080808, 0x0819191908192b08, + 0x08191919192b082b, 0x0819192b08080808, 0x0819192b0819192b, 0x08192b0808080819, + 0x08192b0808081908, 0x08192b0808190808, 0x08192b0819080808, 0x08192b082b080819, + 0x08192b1908080808, 0x08192b1908081919, 0x08192b192b2b0808, 0x08192b2b19190819, + 0x082b080808080808, 0x082b08080808082b, 0x082b080808082b2b, 0x082b080819081908, + 0x082b0808192b0819, 0x082b08082b080808, 0x082b08082b08082b, 0x082b0819082b2b19, + 0x082b081919082b08, 0x082b082b08080808, 0x082b082b0808082b, 0x082b190808080819, + 0x082b190808081908, 0x082b190808190808, 0x082b190819080808, 0x082b19081919192b, + 0x082b191908080808, 0x082b191919080819, 
0x082b1919192b1908, 0x082b192b2b190808, + 0x082b2b0808082b08, 0x082b2b08082b0808, 0x082b2b082b191908, 0x082b2b2b19081908, + 0x1908080808080819, 0x1908080808081908, 0x1908080808190808, 0x1908080808192b08, + 0x19080808082b0819, 0x19080808082b1908, 0x1908080819080808, 0x1908080819082b08, + 0x190808081919192b, 0x19080808192b0808, 0x190808082b080819, 0x190808082b081908, + 0x190808082b190808, 0x1908081908080808, 0x19080819082b0808, 0x19080819192b0819, + 0x190808192b080808, 0x190808192b081919, 0x1908082b08080819, 0x1908082b08190808, + 0x1908082b19082b08, 0x1908082b1919192b, 0x1908082b192b2b08, 0x1908190808080808, + 0x1908190808082b08, 0x19081908082b0808, 0x190819082b080808, 0x190819082b192b19, + 0x190819190819082b, 0x19081919082b1908, 0x1908192b08080808, 0x19082b0808080819, + 0x19082b0808081908, 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919, + 0x19082b1908080808, 0x19082b1919192b08, 0x19082b19192b0819, 0x19082b192b08082b, + 0x19082b2b19081919, 0x19082b2b2b190808, 0x1919080808080808, 0x1919080808082b08, + 0x1919080808190819, 0x1919080808192b19, 0x19190808082b0808, 0x191908082b080808, + 0x191908082b082b08, 0x1919081908081908, 0x191908191908082b, 0x191908192b2b1908, + 0x1919082b2b190819, 0x191919082b190808, 0x191919082b19082b, 0x1919191908082b2b, + 0x1919192b08080819, 0x1919192b19191908, 0x19192b0808080808, 0x19192b0808190819, + 0x19192b0808192b19, 0x19192b08192b1908, 0x19192b1919080808, 0x19192b2b08082b08, + 0x192b080808081908, 0x192b080808190808, 0x192b080819080808, 0x192b0808192b2b08, + 0x192b081908080808, 0x192b081919191919, 0x192b082b08192b08, 0x192b082b192b0808, + 0x192b190808080808, 0x192b190808081919, 0x192b191908190808, 0x192b19190819082b, + 0x192b19192b081908, 0x192b2b081908082b, 0x2b08080808080808, 0x2b0808080808082b, + 0x2b08080808082b2b, 0x2b08080819080819, 0x2b0808082b08082b, 0x2b08081908081908, + 0x2b08081908192b08, 0x2b08081919080808, 0x2b08082b08190819, 0x2b08190808080819, + 0x2b08190808081908, 0x2b08190808190808, 0x2b08190808191919, 0x2b08190819080808, + 0x2b081908192b0808, 0x2b08191908080808, 0x2b0819191908192b, 0x2b0819192b191908, + 0x2b08192b08082b19, 0x2b08192b19080808, 0x2b08192b192b0808, 0x2b082b080808082b, + 0x2b082b1908081908, 0x2b082b2b08190819, 0x2b19080808081908, 0x2b19080808190808, + 0x2b190808082b1908, 0x2b19080819080808, 0x2b1908082b2b0819, 0x2b1908190819192b, + 0x2b1908192b080808, 0x2b19082b19081919, 0x2b19190808080808, 0x2b191908082b082b, + 0x2b19190819081908, 0x2b19191919190819, 0x2b192b082b080819, 0x2b192b19082b0808, + 0x2b2b08080808082b, 0x2b2b080819190808, 0x2b2b08082b081919, 0x2b2b081908082b19, + 0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908, +}; + +constexpr constant static uint8_t ksigns_iq2xs[128] = { + 0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15, + 144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159, + 160, 33, 34, 163, 36, 165, 166, 39, 40, 169, 170, 43, 172, 45, 46, 175, + 48, 177, 178, 51, 180, 53, 54, 183, 184, 57, 58, 187, 60, 189, 190, 63, + 192, 65, 66, 195, 68, 197, 198, 71, 72, 201, 202, 75, 204, 77, 78, 207, + 80, 209, 210, 83, 212, 85, 86, 215, 216, 89, 90, 219, 92, 221, 222, 95, + 96, 225, 226, 99, 228, 101, 102, 231, 232, 105, 106, 235, 108, 237, 238, 111, + 240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255, +}; + +constexpr constant static uint8_t kmask_iq2xs[8] = {1, 2, 4, 8, 16, 32, 64, 128}; + +void kernel_mul_mv_iq2_xxs_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t 
& ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne10, + constant int64_t & ne12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup int8_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + device const block_iq2_xxs * x = (device const block_iq2_xxs *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[32]; + float sumf[N_DST]={0.f}, all_sum; + + const int nb32 = nb * (QK_K / 32); + + threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values; + threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 256); + { + int nval = 4; + int pos = (32*sgitg + tiisg)*nval; + for (int i = 0; i < nval; ++i) values[pos + i] = kgrid_iq2xxs[pos + i]; + nval = 2; + pos = (32*sgitg + tiisg)*nval; + for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i]; + threadgroup_barrier(mem_flags::mem_threadgroup); + } + +#if QK_K == 256 + const int ix = tiisg; + + device const float * y4 = y + 32 * ix; + + for (int ib32 = ix; ib32 < nb32; ib32 += 32) { + + for (int i = 0; i < 32; ++i) { + yl[i] = y4[i]; + } + + const int ibl = ib32 / (QK_K / 32); + const int ib = ib32 % (QK_K / 32); + + device const block_iq2_xxs * xr = x + ibl; + device const uint16_t * q2 = xr->qs + 4 * ib; + device const half * dh = &xr->d; + + for (int row = 0; row < N_DST; row++) { + + const float db = dh[0]; + device const uint8_t * aux8 = (device const uint8_t *)q2; + const uint32_t aux32 = q2[2] | (q2[3] << 16); + const float d = db * (0.5f + (aux32 >> 28)); + + float sum = 0; + for (int l = 0; l < 4; ++l) { + const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + aux8[l]); + const uint8_t signs = shared_signs[(aux32 >> 7*l) & 127]; + for (int j = 0; j < 8; ++j) { + sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? 
-1.f : 1.f); + } + } + sumf[row] += d * sum; + + dh += nb*sizeof(block_iq2_xxs)/2; + q2 += nb*sizeof(block_iq2_xxs)/2; + } + + y4 += 32 * 32; + } +#else + // TODO +#endif + + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f; + } + } +} + +[[host_name("kernel_mul_mv_iq2_xxs_f32")]] +kernel void kernel_mul_mv_iq2_xxs_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup int8_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); +} + //============================= templates and their specializations ============================= // NOTE: this is not dequantizing - we are simply fitting the template @@ -3620,8 +3841,8 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4]; int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2) : (scale_2&kmask2) | ((scale_1&kmask1) << 4); - half dl = il<8 ? d_all * (dl_int - 32.h) : d_all * (dl_int / 16.h - 32.h); - const half ml = 4.h * dl; + float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f); + const float ml = 4.f * dl; il = (il/2) & 3; const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h); @@ -3688,7 +3909,7 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg uint8_t ul = 1 << (il/2); il = il & 3; const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales); - const float d = il < 2 ? xb->d : xb->d / 16.h; + const float d = il < 2 ? xb->d : xb->d / 16.f; const float min = xb->dmin; const float dl = d * sc[0]; const float ml = min * sc[1]; @@ -3721,17 +3942,17 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg #if QK_K == 256 ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1); qh = qh + 32*(il/8) + 16*(il&1); - half sc = scales[(il%2) + 2 * ((il/2))]; + float sc = scales[(il%2) + 2 * ((il/2))]; il = (il/2) & 3; #else ql = ql + 16 * (il&1); - half sc = scales[il]; + float sc = scales[il]; #endif const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F; - const half coef = il>1 ? 1.f/16.h : 1.h; - const half ml = d_all * sc * 32.h; - const half dl = d_all * sc * coef; + const float coef = il>1 ? 1.f/16.f : 1.f; + const float ml = d_all * sc * 32.f; + const float dl = d_all * sc * coef; for (int i = 0; i < 16; ++i) { const half q = il&1 ? 
((ql[i] & kmask2) | ((qh[i] & kmask1) << 2)) : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4)); @@ -3739,6 +3960,31 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg } } +template +void dequantize_iq2_xxs(device const block_iq2_xxs * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const float d = xb->d; + const int ib32 = il/2; + il = il%2; + // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 + // each block of 32 needs 2 uint32_t's for the quants & scale, so 4 uint16_t's. + device const uint16_t * q2 = xb->qs + 4*ib32; + const uint32_t aux32_g = q2[0] | (q2[1] << 16); + const uint32_t aux32_s = q2[2] | (q2[3] << 16); + thread const uint8_t * aux8 = (thread const uint8_t *)&aux32_g; + const float dl = d * (0.5f + (aux32_s >> 28)) * 0.25f; + constant uint8_t * grid = (constant uint8_t *)(kgrid_iq2xxs + aux8[2*il+0]); + uint8_t signs = ksigns_iq2xs[(aux32_s >> 14*il) & 127]; + for (int i = 0; i < 8; ++i) { + reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f); + } + grid = (constant uint8_t *)(kgrid_iq2xxs + aux8[2*il+1]); + signs = ksigns_iq2xs[(aux32_s >> (14*il+7)) & 127]; + for (int i = 0; i < 8; ++i) { + reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f); + } +} + template kernel void kernel_get_rows( device const void * src0, @@ -4278,6 +4524,7 @@ template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows; // // matrix-matrix multiplication @@ -4314,6 +4561,7 @@ template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm; // // indirect matrix-matrix multiplication @@ -4362,6 +4610,7 @@ template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mu template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; // // matrix-vector multiplication @@ -5134,3 +5383,68 @@ kernel void kernel_mul_mv_id_q6_K_f32( tiisg, sgitg); } + +[[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] +kernel void kernel_mul_mv_id_iq2_xxs_f32( + device const char * ids, + device const char * src1, + device float * dst, + constant uint64_t & nbi1, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint64_t & nb1, + constant uint & r2, + constant uint & r3, + constant int & idx, + 
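// note: the 8 src0 pointers that follow are the per-expert weight tensors of the indirect (mixture-of-experts) matmul; the `ids` tensor selects which expert is used for each row
+    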
device const char * src00, + device const char * src01, + device const char * src02, + device const char * src03, + device const char * src04, + device const char * src05, + device const char * src06, + device const char * src07, + threadgroup int8_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; + + const int64_t bid = tgpig.z/(ne12*ne13); + + tgpig.z = tgpig.z%(ne12*ne13); + + const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + + kernel_mul_mv_iq2_xxs_f32_impl( + src0[id], + (device const float *) (src1 + bid*nb11), + dst + bid*ne0, + ne00, + ne01, + ne02, + ne10, + ne12, + ne0, + ne1, + r2, + r3, + shared_values, + tgpig, + tiisg, + sgitg); +} diff --git a/ggml-quants.c b/ggml-quants.c index 55a9496d1b37ba..d497e6de9ceb5c 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -2340,6 +2340,138 @@ size_t ggml_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t * return (n/QK_K*sizeof(block_q6_K)); } +// ====================== "True" 2-bit (de)-quantization + +void quantize_row_iq2_xxs_reference(const float * restrict x, block_iq2_xxs * restrict y, int k) { + (void)x; + (void)y; + (void)k; + assert(k % QK_K == 0); + //fprintf(stderr, "=========================== %s: not implemented\n", __func__); +} + +static const uint64_t iq2xxs_grid[256] = { + 0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08, + 0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808, + 0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819, + 0x0808080819081908, 0x0808080819190808, 0x0808080819192b08, 0x08080808192b0819, + 0x08080808192b1908, 0x080808082b080808, 0x080808082b08082b, 0x080808082b082b2b, + 0x080808082b2b082b, 0x0808081908080819, 0x0808081908081908, 0x0808081908190808, + 0x0808081908191919, 0x0808081919080808, 0x080808192b081908, 0x080808192b192b08, + 0x0808082b08080808, 0x0808082b0808082b, 0x0808082b082b082b, 0x0808082b2b08082b, + 0x0808190808080819, 0x0808190808081908, 0x0808190808190808, 0x08081908082b0819, + 0x08081908082b1908, 0x0808190819080808, 0x080819081908082b, 0x0808190819082b08, + 0x08081908192b0808, 0x080819082b080819, 0x080819082b081908, 0x080819082b190808, + 0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b, 0x0808191908082b08, + 0x08081919082b0808, 0x080819191908192b, 0x08081919192b2b19, 0x080819192b080808, + 0x080819192b190819, 0x0808192b08082b19, 0x0808192b08190808, 0x0808192b19080808, + 0x0808192b2b081908, 0x0808192b2b2b1908, 0x08082b0808080808, 0x08082b0808081919, + 0x08082b0808082b08, 0x08082b0808191908, 0x08082b08082b2b08, 0x08082b0819080819, + 0x08082b0819081908, 0x08082b0819190808, 0x08082b081919082b, 0x08082b082b082b08, + 0x08082b1908081908, 0x08082b1919080808, 0x08082b2b0808082b, 0x08082b2b08191908, + 0x0819080808080819, 0x0819080808081908, 0x0819080808190808, 0x08190808082b0819, + 0x0819080819080808, 0x08190808192b0808, 0x081908082b081908, 0x081908082b190808, + 0x081908082b191919, 0x0819081908080808, 0x0819081908082b08, 0x08190819082b0808, + 0x0819081919190808, 0x0819081919192b2b, 0x081908192b080808, 0x0819082b082b1908, + 0x0819082b19081919, 0x0819190808080808, 0x0819190808082b08, 0x08191908082b0808, + 0x08191908082b1919, 0x0819190819082b19, 0x081919082b080808, 0x0819191908192b08, + 0x08191919192b082b, 
0x0819192b08080808, 0x0819192b0819192b, 0x08192b0808080819, + 0x08192b0808081908, 0x08192b0808190808, 0x08192b0819080808, 0x08192b082b080819, + 0x08192b1908080808, 0x08192b1908081919, 0x08192b192b2b0808, 0x08192b2b19190819, + 0x082b080808080808, 0x082b08080808082b, 0x082b080808082b2b, 0x082b080819081908, + 0x082b0808192b0819, 0x082b08082b080808, 0x082b08082b08082b, 0x082b0819082b2b19, + 0x082b081919082b08, 0x082b082b08080808, 0x082b082b0808082b, 0x082b190808080819, + 0x082b190808081908, 0x082b190808190808, 0x082b190819080808, 0x082b19081919192b, + 0x082b191908080808, 0x082b191919080819, 0x082b1919192b1908, 0x082b192b2b190808, + 0x082b2b0808082b08, 0x082b2b08082b0808, 0x082b2b082b191908, 0x082b2b2b19081908, + 0x1908080808080819, 0x1908080808081908, 0x1908080808190808, 0x1908080808192b08, + 0x19080808082b0819, 0x19080808082b1908, 0x1908080819080808, 0x1908080819082b08, + 0x190808081919192b, 0x19080808192b0808, 0x190808082b080819, 0x190808082b081908, + 0x190808082b190808, 0x1908081908080808, 0x19080819082b0808, 0x19080819192b0819, + 0x190808192b080808, 0x190808192b081919, 0x1908082b08080819, 0x1908082b08190808, + 0x1908082b19082b08, 0x1908082b1919192b, 0x1908082b192b2b08, 0x1908190808080808, + 0x1908190808082b08, 0x19081908082b0808, 0x190819082b080808, 0x190819082b192b19, + 0x190819190819082b, 0x19081919082b1908, 0x1908192b08080808, 0x19082b0808080819, + 0x19082b0808081908, 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919, + 0x19082b1908080808, 0x19082b1919192b08, 0x19082b19192b0819, 0x19082b192b08082b, + 0x19082b2b19081919, 0x19082b2b2b190808, 0x1919080808080808, 0x1919080808082b08, + 0x1919080808190819, 0x1919080808192b19, 0x19190808082b0808, 0x191908082b080808, + 0x191908082b082b08, 0x1919081908081908, 0x191908191908082b, 0x191908192b2b1908, + 0x1919082b2b190819, 0x191919082b190808, 0x191919082b19082b, 0x1919191908082b2b, + 0x1919192b08080819, 0x1919192b19191908, 0x19192b0808080808, 0x19192b0808190819, + 0x19192b0808192b19, 0x19192b08192b1908, 0x19192b1919080808, 0x19192b2b08082b08, + 0x192b080808081908, 0x192b080808190808, 0x192b080819080808, 0x192b0808192b2b08, + 0x192b081908080808, 0x192b081919191919, 0x192b082b08192b08, 0x192b082b192b0808, + 0x192b190808080808, 0x192b190808081919, 0x192b191908190808, 0x192b19190819082b, + 0x192b19192b081908, 0x192b2b081908082b, 0x2b08080808080808, 0x2b0808080808082b, + 0x2b08080808082b2b, 0x2b08080819080819, 0x2b0808082b08082b, 0x2b08081908081908, + 0x2b08081908192b08, 0x2b08081919080808, 0x2b08082b08190819, 0x2b08190808080819, + 0x2b08190808081908, 0x2b08190808190808, 0x2b08190808191919, 0x2b08190819080808, + 0x2b081908192b0808, 0x2b08191908080808, 0x2b0819191908192b, 0x2b0819192b191908, + 0x2b08192b08082b19, 0x2b08192b19080808, 0x2b08192b192b0808, 0x2b082b080808082b, + 0x2b082b1908081908, 0x2b082b2b08190819, 0x2b19080808081908, 0x2b19080808190808, + 0x2b190808082b1908, 0x2b19080819080808, 0x2b1908082b2b0819, 0x2b1908190819192b, + 0x2b1908192b080808, 0x2b19082b19081919, 0x2b19190808080808, 0x2b191908082b082b, + 0x2b19190819081908, 0x2b19191919190819, 0x2b192b082b080819, 0x2b192b19082b0808, + 0x2b2b08080808082b, 0x2b2b080819190808, 0x2b2b08082b081919, 0x2b2b081908082b19, + 0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908, +}; + +static const uint8_t ksigns_iq2xs[128] = { + 0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15, + 144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159, + 160, 33, 34, 163, 36, 165, 166, 39, 40, 169, 170, 43, 172, 45, 46, 175, + 48, 177, 178, 51, 180, 53, 
54, 183, 184, 57, 58, 187, 60, 189, 190, 63, + 192, 65, 66, 195, 68, 197, 198, 71, 72, 201, 202, 75, 204, 77, 78, 207, + 80, 209, 210, 83, 212, 85, 86, 215, 216, 89, 90, 219, 92, 221, 222, 95, + 96, 225, 226, 99, 228, 101, 102, 231, 232, 105, 106, 235, 108, 237, 238, 111, + 240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255, +}; +static const uint8_t kmask_iq2xs[8] = {1, 2, 4, 8, 16, 32, 64, 128}; + +void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + uint32_t aux32[2]; + const uint8_t * aux8 = (const uint8_t *)aux32; + + for (int i = 0; i < nb; i++) { + + const float d = GGML_FP16_TO_FP32(x[i].d); + + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + memcpy(aux32, x[i].qs + 4*ib32, 2*sizeof(uint32_t)); + const float db = d * (0.5f + (aux32[1] >> 28)) * 0.25f; + for (int l = 0; l < 4; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]); + const uint8_t signs = ksigns_iq2xs[(aux32[1] >> 7*l) & 127]; + for (int j = 0; j < 8; ++j) { + y[j] = db * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); + } + y += 8; + } + } + } +} + +void quantize_row_iq2_xxs(const float * restrict x, void * restrict vy, int k) { + assert(k % QK_K == 0); + block_iq2_xxs * restrict y = vy; + quantize_row_iq2_xxs_reference(x, y, k); +} + +size_t ggml_quantize_iq2_xxs(const float * src, void * dst, int n, int k, int64_t * hist) { + assert(k % QK_K == 0); + (void)hist; // TODO: collect histograms + + for (int j = 0; j < n; j += k) { + block_iq2_xxs * restrict y = (block_iq2_xxs *)dst + j/QK_K; + quantize_row_iq2_xxs_reference(src + j, y, k); + } + return (n/QK_K*sizeof(block_iq2_xxs)); +} + //===================================== Q8_K ============================================== void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k) { @@ -2362,7 +2494,9 @@ void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict x += QK_K; continue; } - const float iscale = -128.f/max; + //const float iscale = -128.f/max; + // We need this change for IQ2_XXS, else the AVX implementation becomes very awkward + const float iscale = -127.f/max; for (int j = 0; j < QK_K; ++j) { int v = nearest_int(iscale*x[j]); y[i].qs[j] = MIN(127, v); @@ -7065,3 +7199,161 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri } #endif + +static const int8_t keven_signs_q2xs[1024] = { + 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1, + 1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, -1, + 1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, -1, + 1, 1, -1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, 1, + 1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, -1, + 1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, 1, + 1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, 1, + 1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, -1, + 1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, -1, + 1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, 
1, -1, -1, -1, -1, 1, 1, -1, 1, 1, + 1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, 1, + 1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, -1, + 1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, 1, + 1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, -1, + 1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, -1, + 1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, 1, + 1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, -1, + 1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, 1, + 1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, 1, + 1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, -1, + 1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, 1, + 1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, -1, + 1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1, + 1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, 1, + 1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, 1, + 1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, -1, + 1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, -1, + 1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, 1, + 1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, -1, + 1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, 1, + 1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, 1, + 1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, +}; + +void ggml_vec_dot_iq2_xxs_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + assert(n % QK_K == 0); + + const block_iq2_xxs * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#if defined(__ARM_NEON) + + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + + uint32_t aux32[4]; + const uint8_t * aux8 = (const uint8_t *)aux32; + + ggml_int8x16x4_t q2u; + ggml_int8x16x4_t q2s; + ggml_int8x16x4_t q8b; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + float sumf1 = 0, sumf2 = 0; + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + q8b = ggml_vld1q_s8_x4(q8); q8 += 64; + memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8; + q2u.val[0] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 0])), vld1_s8((const void *)(iq2xxs_grid + 
aux8[ 1]))); + q2u.val[1] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 2])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 3]))); + q2u.val[2] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 8])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 9]))); + q2u.val[3] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[10])), vld1_s8((const void *)(iq2xxs_grid + aux8[11]))); + q2s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 7) & 127)))); + q2s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 21) & 127)))); + q2s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[3] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[3] >> 7) & 127)))); + q2s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[3] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[3] >> 21) & 127)))); + q2u.val[0] = vmulq_s8(q2u.val[0], q2s.val[0]); + q2u.val[1] = vmulq_s8(q2u.val[1], q2s.val[1]); + q2u.val[2] = vmulq_s8(q2u.val[2], q2s.val[2]); + q2u.val[3] = vmulq_s8(q2u.val[3], q2s.val[3]); + const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[0], q8b.val[0]), q2u.val[1], q8b.val[1]); + const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[2], q8b.val[2]), q2u.val[3], q8b.val[3]); + sumf1 += vaddvq_s32(p1) * (0.5f + (aux32[1] >> 28)); + sumf2 += vaddvq_s32(p2) * (0.5f + (aux32[3] >> 28)); + } + sumf += d*(sumf1 + sumf2); + } + *s = 0.25f * sumf; + +#elif defined(__AVX2__) + + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + + uint32_t aux32[4]; + const uint8_t * aux8 = (const uint8_t *)aux32; + + __m256 accumf = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + __m256i sumi1 = _mm256_setzero_si256(); + __m256i sumi2 = _mm256_setzero_si256(); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8; + const __m256i q2_1 = _mm256_set_epi64x(iq2xxs_grid[aux8[ 3]], iq2xxs_grid[aux8[ 2]], iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]); + const __m256i q2_2 = _mm256_set_epi64x(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]], iq2xxs_grid[aux8[9]], iq2xxs_grid[aux8[8]]); + const __m256i s2_1 = _mm256_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127], + signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]); + const __m256i s2_2 = _mm256_set_epi64x(signs64[(aux32[3] >> 21) & 127], signs64[(aux32[3] >> 14) & 127], + signs64[(aux32[3] >> 7) & 127], signs64[(aux32[3] >> 0) & 127]); + const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1); + const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2); + const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1); + const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2); + const uint16_t ls1 = aux32[1] >> 28; + const uint16_t ls2 = aux32[3] >> 28; + const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1)); + const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1)); + sumi1 = _mm256_add_epi32(sumi1, p1); + sumi2 = _mm256_add_epi32(sumi2, p2); + } + + accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf); + + 
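// note: the 0.125f applied after this loop folds the stored (2*ls + 1) integer block scale into the grid normalization: 0.125f*(2*ls + 1) == 0.25f*(0.5f + ls), matching dequantize_row_iq2_xxs
+    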
} + + *s = 0.125f * hsum_float_8(accumf); + +#else + + uint32_t aux32[2]; + const uint8_t * aux8 = (const uint8_t *)aux32; + + float sumf = 0.f; + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + int32_t bsum = 0; + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + memcpy(aux32, q2, 2*sizeof(uint32_t)); + q2 += 4; + const uint32_t ls = 2*(aux32[1] >> 28) + 1; + int32_t sumi = 0; + for (int l = 0; l < 4; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]); + const uint8_t signs = ksigns_iq2xs[(aux32[1] >> 7*l) & 127]; + for (int j = 0; j < 8; ++j) { + sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1); + } + q8 += 8; + } + bsum += sumi * ls; + } + sumf += d * bsum; + } + *s = 0.125f * sumf; +#endif +} diff --git a/ggml-quants.h b/ggml-quants.h index 62c1df6cbd2748..8dd911d4182fa9 100644 --- a/ggml-quants.h +++ b/ggml-quants.h @@ -165,6 +165,14 @@ typedef struct { } block_q8_K; static_assert(sizeof(block_q8_K) == sizeof(float) + QK_K + QK_K/16*sizeof(int16_t), "wrong q8_K block size/padding"); +// (Almost) "true" 2-bit quantization. +// Due to the need to use blocks as per ggml dsign, it ends up using +// 2.0625 bpw because of the 16-bit scale for each block of 256. +typedef struct { + ggml_fp16_t d; + uint16_t qs[QK_K/8]; +} block_iq2_xxs; +static_assert(sizeof(block_iq2_xxs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16_t), "wrong iq2_xxs block size/padding"); // Quantization void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k); @@ -180,6 +188,7 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int k); void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int k); void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k); +void quantize_row_iq2_xxs_reference(const float * restrict x, block_iq2_xxs * restrict y, int k); void quantize_row_q4_0(const float * restrict x, void * restrict y, int k); void quantize_row_q4_1(const float * restrict x, void * restrict y, int k); @@ -194,6 +203,7 @@ void quantize_row_q4_K(const float * restrict x, void * restrict y, int k); void quantize_row_q5_K(const float * restrict x, void * restrict y, int k); void quantize_row_q6_K(const float * restrict x, void * restrict y, int k); void quantize_row_q8_K(const float * restrict x, void * restrict y, int k); +void quantize_row_iq2_xxs(const float * restrict x, void * restrict y, int k); // Dequantization void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int k); @@ -209,6 +219,7 @@ void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int k); void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int k); void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int k); +void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int k); // Dot product void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy); @@ -222,3 +233,4 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, const void * restrict vx, void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy); void 
ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy); void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy); +void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy); diff --git a/ggml.c b/ggml.c index 7dcc70a5843b32..1027fabdbf0e49 100644 --- a/ggml.c +++ b/ggml.c @@ -573,6 +573,17 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = ggml_vec_dot_q6_K_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, }, + [GGML_TYPE_IQ2_XXS] = { + .type_name = "iq2_xxs", + .blck_size = QK_K, + .type_size = sizeof(block_iq2_xxs), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_iq2_xxs, + .from_float = quantize_row_iq2_xxs, + .from_float_reference = (ggml_from_float_t) quantize_row_iq2_xxs_reference, + .vec_dot = ggml_vec_dot_iq2_xxs_q8_K, + .vec_dot_type = GGML_TYPE_Q8_K, + }, [GGML_TYPE_Q8_K] = { .type_name = "q8_K", .blck_size = QK_K, @@ -2111,6 +2122,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { case GGML_FTYPE_MOSTLY_Q4_K: wtype = GGML_TYPE_Q4_K; break; case GGML_FTYPE_MOSTLY_Q5_K: wtype = GGML_TYPE_Q5_K; break; case GGML_FTYPE_MOSTLY_Q6_K: wtype = GGML_TYPE_Q6_K; break; + case GGML_FTYPE_MOSTLY_IQ2_XXS: wtype = GGML_TYPE_IQ2_XXS; break; case GGML_FTYPE_UNKNOWN: wtype = GGML_TYPE_COUNT; break; case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break; } @@ -7457,6 +7469,7 @@ static void ggml_compute_forward_add( case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ2_XXS: { ggml_compute_forward_add_q_f32(params, src0, src1, dst); } break; @@ -7721,6 +7734,7 @@ static void ggml_compute_forward_add1( case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ2_XXS: { ggml_compute_forward_add1_q_f32(params, src0, src1, dst); } break; @@ -7835,6 +7849,7 @@ static void ggml_compute_forward_acc( case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ2_XXS: default: { GGML_ASSERT(false); @@ -10476,6 +10491,7 @@ static void ggml_compute_forward_out_prod( case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ2_XXS: { ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst); } break; @@ -10650,6 +10666,7 @@ static void ggml_compute_forward_set( case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ2_XXS: default: { GGML_ASSERT(false); @@ -10844,6 +10861,7 @@ static void ggml_compute_forward_get_rows( case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ2_XXS: { ggml_compute_forward_get_rows_q(params, src0, src1, dst); } break; @@ -11480,6 +11498,7 @@ static void ggml_compute_forward_alibi( case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ2_XXS: case GGML_TYPE_Q8_K: case GGML_TYPE_I8: case GGML_TYPE_I16: @@ -11554,6 +11573,7 @@ static void ggml_compute_forward_clamp( case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ2_XXS: case GGML_TYPE_Q8_K: case GGML_TYPE_I8: case GGML_TYPE_I16: @@ -18670,6 +18690,12 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i block_q6_K * block = (block_q6_K*)dst + start / QK_K; result = ggml_quantize_q6_K(src + start, block, n, n, hist); } break; + case GGML_TYPE_IQ2_XXS: + { + GGML_ASSERT(start % QK_K == 0); + block_iq2_xxs * block = (block_iq2_xxs*)dst + start / QK_K; + result = 
ggml_quantize_iq2_xxs(src + start, block, n, n, hist); + } break; case GGML_TYPE_F16: { int elemsize = sizeof(ggml_fp16_t); diff --git a/ggml.h b/ggml.h index 98191b1a3624d0..2013a73d1bd81a 100644 --- a/ggml.h +++ b/ggml.h @@ -339,6 +339,7 @@ extern "C" { GGML_TYPE_Q5_K = 13, GGML_TYPE_Q6_K = 14, GGML_TYPE_Q8_K = 15, + GGML_TYPE_IQ2_XXS = 16, GGML_TYPE_I8, GGML_TYPE_I16, GGML_TYPE_I32, @@ -373,6 +374,7 @@ extern "C" { GGML_FTYPE_MOSTLY_Q4_K = 12, // except 1d tensors GGML_FTYPE_MOSTLY_Q5_K = 13, // except 1d tensors GGML_FTYPE_MOSTLY_Q6_K = 14, // except 1d tensors + GGML_FTYPE_MOSTLY_IQ2_XXS = 15, // except 1d tensors }; // available tensor operations: @@ -2072,6 +2074,7 @@ extern "C" { GGML_API size_t ggml_quantize_q4_K(const float * src, void * dst, int n, int k, int64_t * hist); GGML_API size_t ggml_quantize_q5_K(const float * src, void * dst, int n, int k, int64_t * hist); GGML_API size_t ggml_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t * hist); + GGML_API size_t ggml_quantize_iq2_xxs(const float * src, void * dst, int n, int k, int64_t * hist); GGML_API size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist); diff --git a/llama.cpp b/llama.cpp index 3af2419d49dc76..ceb70025d20eb7 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1936,6 +1936,28 @@ static void llama_kv_cache_seq_shift( cache.head = new_head != cache.size ? new_head : 0; } +static void llama_kv_cache_seq_div( + struct llama_kv_cache & cache, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + int d) { + if (p0 < 0) p0 = 0; + if (p1 < 0) p1 = std::numeric_limits::max(); + + for (uint32_t i = 0; i < cache.size; ++i) { + if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { + cache.has_shift = true; + + { + llama_pos p_old = cache.cells[i].pos; + cache.cells[i].pos /= d; + cache.cells[i].delta += cache.cells[i].pos - p_old; + } + } + } +} + // // model loading and saving // @@ -2233,6 +2255,7 @@ struct llama_model_loader { case GGML_TYPE_Q4_K: ftype = LLAMA_FTYPE_MOSTLY_Q4_K_M; break; case GGML_TYPE_Q5_K: ftype = LLAMA_FTYPE_MOSTLY_Q5_K_M; break; case GGML_TYPE_Q6_K: ftype = LLAMA_FTYPE_MOSTLY_Q6_K; break; + case GGML_TYPE_IQ2_XXS: ftype = LLAMA_FTYPE_MOSTLY_IQ2_XXS; break; default: { LLAMA_LOG_WARN("%s: unknown type %s\n", __func__, ggml_type_name(type_max)); @@ -2577,6 +2600,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { case LLAMA_FTYPE_MOSTLY_Q5_K_S: return "Q5_K - Small"; case LLAMA_FTYPE_MOSTLY_Q5_K_M: return "Q5_K - Medium"; case LLAMA_FTYPE_MOSTLY_Q6_K: return "Q6_K"; + case LLAMA_FTYPE_MOSTLY_IQ2_XXS:return "IQ2_XSS - 2.0625 bpw"; default: return "unknown, may not work"; } @@ -8490,6 +8514,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s case LLAMA_FTYPE_MOSTLY_Q5_K_S: case LLAMA_FTYPE_MOSTLY_Q5_K_M: quantized_type = GGML_TYPE_Q5_K; break; case LLAMA_FTYPE_MOSTLY_Q6_K: quantized_type = GGML_TYPE_Q6_K; break; + case LLAMA_FTYPE_MOSTLY_IQ2_XXS:quantized_type = GGML_TYPE_IQ2_XXS; break; default: throw std::runtime_error(format("invalid output file type %d\n", ftype)); } @@ -9622,9 +9647,21 @@ void llama_kv_cache_seq_keep(struct llama_context * ctx, llama_seq_id seq_id) { } void llama_kv_cache_seq_shift(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { + if (delta == 0) { + return; + } + llama_kv_cache_seq_shift(ctx->kv_self, seq_id, p0, p1, delta); } +void llama_kv_cache_seq_div(struct 
llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { + if (d == 1) { + return; + } + + llama_kv_cache_seq_div(ctx->kv_self, seq_id, p0, p1, d); +} + // Returns the *maximum* size of the state size_t llama_get_state_size(const struct llama_context * ctx) { // we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state. diff --git a/llama.h b/llama.h index 605783b424ec92..5dc4c93571a58b 100644 --- a/llama.h +++ b/llama.h @@ -103,6 +103,7 @@ extern "C" { LLAMA_FTYPE_MOSTLY_Q5_K_S = 16, // except 1d tensors LLAMA_FTYPE_MOSTLY_Q5_K_M = 17, // except 1d tensors LLAMA_FTYPE_MOSTLY_Q6_K = 18, // except 1d tensors + LLAMA_FTYPE_MOSTLY_IQ2_XXS = 19, // except 1d tensors LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file }; @@ -495,6 +496,17 @@ extern "C" { llama_pos p1, llama_pos delta); + // Integer division of the positions by factor of `d > 1` + // If the KV cache is RoPEd, the KV data is updated accordingly + // p0 < 0 : [0, p1] + // p1 < 0 : [p0, inf) + LLAMA_API void llama_kv_cache_seq_div( + struct llama_context * ctx, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + int d); + // // State / sessions // diff --git a/scripts/get-pg.sh b/scripts/get-pg.sh new file mode 100755 index 00000000000000..b027793e19f7a7 --- /dev/null +++ b/scripts/get-pg.sh @@ -0,0 +1,70 @@ +#!/bin/bash + +function usage { + echo "usage: $0" + echo "note: n is the number of essays to download" + echo "for specific n, the resulting pg.txt file will have the following number of tokens:" + echo "n | tokens" + echo "--- | ---" + echo "1 | 6230" + echo "2 | 23619" + echo "5 | 25859" + echo "10 | 36888" + echo "15 | 50188" + echo "20 | 59094" + echo "25 | 88764" + echo "30 | 103121" + echo "32 | 108338" + echo "35 | 113403" + echo "40 | 127699" + echo "45 | 135896" + exit 1 +} + +function has_cmd { + if ! [ -x "$(command -v $1)" ]; then + echo "error: $1 is not available" >&2 + exit 1 + fi +} + +# check for: curl, html2text, tail, sed, fmt +has_cmd curl +has_cmd html2text +has_cmd tail +has_cmd sed + +if [ $# -ne 1 ]; then + usage +fi + +n=$1 + +# get urls +urls="$(curl http://www.aaronsw.com/2002/feeds/pgessays.rss | grep html | sed -e "s/.*http/http/" | sed -e "s/html.*/html/" | head -n $n)" + +printf "urls:\n%s\n" "$urls" + +if [ -f pg.txt ]; then + rm pg.txt +fi + +c=1 +for url in $urls; do + echo "processing $url" + + cc=$(printf "%03d" $c) + + curl -L $url | html2text | tail -n +4 | sed -E "s/^[[:space:]]+//g" | fmt -w 80 >> pg-$cc-one.txt + cat pg-$cc-one.txt >> pg.txt + + cp -v pg.txt pg-$cc-all.txt + c=$((c+1)) + + # don't flood the server + sleep 1 +done + +echo "done. 
data in pg.txt" + +exit 0 diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 18eb4a881b47f6..e27f9f2844c667 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -455,7 +455,7 @@ struct test_case { double err = nmse(f1.data(), f2.data(), f1.size()); if (err > ud->max_err) { - printf("[%s] NMSE = %f ", ggml_op_desc(t1), err); + printf("[%s] NMSE = %.9f > %.9f ", ggml_op_desc(t1), err, ud->max_err); //for (int i = 0; i < (int) f1.size(); i++) { // printf("%5d %9.6f %9.6f, diff = %9.6f\n", i, f1[i], f2[i], f1[i] - f2[i]); //} @@ -1463,6 +1463,7 @@ struct test_moe : public test_case { static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name) { std::vector> test_cases; + std::default_random_engine rng(0); const ggml_type all_types[] = { GGML_TYPE_F32, GGML_TYPE_F16, @@ -1597,7 +1598,19 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 10, 1}, 5)); test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 10, 10}, 5)); - test_cases.emplace_back(new test_soft_max()); + std::uniform_int_distribution<> dist_ne1(1, 50); + int exponent = 1; + while (exponent < (1 << 17)) { + std::uniform_int_distribution<> dist_ne0(exponent, 2*exponent); + + for (int n = 0; n < 10; ++n) { + int64_t ne0 = dist_ne0(rng); + int64_t ne1 = dist_ne1(rng); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1})); + } + + exponent <<= 1; + } for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) { test_cases.emplace_back(new test_rope(type, {128, 32, 10, 1}, 128, 0, 512)); // llama 7B diff --git a/tests/test-quantize-fns.cpp b/tests/test-quantize-fns.cpp index a2459a2867c5c0..cee712618be3d1 100644 --- a/tests/test-quantize-fns.cpp +++ b/tests/test-quantize-fns.cpp @@ -134,6 +134,11 @@ int main(int argc, char * argv[]) { continue; } + if ((ggml_type)i == GGML_TYPE_IQ2_XXS) { + printf("Skip %s due to missing quantization functionality\n", ggml_type_name((ggml_type) i)); + continue; + } + printf("Testing %s\n", ggml_type_name((ggml_type) i)); if (qfns.from_float && qfns.to_float) {