diff --git a/.github/ISSUE_TEMPLATE/01-bug-low.yml b/.github/ISSUE_TEMPLATE/01-bug-low.yml new file mode 100644 index 0000000000000..bfb9d9a0692c4 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/01-bug-low.yml @@ -0,0 +1,50 @@ +name: Low Severity Bugs +description: Used to report low severity bugs in llama.cpp (e.g. cosmetic issues, non critical UI glitches) +title: "Bug: " +labels: ["bug-unconfirmed", "low severity"] +body: + - type: markdown + attributes: + value: | + Thanks for taking the time to fill out this bug report! + Please include information about your system, the steps to reproduce the bug, + and the version of llama.cpp that you are using. + If possible, please provide a minimal code example that reproduces the bug. + - type: textarea + id: what-happened + attributes: + label: What happened? + description: Also tell us, what did you expect to happen? + placeholder: Tell us what you see! + validations: + required: true + - type: textarea + id: version + attributes: + label: Name and Version + description: Which executable and which version of our software are you running? (use `--version` to get a version string) + placeholder: | + $./main --version + version: 2999 (42b4109e) + built with cc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0 for x86_64-linux-gnu + validations: + required: true + - type: dropdown + id: operating-system + attributes: + label: What operating system are you seeing the problem on? + multiple: true + options: + - Linux + - Mac + - Windows + - BSD + - Other? (Please let us know in description) + validations: + required: false + - type: textarea + id: logs + attributes: + label: Relevant log output + description: Please copy and paste any relevant log output. This will be automatically formatted into code, so no need for backticks. + render: shell diff --git a/.github/ISSUE_TEMPLATE/02-bug-medium.yml b/.github/ISSUE_TEMPLATE/02-bug-medium.yml new file mode 100644 index 0000000000000..e8297eea03551 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/02-bug-medium.yml @@ -0,0 +1,50 @@ +name: Medium Severity Bug +description: Used to report medium severity bugs in llama.cpp (e.g. Malfunctioning Features but generally still useable) +title: "Bug: " +labels: ["bug-unconfirmed", "medium severity"] +body: + - type: markdown + attributes: + value: | + Thanks for taking the time to fill out this bug report! + Please include information about your system, the steps to reproduce the bug, + and the version of llama.cpp that you are using. + If possible, please provide a minimal code example that reproduces the bug. + - type: textarea + id: what-happened + attributes: + label: What happened? + description: Also tell us, what did you expect to happen? + placeholder: Tell us what you see! + validations: + required: true + - type: textarea + id: version + attributes: + label: Name and Version + description: Which executable and which version of our software are you running? (use `--version` to get a version string) + placeholder: | + $./main --version + version: 2999 (42b4109e) + built with cc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0 for x86_64-linux-gnu + validations: + required: true + - type: dropdown + id: operating-system + attributes: + label: What operating system are you seeing the problem on? + multiple: true + options: + - Linux + - Mac + - Windows + - BSD + - Other? (Please let us know in description) + validations: + required: false + - type: textarea + id: logs + attributes: + label: Relevant log output + description: Please copy and paste any relevant log output. 
This will be automatically formatted into code, so no need for backticks. + render: shell diff --git a/.github/ISSUE_TEMPLATE/03-bug-high.yml b/.github/ISSUE_TEMPLATE/03-bug-high.yml new file mode 100644 index 0000000000000..3c9d50d169720 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/03-bug-high.yml @@ -0,0 +1,50 @@ +name: High Severity Bug +description: Used to report high severity bugs in llama.cpp (e.g. Malfunctioning features hindering important common workflow) +title: "Bug: " +labels: ["bug-unconfirmed", "high severity"] +body: + - type: markdown + attributes: + value: | + Thanks for taking the time to fill out this bug report! + Please include information about your system, the steps to reproduce the bug, + and the version of llama.cpp that you are using. + If possible, please provide a minimal code example that reproduces the bug. + - type: textarea + id: what-happened + attributes: + label: What happened? + description: Also tell us, what did you expect to happen? + placeholder: Tell us what you see! + validations: + required: true + - type: textarea + id: version + attributes: + label: Name and Version + description: Which executable and which version of our software are you running? (use `--version` to get a version string) + placeholder: | + $./main --version + version: 2999 (42b4109e) + built with cc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0 for x86_64-linux-gnu + validations: + required: true + - type: dropdown + id: operating-system + attributes: + label: What operating system are you seeing the problem on? + multiple: true + options: + - Linux + - Mac + - Windows + - BSD + - Other? (Please let us know in description) + validations: + required: false + - type: textarea + id: logs + attributes: + label: Relevant log output + description: Please copy and paste any relevant log output. This will be automatically formatted into code, so no need for backticks. + render: shell diff --git a/.github/ISSUE_TEMPLATE/04-bug-critical.yml b/.github/ISSUE_TEMPLATE/04-bug-critical.yml new file mode 100644 index 0000000000000..d089d5fa10cfc --- /dev/null +++ b/.github/ISSUE_TEMPLATE/04-bug-critical.yml @@ -0,0 +1,50 @@ +name: Critical Severity Bug +description: Used to report critical severity bugs in llama.cpp (e.g. Crashing, Corrupted, Dataloss) +title: "Bug: " +labels: ["bug-unconfirmed", "critical severity"] +body: + - type: markdown + attributes: + value: | + Thanks for taking the time to fill out this bug report! + Please include information about your system, the steps to reproduce the bug, + and the version of llama.cpp that you are using. + If possible, please provide a minimal code example that reproduces the bug. + - type: textarea + id: what-happened + attributes: + label: What happened? + description: Also tell us, what did you expect to happen? + placeholder: Tell us what you see! + validations: + required: true + - type: textarea + id: version + attributes: + label: Name and Version + description: Which executable and which version of our software are you running? (use `--version` to get a version string) + placeholder: | + $./main --version + version: 2999 (42b4109e) + built with cc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0 for x86_64-linux-gnu + validations: + required: true + - type: dropdown + id: operating-system + attributes: + label: What operating system are you seeing the problem on? + multiple: true + options: + - Linux + - Mac + - Windows + - BSD + - Other? 
(Please let us know in description) + validations: + required: false + - type: textarea + id: logs + attributes: + label: Relevant log output + description: Please copy and paste any relevant log output. This will be automatically formatted into code, so no need for backticks. + render: shell diff --git a/.github/ISSUE_TEMPLATE/05-enhancement.yml b/.github/ISSUE_TEMPLATE/05-enhancement.yml new file mode 100644 index 0000000000000..58fca73183d41 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/05-enhancement.yml @@ -0,0 +1,51 @@ +name: Enhancement +description: Used to request enhancements for llama.cpp +title: "Feature Request: " +labels: ["enhancement"] +body: + - type: markdown + attributes: + value: | + [Please post your idea first in Discussion if there is not yet a consensus for this enhancement request. This will help to keep this issue tracker focused on enhancements that the community has agreed needs to be implemented.](https://github.com/ggerganov/llama.cpp/discussions/categories/ideas) + + - type: checkboxes + id: prerequisites + attributes: + label: Prerequisites + description: Please confirm the following before submitting your enhancement request. + options: + - label: I am running the latest code. Mention the version if possible as well. + required: true + - label: I carefully followed the [README.md](https://github.com/ggerganov/llama.cpp/blob/master/README.md). + required: true + - label: I searched using keywords relevant to my issue to make sure that I am creating a new issue that is not already open (or closed). + required: true + - label: I reviewed the [Discussions](https://github.com/ggerganov/llama.cpp/discussions), and have a new and useful enhancement to share. + required: true + + - type: textarea + id: feature-description + attributes: + label: Feature Description + description: Please provide a detailed written description of what you were trying to do, and what you expected `llama.cpp` to do as an enhancement. + placeholder: Detailed description of the enhancement + validations: + required: true + + - type: textarea + id: motivation + attributes: + label: Motivation + description: Please provide a detailed written description of reasons why this feature is necessary and how it is useful to `llama.cpp` users. + placeholder: Explanation of why this feature is needed and its benefits + validations: + required: true + + - type: textarea + id: possible-implementation + attributes: + label: Possible Implementation + description: If you have an idea as to how it can be implemented, please write a detailed description. Feel free to give links to external sources or share visuals that might be helpful to understand the details better. + placeholder: Detailed description of potential implementation + validations: + required: false diff --git a/.github/ISSUE_TEMPLATE/06-question.yml b/.github/ISSUE_TEMPLATE/06-question.yml new file mode 100644 index 0000000000000..9d3ff4972383e --- /dev/null +++ b/.github/ISSUE_TEMPLATE/06-question.yml @@ -0,0 +1,38 @@ +name: Question +description: Used to ask questions about llama.cpp +title: "Question: " +labels: ["question"] +body: + - type: markdown + attributes: + value: | + [Please search your question first in Discussion if you got a common general question.](https://github.com/ggerganov/llama.cpp/discussions/categories/q-a) + + - type: checkboxes + id: prerequisites + attributes: + label: Prerequisites + description: Please confirm the following before submitting your question. 
+ options: + - label: I searched using keywords relevant to my issue to make sure that I am creating a new issue that is not already open (or closed). + required: true + - label: I reviewed the [Discussions](https://github.com/ggerganov/llama.cpp/discussions), and have a new useful question to share that cannot be answered within Discussions. + required: true + + - type: textarea + id: background-description + attributes: + label: Background Description + description: Please provide a detailed written description of what you were trying to do, and what you expected `llama.cpp` to do as an question. + placeholder: Detailed description of your question + validations: + required: true + + - type: textarea + id: possible-answer + attributes: + label: Possible Answer + description: If you have some idea of possible answers you want to confirm, that would also be appreciated. + placeholder: Your idea of possible answers + validations: + required: false diff --git a/.github/ISSUE_TEMPLATE/07-refactor.yml b/.github/ISSUE_TEMPLATE/07-refactor.yml new file mode 100644 index 0000000000000..3a68d3d5355d6 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/07-refactor.yml @@ -0,0 +1,28 @@ +name: Refactor (Maintainers) +description: Used to track refactoring opportunities +title: "Refactor: " +labels: ["refactor"] +body: + - type: markdown + attributes: + value: | + Don't forget to [check for existing refactor issue tickets](https://github.com/ggerganov/llama.cpp/issues?q=is%3Aopen+is%3Aissue+label%3Arefactoring) in case it's already covered. + Also you may want to check [Pull request refactor label as well](https://github.com/ggerganov/llama.cpp/pulls?q=is%3Aopen+is%3Apr+label%3Arefactoring) for duplicates too. + + - type: textarea + id: background-description + attributes: + label: Background Description + description: Please provide a detailed written description of the pain points you are trying to solve. + placeholder: Detailed description behind your motivation to request refactor + validations: + required: true + + - type: textarea + id: possible-approaches + attributes: + label: Possible Refactor Approaches + description: If you have some idea of possible approaches to solve this problem. You may want to make it a todo list. + placeholder: Your idea of possible refactoring opportunity/approaches + validations: + required: false diff --git a/.github/ISSUE_TEMPLATE/bug.md b/.github/ISSUE_TEMPLATE/bug.md deleted file mode 100644 index 49812832ca542..0000000000000 --- a/.github/ISSUE_TEMPLATE/bug.md +++ /dev/null @@ -1,11 +0,0 @@ ---- -name: Bug template -about: Used to report bugs in llama.cpp -labels: ["bug-unconfirmed"] -assignees: '' - ---- - -Please include information about your system, the steps to reproduce the bug, and the version of llama.cpp that you are using. If possible, please provide a minimal code example that reproduces the bug. - -If the bug concerns the server, please try to reproduce it first using the [server test scenario framework](https://github.com/ggerganov/llama.cpp/tree/master/examples/server/tests). diff --git a/.github/ISSUE_TEMPLATE/enhancement.md b/.github/ISSUE_TEMPLATE/enhancement.md deleted file mode 100644 index dcffda7500f52..0000000000000 --- a/.github/ISSUE_TEMPLATE/enhancement.md +++ /dev/null @@ -1,28 +0,0 @@ ---- -name: Enhancement template -about: Used to request enhancements for llama.cpp -labels: ["enhancement"] -assignees: '' - ---- - -# Prerequisites - -Please answer the following questions for yourself before submitting an issue. 
- -- [ ] I am running the latest code. Development is very rapid so there are no tagged versions as of now. -- [ ] I carefully followed the [README.md](https://github.com/ggerganov/llama.cpp/blob/master/README.md). -- [ ] I [searched using keywords relevant to my issue](https://docs.github.com/en/issues/tracking-your-work-with-issues/filtering-and-searching-issues-and-pull-requests) to make sure that I am creating a new issue that is not already open (or closed). -- [ ] I reviewed the [Discussions](https://github.com/ggerganov/llama.cpp/discussions), and have a new bug or useful enhancement to share. - -# Feature Description - -Please provide a detailed written description of what you were trying to do, and what you expected `llama.cpp` to do as an enhancement. - -# Motivation - -Please provide a detailed written description of reasons why this feature is necessary and how it is useful to `llama.cpp` users. - -# Possible Implementation - -If you have an idea as to how it can be implemented, please write a detailed description. Feel free to give links to external sources or share visuals that might be helpful to understand the details better. diff --git a/CMakeLists.txt b/CMakeLists.txt index c5add8239c2bd..fbbc38644ef4b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -628,6 +628,10 @@ if (LLAMA_SYCL) add_compile_definitions(GGML_SYCL_F16) endif() + if (LLAMA_CUDA_FORCE_MMQ) + add_compile_definitions(GGML_SYCL_FORCE_MMQ) + endif() + add_compile_options(-I./) #include DPCT add_compile_options(-I/${SYCL_INCLUDE_DIR}) diff --git a/CMakePresets.json b/CMakePresets.json index ad1af7eccebbd..e2b7a79e371bf 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -1,4 +1,4 @@ -{ +{ "version": 4, "configurePresets": [ { @@ -40,6 +40,10 @@ { "name": "arm64-windows-msvc-debug" , "inherits": [ "base", "arm64-windows-msvc", "debug" ] }, { "name": "arm64-windows-msvc-release", "inherits": [ "base", "arm64-windows-msvc", "release" ] }, - { "name": "arm64-windows-msvc+static-release", "inherits": [ "base", "arm64-windows-msvc", "release", "static" ] } + { "name": "arm64-windows-msvc+static-release", "inherits": [ "base", "arm64-windows-msvc", "release", "static" ] }, + + { "name": "x64-windows-msvc-debug" , "inherits": [ "base", "debug" ] }, + { "name": "x64-windows-msvc-release", "inherits": [ "base", "release" ] }, + { "name": "x64-windows-msvc+static-release", "inherits": [ "base", "release", "static" ] } ] } diff --git a/Makefile b/Makefile index fe63cbd6063aa..5caf31cdf3737 100644 --- a/Makefile +++ b/Makefile @@ -441,6 +441,9 @@ endif # JETSON_EOL_MODULE_DETECT ifdef LLAMA_DEBUG MK_NVCCFLAGS += -lineinfo endif # LLAMA_DEBUG +ifdef LLAMA_CUDA_DEBUG + MK_NVCCFLAGS += --device-debug +endif # LLAMA_CUDA_DEBUG ifdef LLAMA_CUDA_NVCC NVCC = $(CCACHE) $(LLAMA_CUDA_NVCC) else diff --git a/README.md b/README.md index 15519c97f43c2..1cab7f19d596f 100644 --- a/README.md +++ b/README.md @@ -477,7 +477,8 @@ Building the program with BLAS support may lead to some performance improvements |--------------------------------|------------------------|---------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| | LLAMA_CUDA_FORCE_DMMV | Boolean | false | Force the use of dequantization + matrix vector multiplication kernels instead of using kernels that do matrix vector multiplication on 
quantized data. By default the decision is made based on compute capability (MMVQ for 6.1/Pascal/GTX 1000 or higher). Does not affect k-quants. | | LLAMA_CUDA_DMMV_X | Positive integer >= 32 | 32 | Number of values in x direction processed by the CUDA dequantization + matrix vector multiplication kernel per iteration. Increasing this value can improve performance on fast GPUs. Power of 2 heavily recommended. Does not affect k-quants. | - | LLAMA_CUDA_MMV_Y | Positive integer | 1 | Block size in y direction for the CUDA mul mat vec kernels. Increasing this value can improve performance on fast GPUs. Power of 2 recommended. | + | LLAMA_CUDA_MMV_Y | Positive integer | 1 | Block size in y direction for the CUDA mul mat vec kernels. Increasing this value can improve performance on fast GPUs. Power of 2 recommended. | + | LLAMA_CUDA_FORCE_MMQ | Boolean | false | Force the use of dequantization + matrix multiplication kernels instead of leveraging Math libraries. | | | LLAMA_CUDA_F16 | Boolean | false | If enabled, use half-precision floating point arithmetic for the CUDA dequantization + mul mat vec kernels and for the q4_1 and q5_1 matrix matrix multiplication kernels. Can improve performance on relatively recent GPUs. | | LLAMA_CUDA_KQUANTS_ITER | 1 or 2 | 2 | Number of values processed per iteration and per CUDA thread for Q2_K and Q6_K quantization formats. Setting this value to 1 can improve performance for slow GPUs. | | LLAMA_CUDA_PEER_MAX_BATCH_SIZE | Positive integer | 128 | Maximum batch size for which to enable peer access between multiple GPUs. Peer access requires either Linux or NVLink. When using NVLink enabling peer access for larger batch sizes is potentially beneficial. | diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index b8ec488351d8f..159088c988ab8 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -1265,6 +1265,17 @@ def set_gguf_parameters(self): self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"]) + tokenizer_config_file = self.dir_model / 'tokenizer_config.json' + if tokenizer_config_file.is_file(): + with open(tokenizer_config_file, "r", encoding="utf-8") as f: + tokenizer_config_json = json.load(f) + if "add_prefix_space" in tokenizer_config_json: + self.gguf_writer.add_add_space_prefix(tokenizer_config_json["add_prefix_space"]) + + # Apply to granite small models only + if self.hparams.get("vocab_size", 32000) == 49152: + self.gguf_writer.add_add_bos_token(False) + @staticmethod def permute(weights: Tensor, n_head: int, n_head_kv: int | None): if n_head_kv is not None and n_head != n_head_kv: @@ -1279,9 +1290,9 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter n_head = self.hparams["num_attention_heads"] n_kv_head = self.hparams.get("num_key_value_heads") - if name.endswith("q_proj.weight"): + if name.endswith(("q_proj.weight", "q_proj.bias")): data_torch = LlamaModel.permute(data_torch, n_head, n_head) - if name.endswith("k_proj.weight"): + if name.endswith(("k_proj.weight", "k_proj.bias")): data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head) # process the experts separately @@ -2568,6 +2579,85 @@ def write_tensors(self): raise ValueError(f"Unprocessed experts: {experts}") +@Model.register("DeepseekV2ForCausalLM") +class DeepseekV2Model(Model): + model_arch = gguf.MODEL_ARCH.DEEPSEEK2 + + def set_vocab(self): + self._set_vocab_gpt2() + + def set_gguf_parameters(self): + super().set_gguf_parameters() + 
hparams = self.hparams + + self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"]) + self.gguf_writer.add_vocab_size(hparams["vocab_size"]) + if "q_lora_rank" in hparams and hparams["q_lora_rank"] is not None: + self.gguf_writer.add_q_lora_rank(hparams["q_lora_rank"]) + self.gguf_writer.add_kv_lora_rank(hparams["kv_lora_rank"]) + self.gguf_writer.add_key_length(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"]) + self.gguf_writer.add_value_length(hparams["v_head_dim"]) + self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"]) + self.gguf_writer.add_expert_count(hparams["n_routed_experts"]) + self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"]) + self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"]) + self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"]) + + if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]: + if self.hparams["rope_scaling"].get("type") == "yarn": + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN) + self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"]) + self.gguf_writer.add_rope_scaling_orig_ctx_len(self.hparams["rope_scaling"]["original_max_position_embeddings"]) + self.gguf_writer.add_rope_scaling_yarn_log_mul(0.1 * hparams["rope_scaling"]["mscale_all_dim"]) + + _experts: list[dict[str, Tensor]] | None = None + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # process the experts separately + if name.find("mlp.experts") != -1: + n_experts = self.hparams["n_routed_experts"] + assert bid is not None + + if self._experts is None: + self._experts = [{} for _ in range(self.block_count)] + + self._experts[bid][name] = data_torch + + if len(self._experts[bid]) >= n_experts * 3: + tensors: list[tuple[str, Tensor]] = [] + + # merge the experts into a single 3d tensor + for w_name in ["down_proj", "gate_proj", "up_proj"]: + datas: list[Tensor] = [] + + for xid in range(n_experts): + ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight" + datas.append(self._experts[bid][ename]) + del self._experts[bid][ename] + + data_torch = torch.stack(datas, dim=0) + + merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight" + + new_name = self.map_tensor_name(merged_name) + + tensors.append((new_name, data_torch)) + return tensors + else: + return [] + + return [(self.map_tensor_name(name), data_torch)] + + def write_tensors(self): + super().write_tensors() + + if self._experts is not None: + # flatten `list[dict[str, Tensor]]` into `list[str]` + experts = [k for d in self._experts for k in d.keys()] + if len(experts) > 0: + raise ValueError(f"Unprocessed experts: {experts}") + + ###### CONVERSION LOGIC ###### diff --git a/examples/llava/clip.h b/examples/llava/clip.h index 45bdad6897658..ca36313844c13 100644 --- a/examples/llava/clip.h +++ b/examples/llava/clip.h @@ -68,7 +68,7 @@ CLIP_API bool clip_image_load_from_file(const char * fname, struct clip_image_u8 /** interpret bytes as an image file with length bytes_length, and use the result to populate img */ CLIP_API bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length, struct clip_image_u8 * img); -/** preprocess img and store the result in res_imgs, pad_to_square may be overriden to false depending on model configuration */ +/** preprocess img and store the result in res_imgs, pad_to_square may be overridden to false depending on model 
configuration */
 CLIP_API bool clip_image_preprocess(struct clip_ctx * ctx, const struct clip_image_u8 * img, struct clip_image_f32_batch * res_imgs );

 CLIP_API struct ggml_tensor * clip_get_newline_tensor(const struct clip_ctx * ctx);
diff --git a/examples/server/public/index.html b/examples/server/public/index.html
index 2961999f2451a..4c5a34d903309 100644
--- a/examples/server/public/index.html
+++ b/examples/server/public/index.html
@@ -594,7 +594,7 @@
                 message = html`<${Probabilities} data=${data} />`
             } else {
                 const text = isArrayMessage ?
-                    data.map(msg => msg.content).join('').replace(/^\s+/, '') :
+                    data.map(msg => msg.content).join('') :
                     data;
                 message = isCompletionMode ?
                     text :
@@ -877,19 +877,30 @@
     // poor mans markdown replacement
     const Markdownish = (params) => {
-        const md = params.text
-            .replace(/&/g, '&amp;')
-            .replace(/</g, '&lt;')
-            .replace(/>/g, '&gt;')
-            .replace(/(^|\n)#{1,6} ([^\n]*)(?=([^`]*`[^`]*`)*[^`]*$)/g, '$1<h3>$2</h3>')
-            .replace(/\*\*(.*?)\*\*(?=([^`]*`[^`]*`)*[^`]*$)/g, '<strong>$1</strong>')
-            .replace(/__(.*?)__(?=([^`]*`[^`]*`)*[^`]*$)/g, '<strong>$1</strong>')
-            .replace(/\*(.*?)\*(?=([^`]*`[^`]*`)*[^`]*$)/g, '<em>$1</em>')
-            .replace(/_(.*?)_(?=([^`]*`[^`]*`)*[^`]*$)/g, '<em>$1</em>')
-            .replace(/```.*?\n([\s\S]*?)```/g, '<pre><code>$1</code></pre>')
-            .replace(/`(.*?)`/g, '<code>$1</code>')
-            .replace(/\n/gim, '<br />');
-        return html`<span dangerouslySetInnerHTML=${{ __html: md }} />`;
+        const chunks = params.text.split('```');
+
+        for (let i = 0; i < chunks.length; i++) {
+            if (i % 2 === 0) { // outside code block
+                chunks[i] = chunks[i]
+                    .replace(/&/g, '&amp;')
+                    .replace(/</g, '&lt;')
+                    .replace(/>/g, '&gt;')
+                    .replace(/(^|\n)#{1,6} ([^\n]*)(?=([^`]*`[^`]*`)*[^`]*$)/g, '$1<h3>$2</h3>')
+                    .replace(/\*\*(.*?)\*\*(?=([^`]*`[^`]*`)*[^`]*$)/g, '<strong>$1</strong>')
+                    .replace(/__(.*?)__(?=([^`]*`[^`]*`)*[^`]*$)/g, '<strong>$1</strong>')
+                    .replace(/\*(.*?)\*(?=([^`]*`[^`]*`)*[^`]*$)/g, '<em>$1</em>')
+                    .replace(/_(.*?)_(?=([^`]*`[^`]*`)*[^`]*$)/g, '<em>$1</em>')
+                    .replace(/```.*?\n([\s\S]*?)```/g, '<pre><code>$1</code></pre>')
+                    .replace(/`(.*?)`/g, '<code>$1</code>')
+                    .replace(/\n/gim, '<br />');
+            } else { // inside code block
+                chunks[i] = `<pre><code>${chunks[i]}</code></pre>`;
+            }
+        }
+
+        const restoredText = chunks.join('');
+
+        return html`<span dangerouslySetInnerHTML=${{ __html: restoredText }} />`;
     };

     const ModelGenerationInfo = (params) => {
@@ -903,6 +914,7 @@
         `
     }

+    // simple popover impl
     const Popover = (props) => {
         const isOpen = useSignal(false);

@@ -1054,4 +1066,3 @@

llama.cpp

- diff --git a/ggml-cuda.cu b/ggml-cuda.cu index b82167cbf7227..d0a754ee11b67 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -119,6 +119,20 @@ int ggml_cuda_get_device() { return id; } +static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) { + ggml_cuda_set_device(device); +#if defined(GGML_USE_HIPBLAS) && defined(GGML_HIP_UMA) + auto res = hipMallocManaged(ptr, size); + if (res == hipSuccess) { + // if error we "need" to know why... + CUDA_CHECK(hipMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device)); + } + return res; +#else + return cudaMalloc(ptr, size); +#endif +} + static ggml_cuda_device_info ggml_cuda_init() { #ifdef __HIP_PLATFORM_AMD__ // Workaround for a rocBLAS bug when using multiple graphics cards: @@ -271,7 +285,7 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { size_t look_ahead_size = (size_t) (1.05 * size); look_ahead_size = 256 * ((look_ahead_size + 255)/256); ggml_cuda_set_device(device); - CUDA_CHECK(cudaMalloc((void **) &ptr, look_ahead_size)); + CUDA_CHECK(ggml_cuda_device_malloc(&ptr, look_ahead_size, device)); *actual_size = look_ahead_size; pool_size += look_ahead_size; #ifdef DEBUG_CUDA_MALLOC @@ -537,7 +551,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffe size = std::max(size, (size_t)1); // cudaMalloc returns null for size 0 void * dev_ptr; - cudaError_t err = cudaMalloc(&dev_ptr, size); + cudaError_t err = ggml_cuda_device_malloc(&dev_ptr, size, buft_ctx->device); if (err != cudaSuccess) { // clear the error cudaGetLastError(); @@ -798,7 +812,7 @@ GGML_CALL static void ggml_backend_cuda_split_buffer_init_tensor(ggml_backend_bu // currently, init_tensor cannot fail, it needs to be fixed in ggml-backend first ggml_cuda_set_device(id); char * buf; - CUDA_CHECK(cudaMalloc(&buf, size)); + CUDA_CHECK(ggml_cuda_device_malloc((void**)&buf, size, id)); // set padding to 0 to avoid possible NaN values if (size > original_size) { @@ -2510,9 +2524,9 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t bool use_cuda_graph = true; bool cuda_graph_update_required = false; - // pointer to CUDA cpy kernel, which is required to identify + // vector of pointers to CUDA cpy kernels, which are required to identify // kernel parameters which need updated in the graph for each token - void * ggml_cuda_cpy_fn_ptr = nullptr; + std::vector ggml_cuda_cpy_fn_ptrs; if (cuda_ctx->cuda_graph->graph == nullptr) { if (ggml_cuda_info().devices[cuda_ctx->device].cc < CC_AMPERE) { @@ -2588,9 +2602,10 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t if (node->op == GGML_OP_CPY) { // store the copy op parameter which changes with each token. 
cuda_ctx->cuda_graph->updated_kernel_arg.push_back((char **) &(node->src[1]->data)); - if (ggml_cuda_cpy_fn_ptr == nullptr) { - // store a pointer to the copy op CUDA kernel to identify it later - ggml_cuda_cpy_fn_ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]); + // store a pointer to each copy op CUDA kernel to identify it later + void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]); + if (std::find(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), ptr) == ggml_cuda_cpy_fn_ptrs.end()) { + ggml_cuda_cpy_fn_ptrs.push_back(ptr); } } @@ -2720,7 +2735,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t if (!cuda_graph_update_required) { // on update steps, the live parameters will already be captured int k = 0; for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) { - if (cuda_ctx->cuda_graph->params[i].func == ggml_cuda_cpy_fn_ptr) { + if(count(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), cuda_ctx->cuda_graph->params[i].func) > 0) { char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++); cuda_ctx->cuda_graph->params[i].kernelParams[1] = updated_kernel_arg_ptr; CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i])); diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh index 8f6fd71cfea35..22872ca5c1d81 100644 --- a/ggml-cuda/common.cuh +++ b/ggml-cuda/common.cuh @@ -79,13 +79,8 @@ #define cudaHostRegisterReadOnly hipHostRegisterReadOnly #define cudaHostUnregister hipHostUnregister #define cudaLaunchHostFunc hipLaunchHostFunc -#ifdef GGML_HIP_UMA -#define cudaMalloc hipMallocManaged -#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size) -#else #define cudaMalloc hipMalloc #define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault) -#endif #define cudaMemcpy hipMemcpy #define cudaMemcpyAsync hipMemcpyAsync #define cudaMemcpyPeerAsync hipMemcpyPeerAsync diff --git a/ggml-cuda/concat.cu b/ggml-cuda/concat.cu index 2941d2f1770a8..fb9dee8f8cee5 100644 --- a/ggml-cuda/concat.cu +++ b/ggml-cuda/concat.cu @@ -1,15 +1,68 @@ #include "concat.cuh" -static __global__ void concat_f32(const float * x,const float * y, float * dst, const int ne0, const int ne02) { +static __global__ void concat_f32_dim0(const float * x, const float * y, float * dst, const int ne0, const int ne00) { int nidx = threadIdx.x + blockIdx.x * blockDim.x; if (nidx >= ne0) { return; } - // operation + + int offset_dst = + nidx + + blockIdx.y * ne0 + + blockIdx.z * ne0 * gridDim.y; + + if (nidx < ne00) { // src0 + int offset_src = + nidx + + blockIdx.y * ne00 + + blockIdx.z * ne00 * gridDim.y; + dst[offset_dst] = x[offset_src]; + } else { + int offset_src = + (nidx - ne00) + + blockIdx.y * (ne0 - ne00) + + blockIdx.z * (ne0 - ne00) * gridDim.y; + dst[offset_dst] = y[offset_src]; + } +} + +static __global__ void concat_f32_dim1(const float * x, const float * y, float * dst, const int ne0, const int ne01) { + int nidx = threadIdx.x + blockIdx.x * blockDim.x; + if (nidx >= ne0) { + return; + } + + int offset_dst = + nidx + + blockIdx.y * ne0 + + blockIdx.z * ne0 * gridDim.y; + + if (blockIdx.y < ne01) { // src0 + int offset_src = + nidx + + blockIdx.y * ne0 + + blockIdx.z * ne0 * ne01; + dst[offset_dst] = x[offset_src]; + } else { + int offset_src = + nidx + + (blockIdx.y - ne01) * ne0 + + blockIdx.z * ne0 * (gridDim.y - ne01); + dst[offset_dst] = y[offset_src]; + } +} + +static __global__ void concat_f32_dim2(const float * x, const float * y, float * dst, 
const int ne0, const int ne02) { + int nidx = threadIdx.x + blockIdx.x * blockDim.x; + if (nidx >= ne0) { + return; + } + int offset_dst = nidx + blockIdx.y * ne0 + blockIdx.z * ne0 * gridDim.y; + if (blockIdx.z < ne02) { // src0 int offset_src = nidx + @@ -25,25 +78,53 @@ static __global__ void concat_f32(const float * x,const float * y, float * dst, } } -static void concat_f32_cuda(const float * x, const float * y, float * dst, const int ne0, int ne1, int ne2, int ne02, cudaStream_t stream) { +static void concat_f32_cuda(const float * x, const float * y, float * dst, int ne00, int ne01, int ne02, int ne0, int ne1, int ne2, int dim, cudaStream_t stream) { int num_blocks = (ne0 + CUDA_CONCAT_BLOCK_SIZE - 1) / CUDA_CONCAT_BLOCK_SIZE; dim3 gridDim(num_blocks, ne1, ne2); - concat_f32<<>>(x, y, dst, ne0, ne02); + if (dim == 0) { + concat_f32_dim0<<>>(x, y, dst, ne0, ne00); + return; + } + if (dim == 1) { + concat_f32_dim1<<>>(x, y, dst, ne0, ne01); + return; + } + concat_f32_dim2<<>>(x, y, dst, ne0, ne02); } void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; + const float * src0_d = (const float *)src0->data; const float * src1_d = (const float *)src1->data; + float * dst_d = (float *)dst->data; cudaStream_t stream = ctx.stream(); + const int32_t dim = ((int32_t *) dst->op_params)[0]; + + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(src1->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32); - for (int i3 = 0; i3 < dst->ne[3]; i3++) { - concat_f32_cuda(src0_d + i3 * (src0->nb[3] / 4), src1_d + i3 * (src1->nb[3] / 4), dst_d + i3 * (dst->nb[3] / 4), dst->ne[0], dst->ne[1], dst->ne[2], src0->ne[2], stream); + if (dim != 3) { + for (int i3 = 0; i3 < dst->ne[3]; i3++) { + concat_f32_cuda( + src0_d + i3 * (src0->nb[3] / 4), + src1_d + i3 * (src1->nb[3] / 4), + dst_d + i3 * ( dst->nb[3] / 4), + src0->ne[0], src0->ne[1], src0->ne[2], + dst->ne[0], dst->ne[1], dst->ne[2], dim, stream); + } + } else { + const size_t size0 = ggml_nbytes(src0); + const size_t size1 = ggml_nbytes(src1); + + CUDA_CHECK(cudaMemcpyAsync(dst_d, src0_d, size0, cudaMemcpyDeviceToDevice, stream)); + CUDA_CHECK(cudaMemcpyAsync(dst_d + size0/4, src1_d, size1, cudaMemcpyDeviceToDevice, stream)); } } diff --git a/ggml-metal.m b/ggml-metal.m index c9e570dbf5a3a..4ba498e87f9d0 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -35,6 +35,10 @@ GGML_METAL_KERNEL_TYPE_MUL_ROW, GGML_METAL_KERNEL_TYPE_DIV, GGML_METAL_KERNEL_TYPE_DIV_ROW, + GGML_METAL_KERNEL_TYPE_REPEAT_F32, + GGML_METAL_KERNEL_TYPE_REPEAT_F16, + GGML_METAL_KERNEL_TYPE_REPEAT_I32, + GGML_METAL_KERNEL_TYPE_REPEAT_I16, GGML_METAL_KERNEL_TYPE_SCALE, GGML_METAL_KERNEL_TYPE_SCALE_4, GGML_METAL_KERNEL_TYPE_CLAMP, @@ -184,9 +188,9 @@ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, + //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261 GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, + //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261 GGML_METAL_KERNEL_TYPE_CPY_F32_F16, GGML_METAL_KERNEL_TYPE_CPY_F32_F32, GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, @@ -485,6 +489,10 @@ static void 
ggml_metal_log(enum ggml_log_level level, const char * format, ...){ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I16, repeat_i16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true); @@ -634,9 +642,9 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm); + //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction); + //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true); @@ -746,6 +754,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const case GGML_OP_ACC: case GGML_OP_MUL: case GGML_OP_DIV: + case GGML_OP_REPEAT: case GGML_OP_SCALE: case GGML_OP_CLAMP: case GGML_OP_SQR: @@ -770,6 +779,9 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const case GGML_OP_LEAKY_RELU: return true; case GGML_OP_FLASH_ATTN_EXT: + if (op->src[0]->ne[0] == 256) { + return false; + } return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: @@ -976,10 +988,10 @@ static enum ggml_status ggml_metal_graph_compute( switch (dst->op) { case GGML_OP_CONCAT: { - const int64_t nb = ne00; - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline; + const int32_t dim = ((int32_t *) dst->op_params)[0]; + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; @@ -1008,7 +1020,7 @@ static enum ggml_status ggml_metal_graph_compute( [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24]; [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25]; [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26]; - [encoder setBytes:&nb length:sizeof(nb) atIndex:27]; + [encoder setBytes:&dim length:sizeof(dim) atIndex:27]; const int nth = MIN(1024, ne0); @@ 
-1018,11 +1030,14 @@ static enum ggml_status ggml_metal_graph_compute( case GGML_OP_MUL: case GGML_OP_DIV: { + GGML_ASSERT(src0t == GGML_TYPE_F32); + GGML_ASSERT(src1t == GGML_TYPE_F32); + const size_t offs = 0; bool bcast_row = false; - int64_t nb = ne00; + int64_t nb = ne00; // used by the "row" kernels id pipeline = nil; @@ -1091,6 +1106,42 @@ static enum ggml_status ggml_metal_graph_compute( [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } } break; + case GGML_OP_REPEAT: + { + id pipeline; + + switch (src0t) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F16].pipeline; break; + case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I32].pipeline; break; + case GGML_TYPE_I16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I16].pipeline; break; + default: GGML_ASSERT(false); + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; + [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11]; + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12]; + [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13]; + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; + [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16]; + [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17]; + + const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0); + + [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; case GGML_OP_ACC: { GGML_ASSERT(src0t == GGML_TYPE_F32); @@ -2573,7 +2624,7 @@ static enum ggml_status ggml_metal_graph_compute( case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break; case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break; case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break; - case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break; + //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break; default: { GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00); @@ -2586,7 +2637,7 @@ static enum ggml_status ggml_metal_graph_compute( switch (ne00) { case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break; - case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break; + //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break; default: { GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00); diff --git a/ggml-metal.metal b/ggml-metal.metal index 8ff70d7a79ca7..b16f2b7e0c74f 100644 --- 
a/ggml-metal.metal +++ b/ggml-metal.metal @@ -168,6 +168,53 @@ kernel void kernel_div( } } +template +kernel void kernel_repeat( + device const char * src0, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i3 = tgpig.z; + const int64_t i2 = tgpig.y; + const int64_t i1 = tgpig.x; + + const int64_t i03 = i3 % ne03; + const int64_t i02 = i2 % ne02; + const int64_t i01 = i1 % ne01; + + device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; + device char * dst_ptr = dst + i3*nb3 + i2*nb2 + i1*nb1 ; + + for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { + const int i00 = i0 % ne00; + *((device T *)(dst_ptr + i0*nb0)) = *((device T *)(src0_ptr + i00*nb00)); + } +} + +typedef decltype(kernel_repeat) kernel_repeat_t; + +template [[host_name("kernel_repeat_f32")]] kernel kernel_repeat_t kernel_repeat; +template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat; +template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat; +template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat; + // assumption: src1 is a row // broadcast src1 into src0 kernel void kernel_add_row( @@ -2418,7 +2465,7 @@ template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96>; template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112>; template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>; -template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>; +//template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>; template // head size, queries per threadgroup, cache items per threadgroup kernel void kernel_flash_attn_ext_vec_f16( @@ -2696,7 +2743,7 @@ kernel void kernel_flash_attn_ext_vec_f16( } template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>; -template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>; +//template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>; kernel void kernel_cpy_f16_f16( device const half * src0, @@ -3319,31 +3366,30 @@ kernel void kernel_concat( constant uint64_t & nb1, constant uint64_t & nb2, constant uint64_t & nb3, + constant int32_t & dim, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig.z; - const int64_t i02 = tgpig.y; - const int64_t i01 = tgpig.x; + const int64_t i3 = tgpig.z; + const int64_t i2 = tgpig.y; + const int64_t i1 = tgpig.x; - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % 
ne12; - const int64_t i11 = i01 % ne11; + int64_t o[4] = {0, 0, 0, 0}; + o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03)); - device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00; - device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10; - device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0; + device const float * x; for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - if (i02 < ne02) { - ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0]; - src0_ptr += ntg.x*nb00; + if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { + x = (device const float *)(src0 + (i3 )*nb03 + (i2 )*nb02 + (i1 )*nb01 + (i0 )*nb00); } else { - ((device float *)dst_ptr)[0] = ((device float *)src1_ptr)[0]; - src1_ptr += ntg.x*nb10; + x = (device const float *)(src1 + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10); } - dst_ptr += ntg.x*nb0; + + device float * y = (device float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + *y = *x; } } diff --git a/ggml-rpc.cpp b/ggml-rpc.cpp index cc1d3ace1ddac..49a20df4bd85e 100644 --- a/ggml-rpc.cpp +++ b/ggml-rpc.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #ifdef _WIN32 @@ -47,6 +48,7 @@ struct socket_t { sockfd_t fd; socket_t(sockfd_t fd) : fd(fd) {} ~socket_t() { + GGML_PRINT_DEBUG("[%s] closing socket %d\n", __func__, this->fd); #ifdef _WIN32 closesocket(this->fd); #else @@ -97,7 +99,7 @@ static ggml_guid_t ggml_backend_rpc_guid() { } struct ggml_backend_rpc_buffer_type_context { - std::shared_ptr sock; + std::string endpoint; std::string name; size_t alignment; size_t max_size; @@ -106,8 +108,6 @@ struct ggml_backend_rpc_buffer_type_context { struct ggml_backend_rpc_context { std::string endpoint; std::string name; - std::shared_ptr sock; - ggml_backend_buffer_type_t buft; }; struct ggml_backend_rpc_buffer_context { @@ -231,14 +231,13 @@ static bool recv_data(sockfd_t sockfd, void * data, size_t size) { return true; } -static bool parse_endpoint(const char * endpoint, std::string & host, int & port) { - std::string str(endpoint); - size_t pos = str.find(':'); +static bool parse_endpoint(const std::string & endpoint, std::string & host, int & port) { + size_t pos = endpoint.find(':'); if (pos == std::string::npos) { return false; } - host = str.substr(0, pos); - port = std::stoi(str.substr(pos + 1)); + host = endpoint.substr(0, pos); + port = std::stoi(endpoint.substr(pos + 1)); return true; } @@ -273,6 +272,44 @@ static bool send_rpc_cmd(const std::shared_ptr & sock, enum rpc_cmd cm // RPC client-side implementation +static std::shared_ptr get_socket(const std::string & endpoint) { + static std::mutex mutex; + std::lock_guard lock(mutex); + static std::unordered_map> sockets; + static bool initialized = false; + + auto it = sockets.find(endpoint); + if (it != sockets.end()) { + if (auto sock = it->second.lock()) { + return sock; + } + } + std::string host; + int port; + if (!parse_endpoint(endpoint, host, port)) { + return nullptr; + } +#ifdef _WIN32 + if (!initialized) { + WSADATA wsaData; + int res = WSAStartup(MAKEWORD(2, 2), &wsaData); + if (res != 0) { + return nullptr; + } + initialized = true; + } +#else + UNUSED(initialized); +#endif + auto sock = socket_connect(host.c_str(), port); + if (sock == nullptr) { + return nullptr; + } + GGML_PRINT_DEBUG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd); + sockets[endpoint] = sock; + return sock; +} + GGML_CALL static const char * 
ggml_backend_rpc_buffer_get_name(ggml_backend_buffer_t buffer) { ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; return ctx->name.c_str(); @@ -442,7 +479,8 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer std::vector input(input_size, 0); memcpy(input.data(), &size, sizeof(size)); std::vector output; - bool status = send_rpc_cmd(buft_ctx->sock, ALLOC_BUFFER, input, output); + auto sock = get_socket(buft_ctx->endpoint); + bool status = send_rpc_cmd(sock, ALLOC_BUFFER, input, output); GGML_ASSERT(status); GGML_ASSERT(output.size() == 2*sizeof(uint64_t)); // output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) | @@ -453,7 +491,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer if (remote_ptr != 0) { ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft, ggml_backend_rpc_buffer_interface, - new ggml_backend_rpc_buffer_context{buft_ctx->sock, {}, remote_ptr, "RPC"}, + new ggml_backend_rpc_buffer_context{sock, {}, remote_ptr, "RPC"}, remote_size); return buffer; } else { @@ -508,7 +546,7 @@ GGML_CALL static bool ggml_backend_rpc_buffer_type_supports_backend(ggml_backend } ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context; ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context; - return buft_ctx->sock == rpc_ctx->sock; + return buft_ctx->endpoint == rpc_ctx->endpoint; } static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = { @@ -521,7 +559,6 @@ static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = { /* .is_host = */ NULL, }; - GGML_CALL static const char * ggml_backend_rpc_name(ggml_backend_t backend) { ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context; @@ -530,16 +567,13 @@ GGML_CALL static const char * ggml_backend_rpc_name(ggml_backend_t backend) { GGML_CALL static void ggml_backend_rpc_free(ggml_backend_t backend) { ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context; - ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)rpc_ctx->buft->context; - delete buft_ctx; - delete rpc_ctx->buft; delete rpc_ctx; delete backend; } GGML_CALL static ggml_backend_buffer_type_t ggml_backend_rpc_get_default_buffer_type(ggml_backend_t backend) { ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context; - return ctx->buft; + return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str()); } GGML_CALL static void ggml_backend_rpc_synchronize(ggml_backend_t backend) { @@ -590,7 +624,8 @@ GGML_CALL static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t std::vector input; serialize_graph(cgraph, input); std::vector output; - bool status = send_rpc_cmd(rpc_ctx->sock, GRAPH_COMPUTE, input, output); + auto sock = get_socket(rpc_ctx->endpoint); + bool status = send_rpc_cmd(sock, GRAPH_COMPUTE, input, output); GGML_ASSERT(status); GGML_ASSERT(output.size() == 1); return (enum ggml_status)output[0]; @@ -624,65 +659,48 @@ static ggml_backend_i ggml_backend_rpc_interface = { /* .event_synchronize = */ NULL, }; -static std::unordered_map instances; - GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) { - ggml_backend_t backend = ggml_backend_rpc_init(endpoint); - return backend != nullptr ? 
ggml_backend_rpc_get_default_buffer_type(backend) : nullptr; -} - -GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) { - std::string endpoint_str(endpoint); - if (instances.find(endpoint_str) != instances.end()) { - return instances[endpoint_str]; - } -#ifdef _WIN32 - { - WSADATA wsaData; - int res = WSAStartup(MAKEWORD(2, 2), &wsaData); - if (res != 0) { - return nullptr; - } - } -#endif - fprintf(stderr, "Connecting to %s\n", endpoint); - std::string host; - int port; - if (!parse_endpoint(endpoint, host, port)) { - return nullptr; - } - auto sock = socket_connect(host.c_str(), port); + static std::mutex mutex; + std::lock_guard lock(mutex); + // NOTE: buffer types are allocated and never freed; this is by design + static std::unordered_map buft_map; + auto it = buft_map.find(endpoint); + if (it != buft_map.end()) { + return it->second; + } + auto sock = get_socket(endpoint); if (sock == nullptr) { return nullptr; } size_t alignment = get_alignment(sock); size_t max_size = get_max_size(sock); ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context { - /* .sock = */ sock, - /* .name = */ "RPC" + std::to_string(sock->fd), + /* .endpoint = */ endpoint, + /* .name = */ "RPC[" + std::string(endpoint) + "]", /* .alignment = */ alignment, - /* .max_size = */ max_size + /* .max_size = */ max_size }; ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type { /* .iface = */ ggml_backend_rpc_buffer_type_interface, /* .context = */ buft_ctx }; + buft_map[endpoint] = buft; + return buft; +} +GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) { ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context { - /* .endpoint = */ endpoint, - /* .name = */ "RPC" + std::to_string(sock->fd), - /* .sock = */ sock, - /* .buft = */ buft + /* .endpoint = */ endpoint, + /* .name = */ "RPC", }; - instances[endpoint] = new ggml_backend { + ggml_backend_t backend = new ggml_backend { /* .guid = */ ggml_backend_rpc_guid(), /* .interface = */ ggml_backend_rpc_interface, /* .context = */ ctx }; - - return instances[endpoint]; + return backend; } GGML_API GGML_CALL bool ggml_backend_is_rpc(ggml_backend_t backend) { @@ -706,14 +724,13 @@ static void get_device_memory(const std::shared_ptr & sock, size_t * f } GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) { - ggml_backend_t backend = ggml_backend_rpc_init(endpoint); - if (backend == nullptr) { + auto sock = get_socket(endpoint); + if (sock == nullptr) { *free = 0; *total = 0; return; } - ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context; - get_device_memory(ctx->sock, free, total); + get_device_memory(sock, free, total); } // RPC server-side implementation diff --git a/ggml-sycl.cpp b/ggml-sycl.cpp index 496ec61c3c28a..a73448136a4d8 100644 --- a/ggml-sycl.cpp +++ b/ggml-sycl.cpp @@ -2944,6 +2944,57 @@ namespace dpct using shared_memory = detail::device_memory; + template + inline T atomic_fetch_add(T *addr, T operand) { + auto atm = + sycl::atomic_ref(addr[0]); + return atm.fetch_add(operand); + } + + template + inline T1 atomic_fetch_add(T1 *addr, T2 operand) { + auto atm = + sycl::atomic_ref(addr[0]); + return atm.fetch_add(operand); + } + + template + inline T atomic_fetch_add(T *addr, T operand, + sycl::memory_order memoryOrder) { + switch (memoryOrder) { + case sycl::memory_order::relaxed: + return atomic_fetch_add(addr, operand); + case sycl::memory_order::acq_rel: + return 
atomic_fetch_add(addr, operand); + case sycl::memory_order::seq_cst: + return atomic_fetch_add(addr, operand); + default: + assert(false && "Invalid memory_order for atomics. Valid memory_order for " + "atomics are: sycl::memory_order::relaxed, " + "sycl::memory_order::acq_rel, sycl::memory_order::seq_cst!"); + } + } + + template + inline T1 atomic_fetch_add(T1 *addr, T2 operand, + sycl::memory_order memoryOrder) { + atomic_fetch_add(addr, operand, memoryOrder); + } + } // COPY from DPCT head files #define GGML_COMMON_DECL_SYCL @@ -2971,20 +3022,19 @@ static int g_work_group_size = 0; // typedef sycl::half ggml_fp16_t; #define __SYCL_ARCH__ DPCT_COMPATIBILITY_TEMP -#define VER_4VEC 610 //todo for hardward optimize. +#define VER_4VEC 130 //todo for hardward optimize. #define VER_GEN9 700 //todo for hardward optimize. #define VER_GEN12 1000000 //todo for hardward optimize. #define VER_GEN13 (VER_GEN12 + 1030) //todo for hardward optimize. #define GGML_SYCL_MAX_NODES 8192 //TODO: adapt to hardwares - -//define for XMX in Intel GPU -//TODO: currently, it's not used for XMX really. -#define SYCL_USE_XMX +#if !defined(GGML_SYCL_FORCE_MMQ) + #define SYCL_USE_XMX +#endif // max batch size to use MMQ kernels when tensor cores are available -#define XMX_MAX_BATCH_SIZE 32 +#define MMQ_MAX_BATCH_SIZE 32 #if defined(_MSC_VER) @@ -3060,6 +3110,7 @@ void ggml_sycl_get_device_description(int device, char * description, size_t d bool ggml_backend_is_sycl(ggml_backend_t backend); int ggml_backend_sycl_get_device(ggml_backend_t backend); int get_main_device(); +static bool ggml_backend_buffer_is_sycl_split(ggml_backend_buffer_t buffer); void print_ggml_tensor(const char*name, struct ggml_tensor *src); void log_tensor_with_cnt(const char* name, struct ggml_tensor * src, int stop_cnt); @@ -8830,12 +8881,11 @@ static void rope( dst[i + 1] = x0*sin_theta + x1*cos_theta; } -template +template static void rope_neox( const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows, - float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims -, - const sycl::nd_item<3> &item_ct1) { + float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims, + const float * freq_factors, const sycl::nd_item<3> &item_ct1) { const int col = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + item_ct1.get_local_id(1)); @@ -8863,8 +8913,10 @@ static void rope_neox( float cur_rot = inv_ndims * ic - ib; const int p = has_pos ? pos[i2] : 0; + const float freq_factor = has_freq_facs ? 
freq_factors[ic/2] : 1.0f; + const float theta_base = - p * freq_scale * dpct::pow(theta_scale, col / 2.0f); + p * freq_scale * dpct::pow(theta_scale, col / 2.0f)/freq_factor; float cos_theta, sin_theta; rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta); @@ -12413,7 +12465,7 @@ static void rope_neox_sycl(const T *x, T *dst, int ncols, int n_dims, int nrows, const int32_t *pos, float freq_scale, int p_delta_rows, float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, - dpct::queue_ptr stream) { + const float * freq_factors, dpct::queue_ptr stream) { GGML_ASSERT(ncols % 2 == 0); const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1); const int num_blocks_x = (ncols + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE); @@ -12423,38 +12475,48 @@ static void rope_neox_sycl(const T *x, T *dst, int ncols, int n_dims, int nrows, const float inv_ndims = -1.0f / n_dims; if (pos == nullptr) { - /* - DPCT1049:42: The work-group size passed to the SYCL kernel may exceed - the limit. To get the device limit, query - info::device::max_work_group_size. Adjust the work-group size if needed. - */ dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); - - stream->parallel_for( - sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - rope_neox(x, dst, ncols, n_dims, pos, freq_scale, - p_delta_rows, ext_factor, attn_factor, - corr_dims, theta_scale, inv_ndims, - item_ct1); - }); + if (freq_factors == nullptr) { + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + rope_neox(x, dst, ncols, n_dims, pos, freq_scale, + p_delta_rows, ext_factor, attn_factor, + corr_dims, theta_scale, inv_ndims, freq_factors, + item_ct1); + }); + } else { + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + rope_neox(x, dst, ncols, n_dims, pos, freq_scale, + p_delta_rows, ext_factor, attn_factor, + corr_dims, theta_scale, inv_ndims, freq_factors, + item_ct1); + }); + } } else { - /* - DPCT1049:43: The work-group size passed to the SYCL kernel may exceed - the limit. To get the device limit, query - info::device::max_work_group_size. Adjust the work-group size if needed. 
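// [Editor's note: illustrative sketch, not part of the patch.] Scalar mirror of the
// frequency-factor adjustment added to rope_neox above: with Phi-3 style factors, the
// base angle of each rotated pair is divided by its per-dimension factor, and passing
// nullptr reproduces the previous behaviour (factor == 1.0f). The kernel uses separate
// indices (col for the exponent, ic for the factor lookup); this sketch collapses them,
// which matches the common full-head rotation case where ib == 0 and ic == col.
#include <cmath>

static float rope_neox_theta_base_sketch(int p, float freq_scale, float theta_scale,
                                         int col, const float * freq_factors) {
    const float freq_factor = freq_factors ? freq_factors[col / 2] : 1.0f;
    return p * freq_scale * std::pow(theta_scale, col / 2.0f) / freq_factor;
}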
- */ dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); - stream->parallel_for( - sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - rope_neox(x, dst, ncols, n_dims, pos, freq_scale, - p_delta_rows, ext_factor, attn_factor, - corr_dims, theta_scale, inv_ndims, item_ct1); - }); + if (freq_factors == nullptr) { + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + rope_neox(x, dst, ncols, n_dims, pos, freq_scale, + p_delta_rows, ext_factor, attn_factor, + corr_dims, theta_scale, inv_ndims, freq_factors, item_ct1); + }); + } else { + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + rope_neox(x, dst, ncols, n_dims, pos, freq_scale, + p_delta_rows, ext_factor, attn_factor, + corr_dims, theta_scale, inv_ndims, freq_factors, item_ct1); + }); + } } } @@ -13501,6 +13563,10 @@ inline void ggml_sycl_op_concat(const ggml_tensor *src0, const float *src0_dd, const float *src1_dd, float *dst_dd, const dpct::queue_ptr &main_stream) { +#pragma message("TODO: generalize concat kernel for dim != 2") +#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7563") + int dim = dst->op_params[0]; + GGML_ASSERT(dim == 2); GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(src1->type == GGML_TYPE_F32); @@ -13986,9 +14052,7 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, const dpct::queue_ptr &main_stream) { -#pragma message("TODO: implement phi3 frequency factors support") -#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7225") - GGML_ASSERT(dst->src[2] == nullptr && "phi3 frequency factors not implemented yet"); + const ggml_tensor * src2 = dst->src[2]; GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); @@ -14014,6 +14078,7 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1, memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); + const float * freq_factors = nullptr; const int32_t * pos = nullptr; if ((mode & 1) == 0) { GGML_ASSERT(src1->type == GGML_TYPE_I32); @@ -14024,6 +14089,16 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1, const bool is_neox = mode & 2; const bool is_glm = mode & 4; + if (is_neox) { + pos = (const int32_t *) src1_dd; + + if (src2 != nullptr) { + freq_factors = (const float *) src2->data; + } + } else { + GGML_ASSERT(src2 == nullptr && "TODO: freq_factors not implemented for !is_neox"); + } + rope_corr_dims corr_dims; ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims.v); @@ -14035,13 +14110,13 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1, if (src0->type == GGML_TYPE_F32) { rope_neox_sycl( (const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor, - attn_factor, corr_dims, main_stream + attn_factor, corr_dims, freq_factors, main_stream ); } else if (src0->type == GGML_TYPE_F16) { rope_neox_sycl((const sycl::half *)src0_dd, (sycl::half *)dst_dd, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor, attn_factor, corr_dims, - main_stream); + freq_factors, main_stream); } else { GGML_ASSERT(false); 
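// [Editor's note: sketch under assumptions, not part of the patch.] The rope hunk above
// reads the optional Phi-3 frequency factors from the third source tensor of the ROPE op
// and only allows them for the NeoX variant. A small helper expressing that pattern
// (assumes the ggml_tensor and GGML_ASSERT definitions from this repo's ggml.h):
#include "ggml.h"

static const float * rope_freq_factors_sketch(const struct ggml_tensor * dst, bool is_neox) {
    const struct ggml_tensor * src2 = dst->src[2]; // optional frequency factors
    if (!is_neox) {
        GGML_ASSERT(src2 == nullptr && "freq_factors only implemented for NeoX RoPE");
        return nullptr;
    }
    return src2 != nullptr ? (const float *) src2->data : nullptr;
}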
} @@ -15173,6 +15248,29 @@ catch (sycl::exception const &exc) { std::exit(1); } +inline bool ggml_sycl_supports_mmq(enum ggml_type type) { + // TODO: accuracy issues in MMQ + return false; +} + +bool ggml_sycl_supports_dmmv(enum ggml_type type) { + switch (type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_F16: + return true; + default: + return false; + } +} static void ggml_sycl_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { const bool all_on_device = @@ -15189,75 +15287,42 @@ static void ggml_sycl_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1 } } -#ifdef SYCL_USE_XMX - const bool use_xmx = true; -#else - const bool use_xmx = false; -#endif + // check data types and tensor shapes for custom matrix multiplication kernels: + bool use_dequantize_mul_mat_vec = ggml_sycl_supports_dmmv(src0->type) + && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 + && src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1; + + bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) + && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 + && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE; - // debug helpers - //printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]); - //printf(" %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]); - //printf("src1: %8d %8d %8d %8d\n", src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3]); - //printf(" %8d %8d %8d %8d\n", src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3]); - //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name); - //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name); + bool use_mul_mat_q = ggml_sycl_supports_mmq(src0->type) + && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32; + + // mmvq and mmq need the __dp4a instruction which is available for gen12+ + // Workaround in https://github.com/ggerganov/llama.cpp/commit/95f84d5ce8b449a9b16009434aca800df504a02e + use_mul_mat_q = use_mul_mat_q && (src0->type != GGML_TYPE_IQ2_XXS); +#ifdef SYCL_USE_XMX + use_mul_mat_q = use_mul_mat_q && (src1->ne[1] <= MMQ_MAX_BATCH_SIZE); +#endif // SYCL_USE_XMX - if (!split && all_on_device && !use_xmx && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) { + if (!split && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) { // KQ single-batch - // GGML_SYCL_DEBUG("ggml_sycl_mul_mat_vec_p021\n"); ggml_sycl_mul_mat_vec_p021(src0, src1, dst); - } else if (!split && all_on_device && !use_xmx && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) { + } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) { // KQV single-batch - // GGML_SYCL_DEBUG("ggml_sycl_mul_mat_vec_nc\n"); ggml_sycl_mul_mat_vec_nc(src0, src1, dst); - } else if (!split && all_on_device && use_xmx && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1)) { + } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == 
GGML_TYPE_F16) && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) { // KQ + KQV multi-batch - // GGML_SYCL_DEBUG("ggml_sycl_mul_mat_batched_sycl\n"); ggml_sycl_mul_mat_batched_sycl(src0, src1, dst); - } else if (src0->type == GGML_TYPE_F32) { - // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat\n"); - ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false); - } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) { - // GGML_SYCL_DEBUG("ggml_is_quantized or GGML_TYPE_F16\n"); - if (src1->ne[1] == 1 && src0->ne[0] % GGML_SYCL_DMMV_X == 0) { -#ifdef GGML_SYCL_FORCE_DMMV - const bool use_mul_mat_vec_q = false; -#else - bool use_mul_mat_vec_q = min_compute_capability >= VER_4VEC && ggml_is_quantized(src0->type); - use_mul_mat_vec_q = use_mul_mat_vec_q || - (src0->type == GGML_TYPE_IQ2_XXS) || (src0->type == GGML_TYPE_IQ2_XS) || (src0->type == GGML_TYPE_IQ2_S) || - (src0->type == GGML_TYPE_IQ3_XXS) || (src0->type == GGML_TYPE_IQ3_S) || - (src0->type == GGML_TYPE_IQ4_NL) || (src0->type == GGML_TYPE_IQ4_XS) || - (src0->type == GGML_TYPE_IQ1_S) || (src0->type == GGML_TYPE_IQ1_M); - - -#endif // GGML_SYCL_FORCE_DMMV - - if (use_mul_mat_vec_q) { - // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_mul_mat_vec_q path\n"); - ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true); - } else { - // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_dequantize_mul_mat_vec path\n"); - ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false); - } - } else { - bool use_mul_mat_q = min_compute_capability >= VER_4VEC && ggml_is_quantized(src0->type); - - if (use_xmx && min_compute_capability >= VER_GEN9 && src1->ne[1] > XMX_MAX_BATCH_SIZE) { - use_mul_mat_q = false; - } - - if (use_mul_mat_q) { - // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_mul_mat_q path\n"); - ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_q, true); - } else { - // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_mul_mat_sycl path\n"); - ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false); - } - } + } else if (use_dequantize_mul_mat_vec) { + ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false); + } else if (use_mul_mat_vec_q) { + ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true); + } else if (use_mul_mat_q) { + ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_q, true); } else { - GGML_ASSERT(false); + ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false); } } @@ -15434,22 +15499,86 @@ static void ggml_sycl_mul_mat_id_sycl(ggml_tensor * dst) { } #endif +struct mmid_row_mapping { + int32_t i1; + int32_t i2; +}; + +__dpct_inline__ static void k_copy_src1_to_contiguous( + const char *__restrict__ src1_original, char *__restrict__ src1_contiguous, + int *__restrict__ cur_src1_row, mmid_row_mapping *__restrict__ row_mapping, + const char *__restrict ids, int64_t i02, size_t ids_nb1, size_t ids_nb0, + int64_t ne11, int64_t ne10, size_t nb11, size_t nb12, + const sycl::nd_item<3> &item_ct1, int &src1_row) { + int32_t iid1 = item_ct1.get_group(2); + int32_t id = item_ct1.get_group(1); + + const int32_t row_id_i = *(const int32_t *) (ids + iid1*ids_nb1 + id*ids_nb0); + + if (row_id_i != i02) { + return; + } + + const int64_t i11 = id % ne11; + const int64_t i12 = iid1; + + if (item_ct1.get_local_id(2) == 0) { + src1_row = + dpct::atomic_fetch_add( + cur_src1_row, 1); + row_mapping[src1_row] = {id, iid1}; + } + /* + DPCT1065:194: Consider 
replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better + performance if there is no access to global memory. + */ + item_ct1.barrier(); + + const float * src1_row_original = (const float *)(src1_original + i11*nb11 + i12*nb12); + float * src1_row_contiguous = (float *)(src1_contiguous + src1_row*nb11); + +#pragma unroll + for (int i = item_ct1.get_local_id(2); i < ne10; + i += item_ct1.get_local_range(2)) { + src1_row_contiguous[i] = src1_row_original[i]; + } +} + +__dpct_inline__ static void k_copy_dst_from_contiguous( + char *__restrict__ dst_original, const char *__restrict__ dst_contiguous, + const mmid_row_mapping *__restrict__ row_mapping, int64_t ne0, size_t nb1, + size_t nb2, const sycl::nd_item<3> &item_ct1) { + int32_t i = item_ct1.get_group(2); + + const int32_t i1 = row_mapping[i].i1; + const int32_t i2 = row_mapping[i].i2; + + const float * dst_row_contiguous = (const float *)(dst_contiguous + i*nb1); + float * dst_row_original = (float *)(dst_original + i1*nb1 + i2*nb2); + +#pragma unroll + for (int j = item_ct1.get_local_id(2); j < ne0; + j += item_ct1.get_local_range(2)) { + dst_row_original[j] = dst_row_contiguous[j]; + } +} + static void ggml_sycl_mul_mat_id(const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst) try { - GGML_ASSERT(src0->backend != GGML_BACKEND_TYPE_GPU_SPLIT && - "mul_mat_id does not support split buffers"); + GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer) && "mul_mat_id does not support split buffers"); + const ggml_tensor *ids = dst->src[2]; - const dpct::queue_ptr stream = g_syclStreams[g_main_device][0]; + GGML_TENSOR_BINARY_OP_LOCALS - const size_t nb11 = src1->nb[1]; - const size_t nb1 = dst->nb[1]; + const dpct::queue_ptr stream = g_syclStreams[g_main_device][0]; - const int32_t id = ((int32_t *)dst->op_params)[0]; - const int32_t n_as = src0->ne[2]; + const int64_t n_as = ne02; + const int64_t n_ids = ids->ne[0]; std::vector ids_host(ggml_nbytes(ids)); - const char *ids_dev = (const char *)ids->data; + const char * ids_dev = (const char *) ids->data; SYCL_CHECK(CHECK_TRY_ERROR( stream->memcpy(ids_host.data(), ids_dev, ggml_nbytes(ids)))); @@ -15489,24 +15618,40 @@ static void ggml_sycl_mul_mat_id(const ggml_tensor *src0, src0_row.ne[2] = 1; src0_row.ne[3] = 1; - src0_row.nb[3] = src0->nb[2]; - - if (src1->ne[1] == 1) { - for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) { - const int32_t row_id = - *(const int32_t *)(ids_host.data() + i01 * ids->nb[1] + - id * ids->nb[0]); - - GGML_ASSERT(row_id >= 0 && row_id < n_as); + src0_row.nb[3] = nb02; + + src1_row.ne[1] = 1; + src1_row.ne[2] = 1; + src1_row.ne[3] = 1; + src1_row.nb[2] = nb11; + src1_row.nb[3] = nb11; + + dst_row.ne[1] = 1; + dst_row.ne[2] = 1; + dst_row.ne[3] = 1; + dst_row.nb[2] = nb1; + dst_row.nb[3] = nb1; + if (ne12 == 1) { + for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) { + for (int64_t id = 0; id < n_ids; id++) { + const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]); + GGML_ASSERT(i02 >= 0 && i02 < n_as); + + const int64_t i11 = id % ne11; + const int64_t i12 = iid1; + + const int64_t i1 = id; + const int64_t i2 = i12; src0_row_extra.data_device[g_main_device] = - src0_original + row_id * src0->nb[2]; + src0_original + i02*nb02; src1_row_extra.data_device[g_main_device] = - src1_original + i01 * src1->nb[1]; + src1_original + + i11*nb11 + i12*nb12; dst_row_extra.data_device[g_main_device] = - dst_original + i01 * dst->nb[1]; + dst_original + i1*nb1 + 
i2*nb2; ggml_sycl_mul_mat(&src0_row, &src1_row, &dst_row); + } } } else { sycl_pool_alloc src1_contiguous(sizeof(float)*ggml_nelements(src1)); @@ -15515,64 +15660,98 @@ static void ggml_sycl_mul_mat_id(const ggml_tensor *src0, src1_row_extra.data_device[g_main_device] = src1_contiguous.get(); dst_row_extra.data_device[g_main_device] = dst_contiguous.get(); - for (int32_t row_id = 0; row_id < n_as; ++row_id) { + for (int64_t i02 = 0; i02 < n_as; i02++) { int64_t num_src1_rows = 0; - for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) { - const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]); + for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) { + for (int64_t id = 0; id < n_ids; id++) { + const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]); - if (row_id_i != row_id) { - continue; - } + GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as); - GGML_ASSERT(row_id >= 0 && row_id < n_as); + if (row_id_i != i02) { + continue; + } - SYCL_CHECK(CHECK_TRY_ERROR( - stream->memcpy(src1_contiguous.get() + num_src1_rows * nb11, - src1_original + i01 * nb11, nb11))); - num_src1_rows++; + num_src1_rows++; + } } if (num_src1_rows == 0) { continue; } - src0_row_extra.data_device[g_main_device] = - src0_original + row_id * src0->nb[2]; + sycl_pool_alloc dev_cur_src1_row(1); + sycl_pool_alloc dev_row_mapping(num_src1_rows); + SYCL_CHECK(CHECK_TRY_ERROR( + stream->memset(dev_cur_src1_row.get(), 0, sizeof(int)))); + + { + sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10, 768u)); + sycl::range<3> grid_dims(1, n_ids, ids->ne[1]); + stream->submit([&](sycl::handler &cgh) { + sycl::local_accessor src1_row_acc(cgh); + + char *__restrict src1_contiguous_get = + src1_contiguous.get(); + int *__restrict dev_cur_src1_row_get = + dev_cur_src1_row.get(); + mmid_row_mapping *__restrict dev_row_mapping_get = + dev_row_mapping.get(); + size_t ids_nb_ct6 = ids->nb[1]; + size_t ids_nb_ct7 = ids->nb[0]; + + cgh.parallel_for( + sycl::nd_range<3>(grid_dims * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + k_copy_src1_to_contiguous( + src1_original, src1_contiguous_get, + dev_cur_src1_row_get, + dev_row_mapping_get, ids_dev, i02, + ids_nb_ct6, ids_nb_ct7, ne11, ne10, nb11, nb12, + item_ct1, src1_row_acc); + }); + }); + } + + src0_row_extra.data_device[g_main_device] = src0_original + i02*nb02; + + GGML_ASSERT(nb11 == sizeof(float)*ne10); + GGML_ASSERT(nb1 == sizeof(float)*ne0); src1_row.ne[1] = num_src1_rows; - dst_row.ne[1] = num_src1_rows; src1_row.nb[1] = nb11; src1_row.nb[2] = num_src1_rows*nb11; src1_row.nb[3] = num_src1_rows*nb11; + dst_row.ne[1] = num_src1_rows; dst_row.nb[1] = nb1; dst_row.nb[2] = num_src1_rows*nb1; dst_row.nb[3] = num_src1_rows*nb1; ggml_sycl_mul_mat(&src0_row, &src1_row, &dst_row); - num_src1_rows = 0; - for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) { - const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]); - - if (row_id_i != row_id) { - continue; - } - - GGML_ASSERT(row_id >= 0 && row_id < n_as); - - SYCL_CHECK(CHECK_TRY_ERROR(stream->memcpy( - dst_original + i01 * nb1, - dst_contiguous.get() + num_src1_rows * nb1, nb1))); - num_src1_rows++; + { + sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0, 768u)); + sycl::range<3> grid_dims(1, 1, num_src1_rows); + stream->submit([&](sycl::handler &cgh) { + const char *__restrict dst_contiguous_get = + dst_contiguous.get(); + const mmid_row_mapping *__restrict dev_row_mapping_get = + dev_row_mapping.get(); + + 
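// [Editor's note: CPU reference of the gather/matmul/scatter scheme introduced above;
// illustrative only, the names and the ids layout are assumptions.] For every expert
// i02, the rows of src1 whose ids entry selects that expert are packed into a contiguous
// buffer, the (i1, i2) destination coordinates are recorded (as mmid_row_mapping does on
// device), one batched matmul runs per expert, and results are scattered back.
#include <cstdint>
#include <functional>
#include <vector>

struct row_mapping_sketch { int32_t i1; int32_t i2; };

static void mul_mat_id_reference(
        int n_as, int n_ids, int64_t n_tokens,
        const std::vector<int32_t> & ids, // ids[iid1*n_ids + id] = expert index (assumed layout)
        const std::function<void(int, const std::vector<row_mapping_sketch> &)> & mul_rows) {
    for (int i02 = 0; i02 < n_as; ++i02) {
        std::vector<row_mapping_sketch> rows; // analogue of dev_row_mapping
        for (int64_t iid1 = 0; iid1 < n_tokens; ++iid1) {
            for (int id = 0; id < n_ids; ++id) {
                if (ids[iid1*n_ids + id] == i02) {
                    rows.push_back({ (int32_t) id, (int32_t) iid1 });
                }
            }
        }
        if (!rows.empty()) {
            mul_rows(i02, rows); // gather these rows, run one matmul for expert i02, scatter results
        }
    }
}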
cgh.parallel_for( + sycl::nd_range<3>(grid_dims * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + k_copy_dst_from_contiguous(dst_original, + dst_contiguous_get, + dev_row_mapping_get, + ne0, nb1, nb2, item_ct1); + }); + }); } } } - - if (dst->backend == GGML_BACKEND_TYPE_CPU) { - SYCL_CHECK(CHECK_TRY_ERROR(stream->wait())); - } } catch (sycl::exception const &exc) { std::cerr << exc.what() << "Exception caught at file:" << __FILE__ @@ -16555,10 +16734,9 @@ GGML_CALL static const char * ggml_backend_sycl_split_buffer_get_name(ggml_backe UNUSED(buffer); } -// unused at the moment -//static bool ggml_backend_buffer_is_sycl_split(ggml_backend_buffer_t buffer) { -// return buffer->iface.get_name == ggml_backend_sycl_split_buffer_get_name; -//} +static bool ggml_backend_buffer_is_sycl_split(ggml_backend_buffer_t buffer) { + return buffer->iface.get_name == ggml_backend_sycl_split_buffer_get_name; +} GGML_CALL static void ggml_backend_sycl_split_buffer_free_buffer(ggml_backend_buffer_t buffer) { ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context; diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp index 79ce1479f16ca..92e622b043177 100644 --- a/ggml-vulkan.cpp +++ b/ggml-vulkan.cpp @@ -6012,6 +6012,8 @@ static ggml_backend_buffer_type_i ggml_backend_vk_buffer_type_interface = { }; GGML_CALL ggml_backend_buffer_type_t ggml_backend_vk_buffer_type(size_t dev_num) { + ggml_vk_instance_init(); + #ifdef GGML_VULKAN_DEBUG std::cerr << "ggml_backend_vk_buffer_type(" << dev_num << ")" << std::endl; #endif diff --git a/ggml.c b/ggml.c index 5145ceec9f4b2..023077ca6e89b 100644 --- a/ggml.c +++ b/ggml.c @@ -4882,10 +4882,21 @@ struct ggml_tensor * ggml_repeat_back( // ggml_concat struct ggml_tensor * ggml_concat( - struct ggml_context* ctx, - struct ggml_tensor* a, - struct ggml_tensor* b) { - GGML_ASSERT(a->ne[0] == b->ne[0] && a->ne[1] == b->ne[1] && a->ne[3] == b->ne[3]); + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int dim) { + GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS); + + int64_t ne[GGML_MAX_DIMS]; + for (int d = 0; d < GGML_MAX_DIMS; ++d) { + if (d == dim) { + ne[d] = a->ne[d] + b->ne[d]; + continue; + } + GGML_ASSERT(a->ne[d] == b->ne[d]); + ne[d] = a->ne[d]; + } bool is_node = false; @@ -4893,7 +4904,9 @@ struct ggml_tensor * ggml_concat( is_node = true; } - struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, a->ne[0], a->ne[1], a->ne[2] + b->ne[2], a->ne[3]); + struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, ne); + + ggml_set_op_params_i32(result, 0, dim); result->op = GGML_OP_CONCAT; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; @@ -5013,6 +5026,7 @@ struct ggml_tensor * ggml_leaky_relu( } struct ggml_tensor * result = inplace ? 
ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + ggml_set_op_params(result, &negative_slope, sizeof(negative_slope)); result->op = GGML_OP_LEAKY_RELU; @@ -10967,26 +10981,29 @@ static void ggml_compute_forward_concat_f32( GGML_ASSERT(nb00 == sizeof(float)); GGML_ASSERT(nb10 == sizeof(float)); + const int32_t dim = ggml_get_op_params_i32(dst, 0); + + GGML_ASSERT(dim >= 0 && dim < 4); + + int64_t o[4] = {0, 0, 0, 0}; + o[dim] = src0->ne[dim]; + + const float * x; + + // TODO: smarter multi-theading for (int i3 = 0; i3 < ne3; i3++) { for (int i2 = ith; i2 < ne2; i2 += nth) { - if (i2 < ne02) { // src0 - for (int i1 = 0; i1 < ne1; i1++) { - for (int i0 = 0; i0 < ne0; i0++) { - const float * x = (float *)((char *) src0->data + i0 * nb00 + i1 * nb01 + i2 * nb02 + i3 * nb03); - - float * y = (float *)((char *)dst->data + i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3); - *y = *x; - } - } - } // src1 - else { - for (int i1 = 0; i1 < ne1; i1++) { - for (int i0 = 0; i0 < ne0; i0++) { - const float * x = (float *)((char *) src1->data + i0 * nb10 + i1 * nb11 + (i2 - ne02) * nb12 + i3 * nb13); - - float * y = (float *)((char *)dst->data + i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3); - *y = *x; + for (int i1 = 0; i1 < ne1; i1++) { + for (int i0 = 0; i0 < ne0; i0++) { + if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { + x = (const float *) ((const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03); + } else { + x = (const float *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13); } + + float * y = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3); + + *y = *x; } } } @@ -10994,7 +11011,7 @@ static void ggml_compute_forward_concat_f32( } static void ggml_compute_forward_concat( - const struct ggml_compute_params* params, + const struct ggml_compute_params * params, struct ggml_tensor* dst) { const struct ggml_tensor * src0 = dst->src[0]; diff --git a/ggml.h b/ggml.h index f803ba7241fe1..4e6bcb30fd931 100644 --- a/ggml.h +++ b/ggml.h @@ -1007,12 +1007,13 @@ extern "C" { struct ggml_tensor * a, struct ggml_tensor * b); - // concat a and b on dim 2 + // concat a and b along dim // used in stable-diffusion GGML_API struct ggml_tensor * ggml_concat( struct ggml_context * ctx, struct ggml_tensor * a, - struct ggml_tensor * b); + struct ggml_tensor * b, + int dim); GGML_API struct ggml_tensor * ggml_abs( struct ggml_context * ctx, diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 0961ea9ed7822..d5b31eb045d71 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -31,17 +31,21 @@ class General: FILE_TYPE = "general.file_type" class LLM: - VOCAB_SIZE = "{arch}.vocab_size" - CONTEXT_LENGTH = "{arch}.context_length" - EMBEDDING_LENGTH = "{arch}.embedding_length" - BLOCK_COUNT = "{arch}.block_count" - FEED_FORWARD_LENGTH = "{arch}.feed_forward_length" - USE_PARALLEL_RESIDUAL = "{arch}.use_parallel_residual" - TENSOR_DATA_LAYOUT = "{arch}.tensor_data_layout" - EXPERT_COUNT = "{arch}.expert_count" - EXPERT_USED_COUNT = "{arch}.expert_used_count" - POOLING_TYPE = "{arch}.pooling_type" - LOGIT_SCALE = "{arch}.logit_scale" + VOCAB_SIZE = "{arch}.vocab_size" + CONTEXT_LENGTH = "{arch}.context_length" + EMBEDDING_LENGTH = "{arch}.embedding_length" + BLOCK_COUNT = "{arch}.block_count" + LEADING_DENSE_BLOCK_COUNT = "{arch}.leading_dense_block_count" + FEED_FORWARD_LENGTH = "{arch}.feed_forward_length" + EXPERT_FEED_FORWARD_LENGTH = "{arch}.expert_feed_forward_length" + USE_PARALLEL_RESIDUAL = 
"{arch}.use_parallel_residual" + TENSOR_DATA_LAYOUT = "{arch}.tensor_data_layout" + EXPERT_COUNT = "{arch}.expert_count" + EXPERT_USED_COUNT = "{arch}.expert_used_count" + EXPERT_SHARED_COUNT = "{arch}.expert_shared_count" + EXPERT_WEIGHTS_SCALE = "{arch}.expert_weights_scale" + POOLING_TYPE = "{arch}.pooling_type" + LOGIT_SCALE = "{arch}.logit_scale" class Attention: HEAD_COUNT = "{arch}.attention.head_count" @@ -53,6 +57,8 @@ class Attention: LAYERNORM_EPS = "{arch}.attention.layer_norm_epsilon" LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon" CAUSAL = "{arch}.attention.causal" + Q_LORA_RANK = "{arch}.attention.q_lora_rank" + KV_LORA_RANK = "{arch}.attention.kv_lora_rank" class Rope: DIMENSION_COUNT = "{arch}.rope.dimension_count" @@ -62,6 +68,7 @@ class Rope: SCALING_ATTN_FACTOR = "{arch}.rope.scaling.attn_factor" SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length" SCALING_FINETUNED = "{arch}.rope.scaling.finetuned" + SCALING_YARN_LOG_MUL = "{arch}.rope.scaling.yarn_log_multiplier" class SSM: CONV_KERNEL = "{arch}.ssm.conv_kernel" @@ -141,6 +148,7 @@ class MODEL_ARCH(IntEnum): DBRX = auto() OLMO = auto() ARCTIC = auto() + DEEPSEEK2 = auto() class MODEL_TENSOR(IntEnum): @@ -186,6 +194,12 @@ class MODEL_TENSOR(IntEnum): SSM_A = auto() SSM_D = auto() SSM_OUT = auto() + ATTN_Q_A = auto() + ATTN_Q_B = auto() + ATTN_KV_A_MQA = auto() + ATTN_KV_B = auto() + ATTN_Q_A_NORM = auto() + ATTN_KV_A_NORM = auto() MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { @@ -222,6 +236,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.DBRX: "dbrx", MODEL_ARCH.OLMO: "olmo", MODEL_ARCH.ARCTIC: "arctic", + MODEL_ARCH.DEEPSEEK2: "deepseek2", } TENSOR_NAMES: dict[MODEL_TENSOR, str] = { @@ -267,6 +282,12 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a", MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d", MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out", + MODEL_TENSOR.ATTN_Q_A: "blk.{bid}.attn_q_a", + MODEL_TENSOR.ATTN_Q_B: "blk.{bid}.attn_q_b", + MODEL_TENSOR.ATTN_KV_A_MQA: "blk.{bid}.attn_kv_a_mqa", + MODEL_TENSOR.ATTN_KV_B: "blk.{bid}.attn_kv_b", + MODEL_TENSOR.ATTN_Q_A_NORM: "blk.{bid}.attn_q_a_norm", + MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm", } MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { @@ -758,6 +779,33 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_DOWN_EXP, MODEL_TENSOR.FFN_UP_EXP, ], + MODEL_ARCH.DEEPSEEK2: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_Q_A, + MODEL_TENSOR.ATTN_Q_B, + MODEL_TENSOR.ATTN_KV_A_MQA, + MODEL_TENSOR.ATTN_KV_B, + MODEL_TENSOR.ATTN_Q_A_NORM, + MODEL_TENSOR.ATTN_KV_A_NORM, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.FFN_GATE_SHEXP, + MODEL_TENSOR.FFN_DOWN_SHEXP, + MODEL_TENSOR.FFN_UP_SHEXP, + ], # TODO } @@ -791,6 +839,10 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_ROT_EMBD, ], + MODEL_ARCH.DEEPSEEK2: [ + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_ROT_EMBD, + ], } diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index c194dd5dd1e65..b93747aff58b3 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -374,9 +374,15 @@ def add_embedding_length(self, length: int) -> None: def add_block_count(self, 
length: int) -> None: self.add_uint32(Keys.LLM.BLOCK_COUNT.format(arch=self.arch), length) + def add_leading_dense_block_count(self, length: int) -> None: + self.add_uint32(Keys.LLM.LEADING_DENSE_BLOCK_COUNT.format(arch=self.arch), length) + def add_feed_forward_length(self, length: int) -> None: self.add_uint32(Keys.LLM.FEED_FORWARD_LENGTH.format(arch=self.arch), length) + def add_expert_feed_forward_length(self, length: int) -> None: + self.add_uint32(Keys.LLM.EXPERT_FEED_FORWARD_LENGTH.format(arch=self.arch), length) + def add_parallel_residual(self, use: bool) -> None: self.add_bool(Keys.LLM.USE_PARALLEL_RESIDUAL.format(arch=self.arch), use) @@ -407,6 +413,12 @@ def add_expert_count(self, count: int) -> None: def add_expert_used_count(self, count: int) -> None: self.add_uint32(Keys.LLM.EXPERT_USED_COUNT.format(arch=self.arch), count) + def add_expert_shared_count(self, count: int) -> None: + self.add_uint32(Keys.LLM.EXPERT_SHARED_COUNT.format(arch=self.arch), count) + + def add_expert_weights_scale(self, value: float) -> None: + self.add_float32(Keys.LLM.EXPERT_WEIGHTS_SCALE.format(arch=self.arch), value) + def add_layer_norm_eps(self, value: float) -> None: self.add_float32(Keys.Attention.LAYERNORM_EPS.format(arch=self.arch), value) @@ -416,6 +428,12 @@ def add_layer_norm_rms_eps(self, value: float) -> None: def add_causal_attention(self, value: bool) -> None: self.add_bool(Keys.Attention.CAUSAL.format(arch=self.arch), value) + def add_q_lora_rank(self, length: int) -> None: + self.add_uint32(Keys.Attention.Q_LORA_RANK.format(arch=self.arch), length) + + def add_kv_lora_rank(self, length: int) -> None: + self.add_uint32(Keys.Attention.KV_LORA_RANK.format(arch=self.arch), length) + def add_pooling_type(self, value: PoolingType) -> None: self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value.value) @@ -440,6 +458,9 @@ def add_rope_scaling_orig_ctx_len(self, value: int) -> None: def add_rope_scaling_finetuned(self, value: bool) -> None: self.add_bool(Keys.Rope.SCALING_FINETUNED.format(arch=self.arch), value) + def add_rope_scaling_yarn_log_mul(self, value: float) -> None: + self.add_float32(Keys.Rope.SCALING_YARN_LOG_MUL.format(arch=self.arch), value) + def add_ssm_conv_kernel(self, value: int) -> None: self.add_uint32(Keys.SSM.CONV_KERNEL.format(arch=self.arch), value) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 8b1b21d78bb09..83e3c4c3381a0 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -256,6 +256,7 @@ class TensorNameMap: MODEL_TENSOR.FFN_UP_SHEXP: ( "model.layers.{bid}.mlp.shared_expert.up_proj", # qwen2moe + "model.layers.{bid}.mlp.shared_experts.up_proj", # deepseek2 ), # AWQ-activation gate @@ -285,6 +286,7 @@ class TensorNameMap: MODEL_TENSOR.FFN_GATE_SHEXP: ( "model.layers.{bid}.mlp.shared_expert.gate_proj", # qwen2moe + "model.layers.{bid}.mlp.shared_experts.gate_proj", # deepseek2 ), # Feed-forward down @@ -320,6 +322,7 @@ class TensorNameMap: MODEL_TENSOR.FFN_DOWN_SHEXP: ( "model.layers.{bid}.mlp.shared_expert.down_proj", # qwen2moe + "model.layers.{bid}.mlp.shared_experts.down_proj", # deepseek2 ), MODEL_TENSOR.ATTN_Q_NORM: ( @@ -383,6 +386,30 @@ class TensorNameMap: "model.layers.{bid}.out_proj", "backbone.layers.{bid}.mixer.out_proj", ), + + MODEL_TENSOR.ATTN_Q_A: ( + "model.layers.{bid}.self_attn.q_a_proj", # deepseek2 + ), + + MODEL_TENSOR.ATTN_Q_B: ( + "model.layers.{bid}.self_attn.q_b_proj", # deepseek2 + ), + + MODEL_TENSOR.ATTN_KV_A_MQA: ( + 
"model.layers.{bid}.self_attn.kv_a_proj_with_mqa", # deepseek2 + ), + + MODEL_TENSOR.ATTN_KV_B: ( + "model.layers.{bid}.self_attn.kv_b_proj", # deepseek2 + ), + + MODEL_TENSOR.ATTN_Q_A_NORM: ( + "model.layers.{bid}.self_attn.q_a_layernorm", # deepseek2 + ), + + MODEL_TENSOR.ATTN_KV_A_NORM: ( + "model.layers.{bid}.self_attn.kv_a_layernorm", # deepseek2 + ), } # architecture-specific block mappings @@ -415,7 +442,7 @@ def __init__(self, arch: MODEL_ARCH, n_blocks: int): if tensor not in MODEL_TENSORS[arch]: continue # TODO: make this configurable - n_experts = 128 + n_experts = 160 for xid in range(n_experts): tensor_name = TENSOR_NAMES[tensor].format(bid = bid, xid = xid) self.mapping[tensor_name] = (tensor, tensor_name) diff --git a/llama.cpp b/llama.cpp index 5b9dbff2c46d2..1eb70ea5bd4f8 100644 --- a/llama.cpp +++ b/llama.cpp @@ -103,7 +103,7 @@ #endif #define LLAMA_MAX_NODES 8192 -#define LLAMA_MAX_EXPERTS 128 +#define LLAMA_MAX_EXPERTS 160 // // logging @@ -222,6 +222,7 @@ enum llm_arch { LLM_ARCH_DBRX, LLM_ARCH_OLMO, LLM_ARCH_ARCTIC, + LLM_ARCH_DEEPSEEK2, LLM_ARCH_UNKNOWN, }; @@ -259,6 +260,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_DBRX, "dbrx" }, { LLM_ARCH_OLMO, "olmo" }, { LLM_ARCH_ARCTIC, "arctic" }, + { LLM_ARCH_DEEPSEEK2, "deepseek2" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -279,11 +281,15 @@ enum llm_kv { LLM_KV_CONTEXT_LENGTH, LLM_KV_EMBEDDING_LENGTH, LLM_KV_BLOCK_COUNT, + LLM_KV_LEADING_DENSE_BLOCK_COUNT, LLM_KV_FEED_FORWARD_LENGTH, + LLM_KV_EXPERT_FEED_FORWARD_LENGTH, LLM_KV_USE_PARALLEL_RESIDUAL, LLM_KV_TENSOR_DATA_LAYOUT, LLM_KV_EXPERT_COUNT, LLM_KV_EXPERT_USED_COUNT, + LLM_KV_EXPERT_SHARED_COUNT, + LLM_KV_EXPERT_WEIGHTS_SCALE, LLM_KV_POOLING_TYPE, LLM_KV_LOGIT_SCALE, @@ -296,6 +302,8 @@ enum llm_kv { LLM_KV_ATTENTION_LAYERNORM_EPS, LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, LLM_KV_ATTENTION_CAUSAL, + LLM_KV_ATTENTION_Q_LORA_RANK, + LLM_KV_ATTENTION_KV_LORA_RANK, LLM_KV_ROPE_DIMENSION_COUNT, LLM_KV_ROPE_FREQ_BASE, @@ -305,6 +313,7 @@ enum llm_kv { LLM_KV_ROPE_SCALING_ATTN_FACTOR, LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, LLM_KV_ROPE_SCALING_FINETUNED, + LLM_KV_ROPE_SCALING_YARN_LOG_MUL, LLM_KV_SPLIT_NO, LLM_KV_SPLIT_COUNT, @@ -355,17 +364,21 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_GENERAL_SOURCE_URL, "general.source.url" }, { LLM_KV_GENERAL_SOURCE_HF_REPO, "general.source.huggingface.repository" }, - { LLM_KV_VOCAB_SIZE, "%s.vocab_size" }, - { LLM_KV_CONTEXT_LENGTH, "%s.context_length" }, - { LLM_KV_EMBEDDING_LENGTH, "%s.embedding_length" }, - { LLM_KV_BLOCK_COUNT, "%s.block_count" }, - { LLM_KV_FEED_FORWARD_LENGTH, "%s.feed_forward_length" }, - { LLM_KV_USE_PARALLEL_RESIDUAL, "%s.use_parallel_residual" }, - { LLM_KV_TENSOR_DATA_LAYOUT, "%s.tensor_data_layout" }, - { LLM_KV_EXPERT_COUNT, "%s.expert_count" }, - { LLM_KV_EXPERT_USED_COUNT, "%s.expert_used_count" }, - { LLM_KV_POOLING_TYPE , "%s.pooling_type" }, - { LLM_KV_LOGIT_SCALE, "%s.logit_scale" }, + { LLM_KV_VOCAB_SIZE, "%s.vocab_size" }, + { LLM_KV_CONTEXT_LENGTH, "%s.context_length" }, + { LLM_KV_EMBEDDING_LENGTH, "%s.embedding_length" }, + { LLM_KV_BLOCK_COUNT, "%s.block_count" }, + { LLM_KV_LEADING_DENSE_BLOCK_COUNT, "%s.leading_dense_block_count" }, + { LLM_KV_FEED_FORWARD_LENGTH, "%s.feed_forward_length" }, + { LLM_KV_EXPERT_FEED_FORWARD_LENGTH, "%s.expert_feed_forward_length" }, + { LLM_KV_USE_PARALLEL_RESIDUAL, "%s.use_parallel_residual" }, + { LLM_KV_TENSOR_DATA_LAYOUT, "%s.tensor_data_layout" }, + { LLM_KV_EXPERT_COUNT, "%s.expert_count" }, + { LLM_KV_EXPERT_USED_COUNT, "%s.expert_used_count" }, 
+ { LLM_KV_EXPERT_SHARED_COUNT, "%s.expert_shared_count" }, + { LLM_KV_EXPERT_WEIGHTS_SCALE, "%s.expert_weights_scale" }, + { LLM_KV_POOLING_TYPE , "%s.pooling_type" }, + { LLM_KV_LOGIT_SCALE, "%s.logit_scale" }, { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" }, { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" }, @@ -376,6 +389,8 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ATTENTION_LAYERNORM_EPS, "%s.attention.layer_norm_epsilon" }, { LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, "%s.attention.layer_norm_rms_epsilon" }, { LLM_KV_ATTENTION_CAUSAL, "%s.attention.causal" }, + { LLM_KV_ATTENTION_Q_LORA_RANK, "%s.attention.q_lora_rank" }, + { LLM_KV_ATTENTION_KV_LORA_RANK, "%s.attention.kv_lora_rank" }, { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, { LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" }, @@ -385,6 +400,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ROPE_SCALING_ATTN_FACTOR, "%s.rope.scaling.attn_factor" }, { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" }, { LLM_KV_ROPE_SCALING_FINETUNED, "%s.rope.scaling.finetuned" }, + { LLM_KV_ROPE_SCALING_YARN_LOG_MUL, "%s.rope.scaling.yarn_log_multiplier" }, { LLM_KV_SPLIT_NO, "split.no" }, { LLM_KV_SPLIT_COUNT, "split.count" }, @@ -478,6 +494,12 @@ enum llm_tensor { LLM_TENSOR_SSM_A, LLM_TENSOR_SSM_D, LLM_TENSOR_SSM_OUT, + LLM_TENSOR_ATTN_Q_A, + LLM_TENSOR_ATTN_Q_B, + LLM_TENSOR_ATTN_KV_A_MQA, + LLM_TENSOR_ATTN_KV_B, + LLM_TENSOR_ATTN_Q_A_NORM, + LLM_TENSOR_ATTN_KV_A_NORM, }; static const std::map> LLM_TENSOR_NAMES = { @@ -1061,6 +1083,35 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, }, }, + { + LLM_ARCH_DEEPSEEK2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q_A_NORM, "blk.%d.attn_q_a_norm" }, + { LLM_TENSOR_ATTN_KV_A_NORM, "blk.%d.attn_kv_a_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_Q_A, "blk.%d.attn_q_a" }, + { LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" }, + { LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" }, + { LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" }, + { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, + { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, + { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, + }, + }, { LLM_ARCH_UNKNOWN, { @@ -1745,6 +1796,7 @@ enum e_model { MODEL_13B, MODEL_14B, MODEL_15B, + MODEL_16B, MODEL_20B, MODEL_30B, MODEL_34B, @@ -1752,6 +1804,7 @@ enum e_model { MODEL_40B, MODEL_65B, MODEL_70B, + MODEL_236B, MODEL_314B, MODEL_SMALL, MODEL_MEDIUM, @@ -1787,6 +1840,13 @@ struct llama_hparams { uint32_t n_expert_used = 0; uint32_t n_vocab_type = 0; // for BERT-style token types + uint32_t n_layer_dense_lead = 0; + uint32_t n_lora_q = 0; + uint32_t n_lora_kv = 0; + uint32_t n_ff_exp = 0; + uint32_t n_expert_shared = 0; + float expert_weights_scale = 0.0; + float f_norm_eps; float f_norm_rms_eps; @@ -1794,6 +1854,7 @@ 
struct llama_hparams { float rope_freq_base_train; float rope_freq_scale_train; uint32_t n_yarn_orig_ctx; + float rope_yarn_log_mul; // for State Space Models uint32_t ssm_d_conv = 0; @@ -1827,6 +1888,12 @@ struct llama_hparams { if (this->n_expert != other.n_expert) return true; if (this->n_expert_used != other.n_expert_used) return true; + if (this->n_layer_dense_lead != other.n_layer_dense_lead) return true; + if (this->n_lora_q != other.n_lora_q) return true; + if (this->n_lora_kv != other.n_lora_kv) return true; + if (this->n_ff_exp != other.n_ff_exp) return true; + if (this->n_expert_shared != other.n_expert_shared) return true; + if (this->rope_finetuned != other.rope_finetuned) return true; if (this->n_yarn_orig_ctx != other.n_yarn_orig_ctx) return true; @@ -1842,6 +1909,8 @@ struct llama_hparams { if (!is_float_close(this->rope_attn_factor, other.rope_attn_factor, EPSILON)) return true; if (!is_float_close(this->rope_freq_base_train, other.rope_freq_base_train, EPSILON)) return true; if (!is_float_close(this->rope_freq_scale_train, other.rope_freq_scale_train, EPSILON)) return true; + if (!is_float_close(this->expert_weights_scale, other.expert_weights_scale, EPSILON)) return true; + if (!is_float_close(this->rope_yarn_log_mul, other.rope_yarn_log_mul, EPSILON)) return true; return false; } @@ -1917,6 +1986,8 @@ struct llama_layer { struct ggml_tensor * attn_k_norm_b; struct ggml_tensor * attn_out_norm; struct ggml_tensor * attn_out_norm_b; + struct ggml_tensor * attn_q_a_norm; + struct ggml_tensor * attn_kv_a_norm; // attention struct ggml_tensor * wq; @@ -1924,6 +1995,10 @@ struct llama_layer { struct ggml_tensor * wv; struct ggml_tensor * wo; struct ggml_tensor * wqkv; + struct ggml_tensor * wq_a; + struct ggml_tensor * wq_b; + struct ggml_tensor * wkv_a_mqa; + struct ggml_tensor * wkv_b; // attention bias struct ggml_tensor * bq; @@ -1957,8 +2032,9 @@ struct llama_layer { struct ggml_tensor * ffn_up_shexp; // ff bias - struct ggml_tensor * ffn_down_b; // b2 - struct ggml_tensor * ffn_up_b; // b3 + struct ggml_tensor * ffn_gate_b = nullptr; + struct ggml_tensor * ffn_down_b = nullptr; // b2 + struct ggml_tensor * ffn_up_b = nullptr; // b3 struct ggml_tensor * ffn_act; // mamba proj @@ -2090,7 +2166,7 @@ struct llama_vocab { std::unordered_map token_to_id; std::vector id_to_token; - std::unordered_map special_tokens_cache; + std::vector special_tokens_cache; std::map, int> bpe_ranks; @@ -3836,6 +3912,7 @@ static const char * llama_model_type_name(e_model type) { case MODEL_13B: return "13B"; case MODEL_14B: return "14B"; case MODEL_15B: return "15B"; + case MODEL_16B: return "16B"; case MODEL_20B: return "20B"; case MODEL_30B: return "30B"; case MODEL_34B: return "34B"; @@ -3843,6 +3920,7 @@ static const char * llama_model_type_name(e_model type) { case MODEL_40B: return "40B"; case MODEL_65B: return "65B"; case MODEL_70B: return "70B"; + case MODEL_236B: return "236B"; case MODEL_314B: return "314B"; case MODEL_SMALL: return "0.1B"; case MODEL_MEDIUM: return "0.4B"; @@ -3985,7 +4063,9 @@ static void llm_load_hparams( switch (hparams.n_layer) { case 22: model.type = e_model::MODEL_1B; break; case 26: model.type = e_model::MODEL_3B; break; - case 32: model.type = hparams.n_vocab < 40000 ? e_model::MODEL_7B : e_model::MODEL_8B; break; + // granite uses a vocab with len 49152 + case 32: model.type = hparams.n_vocab == 49152 ? e_model::MODEL_3B : (hparams.n_vocab < 40000 ? 
e_model::MODEL_7B : e_model::MODEL_8B); break; + case 36: model.type = e_model::MODEL_8B; break; // granite case 40: model.type = e_model::MODEL_13B; break; case 48: model.type = e_model::MODEL_34B; break; case 60: model.type = e_model::MODEL_30B; break; @@ -4255,6 +4335,8 @@ static void llm_load_hparams( case 30: model.type = e_model::MODEL_3B; break; case 32: model.type = e_model::MODEL_7B; break; case 40: model.type = e_model::MODEL_15B; break; + case 52: model.type = e_model::MODEL_20B; break; // granite + case 88: model.type = e_model::MODEL_34B; break; // granite default: model.type = e_model::MODEL_UNKNOWN; } } break; @@ -4388,6 +4470,26 @@ static void llm_load_hparams( model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_DEEPSEEK2: + { + bool is_lite = (hparams.n_layer == 27); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); + if (!is_lite) { + ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); + } + ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); + ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul); + + switch (hparams.n_layer) { + case 27: model.type = e_model::MODEL_16B; break; + case 60: model.type = e_model::MODEL_236B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; default: (void)0; } @@ -4497,6 +4599,11 @@ static void llm_load_vocab( } else { if (tokenizer_model == "gpt2") { vocab.type = LLAMA_VOCAB_TYPE_BPE; + + const int add_space_prefix_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_ADD_PREFIX).c_str()); + if (add_space_prefix_keyidx != -1) { + vocab.add_space_prefix = gguf_get_val_bool(ctx, add_space_prefix_keyidx); + } } else { LLAMA_LOG_WARN("%s: unknown tokenizer: '%s'", __func__, tokenizer_model.c_str()); LLAMA_LOG_WARN("%s: using default tokenizer: 'llama'", __func__); @@ -4728,97 +4835,19 @@ static void llm_load_vocab( // build special tokens cache { - // TODO: It is unclear (to me) at this point, whether special tokes are guaranteed to be of a deterministic type, - // and will always be correctly labeled in 'added_tokens.json' etc. - // The assumption is, since special tokens aren't meant to be exposed to end user, they are designed - // to be unmatchable by the tokenizer, therefore tokens from the vocab, which are unmatchable by the tokenizer - // are special tokens. - // From testing, this appears to correlate 1:1 with special tokens. - // - - // Counting special tokens and verifying in only one direction - // is sufficient to detect difference in those two sets. 
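// [Editor's note: minimal sketch with simplified types, not part of the patch.] The
// special-token handling rewritten in the hunk continuing below drops the old
// split-and-verify heuristic (described in the removed comment above) and simply caches
// every token whose type is not NORMAL, sorted by text length, presumably so that longer
// special tokens are matched first during tokenization.
#include <algorithm>
#include <cstdint>
#include <string>
#include <vector>

struct token_data_sketch { std::string text; bool is_normal; };

static std::vector<int32_t> build_special_tokens_cache_sketch(const std::vector<token_data_sketch> & id_to_token) {
    std::vector<int32_t> cache;
    for (int32_t id = 0; id < (int32_t) id_to_token.size(); ++id) {
        if (!id_to_token[id].is_normal) {
            cache.push_back(id);
        }
    }
    std::sort(cache.begin(), cache.end(), [&](int32_t a, int32_t b) {
        return id_to_token[a].text.size() > id_to_token[b].text.size(); // longest first
    });
    return cache;
}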
- // - uint32_t special_tokens_count_by_type = 0; - uint32_t special_tokens_count_from_verification = 0; - - bool special_tokens_definition_mismatch = false; - - for (const auto & t : vocab.token_to_id) { - const auto & token = t.first; - const auto & id = t.second; - - // Count all non-normal tokens in the vocab while iterating + for (llama_vocab::id id = 0; id < (llama_vocab::id)n_vocab; ++id) { if (vocab.id_to_token[id].type != LLAMA_TOKEN_TYPE_NORMAL) { - special_tokens_count_by_type++; + vocab.special_tokens_cache.push_back(id); } + } - // Skip single character tokens - if (token.length() > 1) { - bool is_tokenizable = false; - - // Split token string representation in two, in all possible ways - // and check if both halves can be matched to a valid token - for (unsigned i = 1; i < token.length();) { - const auto left = token.substr(0, i); - const auto right = token.substr(i); - - // check if we didnt partition in the middle of a utf sequence - auto utf = utf8_len(left.at(left.length() - 1)); - - if (utf == 1) { - if (vocab.token_to_id.find(left) != vocab.token_to_id.end() && - vocab.token_to_id.find(right) != vocab.token_to_id.end() ) { - is_tokenizable = true; - break; - } - i++; - } else { - // skip over the rest of multibyte utf sequence - i += utf - 1; - } - } - - if (!is_tokenizable) { - // Some tokens are multibyte, but they are utf sequences with equivalent text length of 1 - // it's faster to re-filter them here, since there are way less candidates now - - // Calculate a total "utf" length of a token string representation - size_t utf8_str_len = 0; - for (unsigned i = 0; i < token.length();) { - utf8_str_len++; - i += utf8_len(token.at(i)); - } - - // And skip the ones which are one character - if (utf8_str_len > 1) { - // At this point what we have left are special tokens only - vocab.special_tokens_cache[token] = id; - - // Count manually found special tokens - special_tokens_count_from_verification++; - - // If this manually found special token is not marked as such, flag a mismatch - if (vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_NORMAL) { - special_tokens_definition_mismatch = true; - } - } - } + std::sort( vocab.special_tokens_cache.begin(), vocab.special_tokens_cache.end(), + [&] (const llama_vocab::id a, const llama_vocab::id b) { + return vocab.id_to_token[a].text.size() > vocab.id_to_token[b].text.size(); } - } + ); - if (special_tokens_definition_mismatch || special_tokens_count_from_verification != special_tokens_count_by_type) { - LLAMA_LOG_WARN("%s: mismatch in special tokens definition ( %u/%zu vs %u/%zu ).\n", - __func__, - special_tokens_count_from_verification, vocab.id_to_token.size(), - special_tokens_count_by_type, vocab.id_to_token.size() - ); - } else { - LLAMA_LOG_INFO("%s: special tokens definition check successful ( %u/%zu ).\n", - __func__, - special_tokens_count_from_verification, vocab.id_to_token.size() - ); - } + LLAMA_LOG_INFO("%s: special tokens cache size = %u.\n", __func__, (uint32_t)vocab.special_tokens_cache.size()); } } @@ -4899,6 +4928,16 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { if (vocab.special_suffix_id != -1) { LLAMA_LOG_INFO( "%s: SUF token = %d '%s'\n", __func__, vocab.special_suffix_id, vocab.id_to_token[vocab.special_suffix_id].text.c_str() ); } if (vocab.special_middle_id != -1) { LLAMA_LOG_INFO( "%s: MID token = %d '%s'\n", __func__, vocab.special_middle_id, vocab.id_to_token[vocab.special_middle_id].text.c_str() ); } if (vocab.special_eot_id != -1) { LLAMA_LOG_INFO( "%s: EOT token 
= %d '%s'\n", __func__, vocab.special_eot_id, vocab.id_to_token[vocab.special_eot_id].text.c_str() ); } + + if (model.arch == LLM_ARCH_DEEPSEEK2) { + LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); + LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q); + LLAMA_LOG_INFO("%s: n_lora_kv = %d\n", __func__, hparams.n_lora_kv); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); + LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); + LLAMA_LOG_INFO("%s: rope_yarn_log_mul = %.4f\n", __func__, hparams.rope_yarn_log_mul); + } } // Returns false if cancelled by progress_callback @@ -5055,8 +5094,6 @@ static bool llm_load_tensors( throw std::runtime_error("model has expert layers but no expert layers are used"); } - GGML_ASSERT(n_embd_gqa == n_embd_k_gqa); - ggml_context * ctx_input = ctx_map.at(model.buft_input.buft); ggml_context * ctx_output = ctx_map.at(model.buft_output.buft); ggml_context * ctx_output_split = ctx_map.at(model.buft_output.buft_matrix); @@ -5110,6 +5147,11 @@ static bool llm_load_tensors( layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + + // optional MLP bias + layer.ffn_gate_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_down_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_up_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); } else { layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}); @@ -6217,6 +6259,70 @@ static bool llm_load_tensors( layer.ffn_up_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}); } } break; + case LLM_ARCH_DEEPSEEK2: + { + bool is_lite = (hparams.n_layer == 27); + + const uint32_t n_embd_head_qk_rope = hparams.n_rot; + const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot; + const uint32_t q_lora_rank = hparams.n_lora_q; + const uint32_t kv_lora_rank = hparams.n_lora_kv; + const uint32_t n_ff_exp = hparams.n_ff_exp; + + model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + + // output + { + model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + } + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + + layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + if (!is_lite) { + layer.attn_q_a_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}); + } + layer.attn_kv_a_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}); + + if (!is_lite) { + layer.wq_a = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, 
q_lora_rank}); + layer.wq_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, hparams.n_head * hparams.n_embd_head_k}); + } else { + layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}); + } + layer.wkv_a_mqa = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + n_embd_head_qk_rope}); + layer.wkv_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, hparams.n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)}); + layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {hparams.n_head * hparams.n_embd_head_v, n_embd}); + + layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + + if ((uint32_t) i < hparams.n_layer_dense_lead) { + layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); + layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + } else { + layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}); + + GGML_ASSERT(hparams.n_expert > 0); + GGML_ASSERT(hparams.n_expert_used > 0); + + // MoE branch + layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}); + layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}); + layer.ffn_up_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}); + + // Shared expert branch + layer.ffn_gate_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * hparams.n_expert_shared}); + layer.ffn_down_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * hparams.n_expert_shared, n_embd}); + layer.ffn_up_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * hparams.n_expert_shared}); + } + } + } break; default: throw std::runtime_error("unknown architecture"); } @@ -6671,6 +6777,8 @@ static struct ggml_tensor * llm_build_moe_ffn( int64_t n_expert_used, llm_ffn_op_type type_op, bool norm_w, + bool scale_w, + float w_scale, const llm_build_cb & cb, int il) { int64_t n_embd = cur->ne[0]; @@ -6702,6 +6810,10 @@ static struct ggml_tensor * llm_build_moe_ffn( weights = ggml_reshape_3d(ctx, weights, 1, n_expert_used, n_tokens); } + if (scale_w) { + weights = ggml_scale(ctx, weights, w_scale); + cb(weights, "ffn_moe_weights_scaled", il); + } cur = ggml_reshape_3d(ctx, cur, n_embd, 1, n_tokens); ggml_tensor * up = ggml_mul_mat_id(ctx, up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] @@ -7312,9 +7424,9 @@ struct llm_build_context { cb(cur, "ffn_norm", il); cur = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up, NULL, - model.layers[il].ffn_gate, NULL, - model.layers[il].ffn_down, NULL, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, LLM_FFN_SILU, LLM_FFN_PAR, cb, il); cb(cur, "ffn_out", il); @@ -7332,6 +7444,7 @@ struct llm_build_context { model.layers[il].ffn_down_exps, n_expert, n_expert_used, LLM_FFN_SILU, true, + false, 0.0, cb, il); cb(cur, "ffn_moe_out", 
il); } @@ -7813,6 +7926,7 @@ struct llm_build_context { model.layers[il].ffn_down_exps, n_expert, n_expert_used, LLM_FFN_GELU, true, + false, 0.0, cb, il); cb(cur, "ffn_moe_out", il); @@ -7956,6 +8070,7 @@ struct llm_build_context { model.layers[il].ffn_down_exps, n_expert, n_expert_used, LLM_FFN_SILU, true, + false, 0.0, cb, il); cb(cur, "ffn_moe_out", il); @@ -9094,6 +9209,7 @@ struct llm_build_context { model.layers[il].ffn_down_exps, n_expert, n_expert_used, LLM_FFN_SILU, false, + false, 0.0, cb, il); cb(cur, "ffn_moe_out", il); @@ -10981,6 +11097,7 @@ struct llm_build_context { model.layers[il].ffn_down_exps, n_expert, n_expert_used, LLM_FFN_SILU, true, + false, 0.0, cb, il); cb(cur, "ffn_moe_out", il); @@ -11012,6 +11129,215 @@ struct llm_build_context { return gf; } + + struct ggml_cgraph * build_deepseek2() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + + // mutable variable, needed during the last layer of the computation to skip unused tokens + int32_t n_tokens = this->n_tokens; + + bool is_lite = (hparams.n_layer == 27); + + // We have to pre-scale kq_scale and attn_factor to make the YaRN RoPE work correctly. + // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation. + const float mscale = attn_factor * (1.0f + hparams.rope_yarn_log_mul * logf(1.0f / freq_scale)); + const float kq_scale = 1.0f*mscale*mscale/sqrtf(float(hparams.n_embd_head_k)); + const float attn_factor_scaled = 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)); + + const uint32_t n_embd_head_qk_rope = hparams.n_rot; + const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot; + const uint32_t kv_lora_rank = hparams.n_lora_kv; + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + // {n_embd, n_tokens} + inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = build_inp_pos(); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + + // norm + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + + // self_attention + { + struct ggml_tensor * q = NULL; + if (!is_lite) { + // {n_embd, q_lora_rank} * {n_embd, n_tokens} -> {q_lora_rank, n_tokens} + q = ggml_mul_mat(ctx0, model.layers[il].wq_a, cur); + cb(q, "q", il); + + q = llm_build_norm(ctx0, q, hparams, + model.layers[il].attn_q_a_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(q, "q", il); + + // {q_lora_rank, n_head * hparams.n_embd_head_k} * {q_lora_rank, n_tokens} -> {n_head * hparams.n_embd_head_k, n_tokens} + q = ggml_mul_mat(ctx0, model.layers[il].wq_b, q); + cb(q, "q", il); + } else { + q = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + cb(q, "q", il); + } + + // split into {n_head * n_embd_head_qk_nope, n_tokens} + struct ggml_tensor * q_nope = ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens, ggml_element_size(q) * hparams.n_embd_head_k, ggml_element_size(q) * hparams.n_embd_head_k * n_head, 0); + cb(q_nope, "q_nope", il); + // and {n_head * n_embd_head_qk_rope, n_tokens} + struct ggml_tensor * q_pe = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens, ggml_element_size(q) * hparams.n_embd_head_k, ggml_element_size(q) * hparams.n_embd_head_k * n_head, ggml_element_size(q) * n_embd_head_qk_nope); + cb(q_pe, "q_pe", il); + + // 
{n_embd, kv_lora_rank + n_embd_head_qk_rope} * {n_embd, n_tokens} -> {kv_lora_rank + n_embd_head_qk_rope, n_tokens} + struct ggml_tensor * compressed_kv_pe = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur); + cb(compressed_kv_pe, "compressed_kv_pe", il); + + // split into {kv_lora_rank, n_tokens} + struct ggml_tensor * compressed_kv = ggml_view_2d(ctx0, compressed_kv_pe, kv_lora_rank, n_tokens, compressed_kv_pe->nb[1], 0); + cb(compressed_kv, "compressed_kv", il); + // and {n_embd_head_qk_rope, n_tokens} + struct ggml_tensor * k_pe = ggml_view_2d(ctx0, compressed_kv_pe, n_embd_head_qk_rope, n_tokens, compressed_kv_pe->nb[1], ggml_element_size(compressed_kv_pe)*kv_lora_rank); + cb(k_pe, "k_pe", il); + + compressed_kv = llm_build_norm(ctx0, compressed_kv, hparams, + model.layers[il].attn_kv_a_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(compressed_kv, "compressed_kv", il); + + // {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens} + struct ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, compressed_kv); + cb(kv, "kv", il); + + // split into {n_head * n_embd_head_qk_nope, n_tokens} + struct ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens, ggml_element_size(kv) * (n_embd_head_qk_nope + hparams.n_embd_head_v), ggml_element_size(kv) * n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v), 0); + cb(k_nope, "k_nope", il); + + // and {n_head * n_embd_head_v, n_tokens} + struct ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens, ggml_element_size(kv) * (n_embd_head_qk_nope + hparams.n_embd_head_v), ggml_element_size(kv) * n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v), ggml_element_size(kv) * n_embd_head_qk_nope); + cb(v_states, "v_states", il); + + v_states = ggml_cont(ctx0, v_states); + cb(v_states, "v_states", il); + + v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens, ggml_element_size(kv) * hparams.n_embd_head_v * n_head, 0); + cb(v_states, "v_states", il); + + q_pe = ggml_rope_ext( + ctx0, q_pe, inp_pos, nullptr, + n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + ext_factor, attn_factor_scaled, beta_fast, beta_slow + ); + cb(q_pe, "q_pe", il); + + // shared RoPE key + k_pe = ggml_rope_ext( + ctx0, ggml_view_3d(ctx0, k_pe, n_embd_head_qk_rope, 1, n_tokens, k_pe->nb[0], k_pe->nb[1], 0), inp_pos, nullptr, + n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + ext_factor, attn_factor_scaled, beta_fast, beta_slow + ); + cb(k_pe, "k_pe", il); + + struct ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_pe, 0); + cb(q_states, "q_states", il); + + struct ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0); + cb(k_states, "k_states", il); + + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + model.layers[il].wo, NULL, + k_states, v_states, q_states, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + n_tokens = n_outputs; + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + if ((uint32_t) il < hparams.n_layer_dense_lead) { + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + 
cb(cur, "ffn_norm", il); + + cur = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up, NULL, + model.layers[il].ffn_gate, NULL, + model.layers[il].ffn_down, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(cur, "ffn_out", il); + } else { + // MoE branch + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + ggml_tensor * moe_out = + llm_build_moe_ffn(ctx0, cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + n_expert, n_expert_used, + LLM_FFN_SILU, false, + true, hparams.expert_weights_scale, + cb, il); + cb(moe_out, "ffn_moe_out", il); + + // FFN shared expert + { + ggml_tensor * ffn_shexp = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up_shexp, NULL, + model.layers[il].ffn_gate_shexp, NULL, + model.layers[il].ffn_down_shexp, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(ffn_shexp, "ffn_shexp", il); + + cur = ggml_add(ctx0, moe_out, ffn_shexp); + cb(cur, "ffn_out", il); + } + } + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = llm_build_norm(ctx0, cur, hparams, + model.output_norm, NULL, + LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = ggml_mul_mat(ctx0, model.output, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } + }; static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector & ids) { @@ -11230,6 +11556,10 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_arctic(); } break; + case LLM_ARCH_DEEPSEEK2: + { + result = llm.build_deepseek2(); + } break; default: GGML_ASSERT(false); } @@ -12742,7 +13072,7 @@ struct llm_tokenizer_wpm { llm_tokenizer_wpm(const llama_vocab & vocab): vocab(vocab) {} void tokenize(const std::string & text, std::vector & output) { - auto * token_map = &vocab.token_to_id; + const auto & token_map = vocab.token_to_id; // normalize and split by whitespace std::vector words = preprocess(text); @@ -12757,108 +13087,89 @@ struct llm_tokenizer_wpm { } // prepend phantom space - std::string word1 = "\xe2\x96\x81" + word; - int n = word1.size(); + const std::string word1 = "\xe2\x96\x81" + word; + const int n = word1.size(); - // we're at the start of a new word - int i = 0; - bool match_any = false; + const size_t current_tokens = output.size(); + // we're at the start of a new word // move through character position in word - while (i < n) { + for (int i = 0; i < n; ++i) { // loop through possible match length bool match = false; for (int j = n; j > i; j--) { - auto it = token_map->find(word1.substr(i, j - i)); - if (it != token_map->end()) { + auto it = token_map.find(word1.substr(i, j - i)); + if (it != token_map.end()) { output.push_back(it->second); match = true; - match_any = true; - i = j; + i = j - 1; break; } } - // must be an unknown character - if (!match) { - i++; + if (!match) { // discard all + output.resize(current_tokens); + break; // and discard next tokens } } // we didn't find any matches for this word - if (!match_any) { + if (current_tokens == output.size()) { output.push_back(vocab.special_unk_id); } } } std::vector preprocess(const std::string & text) { - std::vector cpts_nfd = unicode_cpts_normalize_nfd(unicode_cpts_from_utf8(text)); - - // strip accents, strip control, uniformize whitespace, - // to lowercase, pad chinese characters, pad 
punctuation - std::string new_str = ""; - for (uint32_t code : cpts_nfd) { - const codepoint_flags flags = unicode_cpt_flags(code); - if (flags.is_accent_mark || flags.is_control) { + const std::vector cpts_nfd = unicode_cpts_normalize_nfd(unicode_cpts_from_utf8(text)); + std::vector words(1, ""); + + for (const char32_t cpt : cpts_nfd) { + const auto flags = unicode_cpt_flags(cpt); + + if (flags.is_whitespace) { + if (words.back().size()) { // finish previous word if any + words.emplace_back(); + } continue; } - code = unicode_tolower(code); - if (flags.is_separator || flags.is_whitespace) { //####FIXME: is_separator ? - code = ' '; - } - std::string s = unicode_cpt_to_utf8(code); - if (flags.is_punctuation || is_ascii_punct(code) || is_chinese_char(code)) { - new_str += " "; - new_str += s; - new_str += " "; - } else { - new_str += s; + + assert (!flags.is_separator); + if (cpt == 0 || cpt == 0xFFFD || flags.is_control) { + continue; } - } - // split by whitespace - uint64_t l = 0; - uint64_t r = 0; - std::vector words; - while (r < new_str.size()) { - // if is whitespace - if (isspace(new_str[r], std::locale::classic())) { - if (r > l) words.push_back(new_str.substr(l, (r - l))); - l = r + 1; - r = l; + const std::string s = unicode_cpt_to_utf8(unicode_tolower(cpt)); + if (flags.is_punctuation || ( cpt < 0x7F && flags.is_symbol ) || is_chinese_char(cpt)) { + if (words.back().size()) { // finish previous word if any + words.emplace_back(); + } + words.back() = s; // single char word + words.emplace_back(); // start a new word } else { - r += 1; + words.back() += s; // append char to word } } - if (r > l) { - words.push_back(new_str.substr(l, (r - l))); - } - return words; - } - bool is_ascii_punct(uint32_t code) { - if (code > 0xFF) { - return false; + if (!words.back().size()) { + words.pop_back(); } - auto c = char(static_cast(code)); - return ispunct(c, std::locale::classic()); + + return words; } - bool is_chinese_char(uint32_t cpt) { - if ((cpt >= 0x4E00 && cpt <= 0x9FFF) || - (cpt >= 0x3400 && cpt <= 0x4DBF) || + static bool is_chinese_char(uint32_t cpt) { + return + (cpt >= 0x04E00 && cpt <= 0x09FFF) || + (cpt >= 0x03400 && cpt <= 0x04DBF) || (cpt >= 0x20000 && cpt <= 0x2A6DF) || (cpt >= 0x2A700 && cpt <= 0x2B73F) || (cpt >= 0x2B740 && cpt <= 0x2B81F) || (cpt >= 0x2B920 && cpt <= 0x2CEAF) || // this should be 0x2B820 but in hf rust code it is 0x2B920 - (cpt >= 0xF900 && cpt <= 0xFAFF) || - (cpt >= 0x2F800 && cpt <= 0x2FA1F) || - (cpt >= 0x3000 && cpt <= 0x303F) || - (cpt >= 0xFF00 && cpt <= 0xFFEF)) { - return true; // NOLINT - } - return false; + (cpt >= 0x0F900 && cpt <= 0x0FAFF) || + (cpt >= 0x2F800 && cpt <= 0x2FA1F); + //(cpt >= 0x3000 && cpt <= 0x303F) || + //(cpt >= 0xFF00 && cpt <= 0xFFEF); } const llama_vocab & vocab; @@ -12902,9 +13213,8 @@ struct fragment_buffer_variant { static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list & buffer) { // for each special token - for (const auto & st: vocab.special_tokens_cache) { - const auto & special_token = st.first; - const auto & special_id = st.second; + for (const llama_vocab::id special_id : vocab.special_tokens_cache) { + const auto & special_token = vocab.id_to_token[special_id].text; // for each text fragment std::forward_list::iterator it = buffer.begin(); @@ -12913,7 +13223,7 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list< // if a fragment is text ( not yet processed ) if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { - auto * raw_text = 
&(fragment.raw_text); + auto & raw_text = fragment.raw_text; auto raw_text_base_offset = fragment.offset; auto raw_text_base_length = fragment.length; @@ -12923,7 +13233,7 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list< // find the first occurrence of a given special token in this fragment // passing offset argument only limit the "search area" but match coordinates // are still relative to the source full raw_text - auto match = raw_text->find(special_token, raw_text_base_offset); + auto match = raw_text.find(special_token, raw_text_base_offset); // no occurrences found, stop processing this fragment for a given special token if (match == std::string::npos) break; @@ -12942,7 +13252,7 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list< // left const int64_t left_reminder_offset = raw_text_base_offset + 0; const int64_t left_reminder_length = match - raw_text_base_offset; - buffer.emplace_after(it, (*raw_text), left_reminder_offset, left_reminder_length); + buffer.emplace_after(it, raw_text, left_reminder_offset, left_reminder_length); #ifdef PRETOKENIZERDEBUG LLAMA_LOG_WARN("FL: (%ld %ld) '%s'\n", left_reminder_offset, left_reminder_length, raw_text->substr(left_reminder_offset, left_reminder_length).c_str()); @@ -12958,7 +13268,7 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list< if (match + special_token.length() < raw_text_base_offset + raw_text_base_length) { const int64_t right_reminder_offset = match + special_token.length(); const int64_t right_reminder_length = raw_text_base_length - ((match - raw_text_base_offset) + special_token.length()); - buffer.emplace_after(it, (*raw_text), right_reminder_offset, right_reminder_length); + buffer.emplace_after(it, raw_text, right_reminder_offset, right_reminder_length); #ifdef PRETOKENIZERDEBUG LLAMA_LOG_WARN("FR: (%ld %ld) '%s'\n", right_reminder_offset, right_reminder_length, raw_text->substr(right_reminder_offset, right_reminder_length).c_str()); @@ -16243,6 +16553,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { case LLM_ARCH_COMMAND_R: case LLM_ARCH_OLMO: case LLM_ARCH_ARCTIC: + case LLM_ARCH_DEEPSEEK2: return LLAMA_ROPE_TYPE_NORM; // the pairs of head values are offset by n_rot/2 @@ -17944,7 +18255,16 @@ static std::string llama_decode_text(const std::string & text) { const auto cpts = unicode_cpts_from_utf8(text); for (const auto cpt : cpts) { - decoded_text += unicode_utf8_to_byte(unicode_cpt_to_utf8(cpt)); + const auto utf8 = unicode_cpt_to_utf8(cpt); + try { + decoded_text += unicode_utf8_to_byte(utf8); + } catch (const std::out_of_range & e) { + decoded_text += "[UNK_BYTE_0x"; + for (const auto c : utf8) { + decoded_text += format("%02x", (uint8_t) c); + } + decoded_text += text + "]"; + } } return decoded_text; diff --git a/llama.h b/llama.h index 7671b8a57f4e7..3e4474bb94e9a 100644 --- a/llama.h +++ b/llama.h @@ -265,6 +265,8 @@ extern "C" { bool check_tensors; // validate model tensor data }; + // NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations + // https://github.com/ggerganov/llama.cpp/pull/7544 struct llama_context_params { uint32_t seed; // RNG seed, -1 for random uint32_t n_ctx; // text context, 0 = from model @@ -291,14 +293,14 @@ extern "C" { ggml_backend_sched_eval_callback cb_eval; void * cb_eval_user_data; - enum ggml_type type_k; // data type for K cache - enum ggml_type type_v; // data type for V 
cache + enum ggml_type type_k; // data type for K cache [EXPERIMENTAL] + enum ggml_type type_v; // data type for V cache [EXPERIMENTAL] // Keep the booleans together to avoid misalignment during copy-by-value. bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead) bool embeddings; // if true, extract embeddings (together with logits) bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU - bool flash_attn; // whether to use flash attention + bool flash_attn; // whether to use flash attention [EXPERIMENTAL] // Abort callback // if it returns true, execution of llama_decode() will be aborted diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index de74585da29dd..b200ccccd51b0 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1259,22 +1259,26 @@ struct test_im2col : public test_case { // GGML_OP_CONCAT struct test_concat : public test_case { const ggml_type type; - const std::array<int64_t, 4> ne; - const int64_t b_ne2; + const std::array<int64_t, 4> ne_a; + const int64_t ne_b_d; + const int dim; std::string vars() override { - return VARS_TO_STR3(type, ne, b_ne2); + return VARS_TO_STR4(type, ne_a, ne_b_d, dim); } test_concat(ggml_type type = GGML_TYPE_F32, - std::array<int64_t, 4> ne = {10, 10, 10, 10}, - int64_t b_ne2 = 10) - : type(type), ne(ne), b_ne2(b_ne2) {} + std::array<int64_t, 4> ne_a = {10, 10, 10, 10}, + int64_t ne_b_d = 10, + int dim = 2) + : type(type), ne_a(ne_a), ne_b_d(ne_b_d), dim(dim) {} ggml_tensor * build_graph(ggml_context * ctx) override { - ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); - ggml_tensor * b = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], b_ne2, ne[3]); - ggml_tensor * out = ggml_concat(ctx, a, b); + auto ne_b = ne_a; + ne_b[dim] = ne_b_d; + ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data()); + ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne_b.data()); + ggml_tensor * out = ggml_concat(ctx, a, b, dim); return out; } }; @@ -2211,8 +2215,10 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op } } - test_cases.emplace_back(new test_concat(GGML_TYPE_F32)); - test_cases.emplace_back(new test_concat(GGML_TYPE_I32)); + for (int dim : { 0, 1, 2, 3, }) { + test_cases.emplace_back(new test_concat(GGML_TYPE_F32, {11, 12, 13, 14}, 7, dim)); + test_cases.emplace_back(new test_concat(GGML_TYPE_I32, {11, 12, 13, 14}, 7, dim)); + } for (ggml_sort_order order : {GGML_SORT_ORDER_ASC, GGML_SORT_ORDER_DESC}) { test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {8, 1, 1, 1}, order)); diff --git a/tests/test-tokenizer-0.sh b/tests/test-tokenizer-0.sh index 1fec8bbf130db..4d2b8365547df 100755 --- a/tests/test-tokenizer-0.sh +++ b/tests/test-tokenizer-0.sh @@ -28,6 +28,8 @@ printf "Tokenizing using (cpp) llama.cpp ...\n" cat /tmp/test-tokenizer-0-$name-py.log | grep "tokenized in" cat /tmp/test-tokenizer-0-$name-cpp.log | grep "tokenized in" +set +e + diff $input.tok $input.tokcpp > /dev/null 2>&1 if [ $? 
-eq 0 ]; then diff --git a/tests/test-tokenizer-random.py b/tests/test-tokenizer-random.py index 7e1b656e5f5fc..ec1b2837cfab5 100644 --- a/tests/test-tokenizer-random.py +++ b/tests/test-tokenizer-random.py @@ -167,8 +167,10 @@ def generator_random_special_tokens(tokenizer, iterations=100) -> Iterator[str]: for m in range(iterations): rand.seed(m) words = rand.choices(special_tokens, k=500) - if tokenizer.add_bos_token: # skip spam warning of double BOS - while words and words[0] == tokenizer.bos_token: + if words[0] == tokenizer.bos_token: # skip spam warning of double BOS + while len(words) > 1 and words[1] == tokenizer.bos_token: # leave one starting BOS + words.pop(0) + if tokenizer.add_bos_token: # drop all starting BOS words.pop(0) yield "".join(words) @@ -293,15 +295,17 @@ def main(argv: list[str] = None): model = LibLlamaModel(LibLlama(), args.vocab_file, mparams=dict(vocab_only=True), cparams=dict(n_ctx=4096)) tokenizer = AutoTokenizer.from_pretrained(args.dir_tokenizer) - tokenizer.add_bos_token = getattr(tokenizer, "add_bos_token", True) - tokenizer.add_eos_token = getattr(tokenizer, "add_eos_token", False) - def func_tokenize1(text: str): return model.tokenize(text, add_special=True, parse_special=True) def func_tokenize2(text: str): return tokenizer.encode(text, add_special_tokens=True) + ids = func_tokenize2("a") + assert 1 <= len(ids) <= 3 + add_bos_token = len(ids) > 1 and tokenizer.bos_token_id == ids[0] + tokenizer.add_bos_token = getattr(tokenizer, "add_bos_token", add_bos_token) + vocab = list(sorted(tokenizer.batch_decode(list(tokenizer.get_vocab().values()), skip_special_tokens=True))) test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_custom_text()) test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_custom_text_edge_cases()) @@ -324,8 +328,10 @@ def func_tokenize2(text: str): # import os # tokenizers = os.listdir(path_tokenizers) tokenizers = [ - "llama-spm", # SPM - "phi-3", # SPM + # "llama-spm", # SPM + # "phi-3", # SPM + "jina-v2-en", # WPM + "bert-bge", # WPM ] for tokenizer in tokenizers: