From 95bc82fbc0df6d48cf66c857a4dda3d044f45ca2 Mon Sep 17 00:00:00 2001
From: Neo Zhang Jianyu
Date: Thu, 26 Sep 2024 17:38:31 +0800
Subject: [PATCH 01/28] [SYCL] add missing dll file in package (#9577)

* update oneapi to 2024.2

* use 2024.1

---------

Co-authored-by: arthw <14088817+arthw@users.noreply.github.com>
---
 .github/workflows/build.yml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index a54c5de99011c..e6a977b604d9b 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -956,6 +956,7 @@ jobs:
           cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/sycl7.dll" ./build/bin
           cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/svml_dispmd.dll" ./build/bin
           cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/libmmd.dll" ./build/bin
+          cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/libiomp5md.dll" ./build/bin
           echo "cp oneAPI running time dll files to ./build/bin done"
           7z a llama-${{ steps.tag.outputs.name }}-bin-win-sycl-x64.zip ./build/bin/*

From 44f59b4301c51f071daa2e951301bb17c14acc9b Mon Sep 17 00:00:00 2001
From: Borislav Stanimirov
Date: Fri, 27 Sep 2024 10:42:06 +0300
Subject: [PATCH 02/28] cmake : add option for common library (#9661)

---
 CMakeLists.txt | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 973907819d0d9..415743c2afe3f 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -62,6 +62,9 @@ option(LLAMA_SANITIZE_THREAD    "llama: enable thread sanitizer" OFF)
 option(LLAMA_SANITIZE_ADDRESS   "llama: enable address sanitizer" OFF)
 option(LLAMA_SANITIZE_UNDEFINED "llama: enable undefined sanitizer" OFF)
 
+# utils
+option(LLAMA_BUILD_COMMON "llama: build common utils library" ON)
+
 # extra artifacts
 option(LLAMA_BUILD_TESTS    "llama: build tests"    ${LLAMA_STANDALONE})
 option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE})
@@ -191,15 +194,17 @@ install(FILES "${CMAKE_CURRENT_BINARY_DIR}/llama.pc"
         DESTINATION lib/pkgconfig)
 
 #
-# programs, examples and tests
+# utils, programs, examples and tests
 #
 
-add_subdirectory(common)
+if (LLAMA_BUILD_COMMON)
+    add_subdirectory(common)
+endif()
 
 if (LLAMA_BUILD_TESTS AND NOT CMAKE_JS_VERSION)
     include(CTest)
     add_subdirectory(tests)
-endif ()
+endif()
 
 if (LLAMA_BUILD_EXAMPLES)
     add_subdirectory(examples)
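A note on consuming the new switch: projects that embed llama.cpp and only need the core library can now turn the common utilities off, typically together with the tests and examples that depend on them. A minimal sketch of such a configure step (the option names come from the patch above; treating `llama` as the core library target is an assumption about this build):

    # hypothetical downstream invocation: core library only, no common utils
    cmake -B build -DLLAMA_BUILD_COMMON=OFF -DLLAMA_BUILD_TESTS=OFF -DLLAMA_BUILD_EXAMPLES=OFF
    cmake --build build --target llama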
From b5de3b74a595cbfefab7eeb5a567425c6a9690cf Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Fri, 27 Sep 2024 20:57:51 +0300
Subject: [PATCH 03/28] readme : update hot topics

---
 README.md | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/README.md b/README.md
index ce954f713815c..93225b63c319f 100644
--- a/README.md
+++ b/README.md
@@ -17,7 +17,8 @@ Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others)
 
 ## Hot topics
 
-- Huggingface GGUF editor: [discussion](https://github.com/ggerganov/llama.cpp/discussions/9268) | [tool](https://huggingface.co/spaces/CISCai/gguf-editor)
+- **Hugging Face Inference Endpoints now support GGUF out of the box! https://github.com/ggerganov/llama.cpp/discussions/9669**
+- Hugging Face GGUF editor: [discussion](https://github.com/ggerganov/llama.cpp/discussions/9268) | [tool](https://huggingface.co/spaces/CISCai/gguf-editor)
 
 ----

From 89f9944981010d195e411a9fbfbb19959412f710 Mon Sep 17 00:00:00 2001
From: Markus Tavenrath
Date: Sat, 28 Sep 2024 12:05:05 +0200
Subject: [PATCH 04/28] Enable use of the rebar feature to upload buffers to
 the device. (#9251)

---
 ggml/src/ggml-vulkan.cpp | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/ggml/src/ggml-vulkan.cpp b/ggml/src/ggml-vulkan.cpp
index f9da45881e9df..a877145e82b49 100644
--- a/ggml/src/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan.cpp
@@ -1079,7 +1079,8 @@ static vk_buffer ggml_vk_create_buffer_device(vk_device& device, size_t size) {
             // Fall back to host memory type
             buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
         } else {
-            buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal);
+            // use rebar if available, otherwise fall back to device-only visible memory
+            buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal | vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent, vk::MemoryPropertyFlagBits::eDeviceLocal);
         }
     } catch (const vk::SystemError& e) {
         std::cerr << "ggml_vulkan: Device memory allocation of size " << size << " failed." << std::endl;
@@ -2806,7 +2807,11 @@ static void ggml_vk_buffer_read_async(vk_context subctx, vk_buffer& src, size_t
 
 static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_t size) {
     VK_LOG_DEBUG("ggml_vk_buffer_read(" << src->buffer << ", " << offset << ", " << size << ")");
-    if(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
+
+    // If the device is not a UMA device, the memory is host-accessible through rebar. While writing
+    // through PCIe is sufficiently fast, reading data back over PCIe is slower than going through
+    // the HW device-to-host copy path.
+    if(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible && src->device->uma) {
         GGML_ASSERT(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostCoherent);
 
         memcpy(dst, (uint8_t *) src->ptr + offset, size);
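The preference order above follows the usual Vulkan memory-type query: first ask for a type that is both device-local and host-visible (the resizable-BAR case), and only then settle for plain device-local memory. A minimal illustrative sketch of that query, not the ggml-vulkan helper itself (the function name is invented for this example):

    // return the first memory type index matching all requested property flags, or -1
    static int32_t pick_memory_type(vk::PhysicalDevice phys, uint32_t type_bits, vk::MemoryPropertyFlags wanted) {
        const vk::PhysicalDeviceMemoryProperties props = phys.getMemoryProperties();
        for (uint32_t i = 0; i < props.memoryTypeCount; ++i) {
            const bool allowed  = (type_bits & (1u << i)) != 0;
            const bool suitable = (props.memoryTypes[i].propertyFlags & wanted) == wanted;
            if (allowed && suitable) {
                return (int32_t) i;
            }
        }
        return -1;
    }

    // usage mirroring the fallback in the patch:
    //   int32_t idx = pick_memory_type(phys, bits, vk::MemoryPropertyFlagBits::eDeviceLocal |
    //                                              vk::MemoryPropertyFlagBits::eHostVisible |
    //                                              vk::MemoryPropertyFlagBits::eHostCoherent);
    //   if (idx < 0) idx = pick_memory_type(phys, bits, vk::MemoryPropertyFlagBits::eDeviceLocal);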
From 6a0f7794847244fb3b99a983e03137d7e832b585 Mon Sep 17 00:00:00 2001
From: Dan Johansson <164997844+eddnjjn@users.noreply.github.com>
Date: Sat, 28 Sep 2024 14:06:16 +0200
Subject: [PATCH 05/28] ggml : add run-time detection of neon, i8mm and sve (#9331)

* ggml: Added run-time detection of neon, i8mm and sve

Adds run-time detection of the Arm instruction set features neon, i8mm
and sve for Linux and Apple build targets.

* ggml: Extend feature detection to include non-aarch64 Arm arch

* ggml: Move definition of ggml_arm_arch_features to the global data section

---
 ggml/include/ggml.h     |   3 ++
 ggml/src/ggml-aarch64.c |  13 +----
 ggml/src/ggml-quants.c  |   4 +-
 ggml/src/ggml-quants.h  |   4 --
 ggml/src/ggml.c         | 101 ++++++++++++++++++++++++++++++++++------
 5 files changed, 93 insertions(+), 32 deletions(-)

diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index e24b8a319fc50..9f96e0c489b38 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -2507,6 +2507,9 @@ extern "C" {
     GGML_API int ggml_cpu_has_cann     (void);
     GGML_API int ggml_cpu_has_llamafile(void);
 
+    // get the sve vector length in bytes
+    GGML_API int ggml_cpu_get_sve_cnt(void);
+
     //
     // Internal types and functions exposed for tests and benchmarks
     //
diff --git a/ggml/src/ggml-aarch64.c b/ggml/src/ggml-aarch64.c
index 8912de63d9252..b27f411474f4c 100644
--- a/ggml/src/ggml-aarch64.c
+++ b/ggml/src/ggml-aarch64.c
@@ -598,15 +598,6 @@ size_t quantize_q4_0_8x8(const float * restrict src, void * restrict dst, int64_
     return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 8, 8);
 }
 
-// Return the number of byte lanes in the SVE vector if SVE is supported; otherwise, returns 0 if SVE is not supported.
-static int sve_lane_count(void) {
-#if defined(__ARM_FEATURE_SVE)
-    return ggml_sve_cnt_b;
-#else
-    return 0;
-#endif
-}
-
 void ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
     const int qk = QK8_0;
     const int nb = n / qk;
@@ -843,7 +834,7 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
 #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
 #if defined(__ARM_FEATURE_SVE)
-    if (ggml_cpu_has_sve() && sve_lane_count() == QK8_0) {
+    if (ggml_cpu_has_sve() && ggml_cpu_get_sve_cnt() == QK8_0) {
        const void * b_ptr = vx;
        const void * a_ptr = vy;
        float * res_ptr = s;
@@ -2020,7 +2011,7 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
 #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
 #if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
-    if (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && sve_lane_count() == QK8_0) {
+    if (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0) {
         const void * b_ptr = vx;
         const void * a_ptr = vy;
         float * res_ptr = s;
diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c
index 8bffce860a1eb..7aa6dce8907f5 100644
--- a/ggml/src/ggml-quants.c
+++ b/ggml/src/ggml-quants.c
@@ -4013,7 +4013,7 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
     svfloat32_t sumv0 = svdup_n_f32(0.0f);
     svfloat32_t sumv1 = svdup_n_f32(0.0f);
 
-    const int vector_length = ggml_sve_cnt_b*8;
+    const int vector_length = ggml_cpu_get_sve_cnt()*8;
 
     // VLA Implementation using switch case
     switch (vector_length) {
@@ -5597,7 +5597,7 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
     svfloat32_t sumv0 = svdup_n_f32(0.0f);
     svfloat32_t sumv1 = svdup_n_f32(0.0f);
 
-    const int vector_length = ggml_sve_cnt_b*8;
+    const int vector_length = ggml_cpu_get_sve_cnt()*8;
 
     //VLA Implementation for SVE
     switch (vector_length) {
diff --git a/ggml/src/ggml-quants.h b/ggml/src/ggml-quants.h
index e96ce2b5e5c4e..df9c4b24ae74f 100644
--- a/ggml/src/ggml-quants.h
+++ b/ggml/src/ggml-quants.h
@@ -142,10 +142,6 @@ void iq2xs_free_impl(enum ggml_type type);
 void iq3xs_init_impl(int grid_size);
 void iq3xs_free_impl(int grid_size);
 
-#if defined(__ARM_FEATURE_SVE)
-extern int ggml_sve_cnt_b;
-#endif
-
 #ifdef __cplusplus
 }
 #endif
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 4b782b0c13550..fac4466e31d44 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -39,9 +39,6 @@
 #include <sys/prctl.h>
 #endif
 
-#if defined(__ARM_FEATURE_SVE)
-int ggml_sve_cnt_b = 0;
-#endif
 #if defined(__ARM_FEATURE_SVE) || defined(__ARM_FEATURE_MATMUL_INT8)
 #undef GGML_USE_LLAMAFILE
 #endif
@@ -455,6 +452,15 @@ static ggml_fp16_t ggml_table_gelu_quick_f16[1 << 16];
 // precomputed f32 table for f16 (256 KB) (ggml-impl.h)
 float ggml_table_f32_f16[1 << 16];
 
+#if defined(__ARM_ARCH)
+struct ggml_arm_arch_features_type {
+    int has_neon;
+    int has_i8mm;
+    int has_sve;
+    int sve_cnt;
+} ggml_arm_arch_features = {-1, -1, -1, 0};
+#endif
+
 GGML_CALL const char * ggml_status_to_string(enum ggml_status status) {
     switch (status) {
         case GGML_STATUS_ALLOC_FAILED: return "GGML status: error (failed to allocate memory)";
@@ -3673,6 +3679,66 @@ static inline int ggml_up(int n, int m) {
 
 ////////////////////////////////////////////////////////////////////////////////
 
+#if defined(__ARM_ARCH)
+
+#if defined(__linux__) && defined(__aarch64__)
+#include <sys/auxv.h>
+#elif defined(__APPLE__)
+#include <sys/sysctl.h>
+#endif
+
+static void ggml_init_arm_arch_features(void) {
+#if defined(__linux__) && defined(__aarch64__)
+    uint32_t hwcap = getauxval(AT_HWCAP);
+    uint32_t hwcap2 = getauxval(AT_HWCAP2);
+
+    ggml_arm_arch_features.has_neon = !!(hwcap & HWCAP_ASIMD);
+    ggml_arm_arch_features.has_i8mm = !!(hwcap2 & HWCAP2_I8MM);
+    ggml_arm_arch_features.has_sve  = !!(hwcap & HWCAP_SVE);
+
+#if defined(__ARM_FEATURE_SVE)
+    ggml_arm_arch_features.sve_cnt = PR_SVE_VL_LEN_MASK & prctl(PR_SVE_GET_VL);
+#endif
+#elif defined(__APPLE__)
+    int oldp = 0;
+    size_t size = sizeof(oldp);
+    if (sysctlbyname("hw.optional.AdvSIMD", &oldp, &size, NULL, 0) != 0) {
+        oldp = 0;
+    }
+    ggml_arm_arch_features.has_neon = oldp;
+
+    if (sysctlbyname("hw.optional.arm.FEAT_I8MM", &oldp, &size, NULL, 0) != 0) {
+        oldp = 0;
+    }
+    ggml_arm_arch_features.has_i8mm = oldp;
+
+    ggml_arm_arch_features.has_sve  = 0;
+    ggml_arm_arch_features.sve_cnt  = 0;
+#else
+// Run-time CPU feature detection not implemented for this platform, fall back to compile time
+#if defined(__ARM_NEON)
+    ggml_arm_arch_features.has_neon = 1;
+#else
+    ggml_arm_arch_features.has_neon = 0;
+#endif
+
+#if defined(__ARM_FEATURE_MATMUL_INT8)
+    ggml_arm_arch_features.has_i8mm = 1;
+#else
+    ggml_arm_arch_features.has_i8mm = 0;
+#endif
+
+#if defined(__ARM_FEATURE_SVE)
+    ggml_arm_arch_features.has_sve = 1;
+    ggml_arm_arch_features.sve_cnt = 16;
+#else
+    ggml_arm_arch_features.has_sve = 0;
+    ggml_arm_arch_features.sve_cnt = 0;
+#endif
+#endif
+}
+#endif
+
 struct ggml_context * ggml_init(struct ggml_init_params params) {
     // make this function thread safe
     ggml_critical_section_start();
@@ -3723,6 +3789,10 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
             GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
         }
 
+#if defined(__ARM_ARCH)
+        ggml_init_arm_arch_features();
+#endif
+
         is_first_call = false;
     }
 
@@ -3771,12 +3841,6 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
 
     GGML_ASSERT_ALIGNED(ctx->mem_buffer);
 
-#if defined(__ARM_FEATURE_SVE)
-    if (!ggml_sve_cnt_b) {
-        ggml_sve_cnt_b = PR_SVE_VL_LEN_MASK & prctl(PR_SVE_GET_VL);
-    }
-#endif
-
     GGML_PRINT_DEBUG("%s: context initialized\n", __func__);
 
     ggml_critical_section_end();
@@ -23578,16 +23642,16 @@ int ggml_cpu_has_fma(void) {
 }
 
 int ggml_cpu_has_neon(void) {
-#if defined(__ARM_NEON)
-    return 1;
+#if defined(__ARM_ARCH)
+    return ggml_arm_arch_features.has_neon;
 #else
     return 0;
 #endif
 }
 
 int ggml_cpu_has_sve(void) {
-#if defined(__ARM_FEATURE_SVE)
-    return 1;
+#if defined(__ARM_ARCH)
+    return ggml_arm_arch_features.has_sve;
 #else
     return 0;
 #endif
@@ -23734,11 +23798,18 @@ int ggml_cpu_has_vsx(void) {
 }
 
 int ggml_cpu_has_matmul_int8(void) {
-#if defined(__ARM_FEATURE_MATMUL_INT8)
-    return 1;
+#if defined(__ARM_ARCH)
+    return ggml_arm_arch_features.has_i8mm;
 #else
     return 0;
 #endif
 }
 
+int ggml_cpu_get_sve_cnt(void) {
+#if defined(__ARM_ARCH)
+    return ggml_arm_arch_features.sve_cnt;
+#else
+    return 0;
+#endif
+}
 
 ////////////////////////////////////////////////////////////////////////////////
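Since the Arm flags are now filled in at run time (during the first ggml_init() call) instead of being compile-time constants, an application can inspect them after creating a context. A minimal usage sketch of the API touched above (context size and build details are arbitrary):

    #include <stdio.h>
    #include "ggml.h"

    int main(void) {
        // on Arm, the first ggml_init() populates the run-time feature flags
        struct ggml_init_params params = { 16*1024, NULL, false };
        struct ggml_context * ctx = ggml_init(params);

        printf("neon: %d  i8mm: %d  sve: %d  sve length: %d bytes\n",
               ggml_cpu_has_neon(), ggml_cpu_has_matmul_int8(),
               ggml_cpu_has_sve(), ggml_cpu_get_sve_cnt());

        ggml_free(ctx);
        return 0;
    }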
From 43bcdd9703ec19af7d2a519640b5ed6f4aac3d53 Mon Sep 17 00:00:00 2001
From: Aarni Koskela
Date: Sat, 28 Sep 2024 15:07:14 +0300
Subject: [PATCH 06/28] readme : add tool (#9655)

---
 README.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/README.md b/README.md
index 93225b63c319f..a452a6d786948 100644
--- a/README.md
+++ b/README.md
@@ -174,6 +174,7 @@ Unless otherwise noted these projects are open-source with permissive licensing:
 **Tools:**
 
 - [akx/ggify](https://github.com/akx/ggify) – download PyTorch models from HuggingFace Hub and convert them to GGML
+- [akx/ollama-dl](https://github.com/akx/ollama-dl) – download models from the Ollama library to be used directly with llama.cpp
 - [crashr/gppm](https://github.com/crashr/gppm) – launch llama.cpp instances utilizing NVIDIA Tesla P40 or P100 GPUs with reduced idle power consumption
 - [gpustack/gguf-parser](https://github.com/gpustack/gguf-parser-go/tree/main/cmd/gguf-parser) - review/check the GGUF file and estimate the memory usage
 - [Styled Lines](https://marketplace.unity.com/packages/tools/generative-ai/styled-lines-llama-cpp-model-292902) (proprietary licensed, async wrapper of inference part for game development in Unity3d with prebuild Mobile and Web platform wrappers and a model example)

From 9a913110cf471a8287ac06c43cbe307d3cf6df99 Mon Sep 17 00:00:00 2001
From: nopperl <54780682+nopperl@users.noreply.github.com>
Date: Sat, 28 Sep 2024 12:08:43 +0000
Subject: [PATCH 07/28] llama : add support for Chameleon (#8543)

* convert chameleon hf to gguf

* add chameleon tokenizer tests

* fix lint

* implement chameleon graph

* add swin norm param

* return qk norm weights and biases to original format

* implement swin norm

* suppress image token output

* rem tabs

* add comment to conversion

* fix ci

* check for k norm separately

* adapt to new lora implementation

* fix layer input for swin norm

* move swin_norm in gguf writer

* add comment regarding special token regex in chameleon pre-tokenizer

* Update src/llama.cpp

Co-authored-by: compilade

* fix punctuation regex in chameleon pre-tokenizer (@compilade)

Co-authored-by: compilade

* fix lint

* trigger ci

---------

Co-authored-by: compilade
---
 convert_hf_to_gguf.py                |  44 +++++
 convert_hf_to_gguf_update.py         |   1 +
 gguf-py/gguf/constants.py            |  19 ++
 gguf-py/gguf/gguf_writer.py          |   3 +
 gguf-py/gguf/tensor_mapping.py       |   4 +-
 include/llama.h                      |   1 +
 models/ggml-vocab-chameleon.gguf.inp | 112 ++++++++++++
 models/ggml-vocab-chameleon.gguf.out |  46 +++++
 src/llama-vocab.cpp                  |  14 ++
 src/llama.cpp                        | 263 +++++++++++++++++++++++++++
 10 files changed, 505 insertions(+), 2 deletions(-)
 create mode 100644 models/ggml-vocab-chameleon.gguf.inp
 create mode 100644 models/ggml-vocab-chameleon.gguf.out

diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 7be609054d6b8..2cd5a8c11bc18 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -640,6 +640,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
         if chkhsh == "fcace8b9cac38ce847670c970cd5892031a753a1ef381abd1d9af00f713da085":
             # ref: https://huggingface.co/microsoft/phi-2
             res = "phi-2"
+        if chkhsh == "60824e3c0d9401f89943cbb2fff727f0e2d4c545ba4df2d6e4f09a6db0f5b450":
+            # ref: https://huggingface.co/facebook/chameleon-7b
+            res = "chameleon"
 
         if res is None:
             logger.warning("\n")
@@ -4138,6 +4141,47 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         return super().modify_tensors(data_torch, name, bid)
 
 
+@Model.register("ChameleonForCausalLM")
+class ChameleonModel(Model):
+    model_arch = gguf.MODEL_ARCH.CHAMELEON
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_swin_norm(self.hparams.get("swin_norm", False))
+
+    def set_vocab(self):
+        self._set_vocab_gpt2()
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # ignore image tokenizer for now
+        # TODO: remove this once image support is implemented for Chameleon
+        if name.startswith("model.vqmodel"):
+            return []
+
+        n_head = self.hparams["num_attention_heads"]
+        n_kv_head = self.hparams.get("num_key_value_heads")
+        hidden_dim = self.hparams.get("hidden_size")
+
+        if name.endswith(("q_proj.weight", "q_proj.bias")):
+            data_torch = LlamaModel.permute(data_torch, n_head, n_head)
+        if name.endswith(("k_proj.weight", "k_proj.bias")):
+            data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head)
+        if name.endswith(("q_norm.weight", "q_norm.bias")):
+            data_torch = ChameleonModel._reverse_hf_permute(data_torch, n_head, hidden_dim)
+        if name.endswith(("k_norm.weight", "k_norm.bias")):
+            data_torch = ChameleonModel._reverse_hf_permute(data_torch, n_kv_head, hidden_dim)
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+    # see:
https://github.com/huggingface/transformers/blob/72fb02c47dbbe1999ae105319f24631cad6e2e00/src/transformers/models/chameleon/convert_chameleon_weights_to_hf.py#L176-L203 + @staticmethod + def _reverse_hf_permute(data_torch, n_heads, hidden_dim): + head_dim = hidden_dim // n_heads + data_torch = data_torch[0].view(2, head_dim // 2).t().reshape(1, -1) + data_torch = data_torch.repeat_interleave(n_heads, 0) + return data_torch + + ###### CONVERSION LOGIC ###### diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py index 021f65abdc45d..4d11059f374d2 100755 --- a/convert_hf_to_gguf_update.py +++ b/convert_hf_to_gguf_update.py @@ -99,6 +99,7 @@ class TOKENIZER_TYPE(IntEnum): {'name': "gpt3-finnish", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/TurkuNLP/gpt3-finnish-small", }, {"name": "exaone", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct", }, {"name": "phi-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/microsoft/phi-2", }, + {"name": "chameleon", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/facebook/chameleon-7b", }, ] diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 560eee916f27e..2fd2e9d2be828 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -94,6 +94,7 @@ class LLM: DECODER_START_TOKEN_ID = "{arch}.decoder_start_token_id" ATTN_LOGIT_SOFTCAPPING = "{arch}.attn_logit_softcapping" FINAL_LOGIT_SOFTCAPPING = "{arch}.final_logit_softcapping" + SWIN_NORM = "{arch}.swin_norm" RESCALE_EVERY_N_LAYERS = "{arch}.rescale_every_n_layers" TIME_MIX_EXTRA_DIM = "{arch}.time_mix_extra_dim" TIME_DECAY_EXTRA_DIM = "{arch}.time_decay_extra_dim" @@ -236,6 +237,7 @@ class MODEL_ARCH(IntEnum): EXAONE = auto() GRANITE = auto() GRANITE_MOE = auto() + CHAMELEON = auto() class MODEL_TENSOR(IntEnum): @@ -394,6 +396,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.EXAONE: "exaone", MODEL_ARCH.GRANITE: "granite", MODEL_ARCH.GRANITE_MOE: "granitemoe", + MODEL_ARCH.CHAMELEON: "chameleon", } TENSOR_NAMES: dict[MODEL_TENSOR, str] = { @@ -1260,6 +1263,22 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_DOWN_EXP, MODEL_TENSOR.FFN_UP_EXP, ], + MODEL_ARCH.CHAMELEON: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], # TODO } diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index bd059b45c64d0..5c460ef1bc260 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -670,6 +670,9 @@ def add_expert_shared_count(self, count: int) -> None: def add_expert_weights_scale(self, value: float) -> None: self.add_float32(Keys.LLM.EXPERT_WEIGHTS_SCALE.format(arch=self.arch), value) + def add_swin_norm(self, value: bool) -> None: + self.add_bool(Keys.LLM.SWIN_NORM.format(arch=self.arch), value) + def add_rescale_every_n_layers(self, count: int) -> None: self.add_uint32(Keys.LLM.RESCALE_EVERY_N_LAYERS.format(arch=self.arch), count) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 4e850726e9ba4..5ef91f11d312f 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -380,7 +380,7 @@ class TensorNameMap: MODEL_TENSOR.ATTN_Q_NORM: ( 
"language_model.encoder.layers.{bid}.self_attention.q_layernorm", "model.layers.{bid}.self_attn.q_layernorm", # persimmon - "model.layers.{bid}.self_attn.q_norm", # cohere olmoe + "model.layers.{bid}.self_attn.q_norm", # cohere olmoe chameleon "transformer.blocks.{bid}.attn.q_ln", # sea-lion "encoder.layer.{bid}.attention.self.layer_norm_q", # jina-bert-v2 "transformer.layers.{bid}.attn.q_norm", # openelm @@ -389,7 +389,7 @@ class TensorNameMap: MODEL_TENSOR.ATTN_K_NORM: ( "language_model.encoder.layers.{bid}.self_attention.k_layernorm", "model.layers.{bid}.self_attn.k_layernorm", # persimmon - "model.layers.{bid}.self_attn.k_norm", # cohere olmoe + "model.layers.{bid}.self_attn.k_norm", # cohere olmoe chameleon "transformer.blocks.{bid}.attn.k_ln", # sea-lion "encoder.layer.{bid}.attention.self.layer_norm_k", # jina-bert-v2 "transformer.layers.{bid}.attn.k_norm", # openelm diff --git a/include/llama.h b/include/llama.h index 132937a0700e7..caef0bfff0b7d 100644 --- a/include/llama.h +++ b/include/llama.h @@ -102,6 +102,7 @@ extern "C" { LLAMA_VOCAB_PRE_TYPE_BLOOM = 23, LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH = 24, LLAMA_VOCAB_PRE_TYPE_EXAONE = 25, + LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26, }; enum llama_rope_type { diff --git a/models/ggml-vocab-chameleon.gguf.inp b/models/ggml-vocab-chameleon.gguf.inp new file mode 100644 index 0000000000000..9baf7d77ae6b5 --- /dev/null +++ b/models/ggml-vocab-chameleon.gguf.inp @@ -0,0 +1,112 @@ +ied 4 ½ months +__ggml_vocab_test__ +Führer +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + + +__ggml_vocab_test__ + + + +__ggml_vocab_test__ + + + + +__ggml_vocab_test__ + + +__ggml_vocab_test__ +Hello world +__ggml_vocab_test__ + Hello world +__ggml_vocab_test__ +Hello World +__ggml_vocab_test__ + Hello World +__ggml_vocab_test__ + Hello World! +__ggml_vocab_test__ +Hello, world! +__ggml_vocab_test__ + Hello, world! +__ggml_vocab_test__ + this is 🦙.cpp +__ggml_vocab_test__ +w048 7tuijk dsdfhu +__ggml_vocab_test__ +нещо на Български +__ggml_vocab_test__ +កាន់តែពិសេសអាចខលចេញ +__ggml_vocab_test__ +🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token) +__ggml_vocab_test__ +Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello + Hello +__ggml_vocab_test__ + ( +__ggml_vocab_test__ + + = +__ggml_vocab_test__ +' era +__ggml_vocab_test__ +Hello, y'all! How are you 😁 ?我想在apple工作1314151天~ +__ggml_vocab_test__ +!!!!!! +__ggml_vocab_test__ +3 +__ggml_vocab_test__ +33 +__ggml_vocab_test__ +333 +__ggml_vocab_test__ +3333 +__ggml_vocab_test__ +33333 +__ggml_vocab_test__ +333333 +__ggml_vocab_test__ +3333333 +__ggml_vocab_test__ +33333333 +__ggml_vocab_test__ +333333333 +__ggml_vocab_test__ +Cửa Việt +__ggml_vocab_test__ + discards +__ggml_vocab_test__ + + + + + + + + + + + +🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? 
We'Ve a'lL +__ggml_vocab_test__ diff --git a/models/ggml-vocab-chameleon.gguf.out b/models/ggml-vocab-chameleon.gguf.out new file mode 100644 index 0000000000000..7c5413fee0adf --- /dev/null +++ b/models/ggml-vocab-chameleon.gguf.out @@ -0,0 +1,46 @@ + 17245 16604 16403 16604 33583 18355 + 16421 51153 + + 16604 + 16650 + 16650 16604 + 16581 + 16582 + 16582 16582 + 16582 16582 16582 + 16581 16582 + 31596 17394 + 34926 17394 + 31596 18671 + 34926 18671 + 34926 18671 16384 + 31596 16395 17394 16384 + 34926 16395 17394 16384 + 16811 16704 20410 16483 16631 16397 52854 + 16470 16399 16403 16407 16604 16406 35764 38185 51595 22592 26639 + 29479 23955 17012 20103 25527 27670 17408 19005 21473 24774 + 54254 42231 48084 29409 16617 61889 29409 16608 21954 16628 21954 16499 58445 29409 16607 58445 21954 16479 42231 21954 16611 21954 16607 21954 16633 21954 16611 29409 16607 21954 16615 + 52351 16604 16391 25825 16392 23686 16498 39161 18885 16618 16488 30853 16604 16391 54124 17153 25134 16656 18476 26169 16895 16392 62193 16611 16604 16391 24664 17153 57169 16721 16872 17073 17304 28729 16392 + 31596 + 34926 + 16650 31596 + 16650 34926 + 16696 31596 + 16696 31596 16582 16696 31596 + 16604 16391 + 16582 16604 16412 + 16390 22623 + 31596 16395 16712 16390 16828 16384 17674 16769 16732 23686 16607 16604 16414 24427 16623 41809 16495 28999 36469 45292 30197 16400 16402 16400 16403 16400 16404 16400 43969 65211 16636 + 16384 16384 16384 16384 16384 16384 + 16402 + 16402 16402 + 16402 16402 16402 + 16402 16402 16402 16402 + 16402 16402 16402 16402 16402 + 16402 16402 16402 16402 16402 16402 + 16402 16402 16402 16402 16402 16402 16402 + 16402 16402 16402 16402 16402 16402 16402 16402 + 16402 16402 16402 16402 16402 16402 16402 16402 16402 + 16418 19038 16639 16448 24315 33727 16467 + 18765 17981 + 16582 16604 16582 16582 16604 16582 16582 16582 16604 16581 16604 16581 16581 16604 16581 16582 16650 16582 16650 16604 16582 16696 16582 16696 16604 16582 52351 16604 16391 25825 16392 23686 16498 39161 18885 16618 16488 30853 16604 16391 54124 17153 25134 16656 18476 26169 16895 16392 62193 16611 20410 16483 16631 18885 16483 16631 16604 16402 16604 16402 16402 16604 16402 16402 16402 16604 16402 16402 16402 16402 16604 16402 16402 16402 16402 16402 16604 16402 16402 16402 16402 16402 16402 16604 16402 16402 16402 16402 16402 16402 16402 16604 16402 16402 16402 16402 16402 16402 16402 16402 16604 16402 16397 16402 16604 16402 16397 16397 16402 16604 16402 16397 16397 16397 16402 16604 54254 42231 48084 29409 16617 61889 29409 16608 21954 16628 21954 16499 58445 29409 16607 58445 21954 16479 42231 21954 16611 27683 16607 16604 16414 24427 16623 41809 16495 28999 36469 45292 30197 16400 16402 16400 16403 16400 16404 16400 43969 65211 16636 16604 16396 16396 16396 16396 16396 16396 16412 16412 16412 16412 16412 16412 16412 27268 23955 17012 20103 25527 27670 17408 19005 21473 24774 16604 16390 16390 16390 16390 16390 16390 16447 16447 16447 16447 16447 16447 16447 16385 16385 16385 16385 16397 16397 16397 16397 16397 16397 16384 16384 16384 16384 16384 16384 16414 16414 16414 16414 16414 16414 16687 16390 16690 16992 16604 16390 61797 16733 16390 16466 16986 16395 16604 16390 17879 16732 17811 16414 16604 16390 16428 16804 17811 16687 16390 16683 17190 16728 16395 16604 16390 16419 16732 16945 16991 25251 16414 17119 16390 38127 16641 16390 16459 16427 diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index a771eccda3017..146d416f770f2 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -450,6 
+450,20 @@ struct llm_tokenizer_bpe { "[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", }; break; + case LLAMA_VOCAB_PRE_TYPE_CHAMELEON: + // Note: in theory, the special token (sentinel and image token) regex_exprs below + // are unnecessary, as they are split in `tokenizer_st_partition` anyway. + // However, since the upstream pre-tokenizer uses them, they are also + // included here (see https://huggingface.co/facebook/chameleon-7b). + regex_exprs = { + "", // Sentinel tokens + "(IMGIMG)((A|B|C|D|E|F|G|H|I){1,4})Z", // Image tokens + "([\\t\\n]| | )", // directly from tokenizer.json + "\\p{N}", // Individual digits + "[\\p{P}!-/:-@\\[-`{-~]", // Punctuation, Isolated + "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", + }; + break; default: // default regex for BPE tokenization pre-processing regex_exprs = { diff --git a/src/llama.cpp b/src/llama.cpp index 0accb1492efaa..f450eaf9ddc6f 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -216,6 +216,7 @@ enum llm_arch { LLM_ARCH_RWKV6, LLM_ARCH_GRANITE, LLM_ARCH_GRANITE_MOE, + LLM_ARCH_CHAMELEON, LLM_ARCH_UNKNOWN, }; @@ -268,6 +269,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_RWKV6, "rwkv6" }, { LLM_ARCH_GRANITE, "granite" }, { LLM_ARCH_GRANITE_MOE, "granitemoe" }, + { LLM_ARCH_CHAMELEON, "chameleon" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -304,6 +306,7 @@ enum llm_kv { LLM_KV_DECODER_START_TOKEN_ID, LLM_KV_ATTN_LOGIT_SOFTCAPPING, LLM_KV_FINAL_LOGIT_SOFTCAPPING, + LLM_KV_SWIN_NORM, LLM_KV_RESCALE_EVERY_N_LAYERS, LLM_KV_TIME_MIX_EXTRA_DIM, LLM_KV_TIME_DECAY_EXTRA_DIM, @@ -411,6 +414,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" }, { LLM_KV_ATTN_LOGIT_SOFTCAPPING, "%s.attn_logit_softcapping" }, { LLM_KV_FINAL_LOGIT_SOFTCAPPING, "%s.final_logit_softcapping" }, + { LLM_KV_SWIN_NORM, "%s.swin_norm" }, { LLM_KV_RESCALE_EVERY_N_LAYERS, "%s.rescale_every_n_layers" }, { LLM_KV_TIME_MIX_EXTRA_DIM, "%s.time_mix_extra_dim" }, { LLM_KV_TIME_DECAY_EXTRA_DIM, "%s.time_decay_extra_dim" }, @@ -1499,6 +1503,25 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, }, }, + { + LLM_ARCH_CHAMELEON, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + }, + }, { LLM_ARCH_UNKNOWN, { @@ -2362,6 +2385,7 @@ struct llama_hparams { bool vocab_only; bool rope_finetuned; bool use_par_res; + bool swin_norm; uint32_t n_vocab; uint32_t n_ctx_train; // context size the model was trained on @@ -6084,6 +6108,18 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_CHAMELEON: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + hparams.f_norm_eps = 1e-5; // eps for qk-norm, torch default + ml.get_key(LLM_KV_SWIN_NORM, 
hparams.swin_norm); + + switch (hparams.n_layer) { + case 32: model.type = e_model::MODEL_7B; break; + case 48: model.type = e_model::MODEL_34B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; default: (void)0; } @@ -6341,6 +6377,11 @@ static void llm_load_vocab( } else if ( tokenizer_pre == "exaone") { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_EXAONE; + } else if ( + tokenizer_pre == "chameleon") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_CHAMELEON; + vocab.tokenizer_add_bos = true; + vocab.tokenizer_clean_spaces = false; } else { throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str())); } @@ -8728,6 +8769,45 @@ static bool llm_load_tensors( } } break; + case LLM_ARCH_CHAMELEON: + { + model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + + // output + { + model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (model.output == NULL) { + model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + } + } + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + + layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head}); + layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}); + layer.attn_q_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd_head_k, n_head}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.attn_k_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd_head_k, n_head_kv}, llama_model_loader::TENSOR_NOT_REQUIRED); + + layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + + layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + + layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); + layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + } + } break; default: throw std::runtime_error("unknown architecture"); } @@ -15872,6 +15952,184 @@ struct llm_build_context { return gf; } + + // ref: https://github.com/facebookresearch/chameleon + // based on the original build_llama() function, changes: + // * qk-norm + // * swin-norm + // * removed bias + // * removed MoE + struct ggml_cgraph * build_chameleon() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + + // mutable variable, needed during the last layer of the computation to skip unused tokens + int32_t 
n_tokens = this->n_tokens; + + const int64_t n_embd_head = hparams.n_embd_head_v; + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = build_inp_pos(); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + + // norm + if (hparams.swin_norm) { + cur = inpL; + } else { + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + } + + // self-attention + { + // compute Q and K and RoPE them + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + if (model.layers[il].attn_q_norm) { + Qcur = ggml_view_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens, + ggml_element_size(Qcur) * n_embd_head, + ggml_element_size(Qcur) * n_embd_head * n_head, + 0); + cb(Qcur, "Qcur", il); + + Qcur = llm_build_norm(ctx0, Qcur, hparams, + model.layers[il].attn_q_norm, + model.layers[il].attn_q_norm_b, + LLM_NORM, cb, il); + cb(Qcur, "Qcur", il); + } + + if (model.layers[il].attn_k_norm) { + Kcur = ggml_view_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens, + ggml_element_size(Kcur) * n_embd_head, + ggml_element_size(Kcur) * n_embd_head * n_head_kv, + 0); + cb(Kcur, "Kcur", il); + + Kcur = llm_build_norm(ctx0, Kcur, hparams, + model.layers[il].attn_k_norm, + model.layers[il].attn_k_norm_b, + LLM_NORM, cb, il); + cb(Kcur, "Kcur", il); + } + + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Qcur, "Qcur", il); + + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Kcur, "Kcur", il); + + cur = llm_build_kv(ctx0, lctx, kv_self, gf, + model.layers[il].wo, nullptr, + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + + if (hparams.swin_norm) { + cur = llm_build_norm(ctx0, cur, hparams, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, cb, il); + } + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + n_tokens = n_outputs; + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + if (!hparams.swin_norm) { + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + } + + cur = llm_build_ffn(ctx0, lctx, cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + 
cb(cur, "ffn_out", il); + + if (hparams.swin_norm) { + cur = llm_build_norm(ctx0, cur, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = lctx.cvec.apply_to(ctx0, cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = llm_build_norm(ctx0, cur, hparams, + model.output_norm, NULL, + LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); + cb(cur, "result_output_with_img_logits", -1); + + // TODO: this suppresses the output of image tokens, which is required to enable text-only outputs. + // Needs to be removed once image outputs are supported. + int img_token_end_idx = 8196; + int img_token_start_idx = 4; + int num_img_tokens = img_token_end_idx - img_token_start_idx; + // creates 1d tensor of size num_img_tokens and values -FLT_MAX, + // which ensures that text token values are always at least larger than image token values + struct ggml_tensor * img_logits = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, num_img_tokens); + img_logits = ggml_clamp(ctx0, img_logits, -FLT_MAX, -FLT_MAX); + cb(img_logits, "img_logits", -1); + cur = ggml_set_1d(ctx0, cur, img_logits, ggml_element_size(cur) * img_token_start_idx); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } }; static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector & ids) { @@ -16132,6 +16390,10 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_rwkv6(); } break; + case LLM_ARCH_CHAMELEON: + { + result = llm.build_chameleon(); + } break; default: GGML_ABORT("fatal error"); } @@ -19257,6 +19519,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { case LLM_ARCH_CHATGLM: case LLM_ARCH_GRANITE: case LLM_ARCH_GRANITE_MOE: + case LLM_ARCH_CHAMELEON: return LLAMA_ROPE_TYPE_NORM; // the pairs of head values are offset by n_rot/2 From 6102037bbb55521880ae78a6ee6c2a0c00c901df Mon Sep 17 00:00:00 2001 From: Zhenwei Jin <109658203+kylo5aby@users.noreply.github.com> Date: Sat, 28 Sep 2024 20:10:58 +0800 Subject: [PATCH 08/28] vocab : refactor tokenizer to reduce init overhead (#9449) * refactor tokenizer * llama : make llm_tokenizer more private ggml-ci * refactor tokenizer * refactor tokenizer * llama : make llm_tokenizer more private ggml-ci * remove unused files * remove unused fileds to avoid unused filed build error * avoid symbol link error * Update src/llama.cpp * Update src/llama.cpp --------- Co-authored-by: Georgi Gerganov --- .../convert-llama2c-to-ggml.cpp | 14 +- src/llama-vocab.cpp | 266 +++++++++++------- src/llama-vocab.h | 9 + src/llama.cpp | 2 + tests/test-tokenizer-0.cpp | 90 +++--- 5 files changed, 239 insertions(+), 142 deletions(-) diff --git a/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp b/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp index ecff95f9a69de..c140daed3c056 100644 --- a/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +++ b/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp @@ -201,7 +201,7 @@ static void print_sample_weights(TransformerWeights *w){ //////////////////////////////////////// ggml structs and functions required to load models, configs and save the model. 
From 6102037bbb55521880ae78a6ee6c2a0c00c901df Mon Sep 17 00:00:00 2001
From: Zhenwei Jin <109658203+kylo5aby@users.noreply.github.com>
Date: Sat, 28 Sep 2024 20:10:58 +0800
Subject: [PATCH 08/28] vocab : refactor tokenizer to reduce init overhead (#9449)

* refactor tokenizer

* llama : make llm_tokenizer more private

ggml-ci

* refactor tokenizer

* refactor tokenizer

* llama : make llm_tokenizer more private

ggml-ci

* remove unused files

* remove unused fields to avoid unused field build error

* avoid symbol link error

* Update src/llama.cpp

* Update src/llama.cpp

---------

Co-authored-by: Georgi Gerganov
---
 .../convert-llama2c-to-ggml.cpp |  14 +-
 src/llama-vocab.cpp             | 266 +++++++++++-------
 src/llama-vocab.h               |   9 +
 src/llama.cpp                   |   2 +
 tests/test-tokenizer-0.cpp      |  90 +++---
 5 files changed, 239 insertions(+), 142 deletions(-)

diff --git a/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp b/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp
index ecff95f9a69de..c140daed3c056 100644
--- a/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp
+++ b/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp
@@ -201,7 +201,7 @@ static void print_sample_weights(TransformerWeights *w){
 
 //////////////////////////////////////// ggml structs and functions required to load models, configs and save the model.
 
-struct llama_vocab {
+struct my_llama_vocab {
     using id    = int32_t;
     using token = std::string;
     using ttype = llama_token_type;
@@ -525,7 +525,7 @@ static std::string llama_escape_whitespaces(const std::string & text) {
     return out.str();
}
 
-static void load_vocab(const char * filename, const Config * config, struct llama_vocab * vocab) {
+static void load_vocab(const char * filename, const Config * config, struct my_llama_vocab * vocab) {
     if (is_ggml_file(filename)) {
         LOG_INF("%s: Loading vocabulary from gguf file %s\n", __func__, filename);
         struct ggml_context * ctx_data = NULL;
@@ -583,13 +583,13 @@ static void load_vocab(const char * filename, const Config * config, struct llam
         const int n_vocab = config->vocab_size;
         /* uint32_t max_token_length =  */ file.read_u32(); // unused
         vocab->id_to_token.resize(n_vocab);
-        for (llama_vocab::id id=0; id<n_vocab; ++id) {
+        for (my_llama_vocab::id id=0; id<n_vocab; ++id) {
@@ -671,7 +671,7 @@ static void save_as_llama_model(
     std::vector<const char*> tokens;
     std::vector<float> scores;
     std::vector<llama_token_type> token_types;
-    for (const llama_vocab::token_data & token_data : vocab->id_to_token) {
+    for (const my_llama_vocab::token_data & token_data : vocab->id_to_token) {
         tokens.push_back(token_data.text.c_str());
         scores.push_back(token_data.score);
         token_types.push_back(token_data.type);
@@ -905,7 +905,7 @@ int main(int argc, char ** argv) {
         fclose(file);
     }
 
-    struct llama_vocab vocab;
+    struct my_llama_vocab vocab;
     load_vocab(params.fn_vocab_model, &config, &vocab);
 
     struct my_llama_model model;
diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
index 146d416f770f2..e4d844a73c216 100644
--- a/src/llama-vocab.cpp
+++ b/src/llama-vocab.cpp
@@ -50,7 +50,7 @@ struct naive_trie {
             res.first->second.insert(key + 1, len - 1, value);
         }
     }
-    std::pair<const char *, size_t> get_longest_prefix(const char * key, size_t len, size_t offset = 0) {
+    std::pair<const char *, size_t> get_longest_prefix(const char * key, size_t len, size_t offset = 0) const {
         if (len == 0 || offset == len) {
             return std::make_pair(key, offset);
         }
@@ -79,6 +79,15 @@ struct naive_trie {
 // impl
 //
 
+struct llm_tokenizer {
+    llm_tokenizer() {}
+    virtual ~llm_tokenizer() = default;
+};
+
+llama_vocab::~llama_vocab() {
+    delete tokenizer;
+}
+
 int llama_vocab::find_bpe_rank(const std::string & token_left, const std::string & token_right) const {
     GGML_ASSERT(token_left.find(' ')  == std::string::npos);
     GGML_ASSERT(token_left.find('\n') == std::string::npos);
@@ -187,10 +196,15 @@ struct llm_bigram_spm {
     size_t size;
 };
 
-struct llm_tokenizer_spm {
-    llm_tokenizer_spm(const llama_vocab & vocab) : vocab(vocab) {}
+struct llm_tokenizer_spm : llm_tokenizer {
+    llm_tokenizer_spm(const llama_vocab & /*vocab*/) : llm_tokenizer() {}
+};
+
+struct llm_tokenizer_spm_session {
+    llm_tokenizer_spm_session(const llama_vocab & vocab) : vocab(vocab) {}
 
     void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
+
         // split string into utf8 chars
         int index = 0;
         size_t offs = 0;
@@ -271,7 +285,7 @@ struct llm_tokenizer_spm {
             return;
         }
-        resegment(symbols[p->second.first], output);
+        resegment(symbols[p->second.first],  output);
         resegment(symbols[p->second.second], output);
     }
 
@@ -279,7 +293,6 @@ struct llm_tokenizer_spm {
         if (left == -1 || right == -1) {
             return;
         }
-
         const std::string text = std::string(symbols[left].text, symbols[left].n + symbols[right].n);
         auto token = vocab.token_to_id.find(text);
 
@@ -306,10 +319,11 @@ struct llm_tokenizer_spm {
     }
 
     const llama_vocab & vocab;
+    // currently unused
+    // const llm_tokenizer_spm * spm_tokenizer;
 
     std::vector<llm_symbol> symbols;
    llm_bigram_spm::queue work_queue;
-
    std::map<std::string, std::pair<int, int>>
rev_merge; }; @@ -352,8 +366,8 @@ struct llm_bigram_bpe { size_t size; }; -struct llm_tokenizer_bpe { - llm_tokenizer_bpe(const llama_vocab & vocab): vocab(vocab) { +struct llm_tokenizer_bpe : llm_tokenizer { + llm_tokenizer_bpe(const llama_vocab & vocab) : llm_tokenizer() { GGML_ASSERT(vocab.type == LLAMA_VOCAB_TYPE_BPE); switch (vocab.type_pre) { case LLAMA_VOCAB_PRE_TYPE_LLAMA3: @@ -476,7 +490,14 @@ struct llm_tokenizer_bpe { } } - void append(const llama_vocab::id token_id, std::vector & output) const { + std::vector regex_exprs; +}; + +struct llm_tokenizer_bpe_session { + llm_tokenizer_bpe_session(const llama_vocab & vocab) : vocab(vocab), + bpe_tokenizer(static_cast(vocab.tokenizer)) {} + + static void append(const llama_vocab::id token_id, std::vector & output) { output.push_back(token_id); } @@ -515,12 +536,11 @@ struct llm_tokenizer_bpe { void tokenize(const std::string & text, std::vector & output) { int final_prev_index = -1; - - const auto word_collection = unicode_regex_split(text, regex_exprs); + const auto word_collection = unicode_regex_split(text, bpe_tokenizer->regex_exprs); symbols_final.clear(); - for (auto & word : word_collection) { + for (const auto & word : word_collection) { work_queue = llm_bigram_bpe::queue(); symbols.clear(); @@ -623,7 +643,6 @@ struct llm_tokenizer_bpe { if (left == -1 || right == -1) { return; } - std::string left_token = std::string(symbols[left].text, symbols[left].n); std::string right_token = std::string(symbols[right].text, symbols[right].n); @@ -647,12 +666,10 @@ struct llm_tokenizer_bpe { } const llama_vocab & vocab; - - std::vector regex_exprs; + const llm_tokenizer_bpe * bpe_tokenizer; std::vector symbols; std::vector symbols_final; - llm_bigram_bpe::queue work_queue; }; @@ -660,15 +677,17 @@ struct llm_tokenizer_bpe { // WPM tokenizer // -struct llm_tokenizer_wpm { - llm_tokenizer_wpm(const llama_vocab & vocab): vocab(vocab) {} +struct llm_tokenizer_wpm : llm_tokenizer { + llm_tokenizer_wpm(const llama_vocab & /*vocab*/) : llm_tokenizer() {} +}; - void tokenize(const std::string & text, std::vector & output) const { - const auto & token_map = vocab.token_to_id; +struct llm_tokenizer_wpm_session { + llm_tokenizer_wpm_session(const llama_vocab & vocab) : vocab(vocab) {} + void tokenize(const std::string & text, std::vector & output) { + const auto & token_map = vocab.token_to_id; // normalize and split by whitespace std::vector words = preprocess(text); - // bos token prepended already // find the longest tokens that form the words @@ -713,7 +732,7 @@ struct llm_tokenizer_wpm { } // TODO: reduce string copies by using cpts_offs array - std::vector preprocess(const std::string & text) const { + static std::vector preprocess(const std::string & text) { const std::vector cpts_nfd = unicode_cpts_normalize_nfd(unicode_cpts_from_utf8(text)); std::vector words(1, ""); @@ -765,15 +784,18 @@ struct llm_tokenizer_wpm { //(cpt >= 0xFF00 && cpt <= 0xFFEF); } +private: const llama_vocab & vocab; + // currently unused + // const llm_tokenizer_wpm * wpm_tokenizer; }; // // UGM tokenizer // -struct llm_tokenizer_ugm { - llm_tokenizer_ugm(const llama_vocab & vocab) : vocab(vocab) { +struct llm_tokenizer_ugm : llm_tokenizer { + llm_tokenizer_ugm(const llama_vocab & vocab) : llm_tokenizer() { if (vocab.precompiled_charsmap.size() > 0) { size_t charsmap_offset = 0; @@ -819,6 +841,30 @@ struct llm_tokenizer_ugm { unknown_token_score = min_score - unknown_token_score_penalty; } + // escaped space symbol - U+2581 (Lower One Eighth Block) + const std::string 
escaped_space = "\xE2\x96\x81"; + + const char * prefix_replacements = NULL; + size_t prefix_replacements_size = 0; + + const uint32_t * xcda_array = NULL; + size_t xcda_array_size = 0; + + struct naive_trie user_defined_token_matcher; + + float min_score = FLT_MAX; + float max_score = -FLT_MAX; + + float unknown_token_score_penalty = 10.0; + float unknown_token_score; + + struct naive_trie token_matcher; +}; + +struct llm_tokenizer_ugm_session { + llm_tokenizer_ugm_session(const llama_vocab & vocab) : vocab(vocab), + ugm_tokenizer(static_cast(vocab.tokenizer)) {} + /* This implementation is based on SentencePiece optimized Viterbi algorithm for * unigram language models. The general idea is to: * - move along the input sequence in steps of one UTF code point, @@ -857,7 +903,7 @@ struct llm_tokenizer_ugm { // traverse the token matcher trie to find a matching token bool single_codepoint_token_found = false; const struct best_tokenization & current_best = tokenization_results[input_offset]; - const struct naive_trie * node = token_matcher.traverse(normalized[prefix_offset++]); + const struct naive_trie * node = ugm_tokenizer->token_matcher.traverse(normalized[prefix_offset++]); while (prefix_offset <= input_len && node != NULL) { // check if we found valid token in prefix @@ -887,7 +933,7 @@ struct llm_tokenizer_ugm { // if we didn't find a valid token corresponding to the whole UTF code point // then use unknown token as the tokenization of this UTF code point if (!single_codepoint_token_found) { - const double challenger_score = current_best.score_sum + unknown_token_score; + const double challenger_score = current_best.score_sum + ugm_tokenizer->unknown_token_score; prefix_offset = input_offset + n_utf8_code_units; struct best_tokenization & current_champ = tokenization_results[prefix_offset]; if (challenger_score > current_champ.score_sum) { @@ -919,7 +965,6 @@ struct llm_tokenizer_ugm { } private: - const llama_vocab & vocab; // helper structure for returning normalization results struct normalization_result { @@ -932,7 +977,7 @@ struct llm_tokenizer_ugm { normalized->clear(); normalized->reserve(input.size() * 3); - const std::string space = vocab.tokenizer_escape_whitespaces ? escaped_space : " "; + const std::string space = vocab.tokenizer_escape_whitespaces ? 
ugm_tokenizer->escaped_space : " "; bool shall_prepend_space = !vocab.tokenizer_treat_whitespace_as_suffix && vocab.tokenizer_add_space_prefix; bool shall_append_space = vocab.tokenizer_treat_whitespace_as_suffix && vocab.tokenizer_add_space_prefix; @@ -1014,13 +1059,21 @@ struct llm_tokenizer_ugm { size_t xcda_array_size; }; + // this structure stores the best tokenization so far at input_offset + struct best_tokenization { + llama_token token_id; + size_t input_offset; + float score_sum; + }; + struct normalization_result normalize_prefix(const std::string & input, size_t input_offset) { if (input_offset == input.size()) { return { &input[input_offset], 0, 0 }; } // if input prefix matches some user-defined token return this token as normalization result - auto user_defined_token_match = user_defined_token_matcher.get_longest_prefix(&input[input_offset], input.size() - input_offset); + auto user_defined_token_match = + ugm_tokenizer->user_defined_token_matcher.get_longest_prefix(&input[input_offset], input.size() - input_offset); if (user_defined_token_match.second > 0) { return { &input[input_offset], user_defined_token_match.second, user_defined_token_match.second }; } @@ -1028,8 +1081,8 @@ struct llm_tokenizer_ugm { size_t longest_prefix_length = 0; size_t longest_prefix_offset = 0; - if (xcda_array_size > 0) { - struct xcda_array_view xcda_view(xcda_array, xcda_array_size); + if (ugm_tokenizer->xcda_array_size > 0) { + struct xcda_array_view xcda_view(ugm_tokenizer->xcda_array, ugm_tokenizer->xcda_array_size); // Find the longest normalized sequence matching the input prefix by walking // the XOR-compressed compact double array (XCDA) starting from the root node @@ -1065,50 +1118,27 @@ struct llm_tokenizer_ugm { if (longest_prefix_length > 0) { // we have a match, so return the replacement sequence - if (longest_prefix_offset >= prefix_replacements_size) { + if (longest_prefix_offset >= ugm_tokenizer->prefix_replacements_size) { throw std::runtime_error("Index out of array bounds in precompiled charsmap!"); } - const char * prefix_replacement = &prefix_replacements[longest_prefix_offset]; + const char * prefix_replacement = &(ugm_tokenizer->prefix_replacements)[longest_prefix_offset]; return { prefix_replacement, strlen(prefix_replacement), longest_prefix_length }; - } else { - // check if the input prefix contains a valid sequence of UTF-8 code units - try { - // if yes, return this sequence unmodified - size_t prefix_offset = input_offset; - unicode_cpt_from_utf8(input, prefix_offset); - return { &input[input_offset], prefix_offset - input_offset, prefix_offset - input_offset }; - } catch (std::invalid_argument & /*ex*/) { - // if no, consume 1 byte and return U+FFFD - REPLACEMENT CHARACTER - return { "\xEF\xBF\xBD", 3, 1 }; - } } - } - - // escaped space symbol - U+2581 (Lower One Eighth Block) - const std::string escaped_space = "\xE2\x96\x81"; - - const char * prefix_replacements = NULL; - size_t prefix_replacements_size = 0; - - const uint32_t * xcda_array = NULL; - size_t xcda_array_size = 0; - struct naive_trie user_defined_token_matcher; - - // this structure stores the best tokenization so far at input_offset - struct best_tokenization { - llama_token token_id; - size_t input_offset; - float score_sum; - }; - - float min_score = FLT_MAX; - float max_score = -FLT_MAX; - - float unknown_token_score_penalty = 10.0; - float unknown_token_score; + // check if the input prefix contains a valid sequence of UTF-8 code units + try { + // if yes, return this sequence unmodified + 
size_t prefix_offset = input_offset; + unicode_cpt_from_utf8(input, prefix_offset); + return { &input[input_offset], prefix_offset - input_offset, prefix_offset - input_offset }; + } catch (std::invalid_argument & /*ex*/) { + // if no, consume 1 byte and return U+FFFD - REPLACEMENT CHARACTER + return { "\xEF\xBF\xBD", 3, 1 }; + } + } - struct naive_trie token_matcher; + const llama_vocab & vocab; + const llm_tokenizer_ugm * ugm_tokenizer; }; // @@ -1169,8 +1199,8 @@ static std::vector llama_unescape_rwkv_token(const std::string & escape return output; } -struct llm_tokenizer_rwkv { - llm_tokenizer_rwkv(const llama_vocab & vocab): vocab(vocab) { +struct llm_tokenizer_rwkv : llm_tokenizer { + llm_tokenizer_rwkv(const llama_vocab & vocab) : llm_tokenizer() { // RWKV supports arbitrary byte tokens, but the vocab struct only supports string tokens. // For now, we decode the vocab here into the lookup we'll use for tokenization. @@ -1182,11 +1212,17 @@ struct llm_tokenizer_rwkv { } } + struct naive_trie token_matcher; +}; + +struct llm_tokenizer_rwkv_session { + llm_tokenizer_rwkv_session(const llama_vocab & vocab) : vocab(vocab), + rwkv_tokenizer(static_cast(*vocab.tokenizer)) {} + void tokenize(const std::string & text, std::vector & output) { uint32_t position = 0; - while (position < text.size()) { - const struct naive_trie * node = token_matcher.traverse(text[position]); + const struct naive_trie * node = rwkv_tokenizer.token_matcher.traverse(text[position]); if (node == NULL) { // no matching token found, add unknown token output.push_back(vocab.special_unk_id); @@ -1211,11 +1247,33 @@ struct llm_tokenizer_rwkv { } } +private: const llama_vocab & vocab; - - struct naive_trie token_matcher; + const llm_tokenizer_rwkv & rwkv_tokenizer; }; +void llama_vocab::init_tokenizer() { + switch (type) { + case LLAMA_VOCAB_TYPE_SPM: + tokenizer = new llm_tokenizer_spm(*this); + break; + case LLAMA_VOCAB_TYPE_BPE: + tokenizer = new llm_tokenizer_bpe(*this); + break; + case LLAMA_VOCAB_TYPE_WPM: + tokenizer = new llm_tokenizer_wpm(*this); + break; + case LLAMA_VOCAB_TYPE_UGM: + tokenizer = new llm_tokenizer_ugm(*this); + break; + case LLAMA_VOCAB_TYPE_RWKV: + tokenizer = new llm_tokenizer_rwkv(*this); + break; + default: + GGML_ABORT("unsupported vocab type"); + } +} + // // (de-) tokenize // @@ -1277,7 +1335,7 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list< // if a fragment is text ( not yet processed ) if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { - auto & raw_text = fragment.raw_text; + const auto & raw_text = fragment.raw_text; auto raw_text_base_offset = fragment.offset; auto raw_text_base_length = fragment.length; @@ -1376,7 +1434,13 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list< } } -std::vector llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool add_special, bool parse_special) { +std::vector llama_tokenize_internal( + const llama_vocab & vocab, + std::string raw_text, + bool add_special, + bool parse_special) { + GGML_ASSERT(vocab.tokenizer && "Tokenizer not initialized. 
Call llama_vocab::init_tokenizer() first."); + std::vector output; std::forward_list fragment_buffer; @@ -1413,9 +1477,9 @@ std::vector llama_tokenize_internal(const llama_vocab & vocab, #ifdef PRETOKENIZERDEBUG LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str()); #endif - llm_tokenizer_spm tokenizer(vocab); llama_escape_whitespace(raw_text); - tokenizer.tokenize(raw_text, output); + llm_tokenizer_spm_session session(vocab); + session.tokenize(raw_text, output); is_prev_special = false; } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN) output.push_back(fragment.token); @@ -1437,10 +1501,11 @@ std::vector llama_tokenize_internal(const llama_vocab & vocab, } break; case LLAMA_VOCAB_TYPE_BPE: { - llm_tokenizer_bpe tokenizer(vocab); - + llm_tokenizer_bpe_session session(vocab); + // it calls some other methods that are not exist in llm_tokenizer, + // here just cast it to bpe tokenizer object if (add_special) { - tokenizer.append_bos(output); + session.append_bos(output); } for (const auto & fragment : fragment_buffer) { if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { @@ -1449,15 +1514,15 @@ std::vector llama_tokenize_internal(const llama_vocab & vocab, #ifdef PRETOKENIZERDEBUG LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str()); #endif - tokenizer.tokenize(raw_text, output); + session.tokenize(raw_text, output); } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN) - tokenizer.append(fragment.token, output); + session.append(fragment.token, output); } } if (add_special) { - tokenizer.append_eos(output); - tokenizer.check_double_bos_eos(output); + session.append_eos(output); + session.check_double_bos_eos(output); } } break; case LLAMA_VOCAB_TYPE_WPM: @@ -1467,7 +1532,7 @@ std::vector llama_tokenize_internal(const llama_vocab & vocab, output.push_back(vocab.special_cls_id); } - llm_tokenizer_wpm tokenizer(vocab); + llm_tokenizer_wpm_session session(vocab); for (const auto & fragment : fragment_buffer) { if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { @@ -1476,7 +1541,7 @@ std::vector llama_tokenize_internal(const llama_vocab & vocab, #ifdef PRETOKENIZERDEBUG LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str()); #endif - tokenizer.tokenize(raw_text, output); + session.tokenize(raw_text, output); } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN) output.push_back(fragment.token); } @@ -1489,12 +1554,11 @@ std::vector llama_tokenize_internal(const llama_vocab & vocab, } break; case LLAMA_VOCAB_TYPE_UGM: { - llm_tokenizer_ugm tokenizer(vocab); - if (add_special && vocab.tokenizer_add_bos != 0) { GGML_ASSERT(vocab.special_bos_id != -1); output.push_back(vocab.special_bos_id); } + llm_tokenizer_ugm_session session(vocab); for (const auto & fragment : fragment_buffer) { if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { @@ -1502,7 +1566,7 @@ std::vector llama_tokenize_internal(const llama_vocab & vocab, #ifdef PRETOKENIZERDEBUG LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str()); #endif - tokenizer.tokenize(raw_text, output); + session.tokenize(raw_text, output); } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN) output.push_back(fragment.token); } @@ -1522,6 +1586,7 @@ std::vector llama_tokenize_internal(const llama_vocab & vocab, } break; case 
LLAMA_VOCAB_TYPE_RWKV: { + llm_tokenizer_rwkv_session session(vocab); for (const auto & fragment : fragment_buffer) { if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length); @@ -1530,8 +1595,7 @@ std::vector llama_tokenize_internal(const llama_vocab & vocab, LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str()); #endif - llm_tokenizer_rwkv tokenizer(vocab); - tokenizer.tokenize(raw_text, output); + session.tokenize(raw_text, output); } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN) output.push_back(fragment.token); } @@ -1644,13 +1708,13 @@ llama_token llama_token_eom_impl(const struct llama_vocab & vocab) { } int32_t llama_tokenize_impl( - const struct llama_vocab & vocab, - const char * text, - int32_t text_len, - llama_token * tokens, - int32_t n_tokens_max, - bool add_special, - bool parse_special) { + const struct llama_vocab & vocab, + const char * text, + int32_t text_len, + llama_token * tokens, + int32_t n_tokens_max, + bool add_special, + bool parse_special) { auto res = llama_tokenize_internal(vocab, std::string(text, text_len), add_special, parse_special); if (n_tokens_max < (int) res.size()) { // LLAMA_LOG_ERROR("%s: too many tokens\n", __func__); @@ -1775,6 +1839,8 @@ int32_t llama_detokenize_impl( int32_t text_len_max, bool remove_special, bool unparse_special) { + GGML_ASSERT(vocab.tokenizer && "Tokenizer not initialized. Call llama_vocab::init_tokenizer() first."); + int32_t avail = text_len_max; int32_t total = 0; diff --git a/src/llama-vocab.h b/src/llama-vocab.h index cc46f642bf1ae..069bdc423a60b 100644 --- a/src/llama-vocab.h +++ b/src/llama-vocab.h @@ -8,6 +8,8 @@ #include #include +struct llm_tokenizer; + struct llama_vocab { using id = llama_token; using token = std::string; @@ -65,7 +67,14 @@ struct llama_vocab { std::vector precompiled_charsmap; + llm_tokenizer * tokenizer = nullptr; + + llama_vocab() = default; + ~llama_vocab(); + int find_bpe_rank(const std::string & token_left, const std::string & token_right) const; + + void init_tokenizer(); }; // diff --git a/src/llama.cpp b/src/llama.cpp index f450eaf9ddc6f..44afb31d74e53 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -6464,6 +6464,8 @@ static void llm_load_vocab( } GGML_ASSERT(vocab.id_to_token.size() == vocab.token_to_id.size()); + vocab.init_tokenizer(); + // determine the newline token: LLaMA "<0x0A>" == 10 == '\n', Falcon 193 == '\n' if (vocab.type == LLAMA_VOCAB_TYPE_SPM) { // For Fill-In-the-Middle (FIM)/infill models which where converted diff --git a/tests/test-tokenizer-0.cpp b/tests/test-tokenizer-0.cpp index d3d21331bfd3d..4d49850c9ea25 100644 --- a/tests/test-tokenizer-0.cpp +++ b/tests/test-tokenizer-0.cpp @@ -7,6 +7,7 @@ #include #include #include +#include //static const std::map> & k_tests() { // static std::map> _k_tests = { @@ -194,45 +195,64 @@ int main(int argc, char **argv) { const bool add_special = false; - for (const auto & test_kv : k_tests) { - const std::vector res = llama_tokenize(ctx, test_kv.first, add_special, false); - - printf("\n"); - printf("src: '%s'\n", test_kv.first.c_str()); - printf("res: '%s'\n", llama_detokenize(ctx, res).c_str()); - printf("tok: "); - for (const auto & tok : res) { - printf("%d ", tok); - } - printf("\n"); - - bool correct = res.size() == test_kv.second.size(); - for (int i = 0; i < (int) res.size() && correct; ++i) { - if (test_kv.second[i] != res[i]) { - correct = false; + // 
multi-threaded tokenization + const int nthread = std::thread::hardware_concurrency(); + std::vector threads(nthread); + + for (int i = 0; i < nthread; i++) { + threads[i] = std::thread([&, i]() { + for (const auto & test_kv : k_tests) { + const std::vector res = llama_tokenize(ctx, test_kv.first, add_special, false); + + // here only print the result of the first thread + // because the other threads are running the same tests + if (i != 0) { + continue; + } + + printf("\n"); + printf("src: '%s'\n", test_kv.first.c_str()); + printf("res: '%s'\n", llama_detokenize(ctx, res).c_str()); + printf("tok: "); + for (const auto & tok : res) { + printf("%d ", tok); + } + printf("\n"); + + bool correct = res.size() == test_kv.second.size(); + for (int i = 0; i < (int) res.size() && correct; ++i) { + if (test_kv.second[i] != res[i]) { + correct = false; + } + } + + if (!correct) { + fprintf(stderr, "%s : failed test: '%s'\n", __func__, test_kv.first.c_str()); + fprintf(stderr, "%s : detokenized to: '%s' instead of '%s'\n", __func__, + llama_detokenize(ctx, res).c_str(), + llama_detokenize(ctx, test_kv.second).c_str()); + fprintf(stderr, "%s : expected tokens: ", __func__); + for (const auto & t : test_kv.second) { + fprintf(stderr, "%6d '%s', ", t, llama_token_to_piece(ctx, t).c_str()); + } + fprintf(stderr, "\n"); + fprintf(stderr, "%s : got tokens: ", __func__); + for (const auto & t : res) { + fprintf(stderr, "%6d '%s', ", t, llama_token_to_piece(ctx, t).c_str()); + } + fprintf(stderr, "\n"); + + success = false; + } } - } - - if (!correct) { - fprintf(stderr, "%s : failed test: '%s'\n", __func__, test_kv.first.c_str()); - fprintf(stderr, "%s : detokenized to: '%s' instead of '%s'\n", __func__, - llama_detokenize(ctx, res).c_str(), - llama_detokenize(ctx, test_kv.second).c_str()); - fprintf(stderr, "%s : expected tokens: ", __func__); - for (const auto & t : test_kv.second) { - fprintf(stderr, "%6d '%s', ", t, llama_token_to_piece(ctx, t).c_str()); - } - fprintf(stderr, "\n"); - fprintf(stderr, "%s : got tokens: ", __func__); - for (const auto & t : res) { - fprintf(stderr, "%6d '%s', ", t, llama_token_to_piece(ctx, t).c_str()); - } - fprintf(stderr, "\n"); + }); + } - success = false; - } + for (int i = 0; i < nthread; i++) { + threads[i].join(); } + // single threaded tokenization if (!fname_text.empty()) { fprintf(stderr, "%s : tokenizing: '%s'\n", __func__, fname_text.c_str()); From 739842703e32cd43443c45e0b4f6647cc4e6b3d6 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 28 Sep 2024 15:13:21 +0300 Subject: [PATCH 09/28] llama : add comment about thread-safety [no ci] (#9449) --- include/llama.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/include/llama.h b/include/llama.h index caef0bfff0b7d..4ea8a2c2b664b 100644 --- a/include/llama.h +++ b/include/llama.h @@ -911,6 +911,8 @@ extern "C" { // // Tokenization // + // The API is thread-safe. + // /// @details Convert the provided text into tokens. /// @param tokens The tokens pointer must be large enough to hold the resulting tokens. 
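With the tokenizer state now split into an immutable `llm_tokenizer` owned by the vocab and cheap per-call session objects, concurrent tokenization over one shared model is safe, which is what the new comment in `llama.h` documents and the multi-threaded test above exercises. A minimal sketch of the usage this enables, assuming an already initialized `llama_context`; error handling is trimmed and `tokenize_once`/`tokenize_in_parallel` are illustrative names, not llama.cpp functions:

```cpp
#include "llama.h"

#include <string>
#include <thread>
#include <vector>

// tokenize one string through the public, thread-safe API
static std::vector<llama_token> tokenize_once(llama_context * ctx, const std::string & text) {
    const llama_model * model = llama_get_model(ctx);
    // upper bound: one token per byte, plus room for special tokens
    std::vector<llama_token> tokens(text.size() + 2);
    const int32_t n = llama_tokenize(model, text.c_str(), (int32_t) text.size(),
                                     tokens.data(), (int32_t) tokens.size(),
                                     /*add_special*/ true, /*parse_special*/ false);
    tokens.resize(n > 0 ? n : 0);
    return tokens;
}

// tokenize many strings concurrently; no mutex is needed because each
// llama_tokenize call builds its own short-lived session internally
static std::vector<std::vector<llama_token>> tokenize_in_parallel(
        llama_context * ctx, const std::vector<std::string> & texts) {
    std::vector<std::vector<llama_token>> results(texts.size());
    std::vector<std::thread> workers;
    for (size_t i = 0; i < texts.size(); i++) {
        workers.emplace_back([&, i]() { results[i] = tokenize_once(ctx, texts[i]); });
    }
    for (auto & w : workers) {
        w.join();
    }
    return results;
}
```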
From 1b2f992cd2cff0b69e5abe78bb8888d51ed19d67 Mon Sep 17 00:00:00 2001 From: slaren Date: Sat, 28 Sep 2024 14:32:46 +0200 Subject: [PATCH 10/28] test-backend-ops : use flops for some performance tests (#9657) * test-backend-ops : use flops for some performance tests - parallelize tensor quantization - use a different set of cases for performance and correctness tests - run each test for at least one second --- tests/test-backend-ops.cpp | 415 +++++++++++++++++++------------------ 1 file changed, 216 insertions(+), 199 deletions(-) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 9a96cfc4c99de..d2cfe06b592cf 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -32,63 +32,52 @@ #include #include #include +#include #include static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) { - // static RNG initialization (revisit if n_threads stops being constant) - static const size_t n_threads = std::thread::hardware_concurrency(); - static std::vector generators = []() { - std::random_device rd; - std::vector vec; - vec.reserve(n_threads); - //for (size_t i = 0; i < n_threads; i++) { vec.emplace_back(1234 + i); } // fixed seed - for (size_t i = 0; i < n_threads; i++) { vec.emplace_back(rd()); } - return vec; - }(); - - size_t size = ggml_nelements(tensor); - std::vector data(size); + size_t nels = ggml_nelements(tensor); + std::vector data(nels); + { + // parallel initialization + static const size_t n_threads = std::thread::hardware_concurrency(); + // static RNG initialization (revisit if n_threads stops being constant) + static std::vector generators = []() { + std::random_device rd; + std::vector vec; + vec.reserve(n_threads); + //for (size_t i = 0; i < n_threads; i++) { vec.emplace_back(1234 + i); } // fixed seed + for (size_t i = 0; i < n_threads; i++) { vec.emplace_back(rd()); } + return vec; + }(); + + auto init_thread = [&](size_t ith, size_t start, size_t end) { + std::uniform_real_distribution distribution(min, max); + auto & gen = generators[ith]; + for (size_t i = start; i < end; i++) { + data[i] = distribution(gen); + } + }; - auto init_thread = [&](size_t ith, size_t start, size_t end) { - std::uniform_real_distribution distribution(min, max); - for (size_t i = start; i < end; i++) { - data[i] = distribution(generators[ith]); + std::vector> tasks; + tasks.reserve(n_threads); + for (size_t i = 0; i < n_threads; i++) { + size_t start = i*nels/n_threads; + size_t end = (i+1)*nels/n_threads; + tasks.push_back(std::async(std::launch::async, init_thread, i, start, end)); } - }; - - std::vector threads; - threads.reserve(n_threads); - for (size_t i = 0; i < n_threads; i++) { - size_t start = i*size/n_threads; - size_t end = (i+1)*size/n_threads; - threads.emplace_back(init_thread, i, start, end); - } - for (auto & t : threads) { - t.join(); - } - -#if 0 - const char * val_str = getenv("GGML_TEST_EPS"); - float val = 1e-9f; - if (val_str != nullptr) { - val = std::stof(val_str); - printf("GGML_TEST_EPS=%e\n", val); - } - - // test quantization with very small values that may result in nan scales due to division by zero - if (ggml_is_quantized(tensor->type)) { - for (int i = 0; i < 256; i++) { - data[i] = val; + for (auto & t : tasks) { + t.get(); } } -#endif if (tensor->type == GGML_TYPE_F32 || tensor->type == GGML_TYPE_I32) { - ggml_backend_tensor_set(tensor, data.data(), 0, size * sizeof(float)); + ggml_backend_tensor_set(tensor, data.data(), 0, nels * sizeof(float)); } else if 
(ggml_is_quantized(tensor->type) || tensor->type == GGML_TYPE_F16 || tensor->type == GGML_TYPE_BF16) { - GGML_ASSERT(size % ggml_blck_size(tensor->type) == 0); - std::vector dataq(ggml_row_size(tensor->type, size)); - std::vector imatrix(tensor->ne[0], 1.0f); // dummy importance matrix + GGML_ASSERT(nels % ggml_blck_size(tensor->type) == 0); + + // dummy importance matrix + std::vector imatrix(tensor->ne[0], 1.0f); const float * im = imatrix.data(); if (!ggml_quantize_requires_imatrix(tensor->type)) { // when the imatrix is optional, we want to test both quantization with and without imatrix @@ -98,15 +87,31 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m } } - ggml_quantize_chunk(tensor->type, data.data(), dataq.data(), 0, size/tensor->ne[0], tensor->ne[0], im); - GGML_ASSERT(ggml_validate_row_data(tensor->type, dataq.data(), dataq.size())); - // TODO: other cases - //#pragma omp parallel for - //for (int i = 0; i < tensor->ne[1]; i++) { - // ggml_quantize_chunk(tensor->type, data.data(), dataq.data(), - // i * tensor->ne[0], 1, tensor->ne[0], im); - //} - + std::vector dataq(ggml_row_size(tensor->type, nels)); + { + // parallel quantization by block + size_t blck_size = ggml_blck_size(tensor->type); + size_t n_blocks = nels / blck_size; + + auto quantize_thread = [&](size_t start, size_t end) { + ggml_quantize_chunk(tensor->type, data.data(), dataq.data(), + start * blck_size, end - start, blck_size, im); + }; + + const size_t min_blocks_per_thread = 1; + const size_t n_threads = std::min(std::thread::hardware_concurrency()/2, + std::max(1, n_blocks / min_blocks_per_thread)); + std::vector> tasks; + tasks.reserve(n_threads); + for (size_t i = 0; i < n_threads; i++) { + size_t start = i*n_blocks/n_threads; + size_t end = (i+1)*n_blocks/n_threads; + tasks.push_back(std::async(std::launch::async, quantize_thread, start, end)); + } + for (auto & t : tasks) { + t.get(); + } + } ggml_backend_tensor_set(tensor, dataq.data(), 0, dataq.size()); } else if (tensor->type == GGML_TYPE_I8 || tensor->type == GGML_TYPE_I16 || tensor->type == GGML_TYPE_I32) { // This is going to create some weird integers though. 
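Both parallel sections above (uniform initialization and block quantization) use the same partitioning scheme: split `[0, n)` into contiguous ranges and hand each range to a `std::async` task. As a standalone illustration of that pattern, here is a generic helper sketched under the assumption that `body` is safe to call concurrently; `parallel_for` is an illustrative name, not a ggml API:

```cpp
#include <algorithm>
#include <future>
#include <thread>
#include <vector>

template <typename F>
static void parallel_for(size_t n, F && body) {
    const size_t n_threads = std::max<size_t>(1, std::thread::hardware_concurrency());

    std::vector<std::future<void>> tasks;
    tasks.reserve(n_threads);
    for (size_t i = 0; i < n_threads; i++) {
        const size_t start = (i + 0)*n/n_threads; // same range math as above:
        const size_t end   = (i + 1)*n/n_threads; // the ranges tile [0, n) exactly
        tasks.push_back(std::async(std::launch::async, [start, end, &body]() {
            for (size_t j = start; j < end; j++) {
                body(j);
            }
        }));
    }
    for (auto & t : tasks) {
        t.get(); // rethrows any exception raised in a worker
    }
}
```

Using `std::async` plus `get()` instead of raw `std::thread` objects lets an exception thrown inside a worker propagate to the caller rather than terminating the process, which is presumably why the patch switches from threads to futures here.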
@@ -160,60 +165,6 @@ static std::vector tensor_to_float(const ggml_tensor * t) { return tv; } -/* -static double cosine_similarity(const float * v1, const float * v2, size_t n) { - double dot = 0.0; - double mag1 = 0.0; - double mag2 = 0.0; - - for (size_t i = 0; i < n; i++) { - if (std::isnan(v1[i]) || std::isnan(v2[i])) { - return -1.0f; - } - if (std::isinf(v1[i]) && std::isinf(v2[i])) { - continue; - } - dot += v1[i]*v2[i]; - mag1 += v1[i]*v1[i]; - mag2 += v2[i]*v2[i]; - } - - return dot/sqrt(mag1*mag2); -} - -static float distance(const float * v1, const float * v2, size_t n) { - double d = 0.0; - - for (size_t i = 0; i < n; i++) { - if (std::isnan(v1[i]) || std::isnan(v2[i])) { - return INFINITY; - } - if (std::isinf(v1[i]) && std::isinf(v2[i])) { - continue; - } - d += (v1[i] - v2[i])*(v1[i] - v2[i]); - } - - return sqrt(d); -} - -static float vec_len(const float * v, size_t n) { - double d = 0.0; - - for (size_t i = 0; i < n; i++) { - if (std::isnan(v[i])) { - return INFINITY; - } - if (std::isinf(v[i])) { - continue; - } - d += v[i]*v[i]; - } - - return sqrt(d); -} -*/ - // normalized mean squared error = mse(a, b) / mse(a, 0) static double nmse(const float * a, const float * b, size_t n) { double mse_a_b = 0.0; @@ -264,7 +215,6 @@ static double mean_abs_asymm(const float * a, const float * b, const size_t n, c } // utils for printing the variables of the test cases -#define VAR_TO_STR(x) (#x "=" + var_to_str(x)) template static std::string var_to_str(const T & x) { @@ -297,10 +247,6 @@ static std::string var_to_str(const std::array & x) { return s; } -//static std::string var_to_str(ggml_unary_op unary_op) { -// return ggml_unary_op_name(unary_op); -//} - static std::string var_to_str(ggml_type type) { return ggml_type_name(type); } @@ -313,6 +259,8 @@ static std::string var_to_str(ggml_op_pool pool) { } } +#define VAR_TO_STR(x) (#x "=" + var_to_str(x)) + #define VARS_TO_STR1(a) VAR_TO_STR(a) #define VARS_TO_STR2(a, b) VAR_TO_STR(a) + "," + VAR_TO_STR(b) #define VARS_TO_STR3(a, b, c) VAR_TO_STR(a) + "," + VARS_TO_STR2(b, c) @@ -370,13 +318,13 @@ struct test_case { return 1e-4; } - virtual float grad_eps(){ + virtual float grad_eps() { return 1e-1f; } // If false, estimate gradient with 2 points, neglects 3rd order derivative and higher. // If true, estimate gradient with 4 points, neglects 5th order derivative and higher. - virtual bool grad_precise(){ + virtual bool grad_precise() { return false; } @@ -409,6 +357,11 @@ struct test_case { return size; } + virtual uint64_t op_flops(ggml_tensor * t) { + GGML_UNUSED(t); + return 0; + } + ggml_cgraph * gf = nullptr; ggml_cgraph * gb = nullptr; @@ -651,12 +604,11 @@ struct test_case { } // align while also leaving some margin for variations in parameters - int align = 20; + int align = 8; int last = (len + align - 1) / align * align; if (last - len < 5) { last += align; } - last = std::max(last, 60); printf("%*s", last - len, ""); // allocate @@ -677,9 +629,25 @@ struct test_case { // warmup run ggml_backend_graph_compute(backend, gf); + // determine number of runs + int n_runs; + if (op_flops(out) > 0) { + // based on flops + const uint64_t GFLOP = 1000 * 1000 * 1000; + const uint64_t target_flops_cpu = 8ULL * GFLOP; + const uint64_t target_flops_gpu = 100ULL * GFLOP; + uint64_t target_flops = ggml_backend_is_cpu(backend) ? 
target_flops_cpu : target_flops_gpu; + n_runs = std::min(ggml_graph_size(gf) - ggml_graph_n_nodes(gf), target_flops / op_flops(out)) + 1; + } else { + // based on memory size + const size_t GB = 1ULL << 30; + const size_t target_size_cpu = 8 * GB; + const size_t target_size_gpu = 32 * GB; + size_t target_size = ggml_backend_is_cpu(backend) ? target_size_cpu : target_size_gpu; + n_runs = std::min(ggml_graph_size(gf) - ggml_graph_n_nodes(gf), target_size / op_size(out)) + 1; + } + // duplicate the op - size_t target_size = ggml_backend_is_cpu(backend) ? 1ULL << 33 : 1ULL << 35; // 8 GB CPU, 32 GB GPU - int n_runs = std::min((size_t) ggml_graph_size(gf) - ggml_graph_n_nodes(gf), target_size / op_size(out)) + 1; for (int i = 1; i < n_runs; i++) { ggml_graph_add_node(gf, out); } @@ -706,17 +674,47 @@ struct test_case { // run ggml_backend_synchronize(backend); - int64_t start_time = ggml_time_us(); - ggml_backend_graph_compute(backend, gf); - ggml_backend_synchronize(backend); - int64_t end_time = ggml_time_us(); - double time_us = end_time - start_time; + int64_t total_time_us = 0; + int total_runs = 0; + do { + int64_t start_time = ggml_time_us(); + ggml_backend_graph_compute(backend, gf); + ggml_backend_synchronize(backend); + int64_t end_time = ggml_time_us(); + + total_time_us += end_time - start_time; + total_runs += n_runs; + } while (total_time_us < 1000*1000); // run for at least 1 second + + printf(" %8d runs - %8.2f us/run - ", + total_runs, + (double)total_time_us / total_runs); + + if (op_flops(out) > 0) { + double flops_per_sec = (op_flops(out) * total_runs) / (total_time_us / 1e6); + auto format_flops = [](double flops) -> std::string { + char buf[256]; + if (flops >= 1e12) { + snprintf(buf, sizeof(buf), "%6.2f TFLOP", flops / 1e12); + } else if (flops >= 1e9) { + snprintf(buf, sizeof(buf), "%6.2f GFLOP", flops / 1e9); + } else if (flops >= 1e6) { + snprintf(buf, sizeof(buf), "%6.2f MFLOP", flops / 1e6); + } else { + snprintf(buf, sizeof(buf), "%6.2f KFLOP", flops / 1e3); + } + return buf; + }; + printf("%s/run - \033[1;34m%sS\033[0m", + format_flops(op_flops(out)).c_str(), + format_flops(flops_per_sec).c_str()); - printf(" %5d runs - %8.2f us/run - %8zu kB/run - \033[1;34m%7.2f GB/s\033[0m\n", - n_runs, - time_us / n_runs, - op_size(out) / 1024, - mem / (time_us/1e6) / 1024.0 / 1024.0 / 1024.0); + } else { + printf("%8zu kB/run - \033[1;34m%7.2f GB/s\033[0m", + op_size(out) / 1024, + mem / (total_time_us / 1e6) / 1024.0 / 1024.0 / 1024.0); + } + printf("\n"); ggml_backend_buffer_free(buf); @@ -1591,13 +1589,9 @@ struct test_mul_mat : public test_case { return 5e-4; } - size_t op_size(ggml_tensor * t) override { - size_t a = ggml_nbytes(t->src[0]) * n * nr[0] * nr[1]; - size_t b = ggml_nbytes(t->src[1]) * m; - size_t c = ggml_nbytes(t); - return a + b + c; - + uint64_t op_flops(ggml_tensor * t) override { GGML_UNUSED(t); + return 2 * m * n * k * bs[0] * nr[0] * bs[1] * nr[1]; } test_mul_mat(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32, @@ -1641,13 +1635,9 @@ struct test_mul_mat_id : public test_case { return 5e-4; } - size_t op_size(ggml_tensor * t) override { - size_t a = ggml_nbytes(t->src[2]) * n; - size_t b = ggml_nbytes(t->src[1]) * m; - size_t c = ggml_nbytes(t); - return a + b + c; - + uint64_t op_flops(ggml_tensor * t) override { GGML_UNUSED(t); + return 2 * m * k * n * n_used; } test_mul_mat_id(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32, @@ -3163,47 +3153,46 @@ struct test_falcon : public test_llm { // 
########################################### // ## Section 3: GGML Op Test Instantiation ## // ########################################### +static const ggml_type all_types[] = { + GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16, + GGML_TYPE_Q4_0, GGML_TYPE_Q4_1, + GGML_TYPE_Q5_0, GGML_TYPE_Q5_1, + GGML_TYPE_Q8_0, + GGML_TYPE_Q2_K, GGML_TYPE_Q3_K, + GGML_TYPE_Q4_K, GGML_TYPE_Q5_K, + GGML_TYPE_Q6_K, + // GGML_TYPE_TQ1_0, GGML_TYPE_TQ2_0, // TODO: implement for all backends + GGML_TYPE_IQ2_XXS, GGML_TYPE_IQ2_XS, GGML_TYPE_IQ2_S, + GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M, + GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS, +}; + +static const ggml_type base_types[] = { + GGML_TYPE_F32, GGML_TYPE_F16, + GGML_TYPE_Q4_0, + GGML_TYPE_Q4_K, + GGML_TYPE_IQ2_XXS +}; +static const ggml_type other_types[] = { + GGML_TYPE_Q4_1, + GGML_TYPE_Q5_0, GGML_TYPE_Q5_1, + GGML_TYPE_Q8_0, + GGML_TYPE_Q2_K, GGML_TYPE_Q3_K, + GGML_TYPE_Q5_K, + GGML_TYPE_Q6_K, + // GGML_TYPE_TQ1_0, GGML_TYPE_TQ2_0, // TODO: implement for all backends + GGML_TYPE_IQ2_XS, GGML_TYPE_IQ2_S, + GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M, + GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS, + GGML_TYPE_BF16, +}; -static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name) { +// Test cases for evaluation: should try to cover edge cases while using small input sizes to keep the runtime low +static std::vector> make_test_cases_eval() { std::vector> test_cases; std::default_random_engine rng(0); - const ggml_type all_types[] = { - GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16, - GGML_TYPE_Q4_0, GGML_TYPE_Q4_1, - GGML_TYPE_Q5_0, GGML_TYPE_Q5_1, - GGML_TYPE_Q8_0, - GGML_TYPE_Q2_K, GGML_TYPE_Q3_K, - GGML_TYPE_Q4_K, GGML_TYPE_Q5_K, - GGML_TYPE_Q6_K, - // GGML_TYPE_TQ1_0, GGML_TYPE_TQ2_0, // TODO: implement for all backends - GGML_TYPE_IQ2_XXS, GGML_TYPE_IQ2_XS, GGML_TYPE_IQ2_S, - GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M, - GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS, - }; - - const ggml_type base_types[] = { - GGML_TYPE_F32, GGML_TYPE_F16, - GGML_TYPE_Q4_0, - GGML_TYPE_Q4_K, - GGML_TYPE_IQ2_XXS - }; - - const ggml_type other_types[] = { - GGML_TYPE_Q4_1, - GGML_TYPE_Q5_0, GGML_TYPE_Q5_1, - GGML_TYPE_Q8_0, - GGML_TYPE_Q2_K, GGML_TYPE_Q3_K, - GGML_TYPE_Q5_K, - GGML_TYPE_Q6_K, - // GGML_TYPE_TQ1_0, GGML_TYPE_TQ2_0, // TODO: implement for all backends - GGML_TYPE_IQ2_XS, GGML_TYPE_IQ2_S, - GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M, - GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS, - GGML_TYPE_BF16, - }; - // unary ops for (int v : {0, 1}) { for (int op = 0; op < GGML_UNARY_OP_COUNT; op++) { @@ -3392,6 +3381,14 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {2, 2})); } } + for (ggml_type type_a : other_types) { + for (ggml_type type_b : {GGML_TYPE_F32}) { + if (ggml_blck_size(type_a) != 256) { + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, ggml_blck_size(type_a), {1, 1}, {1, 1})); + } + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {1, 1}, {1, 1})); + } + } #else // m = a rows // n = b rows @@ -3411,15 +3408,6 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op } #endif - for (ggml_type type_a : other_types) { - for (ggml_type type_b : {GGML_TYPE_F32}) { - if (ggml_blck_size(type_a) != 256) { - test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, ggml_blck_size(type_a), 
{1, 1}, {1, 1})); - } - test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {1, 1}, {1, 1})); - } - } - test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 64, 2, 128, { 8, 1}, {1, 1})); test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 83, 2, 128, { 8, 1}, {4, 1})); test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 64, 2, 64, { 8, 1}, {4, 1})); @@ -3624,20 +3612,30 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_falcon(2)); #endif - // run tests - if (mode == MODE_GRAD) { - size_t n_ok = 0; - for (auto & test : test_cases) { - if (test->eval_grad(backend, op_name)) { - n_ok++; + return test_cases; +} + +// Test cases for performance evaluation: should be representative of real-world use cases +static std::vector> make_test_cases_perf() { + std::vector> test_cases; + + test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 1, 1, 1})); + test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 512, 1, 1})); + + for (int bs : {1, 512}) { + for (ggml_type type_a : all_types) { + for (ggml_type type_b : {GGML_TYPE_F32}) { + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, bs, 14336, {1, 1}, {1, 1})); } } - printf(" %zu/%zu tests passed\n", n_ok, test_cases.size()); - - return n_ok == test_cases.size(); } + return test_cases; +} + +static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name) { if (mode == MODE_TEST) { + auto test_cases = make_test_cases_eval(); ggml_backend_t backend_cpu = ggml_backend_cpu_init(); size_t n_ok = 0; @@ -3653,7 +3651,21 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op return n_ok == test_cases.size(); } + if (mode == MODE_GRAD) { + auto test_cases = make_test_cases_eval(); + size_t n_ok = 0; + for (auto & test : test_cases) { + if (test->eval_grad(backend, op_name)) { + n_ok++; + } + } + printf(" %zu/%zu tests passed\n", n_ok, test_cases.size()); + + return n_ok == test_cases.size(); + } + if (mode == MODE_PERF) { + auto test_cases = make_test_cases_perf(); for (auto & test : test_cases) { test->eval_perf(backend, op_name); } @@ -3667,9 +3679,9 @@ static void usage(char ** argv) { printf("Usage: %s [mode] [-o op] [-b backend]\n", argv[0]); printf(" valid modes:\n"); printf(" - test (default, compare with CPU backend for correctness)\n"); - printf(" - perf (performance evaluation)\n"); printf(" - grad (compare gradients from backpropagation with method of finite differences)\n"); - printf(" op names are as given by ggml_op_desc() (e.g. GGML_ADD)\n"); + printf(" - perf (performance evaluation)\n"); + printf(" op names for -o are as given by ggml_op_desc() (e.g. 
ADD, MUL_MAT, etc)\n");
 }

 int main(int argc, char ** argv) {
@@ -3728,6 +3740,11 @@ int main(int argc, char ** argv) {
             continue;
         }

+        if (ggml_backend_is_cpu(backend)) {
+            // TODO: better value for n_threads
+            ggml_backend_cpu_set_n_threads(backend, std::thread::hardware_concurrency() / 2);
+        }
+
         printf("  Backend name: %s\n", ggml_backend_name(backend));

         bool ok = test_backend(backend, mode, op_name_filter);

From f4d2b8846a6b34419ff9e9491aee6cd95e444bfc Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Sat, 28 Sep 2024 17:42:03 +0300
Subject: [PATCH 11/28] llama : add reranking support (#9510)

* py : add XLMRobertaForSequenceClassification [no ci]
* py : fix scalar-tensor conversion [no ci]
* py : fix position embeddings chop [no ci]
* llama : read new cls tensors [no ci]
* llama : add classification head (wip) [no ci]
* llama : add "rank" pooling type

ggml-ci

* server : add rerank endpoint

ggml-ci

* llama : avoid ggml_repeat during classification
* rerank : cleanup + comments
* server : accept /rerank endpoint in addition to /v1/rerank [no ci]
* embedding : parse special tokens
* jina : support v1 reranker
* vocab : minor style

ggml-ci

* server : initiate tests for later

ggml-ci

* server : add docs
* llama : add comment [no ci]
* llama : fix uninitialized tensors
* ci : add rerank tests

ggml-ci

* add reranking test
* change test data
* Update examples/server/server.cpp

Co-authored-by: Xuan Son Nguyen

* add `--reranking` argument
* update server docs
* llama : fix comment [no ci]

ggml-ci

---------

Co-authored-by: Xuan Son Nguyen
Co-authored-by: Xuan Son Nguyen
---
 ci/run.sh                                     |  85 ++++++-
 common/arg.cpp                                |  18 +-
 common/common.cpp                             |   5 +
 common/common.h                               |   1 +
 convert_hf_to_gguf.py                         |  27 ++-
 convert_hf_to_gguf_update.py                  |   1 +
 examples/embedding/embedding.cpp              |   7 +-
 examples/server/README.md                     |  39 +++-
 examples/server/server.cpp                    | 215 ++++++++++++++++--
 .../server/tests/features/embeddings.feature  |   2 +-
 examples/server/tests/features/rerank.feature |  42 ++++
 examples/server/tests/features/steps/steps.py |  54 ++++-
 examples/server/utils.hpp                     |  25 +-
 gguf-py/gguf/constants.py                     |   7 +
 gguf-py/gguf/tensor_mapping.py                |   9 +
 include/llama.h                               |  10 +-
 src/llama-vocab.cpp                           |  15 +-
 src/llama.cpp                                 |  96 +++++++-
 18 files changed, 602 insertions(+), 56 deletions(-)
 create mode 100644 examples/server/tests/features/rerank.feature

diff --git a/ci/run.sh b/ci/run.sh
index 1ac08ee4e19a8..7d241ecc0ea06 100755
--- a/ci/run.sh
+++ b/ci/run.sh
@@ -712,6 +712,81 @@ function gg_run_embd_bge_small {
     set +e
 }

+function gg_sum_embd_bge_small {
+    gg_printf '### %s\n\n' "${ci}"
+
+    gg_printf 'BGE Small (BERT):\n'
+    gg_printf '- status: %s\n' "$(cat $OUT/${ci}.exit)"
+    gg_printf '- f16: \n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-f16.log)"
+    gg_printf '- q8_0:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q8_0.log)"
+}
+
+# rerank_tiny
+
+function gg_run_rerank_tiny {
+    cd ${SRC}
+
+    gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/config.json
+    gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/tokenizer.json
+    gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/tokenizer_config.json
+    gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/special_tokens_map.json
+    gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/resolve/main/pytorch_model.bin
+    gg_wget models-mnt/rerank-tiny/
https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/sentence_bert_config.json + gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/vocab.txt + gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/modules.json + gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/config.json + + gg_wget models-mnt/rerank-tiny/1_Pooling https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/1_Pooling/config.json + + path_models="../models-mnt/rerank-tiny" + + rm -rf build-ci-release && mkdir build-ci-release && cd build-ci-release + + set -e + + (time cmake -DCMAKE_BUILD_TYPE=Release ${CMAKE_EXTRA} .. ) 2>&1 | tee -a $OUT/${ci}-cmake.log + (time make -j$(nproc) ) 2>&1 | tee -a $OUT/${ci}-make.log + + python3 ../convert_hf_to_gguf.py ${path_models} --outfile ${path_models}/ggml-model-f16.gguf + + model_f16="${path_models}/ggml-model-f16.gguf" + + (time ./bin/llama-embedding --model ${model_f16} -p "what is panda?hi\nwhat is panda?it's a bear\nwhat is panda?The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China." --pooling rank --embd-normalize -1 --verbose-prompt) 2>&1 | tee -a $OUT/${ci}-rk-f16.log + + # sample output + # rerank score 0: 0.029 + # rerank score 1: 0.029 + # rerank score 2: 0.135 + + # check that the score is in the range [$3, $4] + function check_score { + qnt="$1" + score=$(echo "$2" | grep -oE "[0-9]+\.[0-9]+" | tail -n 1) + + if [ $(echo "$score < $3" | bc) -eq 1 ] || [ $(echo "$score > $4" | bc) -eq 1 ]; then + printf ' - %s @ %s (FAIL: score not in range [%s, %s])\n' "$qnt" "$score" "$3" "$4" + return 20 + fi + + printf ' - %s @ %s OK\n' "$qnt" "$score" + return 0 + } + + check_score "rerank score 0" "$(cat $OUT/${ci}-rk-f16.log | grep "rerank score 0")" "0.00" "0.05" | tee -a $OUT/${ci}-rk-f16.log + check_score "rerank score 1" "$(cat $OUT/${ci}-rk-f16.log | grep "rerank score 1")" "0.00" "0.05" | tee -a $OUT/${ci}-rk-f16.log + check_score "rerank score 2" "$(cat $OUT/${ci}-rk-f16.log | grep "rerank score 2")" "0.10" "0.15" | tee -a $OUT/${ci}-rk-f16.log + + set +e +} + +function gg_sum_rerank_tiny { + gg_printf '### %s\n\n' "${ci}" + + gg_printf 'Rerank Tiny (Jina):\n' + gg_printf '- status: %s\n' "$(cat $OUT/${ci}.exit)" + gg_printf '- f16: \n```\n%s\n```\n' "$(cat $OUT/${ci}-rk-f16.log)" +} + function gg_check_build_requirements { if ! 
command -v cmake &> /dev/null; then gg_printf 'cmake not found, please install' @@ -726,15 +801,6 @@ function gg_check_build_requirements { fi } -function gg_sum_embd_bge_small { - gg_printf '### %s\n\n' "${ci}" - - gg_printf 'BGE Small (BERT):\n' - gg_printf '- status: %s\n' "$(cat $OUT/${ci}.exit)" - gg_printf '- f16: \n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-f16.log)" - gg_printf '- q8_0:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q8_0.log)" -} - ## main export LLAMA_LOG_PREFIX=1 @@ -762,6 +828,7 @@ test $ret -eq 0 && gg_run ctest_release if [ -z ${GG_BUILD_LOW_PERF} ]; then test $ret -eq 0 && gg_run embd_bge_small + test $ret -eq 0 && gg_run rerank_tiny if [ -z ${GG_BUILD_CLOUD} ] || [ ${GG_BUILD_EXTRA_TESTS_0} ]; then test $ret -eq 0 && gg_run test_scripts_debug diff --git a/common/arg.cpp b/common/arg.cpp index 6880117ed8001..8266a16c261c5 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -284,6 +284,10 @@ static bool gpt_params_parse_ex(int argc, char ** argv, gpt_params_context & ctx params.kv_overrides.back().key[0] = 0; } + if (params.reranking && params.embedding) { + throw std::invalid_argument("error: either --embedding or --reranking can be specified, but not both"); + } + return true; } @@ -391,7 +395,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex, [](gpt_params & params) { params.verbose_prompt = true; } - ).set_examples({LLAMA_EXAMPLE_MAIN})); + )); add_opt(llama_arg( {"--no-display-prompt"}, format("don't print prompt at generation (default: %s)", !params.display_prompt ? "true" : "false"), @@ -1093,13 +1097,14 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex, } ).set_sparam()); add_opt(llama_arg( - {"--pooling"}, "{none,mean,cls,last}", + {"--pooling"}, "{none,mean,cls,last,rank}", "pooling type for embeddings, use model default if unspecified", [](gpt_params & params, const std::string & value) { /**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; } else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; } - else if (value == "cls") { params.pooling_type = LLAMA_POOLING_TYPE_CLS; } + else if (value == "cls") { params.pooling_type = LLAMA_POOLING_TYPE_CLS; } else if (value == "last") { params.pooling_type = LLAMA_POOLING_TYPE_LAST; } + else if (value == "rank") { params.pooling_type = LLAMA_POOLING_TYPE_RANK; } else { throw std::invalid_argument("invalid value"); } } ).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_POOLING")); @@ -1749,6 +1754,13 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex, params.embedding = true; } ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_EMBEDDINGS")); + add_opt(llama_arg( + {"--reranking", "--rerank"}, + format("enable reranking endpoint on server (default: %s)", params.reranking ? 
"enabled" : "disabled"), + [](gpt_params & params) { + params.reranking = true; + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_RERANKING")); add_opt(llama_arg( {"--api-key"}, "KEY", "API key to use for authentication (default: none)", diff --git a/common/common.cpp b/common/common.cpp index 8d0ed4f95a737..e2b8574bf77d7 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1023,6 +1023,11 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param cparams.flash_attn = params.flash_attn; cparams.no_perf = params.no_perf; + if (params.reranking) { + cparams.embeddings = true; + cparams.pooling_type = LLAMA_POOLING_TYPE_RANK; + } + cparams.type_k = kv_cache_type_from_str(params.cache_type_k); cparams.type_v = kv_cache_type_from_str(params.cache_type_v); diff --git a/common/common.h b/common/common.h index cb87c4479ed0a..8b84cf9ad45ee 100644 --- a/common/common.h +++ b/common/common.h @@ -271,6 +271,7 @@ struct gpt_params { int32_t embd_normalize = 2; // normalisation for embendings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm) std::string embd_out = ""; // empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix std::string embd_sep = "\n"; // separator of embendings + bool reranking = false; // enable reranking support on server // server params int32_t port = 8080; // server listens on this network port diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 2cd5a8c11bc18..96a8830e9e7a3 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -291,8 +291,13 @@ def prepare_tensors(self): bid = int(part) break - for new_name, data in ((n, d.squeeze().numpy()) for n, d in self.modify_tensors(data_torch, name, bid)): - data: np.ndarray # type hint + for new_name, data_torch in (self.modify_tensors(data_torch, name, bid)): + data = data_torch.squeeze().numpy() + + # if data ends up empty, it means data_torch was a scalar tensor -> restore + if len(data.shape) == 0: + data = data_torch.numpy() + n_dims = len(data.shape) data_qtype: gguf.GGMLQuantizationType | bool = self.tensor_force_quant(name, new_name, bid, n_dims) @@ -592,6 +597,9 @@ def get_vocab_base_pre(self, tokenizer) -> str: if chkhsh == "a8594e3edff7c29c003940395316294b2c623e09894deebbc65f33f1515df79e": # ref: https://huggingface.co/databricks/dbrx-base res = "dbrx" + if chkhsh == "c7699093ba4255a91e702aa38a596aa81669f3525dae06c2953267dde580f448": + # ref: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en + res = "jina-v1-en" if chkhsh == "0876d13b50744004aa9aeae05e7b0647eac9d801b5ba4668afc01e709c15e19f": # ref: https://huggingface.co/jinaai/jina-embeddings-v2-base-en res = "jina-v2-en" @@ -2601,7 +2609,7 @@ def set_gguf_parameters(self): self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"]) -@Model.register("XLMRobertaModel") +@Model.register("XLMRobertaModel", "XLMRobertaForSequenceClassification") class XLMRobertaModel(BertModel): model_arch = gguf.MODEL_ARCH.BERT @@ -2699,6 +2707,11 @@ def set_vocab(self): self.gguf_writer.add_add_eos_token(True) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # if name starts with "roberta.", remove the prefix + # e.g. 
https://huggingface.co/BAAI/bge-reranker-v2-m3/tree/main + if name.startswith("roberta."): + name = name[8:] + # position embeddings start at pad_token_id + 1, so just chop down the weight tensor if name == "embeddings.position_embeddings.weight": if self._position_offset is not None: @@ -3110,6 +3123,14 @@ def set_vocab(self): self.gguf_writer.add_add_bos_token(True) self.gguf_writer.add_add_eos_token(True) + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # if name starts with "bert.", remove the prefix + # e.g. https://huggingface.co/jinaai/jina-reranker-v1-tiny-en + if name.startswith("bert."): + name = name[5:] + + return super().modify_tensors(data_torch, name, bid) + @Model.register("OpenELMForCausalLM") class OpenELMModel(Model): diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py index 4d11059f374d2..022354a3b624e 100755 --- a/convert_hf_to_gguf_update.py +++ b/convert_hf_to_gguf_update.py @@ -81,6 +81,7 @@ class TOKENIZER_TYPE(IntEnum): {"name": "qwen2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen1.5-7B", }, {"name": "olmo", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/allenai/OLMo-1.7-7B-hf", }, {"name": "dbrx", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/databricks/dbrx-base", }, + {"name": "jina-v1-en", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-reranker-v1-tiny-en", }, {"name": "jina-v2-en", "tokt": TOKENIZER_TYPE.WPM, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-en", }, # WPM! {"name": "jina-v2-es", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-es", }, {"name": "jina-v2-de", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-de", }, diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index a438dcb5adf34..7349268223827 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -135,7 +135,7 @@ int main(int argc, char ** argv) { // tokenize the prompts and trim std::vector> inputs; for (const auto & prompt : prompts) { - auto inp = ::llama_tokenize(ctx, prompt, true, false); + auto inp = ::llama_tokenize(ctx, prompt, true, true); if (inp.size() > n_batch) { LOG_ERR("%s: number of tokens in input line (%lld) exceeds batch size (%lld), increase batch size and re-run\n", __func__, (long long int) inp.size(), (long long int) n_batch); @@ -234,6 +234,11 @@ int main(int argc, char ** argv) { } LOG("\n"); } + } else if (pooling_type == LLAMA_POOLING_TYPE_RANK) { + for (int j = 0; j < n_embd_count; j++) { + // NOTE: if you change this log - update the tests in ci/run.sh + LOG("rerank score %d: %8.3f\n", j, emb[j * n_embd]); + } } else { // print the first part of the embeddings or for a single prompt, the full embedding for (int j = 0; j < n_prompts; j++) { diff --git a/examples/server/README.md b/examples/server/README.md index dfca07f988824..951c4a44c6058 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -7,6 +7,7 @@ Set of LLM REST APIs and a simple web front end to interact with llama.cpp. 
**Features:**

 * LLM inference of F16 and quantized models on GPU and CPU
 * [OpenAI API](https://github.com/openai/openai-openapi) compatible chat completions and embeddings routes
+ * Reranking endpoint (WIP: https://github.com/ggerganov/llama.cpp/pull/9510)
 * Parallel decoding with multi-user support
 * Continuous batching
 * Multimodal (wip)
@@ -23,6 +24,7 @@ The project is under active development, and we are [looking for feedback and co
 | -------- | ----------- |
 | `-h, --help, --usage` | print usage and exit |
 | `--version` | show version and build info |
+| `--verbose-prompt` | print a verbose prompt before generation (default: false) |
 | `-t, --threads N` | number of threads to use during generation (default: -1)<br/>(env: LLAMA_ARG_THREADS) |
 | `-tb, --threads-batch N` | number of threads to use during batch and prompt processing (default: same as --threads) |
 | `-C, --cpu-mask M` | CPU affinity mask: arbitrarily long hex. Complements cpu-range (default: "") |
@@ -130,7 +132,7 @@ The project is under active development, and we are [looking for feedback and co
 | `--no-context-shift` | disables context shift on infinite text generation (default: disabled)<br/>(env: LLAMA_ARG_NO_CONTEXT_SHIFT) |
 | `-sp, --special` | special tokens output enabled (default: false) |
 | `--spm-infill` | use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this. (default: disabled) |
-| `--pooling {none,mean,cls,last}` | pooling type for embeddings, use model default if unspecified<br/>(env: LLAMA_ARG_POOLING) |
+| `--pooling {none,mean,cls,last,rank}` | pooling type for embeddings, use model default if unspecified<br/>(env: LLAMA_ARG_POOLING) |
 | `-cb, --cont-batching` | enable continuous batching (a.k.a dynamic batching) (default: enabled)<br/>(env: LLAMA_ARG_CONT_BATCHING) |
 | `-nocb, --no-cont-batching` | disable continuous batching<br/>(env: LLAMA_ARG_NO_CONT_BATCHING) |
 | `-a, --alias STRING` | set alias for model name (to be used by REST API)<br/>(env: LLAMA_ARG_ALIAS) |
@@ -138,6 +140,7 @@ The project is under active development, and we are [looking for feedback and co
 | `--port PORT` | port to listen (default: 8080)<br/>(env: LLAMA_ARG_PORT) |
 | `--path PATH` | path to serve static files from (default: )<br/>(env: LLAMA_ARG_STATIC_PATH) |
 | `--embedding, --embeddings` | restrict to only support embedding use case; use only with dedicated embedding models (default: disabled)<br/>(env: LLAMA_ARG_EMBEDDINGS) |
+| `--reranking, --rerank` | enable reranking endpoint on server (default: disabled)<br/>(env: LLAMA_ARG_RERANKING) |
 | `--api-key KEY` | API key to use for authentication (default: none)<br/>(env: LLAMA_API_KEY) |
 | `--api-key-file FNAME` | path to file containing API keys (default: none) |
 | `--ssl-key-file FNAME` | path to a file containing a PEM-encoded SSL private key<br/>(env: LLAMA_ARG_SSL_KEY_FILE) |
@@ -152,6 +155,7 @@ The project is under active development, and we are [looking for feedback and co
 | `-sps, --slot-prompt-similarity SIMILARITY` | how much the prompt of a request must match the prompt of a slot in order to use that slot (default: 0.50, 0.0 = disabled)<br/> |
 | `--lora-init-without-apply` | load LoRA adapters without applying them (apply later via POST /lora-adapters) (default: disabled) |
+
 Note: If both the command line argument and the environment variable are set for the same param, the argument will take precedence over the env var.

 Example usage of docker compose with environment variables:
@@ -478,6 +482,39 @@ The same as [the embedding example](../embedding) does.

 `image_data`: An array of objects to hold base64-encoded image `data` and its `id`s to be referenced in `content`. You can determine the place of the image in the content as in the following: `Image: [img-21].\nCaption: This is a picture of a house`. In this case, `[img-21]` will be replaced by the embeddings of the image with id `21` in the following `image_data` array: `{..., "image_data": [{"data": "<BASE64_STRING>", "id": 21}]}`. Use `image_data` only with multimodal models, e.g., LLaVA.

+### POST `/reranking`: Rerank documents according to a given query
+
+Similar to https://jina.ai/reranker/ but might change in the future.
+Requires a reranker model (such as [bge-reranker-v2-m3](https://huggingface.co/BAAI/bge-reranker-v2-m3)) and the `--embedding --pooling rank` options.
+
+  *Options:*
+
+  `query`: The query against which the documents will be ranked.
+
+  `documents`: An array of strings representing the documents to be ranked.
+
+  *Aliases:*
+  - `/rerank`
+  - `/v1/rerank`
+  - `/v1/reranking`
+
+  *Examples:*
+
+  ```shell
+  curl http://127.0.0.1:8012/v1/rerank \
+      -H "Content-Type: application/json" \
+      -d '{
+        "model": "some-model",
+        "query": "What is panda?",
+        "top_n": 3,
+        "documents": [
+          "hi",
+          "it is a bear",
+          "The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China."
+        ]
+      }' | jq
+  ```
+
 ### POST `/infill`: For code infilling. Takes a prefix and a suffix and returns the predicted completion as stream.
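For reference, the server side of this endpoint (shown in the `server.cpp` changes below) frames each (query, document) pair as a single sequence of the form `[BOS] query [EOS] [BOS] document [EOS]` and reads the score from the pooled output. A minimal sketch of that framing against the C API; `tokenize_plain` and `build_rerank_prompt` are illustrative names, not llama.cpp functions:

```cpp
#include "llama.h"

#include <string>
#include <vector>

// tokenize without adding special tokens; the framing adds them explicitly
static std::vector<llama_token> tokenize_plain(const llama_model * model, const std::string & text) {
    std::vector<llama_token> tokens(text.size() + 2);
    const int32_t n = llama_tokenize(model, text.c_str(), (int32_t) text.size(),
                                     tokens.data(), (int32_t) tokens.size(),
                                     /*add_special*/ false, /*parse_special*/ false);
    tokens.resize(n > 0 ? n : 0);
    return tokens;
}

// build the rerank input: [BOS] query [EOS] [BOS] document [EOS]
static std::vector<llama_token> build_rerank_prompt(
        const llama_model * model, const std::string & query, const std::string & document) {
    std::vector<llama_token> out;

    auto append = [&](const std::string & text) {
        const auto part = tokenize_plain(model, text);
        out.insert(out.end(), part.begin(), part.end());
    };

    out.push_back(llama_token_bos(model)); // query segment
    append(query);
    out.push_back(llama_token_eos(model));

    out.push_back(llama_token_bos(model)); // document segment
    append(document);
    out.push_back(llama_token_eos(model));

    return out;
}
```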
diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 61ff09bb2b40f..f343cc252f89a 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -92,6 +92,7 @@ enum server_task_type { enum server_task_cmpl_type { SERVER_TASK_CMPL_TYPE_NORMAL, SERVER_TASK_CMPL_TYPE_EMBEDDING, + SERVER_TASK_CMPL_TYPE_RERANK, SERVER_TASK_CMPL_TYPE_INFILL, }; @@ -172,6 +173,7 @@ struct server_slot { std::vector generated_token_probs; server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL; + bool has_next_token = true; bool truncated = false; bool stopped_eos = false; @@ -954,8 +956,17 @@ struct server_context { slot.prompt = *prompt; } else if (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_array()) { slot.prompt = prompt->at(0); + } else if (prompt->is_array() && prompt->size() > 1) { + // array of strings + for (const auto & el : *prompt) { + if (!el.is_string()) { + send_error(task, "\"prompt\" must be a string, an array of strings or an array of integers", ERROR_TYPE_INVALID_REQUEST); + return false; + } + } + slot.prompt = *prompt; } else { - send_error(task, "\"prompt\" must be a string or an array of integers", ERROR_TYPE_INVALID_REQUEST); + send_error(task, "\"prompt\" must be a string, an array of strings or an array of integers", ERROR_TYPE_INVALID_REQUEST); return false; } } @@ -1389,6 +1400,7 @@ struct server_context { res.data = json { {"embedding", std::vector(n_embd, 0.0f)}, + {"index", slot.index}, }; continue; @@ -1407,6 +1419,44 @@ struct server_context { queue_results.send(res); } + void send_rerank(const server_slot & slot, const llama_batch & batch) { + server_task_result res; + res.id = slot.id_task; + res.error = false; + res.stop = true; + + for (int i = 0; i < batch.n_tokens; ++i) { + if (!batch.logits[i] || batch.seq_id[i][0] != slot.id + 1) { + continue; + } + + const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); + if (embd == NULL) { + embd = llama_get_embeddings_ith(ctx, i); + } + + if (embd == NULL) { + SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]); + + res.data = json { + {"index", slot.index}, + {"score", -1e6}, + }; + + continue; + } + + res.data = json { + {"index", slot.index}, + {"score", embd[0]}, + }; + } + + SLT_DBG(slot, "sending rerank result, res = '%s'\n", res.data.dump().c_str()); + + queue_results.send(res); + } + // // Functions to create new task(s) and receive result(s) // @@ -1442,13 +1492,27 @@ struct server_context { // otherwise, it's a multiple-prompt task, we break it into smaller tasks else if (prompt.is_array()) { std::vector prompts = prompt; - for (size_t i = 0; i < prompts.size(); i++) { - const auto & e = prompts[i]; - if (e.is_string() || json_is_array_of_numbers(e)) { - data["index"] = i; - create_task(data, true, e); - } else { - throw std::runtime_error(error_msg); + if (cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) { + // prompts[0] is the question + // the rest are the answers/documents + SRV_DBG("creating rerank tasks, n_prompts = %d\n", (int) prompts.size() - 1); + for (size_t i = 1; i < prompts.size(); i++) { + json qd; + qd.push_back(prompts[0]); + qd.push_back(prompts[i]); + data["index"] = i - 1; + create_task(data, true, qd); + } + } else { + SRV_DBG("creating multi-prompt tasks, n_prompts = %d\n", (int) prompts.size()); + for (size_t i = 0; i < prompts.size(); i++) { + const auto & e = prompts[i]; + if (e.is_string() || json_is_array_of_numbers(e)) { + data["index"] = i; + create_task(data, true, e); + } 
else {
+                throw std::runtime_error(error_msg);
+            }
        }
    }
}

@@ -1492,7 +1556,9 @@ struct server_context {
                 return;
             }

-            size_t idx = result.data["index"];
+            const size_t idx = result.data["index"];
+            GGML_ASSERT(idx < results.size() && "index out of range");
+
             results[idx] = result;
         }

         result_handler(results);
@@ -1903,6 +1969,7 @@ struct server_context {
         // track if this is an embedding or non-embedding batch
         // if we've added sampled tokens above, we are in non-embedding mode
         // -1: none, 0: non-embedding, 1: embedding
+        // TODO: make enum
         int32_t batch_type = batch.n_tokens > 0 ? 0 : -1;

         // next, batch any pending prompts without exceeding n_batch
@@ -1951,6 +2018,29 @@ struct server_context {
                         }

                         prompt_tokens = embd_inp;
+                    } else if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
+                        // require slot.prompt to be array of 2 strings
+                        if (!slot.prompt.is_array() || slot.prompt.size() != 2) {
+                            SLT_ERR(slot, "%s", "invalid prompt for rerank task\n");
+                            slot.release();
+                            send_error(slot, "invalid prompt for rerank task", ERROR_TYPE_INVALID_REQUEST);
+                            continue;
+                        }
+
+                        // prompt: [BOS]query[EOS][BOS]doc[EOS]
+                        prompt_tokens.clear();
+                        prompt_tokens.push_back(llama_token_bos(model));
+                        {
+                            const auto part = tokenize(slot.prompt[0], false);
+                            prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
+                        }
+                        prompt_tokens.push_back(llama_token_eos(model));
+                        prompt_tokens.push_back(llama_token_bos(model));
+                        {
+                            const auto part = tokenize(slot.prompt[1], false);
+                            prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
+                        }
+                        prompt_tokens.push_back(llama_token_eos(model));
                     } else {
                         prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt
                     }
@@ -1970,7 +2060,7 @@ struct server_context {
                         continue;
                     }

-                    if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) {
+                    if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
                         // this prompt is too large to process - discard it
                         if (slot.n_prompt_tokens > n_ubatch) {
                             slot.release();
@@ -2048,7 +2138,8 @@ struct server_context {
                         slot.n_prompt_tokens_processed = 0;
                     }

-                    if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) {
+                    // non-causal tasks require the entire prompt to fit in the physical batch
+                    if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
                         // cannot fit the prompt in the current batch - will try next iter
                         if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
                             continue;
@@ -2056,7 +2147,10 @@ struct server_context {
                     }

                     // check that we are in the right batch_type, if not defer the slot
-                    bool slot_type = slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING ? 1 : 0;
+                    const bool slot_type =
+                        slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING ||
+                        slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK ? 1 : 0;
+
                     if (batch_type == -1) {
                         batch_type = slot_type;
                     } else if (batch_type != slot_type) {
@@ -2229,6 +2323,13 @@ struct server_context {
                         continue; // continue loop of slots
                     }

+                    if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
+                        send_rerank(slot, batch_view);
+                        slot.release();
+                        slot.i_batch = -1;
+                        continue; // continue loop of slots
+                    }
+
                     // prompt evaluated for next-token prediction
                     slot.state = SLOT_STATE_GENERATING;
                 } else if (slot.state != SLOT_STATE_GENERATING) {
@@ -2787,8 +2888,8 @@ int main(int argc, char ** argv) {
     };

     const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_cmpl_type cmpl_type, json & data, httplib::Response & res) {
-        if (ctx_server.params.embedding) {
-            res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
+        if (ctx_server.params.embedding || ctx_server.params.reranking) {
+            res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
             return;
         }
@@ -2848,8 +2949,8 @@ int main(int argc, char ** argv) {

     // TODO: maybe merge this function with "handle_completions_generic"
     const auto handle_chat_completions = [&ctx_server, &params, &res_error, &res_ok, verbose](const httplib::Request & req, httplib::Response & res) {
-        if (ctx_server.params.embedding) {
-            res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
+        if (ctx_server.params.embedding || ctx_server.params.reranking) {
+            res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
             return;
         }
@@ -2973,6 +3074,11 @@ int main(int argc, char ** argv) {
     };

     const auto handle_embeddings = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
+        // TODO: somehow clean up these checks in the future
+        if (!ctx_server.params.embedding || ctx_server.params.reranking) {
+            res_error(res, format_error_response("This server does not support embeddings. Start it with `--embeddings` and without `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
+            return;
+        }
         const json body = json::parse(req.body);
         bool is_openai = false;
@@ -3023,6 +3129,79 @@ int main(int argc, char ** argv) {
         res_ok(res, root);
     };

+    const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
+        if (!ctx_server.params.reranking) {
+            res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
+            return;
+        }
+        const json body = json::parse(req.body);
+
+        // TODO: implement
+        //int top_n = 1;
+        //if (body.count("top_n") == 1) {
+        //    top_n = body.at("top_n");
+        //} else {
+        //    res_error(res, format_error_response("\"top_n\" must be provided", ERROR_TYPE_INVALID_REQUEST));
+        //    return;
+        //}
+
+        json query;
+        if (body.count("query") == 1) {
+            query = body.at("query");
+            if (!query.is_string()) {
+                res_error(res, format_error_response("\"query\" must be a string", ERROR_TYPE_INVALID_REQUEST));
+                return;
+            }
+        } else {
+            res_error(res, format_error_response("\"query\" must be provided", ERROR_TYPE_INVALID_REQUEST));
+            return;
+        }
+
+        std::vector<std::string> documents = json_value(body, "documents", std::vector<std::string>());
+        if (documents.empty()) {
+            res_error(res, format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST));
+            return;
+        }
+
+        // construct prompt object: array of ["query", "doc0", "doc1", ...]
+        json prompt;
+        prompt.push_back(query);
+        for (const auto & doc : documents) {
+            prompt.push_back(doc);
+        }
+
+        LOG_DBG("rerank prompt: %s\n", prompt.dump().c_str());
+
+        // create and queue the task
+        json responses = json::array();
+        bool error = false;
+        {
+            std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_RERANK);
+            ctx_server.queue_results.add_waiting_tasks(tasks);
+            ctx_server.queue_tasks.post(tasks);
+
+            // get the result
+            std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
+
+            ctx_server.receive_cmpl_results(task_ids, [&](std::vector<server_task_result> & results) {
+                for (const auto & res : results) {
+                    responses.push_back(res.data);
+                }
+            }, [&](const json & error_data) {
+                res_error(res, error_data);
+                error = true;
+            });
+        }
+
+        if (error) {
+            return;
+        }
+
+        // write JSON response
+        json root = format_response_rerank(body, responses);
+        res_ok(res, root);
+    };
+
     const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) {
         json result = json::array();
         for (size_t i = 0; i < ctx_server.loras.size(); ++i) {
@@ -3119,6 +3298,10 @@ int main(int argc, char ** argv) {
     svr->Post("/embedding", handle_embeddings); // legacy
     svr->Post("/embeddings", handle_embeddings);
     svr->Post("/v1/embeddings", handle_embeddings);
+    svr->Post("/rerank", handle_rerank);
+    svr->Post("/reranking", handle_rerank);
+    svr->Post("/v1/rerank", handle_rerank);
+    svr->Post("/v1/reranking", handle_rerank);
     svr->Post("/tokenize", handle_tokenize);
     svr->Post("/detokenize", handle_detokenize);
     // LoRA adapters hotswap
diff --git a/examples/server/tests/features/embeddings.feature b/examples/server/tests/features/embeddings.feature
index 818ea3beb90cd..f4fe2ee4335ff 100644
--- a/examples/server/tests/features/embeddings.feature
+++ b/examples/server/tests/features/embeddings.feature
@@ -15,7 +15,7 @@ Feature: llama.cpp server
     And 128 as batch size
     And 128 as ubatch size
     And 512 KV cache size
-    And embeddings extraction
+    And enable embeddings endpoint
     Then the server is starting
     Then the server is healthy

diff --git a/examples/server/tests/features/rerank.feature b/examples/server/tests/features/rerank.feature
new file mode 100644
index 0000000000000..c36cc8e215fa6
--- /dev/null
+++ b/examples/server/tests/features/rerank.feature
@@ -0,0 +1,42 @@
+@llama.cpp
+@rerank
+Feature: llama.cpp server
+
+  Background: Server startup
+    Given a server listening on localhost:8080
+    And a model url https://huggingface.co/ggml-org/models/resolve/main/jina-reranker-v1-tiny-en/ggml-model-f16.gguf
+    And a model file jina-reranker-v1-tiny-en.gguf
+    And a model alias jina-reranker-v1-tiny-en
+    And 42 as server seed
+    And 2 slots
+    And 512 as batch size
+    And 512 as ubatch size
+    And 512 KV cache size
+    And enable reranking endpoint
+    Then the server is starting
+    Then the server is healthy
+
+  Scenario: Rerank
+    Given a rerank query:
+      """
+      Machine learning is
+      """
+    And a rerank document:
+      """
+      A machine is a physical system that uses power to apply forces and control movement to perform an action. The term is commonly applied to artificial devices, such as those employing engines or motors, but also to natural biological macromolecules, such as molecular machines.
+      """
+    And a rerank document:
+      """
+      Learning is the process of acquiring new understanding, knowledge, behaviors, skills, values, attitudes, and preferences. The ability to learn is possessed by humans, non-human animals, and some machines; there is also evidence for some kind of learning in certain plants.
+      """
+    And a rerank document:
+      """
+      Machine learning is a field of study in artificial intelligence concerned with the development and study of statistical algorithms that can learn from data and generalize to unseen data, and thus perform tasks without explicit instructions.
+      """
+    And a rerank document:
+      """
+      Paris, capitale de la France, est une grande ville européenne et un centre mondial de l'art, de la mode, de la gastronomie et de la culture. Son paysage urbain du XIXe siècle est traversé par de larges boulevards et la Seine.
+      """
+    When reranking request
+    Then reranking results are returned
+    Then reranking highest score is index 2 and lowest score is index 3
diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py
index 0fea0fe87b799..2611614ba3633 100644
--- a/examples/server/tests/features/steps/steps.py
+++ b/examples/server/tests/features/steps/steps.py
@@ -68,6 +68,7 @@ def step_server_config(context, server_fqdn: str, server_port: str):
     context.server_api_key = None
     context.server_continuous_batching = False
     context.server_embeddings = False
+    context.server_reranking = False
     context.server_metrics = False
     context.server_process = None
     context.seed = None
@@ -83,6 +84,10 @@ def step_server_config(context, server_fqdn: str, server_port: str):
     context.concurrent_tasks = []
     context.prompts = []

+    context.reranking_query = None
+    context.reranking_documents = []
+    context.reranking_results = None
+

 @step('a model file {hf_file} from HF repo {hf_repo}')
 def step_download_hf_model(context, hf_file: str, hf_repo: str):
@@ -172,10 +177,13 @@ def step_server_continuous_batching(context):
     context.server_continuous_batching = True

-@step('embeddings extraction')
+@step('enable embeddings endpoint')
 def step_server_embeddings(context):
     context.server_embeddings = True

+@step('enable reranking endpoint')
+def step_server_reranking(context):
+    context.server_reranking = True

 @step('prometheus compatible metrics exposed')
 def step_server_metrics(context):
@@ -452,6 +460,14 @@ def step_impl(context, n_ga_w):
 def step_prompt_passkey(context):
     context.prompt_passkey = context_text(context)

+@step('a rerank query')
+def step_set_rerank_query(context):
+    context.reranking_query = context_text(context)
+    context.reranking_documents = []
+
+@step('a rerank document')
+def step_set_rerank_document(context):
+    context.reranking_documents.append(context_text(context))
 @step('{n_prompts:d} fixed prompts')
 def step_fixed_prompts(context, n_prompts):
@@ -619,6 +635,22 @@ async def step_compute_embedding(context):
     context.embeddings = await request_embedding(context_text(context), None, base_url=context.base_url)

+@step('reranking request')
+@async_run_until_complete
+async def step_compute_reranking(context):
+    async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session:
+        async with session.post(f'{context.base_url}/reranking',
+                                json={
+                                    "query": context.reranking_query,
+                                    "documents": context.reranking_documents,
+                                }) as response:
+            if response.status == 200:
+                response_json = await response.json()
+                context.reranking_results = response_json['results']
+            else:
+                context.reranking_results = response.status
+

 @step('all embeddings are the same')
 @async_run_until_complete
 async def step_all_embeddings_are_the_same(context):
@@ -704,6 +736,24 @@ async def all_embeddings_are_generated(context):
     for i in range(n_embedding_requests):
         assert_embeddings(context.tasks_result.pop().pop())

+@step('reranking results are returned')
+def reranking_results_are_returned(context):
+    assert len(context.reranking_results) == len(context.reranking_documents)
+
+@step('reranking highest score is index {idx_high:d} and lowest score is index {idx_low:d}')
+def reranking_results_are_ordered(context, idx_high: int, idx_low: int):
+    max_score, max_idx = 0, 0
+    min_score, min_idx = 0, 0
+    for res in context.reranking_results:
+        if max_score < res['relevance_score']:
+            max_score = res['relevance_score']
+            max_idx = res['index']
+        if min_score > res['relevance_score']:
+            min_score = res['relevance_score']
+            min_idx = res['index']
+    print(context.reranking_results)
+    assert max_idx == idx_high
+    assert min_idx == idx_low

 @step('adding special tokens')
 def step_tokenize_set_add_special(context):
@@ -1362,6 +1412,8 @@ def start_server_background(context):
         server_args.append('--cont-batching')
     if context.server_embeddings:
         server_args.append('--embedding')
+    if context.server_reranking:
+        server_args.append('--reranking')
     if context.server_metrics:
         server_args.append('--metrics')
     if context.model_alias:
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index f093f547ff2c1..47dfdfde512dc 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -537,7 +537,7 @@ static json format_embeddings_response_oaicompat(const json & request, const jso
     json res = json {
         {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
         {"object", "list"},
-        {"usage", json {
+        {"usage", json { // TODO: fill
             {"prompt_tokens", 0},
             {"total_tokens", 0}
         }},
@@ -547,6 +547,29 @@ static json format_embeddings_response_oaicompat(const json & request, const jso
     return res;
 }

+static json format_response_rerank(const json & request, const json & ranks) {
+    json data = json::array();
+    int i = 0;
+    for (const auto & rank : ranks) {
+        data.push_back(json{
+            {"index", i++},
+            {"relevance_score", json_value(rank, "score", 0.0)},
+        });
+    }
+
+    json res = json {
+        {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
+        {"object", "list"},
+        {"usage", json { // TODO: fill
+            {"prompt_tokens", 0},
+            {"total_tokens", 0}
+        }},
+        {"results", data}
+    };
+
+    return res;
+}
+
 static bool is_valid_utf8(const std::string & str) {
     const unsigned char* bytes = reinterpret_cast<const unsigned char*>(str.data());
     const unsigned char* end = bytes + str.length();
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 2fd2e9d2be828..ebe66a4a39f5f 
100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -345,6 +345,8 @@ class MODEL_TENSOR(IntEnum): ENC_FFN_DOWN = auto() ENC_FFN_UP = auto() ENC_OUTPUT_NORM = auto() + CLS = auto() # classifier + CLS_OUT = auto() # classifier output projection MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { @@ -504,6 +506,8 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.ENC_FFN_DOWN: "enc.blk.{bid}.ffn_down", MODEL_TENSOR.ENC_FFN_UP: "enc.blk.{bid}.ffn_up", MODEL_TENSOR.ENC_OUTPUT_NORM: "enc.output_norm", + MODEL_TENSOR.CLS: "cls", + MODEL_TENSOR.CLS_OUT: "cls.output", } MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { @@ -613,6 +617,8 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, MODEL_TENSOR.LAYER_OUT_NORM, + MODEL_TENSOR.CLS, + MODEL_TENSOR.CLS_OUT, ], MODEL_ARCH.NOMIC_BERT: [ MODEL_TENSOR.TOKEN_EMBD, @@ -644,6 +650,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_GATE, MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.LAYER_OUT_NORM, + MODEL_TENSOR.CLS, ], MODEL_ARCH.MPT: [ MODEL_TENSOR.TOKEN_EMBD, diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 5ef91f11d312f..e7e9b6fd5efbc 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -679,6 +679,15 @@ class TensorNameMap: MODEL_TENSOR.ENC_OUTPUT_NORM: ( "encoder.final_layer_norm", # t5 ), + + MODEL_TENSOR.CLS: ( + "classifier", # jina + "classifier.dense", # roberta + ), + + MODEL_TENSOR.CLS_OUT: ( + "classifier.out_proj", # roberta + ), } # architecture-specific block mappings diff --git a/include/llama.h b/include/llama.h index 4ea8a2c2b664b..7cae1bbe2e5b8 100644 --- a/include/llama.h +++ b/include/llama.h @@ -193,6 +193,7 @@ extern "C" { LLAMA_POOLING_TYPE_MEAN = 1, LLAMA_POOLING_TYPE_CLS = 2, LLAMA_POOLING_TYPE_LAST = 3, + LLAMA_POOLING_TYPE_RANK = 4, // used by reranking models to attach the classification head to the graph }; enum llama_attention_type { @@ -202,9 +203,9 @@ extern "C" { }; enum llama_split_mode { - LLAMA_SPLIT_MODE_NONE = 0, // single GPU - LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs - LLAMA_SPLIT_MODE_ROW = 2, // split rows across GPUs + LLAMA_SPLIT_MODE_NONE = 0, // single GPU + LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs + LLAMA_SPLIT_MODE_ROW = 2, // split rows across GPUs }; // TODO: simplify (https://github.com/ggerganov/llama.cpp/pull/9294#pullrequestreview-2286561979) @@ -872,7 +873,8 @@ extern "C" { // Get the embeddings for a sequence id // Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE - // shape: [n_embd] (1-dimensional) + // when pooling_type == LLAMA_POOLING_TYPE_RANK, returns float[1] with the rank of the sequence + // otherwise: float[n_embd] (1-dimensional) LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id); // diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index e4d844a73c216..d2f34ddd6b339 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -1554,7 +1554,7 @@ std::vector llama_tokenize_internal( } break; case LLAMA_VOCAB_TYPE_UGM: { - if (add_special && vocab.tokenizer_add_bos != 0) { + if (add_special && vocab.tokenizer_add_bos) { GGML_ASSERT(vocab.special_bos_id != -1); output.push_back(vocab.special_bos_id); } @@ -1572,14 +1572,14 @@ std::vector llama_tokenize_internal( } } - if (add_special && vocab.tokenizer_add_bos != 0 && output.size() >= 2 && output[1] == vocab.special_bos_id) { + if (add_special && vocab.tokenizer_add_bos && output.size() >= 2 && output[1] == vocab.special_bos_id) { LLAMA_LOG_WARN( "%s: 
Added a BOS token to the prompt as specified by the model but the prompt " "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. " "Are you sure this is what you want?\n", __FUNCTION__); } - if (add_special && vocab.tokenizer_add_eos == 1) { + if (add_special && vocab.tokenizer_add_eos) { GGML_ASSERT(vocab.special_eos_id != -1); output.push_back(vocab.special_eos_id); } @@ -1791,11 +1791,13 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token // suppressing them like CONTROL tokens. if (attr & (attr_special | LLAMA_TOKEN_ATTR_USER_DEFINED)) { return _try_copy(token_text.data(), token_text.size()); - } else if (attr & LLAMA_TOKEN_ATTR_NORMAL) { + } + if (attr & LLAMA_TOKEN_ATTR_NORMAL) { std::string result = token_text; llama_unescape_whitespace(result); return _try_copy(result.data(), result.size()); - } else if (attr & LLAMA_TOKEN_ATTR_BYTE) { + } + if (attr & LLAMA_TOKEN_ATTR_BYTE) { char byte = (char) llama_token_to_byte(vocab, token); return _try_copy((char*) &byte, 1); } @@ -1806,7 +1808,8 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token // suppressing them like CONTROL tokens. if (attr & (attr_special | LLAMA_TOKEN_ATTR_USER_DEFINED)) { return _try_copy(token_text.data(), token_text.size()); - } else if (attr & LLAMA_TOKEN_ATTR_NORMAL) { + } + if (attr & LLAMA_TOKEN_ATTR_NORMAL) { std::string result = llama_decode_text(token_text); return _try_copy(result.data(), result.size()); } diff --git a/src/llama.cpp b/src/llama.cpp index 44afb31d74e53..c466cd88b7c14 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -606,6 +606,8 @@ enum llm_tensor { LLM_TENSOR_ENC_FFN_DOWN, LLM_TENSOR_ENC_FFN_UP, LLM_TENSOR_ENC_OUTPUT_NORM, + LLM_TENSOR_CLS, + LLM_TENSOR_CLS_OUT, }; static const std::map> LLM_TENSOR_NAMES = { @@ -793,6 +795,8 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" }, { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_CLS, "cls" }, + { LLM_TENSOR_CLS_OUT, "cls.output" }, }, }, { @@ -828,6 +832,7 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_CLS, "cls" }, }, }, { @@ -2894,6 +2899,7 @@ struct llama_model { llama_hparams hparams = {}; llama_vocab vocab; + // TODO: should init all tensors to nullptr struct ggml_tensor * tok_embd; struct ggml_tensor * type_embd; struct ggml_tensor * pos_embd; @@ -2906,6 +2912,12 @@ struct llama_model { struct ggml_tensor * output_b; struct ggml_tensor * output_norm_enc; + // classifier + struct ggml_tensor * cls; + struct ggml_tensor * cls_b; + struct ggml_tensor * cls_out = nullptr; + struct ggml_tensor * cls_out_b = nullptr; + std::vector layers; llama_split_mode split_mode; @@ -5604,11 +5616,11 @@ static void llm_load_hparams( ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type); - ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type); + ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); hparams.f_max_alibi_bias = 8.0f; switch (hparams.n_layer) { - case 4: model.type = e_model::MODEL_33M; break; // jina-embeddings-small + case 4: model.type = e_model::MODEL_33M; break; // jina-embeddings-small case 12: model.type = e_model::MODEL_137M; break; // jina-embeddings-base } } break; @@ -6313,6 
+6325,7 @@ static void llm_load_vocab( tokenizer_pre == "phi-2" || tokenizer_pre == "jina-es" || tokenizer_pre == "jina-de" || + tokenizer_pre == "jina-v1-en" || tokenizer_pre == "jina-v2-es" || tokenizer_pre == "jina-v2-de" || tokenizer_pre == "jina-v2-code") { @@ -6439,7 +6452,12 @@ static void llm_load_vocab( for (uint32_t i = 0; i < n_vocab; i++) { std::string word = gguf_get_arr_str(ctx, token_idx, i); - GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0); + + //GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0); + if (word.empty()) { + LLAMA_LOG_WARN("%s: empty token at index %u\n", __func__, i); + word = "[EMPTY_" + std::to_string(i) + "]"; + } vocab.token_to_id[word] = i; vocab.max_token_len = std::max(vocab.max_token_len, (int) word.size()); @@ -6520,8 +6538,14 @@ static void llm_load_vocab( vocab.linefeed_id = ids[0]; } else { const std::vector ids = llama_tokenize_internal(vocab, "\xC4\x8A", false); // U+010A - GGML_ASSERT(!ids.empty() && "model vocab missing newline token"); - vocab.linefeed_id = ids[0]; + + //GGML_ASSERT(!ids.empty() && "model vocab missing newline token"); + if (ids.empty()) { + LLAMA_LOG_WARN("%s: model vocab missing newline token, using special_pad_id instead\n", __func__); + vocab.linefeed_id = vocab.special_pad_id; + } else { + vocab.linefeed_id = ids[0]; + } } // special tokens @@ -7394,6 +7418,12 @@ static bool llm_load_tensors( if (model.arch == LLM_ARCH_BERT) { model.pos_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}); + + model.cls = ml.create_tensor(ctx_output, tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + model.cls_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_CLS, "bias"), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + + model.cls_out = ml.create_tensor(ctx_output, tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, 1}, llama_model_loader::TENSOR_NOT_REQUIRED); + model.cls_out_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_CLS_OUT, "bias"), {1}, llama_model_loader::TENSOR_NOT_REQUIRED); } model.tok_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}); @@ -7446,6 +7476,8 @@ static bool llm_load_tensors( model.tok_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}); // LayerNorm model.tok_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}); //LayerNorm bias + model.cls = ml.create_tensor(ctx_output, tn(LLM_TENSOR_CLS, "weight"), {n_embd, 1}, llama_model_loader::TENSOR_NOT_REQUIRED); + model.cls_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_CLS, "bias"), {1}, llama_model_loader::TENSOR_NOT_REQUIRED); for (int i = 0; i < n_layer; ++i) { ggml_context * ctx_layer = ctx_for_layer(i); ggml_context * ctx_split = ctx_for_layer_split(i); @@ -10279,6 +10311,10 @@ struct llm_build_context { struct ggml_tensor * cur; switch (pooling_type) { + case LLAMA_POOLING_TYPE_NONE: + { + cur = inp; + } break; case LLAMA_POOLING_TYPE_MEAN: { struct ggml_tensor * inp_mean = build_inp_mean(); @@ -10290,9 +10326,26 @@ struct llm_build_context { struct ggml_tensor * inp_cls = build_inp_cls(); cur = ggml_get_rows(ctx0, inp, inp_cls); } break; - case LLAMA_POOLING_TYPE_NONE: + case LLAMA_POOLING_TYPE_RANK: { - cur = inp; + struct ggml_tensor * inp_cls = build_inp_cls(); + inp = ggml_get_rows(ctx0, inp, inp_cls); + + // classification head + // 
https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
+                    GGML_ASSERT(model.cls != nullptr);
+                    GGML_ASSERT(model.cls_b != nullptr);
+
+                    cur = ggml_add (ctx0, ggml_mul_mat(ctx0, model.cls, inp), model.cls_b);
+                    cur = ggml_tanh(ctx0, cur);
+
+                    // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
+                    // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
+                    if (model.cls_out) {
+                        GGML_ASSERT(model.cls_out_b != nullptr);
+
+                        cur = ggml_add (ctx0, ggml_mul_mat(ctx0, model.cls_out, cur), model.cls_out_b);
+                    }
                 } break;
             default:
                 {
@@ -11521,8 +11574,8 @@ struct llm_build_context {
             inpL = cur;
         }

-        // final output
         cur = inpL;
+        cb(cur, "result_embd", -1);

         ggml_build_forward_expand(gf, cur);

@@ -16682,7 +16735,9 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
         }
     }

-    if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_CLS) {
+    if (cparams.embeddings && (
+                cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
+                cparams.pooling_type == LLAMA_POOLING_TYPE_RANK)) {
         const int64_t n_tokens = batch.n_tokens;
         const int64_t n_seq_tokens = batch.n_seq_tokens;
         const int64_t n_seqs = batch.n_seqs;
@@ -16697,7 +16752,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
             const llama_seq_id seq_id = batch.seq_id[s][0];

             // TODO: adapt limits to n_seqs when batch.equal_seqs is true
-            GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS");
+            GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS or RANK");

             for (int i = 0; i < n_seq_tokens; ++i) {
                 const llama_pos pos = batch.pos[s*n_seq_tokens + i];
@@ -17237,6 +17292,20 @@ static int llama_decode_internal(
                         ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
                     }
                 } break;
+            case LLAMA_POOLING_TYPE_RANK:
+                {
+                    // extract the rerank score - a single float per sequence
+                    auto & embd_seq_out = lctx.embd_seq;
+
+                    for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
+                        const llama_seq_id seq_id = ubatch.seq_id[s][0];
+                        if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
+                            continue;
+                        }
+                        embd_seq_out[seq_id].resize(1);
+                        ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
+                    }
+                } break;
             case LLAMA_POOLING_TYPE_UNSPECIFIED:
                 {
                     GGML_ABORT("unknown pooling type");
@@ -17443,6 +17512,13 @@ static int llama_encode_internal(
                         ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
                     }
                 } break;
+            case LLAMA_POOLING_TYPE_RANK:
+                {
+                    // TODO: this likely should be the same logic as in llama_decode_internal, but better to
+                    //       wait for an encoder model that requires this pooling type in order to test it
+                    //       https://github.com/ggerganov/llama.cpp/pull/9510
+                    GGML_ABORT("RANK pooling not implemented yet");
+                }
             case LLAMA_POOLING_TYPE_UNSPECIFIED:
                 {
                     GGML_ABORT("unknown pooling type");

From 589b48d41efb0e95133b77c335f4fb9779af9bfb Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Sun, 29 Sep 2024 14:38:18 +0300
Subject: [PATCH 12/28] contrib : add Resources section (#9675)

---
 CONTRIBUTING.md | 5 +++++
 README.md       | 2 +-
 2 files changed, 6 insertions(+), 1 deletion(-)

diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index a9e000e5227d9..3d7c6f86ca73e 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -27,3 +27,8 @@
 ![matmul](media/matmul.png)

+# Resources
+
+The GitHub issues, PRs and discussions contain a lot of information that can be useful to get familiar with the codebase. For convenience, some of the more important information is referenced from GitHub projects:
+
+https://github.com/ggerganov/llama.cpp/projects
diff --git a/README.md b/README.md
index a452a6d786948..ecc2df8ca832d 100644
--- a/README.md
+++ b/README.md
@@ -443,7 +443,7 @@ To learn more how to measure perplexity using llama.cpp, [read this documentatio
 - Contributors can open PRs
 - Collaborators can push to branches in the `llama.cpp` repo and merge PRs into the `master` branch
 - Collaborators will be invited based on contributions
-- Any help with managing issues and PRs is very appreciated!
+- Any help with managing issues, PRs and projects is very appreciated!
 - See [good first issues](https://github.com/ggerganov/llama.cpp/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22) for tasks suitable for first contributions
 - Read the [CONTRIBUTING.md](CONTRIBUTING.md) for more information
 - Make sure to read this: [Inference at the edge](https://github.com/ggerganov/llama.cpp/discussions/205)

From f99d3f8367174f7aba73c07fd87de687d4a0ece1 Mon Sep 17 00:00:00 2001
From: nopperl <54780682+nopperl@users.noreply.github.com>
Date: Sun, 29 Sep 2024 12:02:06 +0000
Subject: [PATCH 13/28] py : add model class for Chameleon conversion (#9683)

---
 convert_hf_to_gguf.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 96a8830e9e7a3..f3857d487ca30 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -4162,7 +4162,8 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         return super().modify_tensors(data_torch, name, bid)

-@Model.register("ChameleonForCausalLM")
+@Model.register("ChameleonForConditionalGeneration")
+@Model.register("ChameleonForCausalLM") # obsolete
 class ChameleonModel(Model):
     model_arch = gguf.MODEL_ARCH.CHAMELEON

From faac0bae265449fd988c57bf894018edc36fbe1e Mon Sep 17 00:00:00 2001
From: matiaslin <45382001+matiaslin@users.noreply.github.com>
Date: Sun, 29 Sep 2024 05:25:00 -0700
Subject: [PATCH 14/28] common : ensure llama_batch size does not exceed max size (#9668)

A crash was observed when the number of tokens added to a batch exceeds
the llama_batch size. An assertion in llama_batch_add was added to protect
against llama_batch size overflow.
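To make the failure mode concrete, here is a minimal hypothetical caller
(illustration only, not part of the patch; it assumes the llama_batch_init()/
llama_batch_free() API from llama.h plus the llama_batch_add() helper from
common.h shown below, and relies on llama_batch_init() leaving a null
sentinel one slot past the requested seq_id capacity):

    #include "common.h"

    int main() {
        // room for at most 4 tokens (and a single sequence)
        llama_batch batch = llama_batch_init(4, 0, 1);
        for (llama_token id = 0; id < 5; ++id) {
            // the 5th call used to write out of bounds; with this patch it
            // trips the new assertion instead of silently corrupting memory
            llama_batch_add(batch, id, /* pos */ id, /* seq_ids */ { 0 }, /* logits */ false);
        }
        llama_batch_free(batch);
        return 0;
    }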
---
 common/common.cpp | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/common/common.cpp b/common/common.cpp
index e2b8574bf77d7..a0611f3d1734b 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1437,6 +1437,8 @@ void llama_batch_add(
                  llama_pos pos,
    const std::vector<llama_seq_id> & seq_ids,
                       bool logits) {
+    GGML_ASSERT(batch.seq_id[batch.n_tokens] && "llama_batch size exceeded");
+
     batch.token [batch.n_tokens] = id;
     batch.pos [batch.n_tokens] = pos;
     batch.n_seq_id[batch.n_tokens] = seq_ids.size();

From 6084bfb261b03f812de2255b05b6b5bb8d1c7171 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Tue, 24 Sep 2024 13:23:59 +0300
Subject: [PATCH 15/28] ggml : fix GGML_MAX_N_THREADS + improve formatting (ggml/969)

---
 ggml/include/ggml.h | 22 ++++++++++++----------
 1 file changed, 12 insertions(+), 10 deletions(-)

diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 9f96e0c489b38..f46d4a8a65f02 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -229,14 +229,16 @@
 #define GGML_MAX_PARAMS 2048
 #define GGML_MAX_CONTEXTS 64
 #define GGML_MAX_SRC 10
-#ifndef GGML_MAX_NAME
-#define GGML_MAX_NAME 64
 #define GGML_MAX_N_THREADS 512
+#define GGML_MAX_OP_PARAMS 64
+#ifndef GGML_MAX_NAME
+# define GGML_MAX_NAME 64
 #endif
-#define GGML_MAX_OP_PARAMS 64
+
 #define GGML_DEFAULT_N_THREADS 4
 #define GGML_DEFAULT_GRAPH_SIZE 2048
+
 #if UINTPTR_MAX == 0xFFFFFFFF
 #define GGML_MEM_ALIGN 4
 #else
@@ -259,21 +261,21 @@
 #define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))

 #ifndef NDEBUG
-#define GGML_UNREACHABLE() do { fprintf(stderr, "statement should be unreachable\n"); abort(); } while(0)
+# define GGML_UNREACHABLE() do { fprintf(stderr, "statement should be unreachable\n"); abort(); } while(0)
 #elif defined(__GNUC__)
-#define GGML_UNREACHABLE() __builtin_unreachable()
+# define GGML_UNREACHABLE() __builtin_unreachable()
 #elif defined(_MSC_VER)
-#define GGML_UNREACHABLE() __assume(0)
+# define GGML_UNREACHABLE() __assume(0)
 #else
-#define GGML_UNREACHABLE() ((void) 0)
+# define GGML_UNREACHABLE() ((void) 0)
 #endif

 #ifdef __cplusplus
-#define GGML_NORETURN [[noreturn]]
+# define GGML_NORETURN [[noreturn]]
 #elif defined(_MSC_VER)
-#define GGML_NORETURN __declspec(noreturn)
+# define GGML_NORETURN __declspec(noreturn)
 #else
-#define GGML_NORETURN _Noreturn
+# define GGML_NORETURN _Noreturn
 #endif

 #define GGML_ABORT(...) ggml_abort(__FILE__, __LINE__, __VA_ARGS__)

From 544f409b4bd8fc98a3e87820f0ac934e00402de7 Mon Sep 17 00:00:00 2001
From: Salvatore Mesoraca
Date: Thu, 26 Sep 2024 08:59:42 +0200
Subject: [PATCH 16/28] vulkan : argsort barriers must be under uniform control flow (ggml/951)

A return before a barrier (that happens only in some threads in a
workgroup) leads to UB. While the old code actually works on some
devices, it fails on some others (e.g. "smaller" GPUs).

BTW, I think it would be better to set specialization constants when
the graph is built, in that way the local workgroup could be sized
appropriately. But it would take a lot of work.
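The rule generalizes beyond Vulkan: every thread that participates in a
barrier must reach it, so bounds checks have to guard the work rather than
the arrival. A host-side C++ analogy of the corrected pattern (illustration
only - std::barrier stands in for the shader's workgroup barrier()):

    #include <barrier>
    #include <thread>
    #include <vector>

    void fill_indices(std::vector<int> & dst, unsigned nthreads) {
        std::barrier sync(nthreads);
        auto worker = [&](unsigned col) {
            if (col < dst.size()) {     // guard the work ...
                dst[col] = (int) col;
            }
            sync.arrive_and_wait();     // ... but always arrive, like barrier()
        };
        std::vector<std::jthread> pool; // jthreads join on destruction
        for (unsigned t = 0; t < nthreads; ++t) {
            pool.emplace_back(worker, t);
        }
    }

Returning early from only some workers before sync.arrive_and_wait() would
leave the rest blocked forever - the CPU-side counterpart of the undefined
behaviour described above.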
Signed-off-by: Salvatore Mesoraca --- ggml/src/vulkan-shaders/argsort.comp | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/ggml/src/vulkan-shaders/argsort.comp b/ggml/src/vulkan-shaders/argsort.comp index e55414b03c519..d4fa45b1e106f 100644 --- a/ggml/src/vulkan-shaders/argsort.comp +++ b/ggml/src/vulkan-shaders/argsort.comp @@ -29,20 +29,18 @@ void main() { const int col = int(gl_LocalInvocationID.x); const uint row = gl_WorkGroupID.y; - if (col >= p.ncols_pad) { - return; - } - const uint row_offset = row * p.ncols; // initialize indices - dst_row[col] = col; + if (col < p.ncols_pad) { + dst_row[col] = col; + } barrier(); for (uint k = 2; k <= p.ncols_pad; k *= 2) { for (uint j = k / 2; j > 0; j /= 2) { const uint ixj = col ^ j; - if (ixj > col) { + if (col < p.ncols_pad && ixj > col) { if ((col & k) == 0) { if (dst_row[col] >= p.ncols || (dst_row[ixj] < p.ncols && (p.order == ASC ? From 0de8b203f1d31cf5ee0d2a3560a0ad78d44f4d4c Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Fri, 27 Sep 2024 02:58:01 -0500 Subject: [PATCH 17/28] vulkan : fix build for GGML_VULKAN_RUN_TESTS, add TFLOPS to log (ggml/961) --- ggml/src/ggml-vulkan.cpp | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/ggml/src/ggml-vulkan.cpp b/ggml/src/ggml-vulkan.cpp index a877145e82b49..70b7291540b2b 100644 --- a/ggml/src/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan.cpp @@ -5013,6 +5013,8 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t } } + ggml_pipeline_allocate_descriptor_sets(ctx->device); + vk_buffer d_X = ggml_vk_create_buffer_check(ctx->device, sizeof(X_TYPE) * x_ne, vk::MemoryPropertyFlagBits::eDeviceLocal); vk_buffer d_Y = ggml_vk_create_buffer_check(ctx->device, sizeof(Y_TYPE) * y_ne, vk::MemoryPropertyFlagBits::eDeviceLocal); vk_buffer d_D = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne, vk::MemoryPropertyFlagBits::eDeviceLocal); @@ -5129,7 +5131,9 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t avg_err /= m * n; - std::cerr << "TEST " << shname << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time / num_it << "ms avg_err=" << avg_err << std::endl; + double tflops = 2.0*m*n*k*batch*num_it / (time / 1000.0) / (1000.0*1000.0*1000.0*1000.0); + + std::cerr << "TEST " << shname << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl; if (avg_err > 0.1) { std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl; @@ -5251,12 +5255,14 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_ ggml_pipeline_request_descriptor_sets(ctx->device, p, 1); + ggml_pipeline_allocate_descriptor_sets(ctx->device); + ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz); vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue); ggml_vk_ctx_begin(ctx->device, subctx); const std::vector pc = { 1, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne }; - ggml_vk_dispatch_pipeline(ctx, subctx, p, { { qx_buf, 0, qx_sz }, { x_buf, 0, x_sz_f16 } }, pc.size() * sizeof(int), pc.data(), { (uint32_t)ne, 1, 1}); + ggml_vk_dispatch_pipeline(ctx, subctx, p, { vk_subbuffer{ qx_buf, 0, qx_sz }, vk_subbuffer{ x_buf, 0, x_sz_f16 } }, pc.size() * sizeof(int), pc.data(), { (uint32_t)ne, 1, 1}); ggml_vk_ctx_end(subctx); auto begin = 
std::chrono::high_resolution_clock::now();
@@ -5383,6 +5389,8 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
         }
     }

+    ggml_pipeline_allocate_descriptor_sets(ctx->device);
+
     ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz);
     ggml_vk_buffer_write(y_buf, 0, y, y_sz);

@@ -5450,7 +5458,9 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,

     avg_err /= m * n;

-    std::cerr << "TEST MMQ " << shname << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time_ms / num_it << "ms avg_err=" << avg_err << std::endl;
+    double tflops = 2.0*m*n*k*batch*num_it / (time_ms / 1000.0) / (1000.0*1000.0*1000.0*1000.0);
+
+    std::cerr << "TEST MMQ " << shname << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time_ms / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl;

     if (avg_err > 0.01 || std::isnan(avg_err)) {
         std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl;
@@ -5502,9 +5512,6 @@ static ggml_tensor_extra_gpu * ggml_vk_tensor_create_extra(ggml_tensor * tensor)

 static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
 #if defined(GGML_VULKAN_RUN_TESTS)
-    ctx->staging = ggml_vk_create_buffer_check(ctx->device, 100ul * 1024ul * 1024ul,
-        vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached,
-        vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
     ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_F32);
     ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q4_0);
     ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q4_1);

From 641002fba8d2a0c0269027e23d2ef58e90546028 Mon Sep 17 00:00:00 2001
From: Jeff Bolz
Date: Sun, 29 Sep 2024 11:50:17 -0500
Subject: [PATCH 18/28] vulkan : multithread pipeline creation (ggml/963)

---
 ggml/src/ggml-vulkan.cpp | 41 ++++++++++++++++++++++++++++++++++++----
 1 file changed, 37 insertions(+), 4 deletions(-)

diff --git a/ggml/src/ggml-vulkan.cpp b/ggml/src/ggml-vulkan.cpp
index 70b7291540b2b..c677a27287cc0 100644
--- a/ggml/src/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan.cpp
@@ -20,6 +20,8 @@
 #include
 #include
 #include
+#include <future>
+#include <thread>
 #include "ggml-impl.h"
 #include "ggml-backend-impl.h"

@@ -607,13 +609,16 @@ typedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context& subctx

 GGML_CALL static void ggml_backend_vk_free(ggml_backend_t backend);

-static void ggml_vk_create_pipeline(vk_device& device, vk_pipeline& pipeline, const std::string& name, size_t spv_size, const void* spv_data, const std::string& entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t>&& specialization_constants, uint32_t align) {
+// variables to track number of compiles in progress
+static uint32_t compile_count = 0;
+static std::mutex compile_count_mutex;
+static std::condition_variable compile_count_cond;
+
+static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipeline, const std::string name, size_t spv_size, const void* spv_data, const std::string entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t> specialization_constants, uint32_t align) {
     VK_LOG_DEBUG("ggml_vk_create_pipeline(" << device->name << ", " << name << ", " << entrypoint << ", " << parameter_count << ", " << push_constant_size << ", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " << align << ")");
     GGML_ASSERT(parameter_count > 0);
     GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0); // NOLINT

-    std::lock_guard<std::mutex> guard(device->mutex);
-
     pipeline = std::make_shared<vk_pipeline_struct>();
     pipeline->name = name;
     pipeline->parameter_count = parameter_count;
@@ -681,7 +686,17 @@ static void ggml_vk_create_pipeline(vk_device& device, vk_pipeline& pipeline, co
         pipeline->layout);
     pipeline->pipeline = device->device.createComputePipeline(VK_NULL_HANDLE, compute_pipeline_create_info).value;

-    device->pipelines.insert({ pipeline->name, pipeline });
+    {
+        std::lock_guard<std::mutex> guard(device->mutex);
+        device->pipelines.insert({ pipeline->name, pipeline });
+    }
+
+    {
+        std::lock_guard<std::mutex> guard(compile_count_mutex);
+        assert(compile_count > 0);
+        compile_count--;
+    }
+    compile_count_cond.notify_all();
 }

 static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline) {
@@ -1194,6 +1209,20 @@ static void ggml_vk_load_shaders(vk_device& device) {
     device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K] = std::make_shared<vk_matmul_pipeline_struct>();
     device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL] = std::make_shared<vk_matmul_pipeline_struct>();

+    std::vector<std::future<void>> compiles;
+    auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t>&& specialization_constants, uint32_t align) {
+        {
+            // wait until fewer than N compiles are in progress
+            uint32_t N = std::max(1u, std::thread::hardware_concurrency());
+            std::unique_lock<std::mutex> guard(compile_count_mutex);
+            while (compile_count >= N) {
+                compile_count_cond.wait(guard);
+            }
+            compile_count++;
+        }
+        compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), name, spv_size, spv_data, entrypoint, parameter_count, push_constant_size, wg_denoms, specialization_constants, align));
+    };
+
     if (device->fp16) {
         ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->l, "matmul_f32_l", matmul_f32_f32_len, matmul_f32_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1);
         ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->m, "matmul_f32_m", matmul_f32_f32_len, matmul_f32_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1);
@@ -1743,6 +1772,10 @@ static void ggml_vk_load_shaders(vk_device& device) {

     ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {256, 1, 1}, {}, 1);

     ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);
+
+    for (auto &c : compiles) {
+        c.wait();
+    }
 }

 static vk_device ggml_vk_get_device(size_t idx) {

From aaa40999251f5d18309b3fddc5a7d576f5fdb4e1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?=
Date: Sun, 29 Sep 2024 19:56:17 +0200
Subject: [PATCH 19/28] CUDA: remove bad assert (ggml/972)

---
 ggml/src/ggml-cuda/im2col.cu | 1 -
 1 file changed, 1 deletion(-)

diff --git a/ggml/src/ggml-cuda/im2col.cu b/ggml/src/ggml-cuda/im2col.cu
index 3d0d8d4e6c686..16463ab0fb683 100644
--- a/ggml/src/ggml-cuda/im2col.cu
+++ b/ggml/src/ggml-cuda/im2col.cu
@@ -69,7 +69,6 @@ void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, 
ggml_tensor * dst) { float * dst_d = (float *)dst->data; cudaStream_t stream = ctx.stream(); - GGML_ASSERT(src0->type == GGML_TYPE_F16); GGML_ASSERT(src1->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32); From d0b1d663e430354ab35853a6e1bce51cc8819376 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 29 Sep 2024 21:16:07 +0300 Subject: [PATCH 20/28] sync : ggml --- scripts/sync-ggml.last | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index 36eeed0cc85c9..aa301462a9a78 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -336c10a4c3c8ec99af484b25a0cddd397a09cdb2 +9a24b8c8c40eab7262d067e91d08df160678df8d From c919d5db39c8a7fcb64737f008e4b105ee0acd20 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 29 Sep 2024 21:18:23 +0300 Subject: [PATCH 21/28] ggml : define missing HWCAP flags (#9684) ggml-ci Co-authored-by: Willy Tarreau --- ggml/src/ggml.c | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index fac4466e31d44..81b651c6a438d 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -3687,6 +3687,10 @@ static inline int ggml_up(int n, int m) { #include #endif +#if !defined(HWCAP2_I8MM) +#define HWCAP2_I8MM 0 +#endif + static void ggml_init_arm_arch_features(void) { #if defined(__linux__) && defined(__aarch64__) uint32_t hwcap = getauxval(AT_HWCAP); From 8277a817f18967581b02b2248989d773e8e99998 Mon Sep 17 00:00:00 2001 From: Ruchira Hasaranga Date: Mon, 30 Sep 2024 13:53:42 +0530 Subject: [PATCH 22/28] console : utf-8 fix for windows stdin (#9690) * utf-8 fix for windows stdin * Update common/console.cpp --------- Co-authored-by: Georgi Gerganov --- common/console.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/common/console.cpp b/common/console.cpp index f65cbc6eda0b1..078a8d678d933 100644 --- a/common/console.cpp +++ b/common/console.cpp @@ -94,6 +94,9 @@ namespace console { simple_io = true; } } + if (simple_io) { + _setmode(_fileno(stdin), _O_U8TEXT); + } #else // POSIX-specific console initialization if (!simple_io) { From ace4f4be37abed4801fbd54a94cf38a7ae462416 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 30 Sep 2024 17:48:49 +0300 Subject: [PATCH 23/28] flake.lock: Update (#9680) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Flake lock file updates: • Updated input 'nixpkgs': 'github:NixOS/nixpkgs/c04d5652cfa9742b1d519688f65d1bbccea9eb7e?narHash=sha256-PmUr/2GQGvFTIJ6/Tvsins7Q43KTMvMFhvG6oaYK%2BWk%3D' (2024-09-19) → 'github:NixOS/nixpkgs/1925c603f17fc89f4c8f6bf6f631a802ad85d784?narHash=sha256-J%2BPeFKSDV%2BpHL7ukkfpVzCOO7mBSrrpJ3svwBFABbhI%3D' (2024-09-26) Co-authored-by: github-actions[bot] --- flake.lock | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/flake.lock b/flake.lock index 6333a09f0106a..dde1ab5277afb 100644 --- a/flake.lock +++ b/flake.lock @@ -20,11 +20,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1726755586, - "narHash": "sha256-PmUr/2GQGvFTIJ6/Tvsins7Q43KTMvMFhvG6oaYK+Wk=", + "lastModified": 1727348695, + "narHash": "sha256-J+PeFKSDV+pHL7ukkfpVzCOO7mBSrrpJ3svwBFABbhI=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "c04d5652cfa9742b1d519688f65d1bbccea9eb7e", + "rev": "1925c603f17fc89f4c8f6bf6f631a802ad85d784", "type": "github" }, "original": { From 08a43d05b6ba74de97610ae519450ad9996475e0 Mon Sep 17 00:00:00 2001 From: vb Date: Mon, 30 Sep 2024 17:03:47 +0200 Subject: [PATCH 24/28] py : 
update transformers version (#9694)

* update transformers version.

* update hfh version.

---
 examples/server/tests/requirements.txt | 2 +-
 requirements/requirements-convert_legacy_llama.txt | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/server/tests/requirements.txt b/examples/server/tests/requirements.txt
index f2d7e5c5731be..5539548720ff1 100644
--- a/examples/server/tests/requirements.txt
+++ b/examples/server/tests/requirements.txt
@@ -1,6 +1,6 @@
 aiohttp~=3.9.3
 behave~=1.2.6
-huggingface_hub~=0.20.3
+huggingface_hub~=0.23.2
 numpy~=1.26.4
 openai~=1.30.3
 prometheus-client~=0.20.0
diff --git a/requirements/requirements-convert_legacy_llama.txt b/requirements/requirements-convert_legacy_llama.txt
index 1d07b09522f61..859204b27ebb8 100644
--- a/requirements/requirements-convert_legacy_llama.txt
+++ b/requirements/requirements-convert_legacy_llama.txt
@@ -1,5 +1,5 @@
 numpy~=1.26.4
 sentencepiece~=0.2.0
-transformers>=4.40.1,<5.0.0
+transformers>=4.45.1,<5.0.0
 gguf>=0.1.0
 protobuf>=4.21.0,<5.0.0

From 511636df0c90826b4dd1fc21ff260c19d69a3b5d Mon Sep 17 00:00:00 2001
From: compilade
Date: Mon, 30 Sep 2024 14:13:16 -0400
Subject: [PATCH 25/28] ci : reduce severity of unused Pyright ignore comments (#9697)

---
 .github/workflows/python-type-check.yml | 4 +++-
 examples/llava/convert_image_encoder_to_gguf.py | 4 ++--
 pyrightconfig.json | 3 ++-
 3 files changed, 7 insertions(+), 4 deletions(-)

diff --git a/.github/workflows/python-type-check.yml b/.github/workflows/python-type-check.yml
index e5ff5e6d792cb..373bb601020b2 100644
--- a/.github/workflows/python-type-check.yml
+++ b/.github/workflows/python-type-check.yml
@@ -4,11 +4,13 @@ on:
   push:
     paths:
       - '.github/workflows/python-type-check.yml'
+      - 'pyrightconfig.json'
       - '**.py'
       - '**/requirements*.txt'
   pull_request:
     paths:
       - '.github/workflows/python-type-check.yml'
+      - 'pyrightconfig.json'
       - '**.py'
       - '**/requirements*.txt'

@@ -33,6 +35,6 @@ jobs:
       - name: Type-check with Pyright
         uses: jakebailey/pyright-action@v2
         with:
-          version: 1.1.370
+          version: 1.1.382
           level: warning
           warnings: true
diff --git a/examples/llava/convert_image_encoder_to_gguf.py b/examples/llava/convert_image_encoder_to_gguf.py
index 36f6b92fbd46e..4fa1d6ceae1bb 100644
--- a/examples/llava/convert_image_encoder_to_gguf.py
+++ b/examples/llava/convert_image_encoder_to_gguf.py
@@ -274,7 +274,7 @@ def bytes_to_unicode():

 if has_llava_projector:
-    model.vision_model.encoder.layers.pop(-1) # pyright: ignore[reportAttributeAccessIssue]
+    model.vision_model.encoder.layers.pop(-1)
     projector = torch.load(args.llava_projector)
     for name, data in projector.items():
         name = get_tensor_name(name)
@@ -288,7 +288,7 @@ def bytes_to_unicode():

 print("Projector tensors added\n")

-state_dict = model.state_dict() # pyright: ignore[reportAttributeAccessIssue]
+state_dict = model.state_dict()
 for name, data in state_dict.items():
     if should_skip_tensor(name, has_text_encoder, has_vision_encoder, has_llava_projector):
         # we don't need this
diff --git a/pyrightconfig.json b/pyrightconfig.json
index 6016f4b6d0120..9acbbeb78a2ed 100644
--- a/pyrightconfig.json
+++ b/pyrightconfig.json
@@ -5,7 +5,8 @@
   "reportUnusedImport": "warning",
   "reportDuplicateImport": "error",
   "reportDeprecated": "warning",
-  "reportUnnecessaryTypeIgnoreComment": "warning",
+  "reportUnnecessaryTypeIgnoreComment": "information",
+  "disableBytesTypePromotions": false, // TODO: change once Python 3.12 is the minimum
   "executionEnvironments": [
     { // TODO: make this version override work correctly

From 6f1d9d71f4c568778a7637ff6582e6f6ba5fb9d3 Mon Sep 17 00:00:00 2001
From: serhii-nakon <57632032+serhii-nakon@users.noreply.github.com>
Date: Mon, 30 Sep 2024 21:57:12 +0300
Subject: [PATCH 26/28] Fix Docker ROCM builds, use AMDGPU_TARGETS instead of GPU_TARGETS (#9641)

* Fix Docker ROCM builds, use AMDGPU_TARGETS instead of GPU_TARGETS

* Set ROCM_DOCKER_ARCH as a string, since the unquoted form broke the build and caused an OOM exit code

---
 .devops/full-rocm.Dockerfile | 6 +++---
 .devops/llama-cli-rocm.Dockerfile | 6 +++---
 .devops/llama-server-rocm.Dockerfile | 6 +++---
 .github/workflows/build.yml | 2 +-
 4 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/.devops/full-rocm.Dockerfile b/.devops/full-rocm.Dockerfile
index 680d1cb92205d..df496bcd2b7ee 100644
--- a/.devops/full-rocm.Dockerfile
+++ b/.devops/full-rocm.Dockerfile
@@ -11,7 +11,7 @@ FROM ${BASE_ROCM_DEV_CONTAINER} AS build
 # Unless otherwise specified, we make a fat build.
 # List from https://github.com/ggerganov/llama.cpp/pull/1087#issuecomment-1682807878
 # This is mostly tied to rocBLAS supported archs.
-ARG ROCM_DOCKER_ARCH=\
+ARG ROCM_DOCKER_ARCH="\
     gfx803 \
     gfx900 \
     gfx906 \
@@ -21,7 +21,7 @@ ARG ROCM_DOCKER_ARCH=\
     gfx1030 \
     gfx1100 \
     gfx1101 \
-    gfx1102
+    gfx1102"

 COPY requirements.txt requirements.txt
 COPY requirements requirements
@@ -34,7 +34,7 @@ WORKDIR /app
 COPY . .

 # Set nvcc architecture
-ENV GPU_TARGETS=${ROCM_DOCKER_ARCH}
+ENV AMDGPU_TARGETS=${ROCM_DOCKER_ARCH}
 # Enable ROCm
 ENV GGML_HIPBLAS=1
 ENV CC=/opt/rocm/llvm/bin/clang
diff --git a/.devops/llama-cli-rocm.Dockerfile b/.devops/llama-cli-rocm.Dockerfile
index c3d1ab06702ec..e60c747bdbf11 100644
--- a/.devops/llama-cli-rocm.Dockerfile
+++ b/.devops/llama-cli-rocm.Dockerfile
@@ -11,7 +11,7 @@ FROM ${BASE_ROCM_DEV_CONTAINER} AS build
 # Unless otherwise specified, we make a fat build.
 # List from https://github.com/ggerganov/llama.cpp/pull/1087#issuecomment-1682807878
 # This is mostly tied to rocBLAS supported archs.
-ARG ROCM_DOCKER_ARCH=\
+ARG ROCM_DOCKER_ARCH="\
     gfx803 \
     gfx900 \
     gfx906 \
@@ -21,7 +21,7 @@ ARG ROCM_DOCKER_ARCH=\
     gfx1030 \
     gfx1100 \
     gfx1101 \
-    gfx1102
+    gfx1102"

 COPY requirements.txt requirements.txt
 COPY requirements requirements
@@ -34,7 +34,7 @@ WORKDIR /app
 COPY . .

 # Set nvcc architecture
-ENV GPU_TARGETS=${ROCM_DOCKER_ARCH}
+ENV AMDGPU_TARGETS=${ROCM_DOCKER_ARCH}
 # Enable ROCm
 ENV GGML_HIPBLAS=1
 ENV CC=/opt/rocm/llvm/bin/clang
diff --git a/.devops/llama-server-rocm.Dockerfile b/.devops/llama-server-rocm.Dockerfile
index fd0e19ad6e49c..8553af75b61fc 100644
--- a/.devops/llama-server-rocm.Dockerfile
+++ b/.devops/llama-server-rocm.Dockerfile
@@ -11,7 +11,7 @@ FROM ${BASE_ROCM_DEV_CONTAINER} AS build
 # Unless otherwise specified, we make a fat build.
 # List from https://github.com/ggerganov/llama.cpp/pull/1087#issuecomment-1682807878
 # This is mostly tied to rocBLAS supported archs.
-ARG ROCM_DOCKER_ARCH=\
+ARG ROCM_DOCKER_ARCH="\
     gfx803 \
     gfx900 \
     gfx906 \
@@ -21,7 +21,7 @@ ARG ROCM_DOCKER_ARCH=\
     gfx1030 \
     gfx1100 \
     gfx1101 \
-    gfx1102
+    gfx1102"

 COPY requirements.txt requirements.txt
 COPY requirements requirements
@@ -34,7 +34,7 @@ WORKDIR /app
 COPY . .
# Set nvcc architecture -ENV GPU_TARGETS=${ROCM_DOCKER_ARCH} +ENV AMDGPU_TARGETS=${ROCM_DOCKER_ARCH} # Enable ROCm ENV GGML_HIPBLAS=1 ENV CC=/opt/rocm/llvm/bin/clang diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index e6a977b604d9b..c71d422e70f21 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -1032,7 +1032,7 @@ jobs: run: | $env:HIP_PATH=$(Resolve-Path 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' | split-path | split-path) $env:CMAKE_PREFIX_PATH="${env:HIP_PATH}" - cmake -G "Unix Makefiles" -B build -S . -DCMAKE_C_COMPILER="${env:HIP_PATH}\bin\clang.exe" -DCMAKE_CXX_COMPILER="${env:HIP_PATH}\bin\clang++.exe" -DGGML_HIPBLAS=ON -DCMAKE_BUILD_TYPE=Release -DGPU_TARGETS=${{ matrix.gpu_target }} -DGGML_RPC=ON + cmake -G "Unix Makefiles" -B build -S . -DCMAKE_C_COMPILER="${env:HIP_PATH}\bin\clang.exe" -DCMAKE_CXX_COMPILER="${env:HIP_PATH}\bin\clang++.exe" -DGGML_HIPBLAS=ON -DCMAKE_BUILD_TYPE=Release -DAMDGPU_TARGETS=${{ matrix.gpu_target }} -DGGML_RPC=ON cmake --build build -j ${env:NUMBER_OF_PROCESSORS} md "build\bin\rocblas\library\" cp "${env:HIP_PATH}\bin\hipblas.dll" "build\bin\" From 1927378bcce20ba72b6c89d5977b854a4bcaeb5d Mon Sep 17 00:00:00 2001 From: compilade Date: Tue, 1 Oct 2024 02:31:36 -0400 Subject: [PATCH 27/28] convert : refactor rope_freqs generation (#9396) * convert : refactor rope_freqs generation This should also fix vocab-only conversion for Phi-3. * convert : adapt MiniCPM3 to separate rope_freqs insertion MiniCPM3's tokenizer is treated as a SentencePiece tokenizer to avoid having to run its custom Python code which mixes tokenization in the same file as tool calls. gguf-py : add long and short RoPE factors to tensor mappings Empty, but the key names are used to populate the mappings. --- convert_hf_to_gguf.py | 61 ++++++++++++++++++---------------- convert_lora_to_gguf.py | 5 ++- gguf-py/gguf/constants.py | 4 +++ gguf-py/gguf/tensor_mapping.py | 3 ++ 4 files changed, 44 insertions(+), 29 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index f3857d487ca30..da5feb25b1961 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -15,6 +15,7 @@ from pathlib import Path from hashlib import sha256 from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Literal, Sequence, TypeVar, cast +from itertools import chain import math import numpy as np @@ -64,7 +65,6 @@ class Model: model_name: str | None metadata_override: Path | None dir_model_card: Path - is_lora: bool # subclasses should define this! 
model_arch: gguf.MODEL_ARCH @@ -72,7 +72,7 @@ class Model: def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool = False, use_temp_file: bool = False, eager: bool = False, metadata_override: Path | None = None, model_name: str | None = None, - split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False, is_lora: bool = False): + split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False): if type(self) is Model: raise TypeError(f"{type(self).__name__!r} should not be directly instantiated") @@ -94,7 +94,6 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, self.metadata_override = metadata_override self.model_name = model_name self.dir_model_card = dir_model # overridden in convert_lora_to_gguf.py - self.is_lora = is_lora # true if model is used inside convert_lora_to_gguf.py # Apply heuristics to figure out typical tensor encoding based on first layer tensor encoding type if self.ftype == gguf.LlamaFileType.GUESSED: @@ -270,10 +269,14 @@ def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: return False + # some models need extra generated tensors (like rope_freqs) + def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: + return () + def prepare_tensors(self): max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,") - for name, data_torch in self.get_tensors(): + for name, data_torch in chain(self.generate_extra_tensors(), self.get_tensors()): # we don't need these if name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")): continue @@ -1617,7 +1620,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter return [(self.map_tensor_name(name), data_torch)] - def prepare_tensors(self): + def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: if rope_scaling := self.find_hparam(["rope_scaling"], optional=True): if rope_scaling.get("rope_type", '').lower() == "llama3": base = self.hparams.get("rope_theta", 10000.0) @@ -1644,9 +1647,9 @@ def prepare_tensors(self): smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor) rope_factors.append(1 / ((1 - smooth) / factor + smooth)) - if not self.is_lora: - self.gguf_writer.add_tensor(self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), np.array(rope_factors, dtype=np.float32)) + yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), torch.tensor(rope_factors, dtype=torch.float32)) + def prepare_tensors(self): super().prepare_tensors() if self._experts is not None: @@ -1870,8 +1873,6 @@ class MiniCPM3Model(Model): def set_gguf_parameters(self): hparams = self.hparams - rope_dims = hparams["qk_rope_head_dim"] - self.gguf_writer.add_file_type(self.ftype) self.gguf_writer.add_context_length(hparams["max_position_embeddings"]) self.gguf_writer.add_embedding_length(hparams["hidden_size"]) @@ -1887,24 +1888,25 @@ def set_gguf_parameters(self): self.gguf_writer.add_key_length(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"]) self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"]) + def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: rope_scaling = self.find_hparam(['rope_scaling'], True) - if rope_scaling is None: - return + if rope_scaling is not None: + rope_dims = self.hparams["qk_rope_head_dim"] - long_factors = rope_scaling.get('long_factor', None) - 
diff --git a/convert_lora_to_gguf.py b/convert_lora_to_gguf.py
index d1c94e58034b3..439a78de108ca 100755
--- a/convert_lora_to_gguf.py
+++ b/convert_lora_to_gguf.py
@@ -331,6 +331,10 @@ def set_gguf_parameters(self):
                 self.gguf_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, self.lora_alpha)
                 super().set_gguf_parameters()
 
+            def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
+                # Never add extra tensors (e.g. rope_freqs) for LoRA adapters
+                return ()
+
             def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
                 tensor_map: dict[str, PartialLoraTensor] = {}
 
@@ -392,7 +396,6 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
             dry_run=args.dry_run,
             dir_lora_model=dir_lora,
             lora_alpha=alpha,
-            is_lora=True,
         )
 
         logger.info("Exporting model...")
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index ebe66a4a39f5f..e08617ba240b3 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -814,6 +814,8 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
         MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ROPE_FACTORS_LONG,
+        MODEL_TENSOR.ROPE_FACTORS_SHORT,
         MODEL_TENSOR.ATTN_NORM,
         MODEL_TENSOR.ATTN_QKV,
         MODEL_TENSOR.ATTN_Q,
@@ -892,6 +894,8 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
         MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ROPE_FACTORS_LONG,
+        MODEL_TENSOR.ROPE_FACTORS_SHORT,
         MODEL_TENSOR.ATTN_NORM,
         MODEL_TENSOR.ATTN_Q_A,
         MODEL_TENSOR.ATTN_Q_B,
diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py
index e7e9b6fd5efbc..f4a787c56993a 100644
--- a/gguf-py/gguf/tensor_mapping.py
+++ b/gguf-py/gguf/tensor_mapping.py
@@ -87,6 +87,9 @@ class TensorNameMap:
             "rope.freqs",                # llama-pth
             "rotary_pos_emb.inv_freq",   # chatglm
         ),
+
+        MODEL_TENSOR.ROPE_FACTORS_LONG:  (),
+        MODEL_TENSOR.ROPE_FACTORS_SHORT: (),
     }
 
     block_mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = {

From a90484c6d9db699bf739d0f33daf1c50cbdd45c9 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Tue, 1 Oct 2024 11:42:01 +0300
Subject: [PATCH 28/28] llama : print correct model type for Llama 3.2 1B and 3B
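
Llama 3.2 1B has 16 transformer layers and Llama 3.2 3B has 28, so the
n_layer switch in llm_load_hparams() had no case for either and the
reported model type was wrong.

The layer counts can be double-checked against the HF configs (a
sketch only; the local config paths are hypothetical):

    # prints num_hidden_layers from each model's config.json
    import json

    for path in ("Llama-3.2-1B/config.json", "Llama-3.2-3B/config.json"):
        with open(path) as f:
            print(path, "->", json.load(f)["num_hidden_layers"])  # 16, 28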
---
 src/llama.cpp | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/llama.cpp b/src/llama.cpp
index c466cd88b7c14..d1d27d21e232f 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -5502,8 +5502,10 @@ static void llm_load_hparams(
                 }
             } else {
                 switch (hparams.n_layer) {
+                    case 16: model.type = e_model::MODEL_1B; break; // Llama 3.2 1B
                     case 22: model.type = e_model::MODEL_1B; break;
                     case 26: model.type = e_model::MODEL_3B; break;
+                    case 28: model.type = e_model::MODEL_3B; break; // Llama 3.2 3B
                     // granite uses a vocab with len 49152
                     case 32: model.type = hparams.n_vocab == 49152 ? e_model::MODEL_3B : (hparams.n_vocab < 40000 ? e_model::MODEL_7B : e_model::MODEL_8B); break;
                     case 36: model.type = e_model::MODEL_8B; break; // granite