Add flash attention on intel-avx512 platform (#18)
* fix: tweak glm-4-9b-chat model

* feat: add llama3 model

* feat: add flash attention support at intel-avx512 platform

---------

Co-authored-by: jinyejun.jyj <[email protected]>
yejunjin authored Jun 21, 2024
1 parent 5ebd2c0 commit 3a0417b
Showing 25 changed files with 1,096 additions and 59 deletions.
5 changes: 5 additions & 0 deletions CMakeLists.txt
@@ -31,6 +31,7 @@ set(CONFIG_HOST_CPU_TYPE "X86" CACHE STRING "host cpu type, like X86, ARMV9, etc

## x86 related option.
option(ENABLE_AVX2 "enable avx2" ON)
option(ENABLE_AVX512 "enable avx512" ON)

## ARM related option.
option(ENABLE_ARMCL "enable use of Arm Compute Library" OFF)
@@ -55,6 +56,10 @@ if(ENABLE_AVX2)
  list(APPEND ALLSPARK_DEFINITION "-DENABLE_AVX2")
endif()

if(ENABLE_AVX512)
  list(APPEND ALLSPARK_DEFINITION "-DENABLE_AVX512")
endif()

if(ENABLE_ARM_V84_V9)
  list(APPEND ALLSPARK_DEFINITION "-DENABLE_ARM_V84_V9")
  if (ENABLE_BF16)
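The `-DENABLE_AVX512` definition added above is a compile-time switch that C++ sources can test with the preprocessor. As a minimal, hypothetical sketch of how such a guard is typically consumed (the macro name matches the definition added here; the guarded function is illustrative and not code from this commit):

```cpp
// Illustrative only: select an AVX-512 code path when the build defines
// ENABLE_AVX512 (via -DENABLE_AVX512), otherwise fall back to the generic path.
#include <cstdio>

void RunAttention() {
#ifdef ENABLE_AVX512
  std::puts("AVX-512 flash-attention path");  // placeholder for the real kernel call
#else
  std::puts("generic attention path");        // placeholder fallback
#endif
}
```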
4 changes: 3 additions & 1 deletion README.md
@@ -27,6 +27,7 @@ Written in C++ runtime, DashInfer aims to deliver production-level implementatio
- **Support for Mainstream Open-Source LLMs**: DashInfer supports mainstream open-source LLMs, including Qwen, LLaMA, ChatGLM, etc., and supports loading models in the Huggingface format.
- **Post Training Quantization (PTQ)**: Using DashInfer's InstantQuant (IQ), weight-only quantization acceleration can be achieved without fine-tuning, improving deployment efficiency. Accuracy evaluation shows that IQ has no impact on model accuracy. The current version supports weight-only 8-bit quantization on ARM CPUs.
- **Optimized Computation Kernels**: With OneDNN and self-developed assembly kernels, DashInfer is able to maximize the performance of the hardware on both ARM and x86.
- **Flash Attention Support**: Significantly accelerates attention computation for long sequences, drastically reducing first-token latency.
- **NUMA-Aware Design**: DashInfer supports tensor parallel inference across multiple NUMA nodes, fully leveraging the computing power of server CPUs. With numactl and a multi-process architecture, the NUMA affinity of threads is accurately controlled to maximize the performance of multi-node CPUs and avoid the performance degradation caused by cross-NUMA access. For more information on NUMA, see: [Optimizing Applications for NUMA - Intel](https://www.intel.com/content/dam/develop/external/us/en/documents/3-5-memmgt-optimizing-applications-for-numa-184398.pdf), [What is NUMA?](https://www.kernel.org/doc/html/v5.0/vm/numa.html).
- **Context Length**: The current version supports up to 32k context length, with plans to extend to longer context lengths in the future.
- **Multi-Language API Interfaces**: Both C++ and Python interfaces are supported. It is possible to extend C++ interface to Java, Rust and other programming languages, via standard cross-language interfaces.
@@ -88,6 +89,7 @@ During inference, the quantized weight is recovered as bfloat16 for matrix multi
| ChatGLMModel | ChatGLM | ChatGLM_v3 | [THUDM/chatglm3-6b](https://huggingface.co/THUDM/chatglm3-6b),<br>[THUDM/chatglm3-6b-32k](https://huggingface.co/THUDM/chatglm3-6b-32k) | [ZhipuAI/chatglm3-6b](https://modelscope.cn/models/ZhipuAI/chatglm3-6b/summary),<br>[ZhipuAI/chatglm3-6b-32k](https://modelscope.cn/models/ZhipuAI/chatglm3-6b-32k/summary) | / |
| ChatGLMModel | ChatGLM | ChatGLM_v4 | [THUDM/glm-4-9b-chat](https://huggingface.co/THUDM/glm-4-9b-chat) | [ZhipuAI/glm-4-9b-chat](https://modelscope.cn/models/ZhipuAI/glm-4-9b-chat/summary) | [dash-infer/glm-4-9b-chat-DI](https://modelscope.cn/models/dash-infer/glm-4-9b-chat-DI/summary) |
| LlamaForCausalLM | LLaMA-2 | LLaMA_v2 | [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf),<br>[meta-llama/Llama-2-13b-chat-hf](https://huggingface.co/meta-llama/Llama-2-13b-chat-hf) | [modelscope/Llama-2-7b-chat-ms](https://modelscope.cn/models/modelscope/Llama-2-7b-chat-ms/summary),<br>[modelscope/Llama-2-13b-chat-ms](https://modelscope.cn/models/modelscope/Llama-2-13b-chat-ms/summary) | / |
| LlamaForCausalLM | LLaMA-3 | LLaMA_v3 | [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) | [modelscope/Meta-Llama-3-8B-Instruct](https://modelscope.cn/models/modelscope/Meta-Llama-3-8B-Instruct/summary) | / |

# Software Architecture

@@ -187,7 +189,7 @@ This subsection lists the third-party dependencies for the different stages of D

# Future Plans

- [ ] Accelerate attention with Flash-Attention
- [x] Accelerate attention with Flash-Attention
- [x] Expand context length to over 32k
- [ ] Support 4-bit quantization
- [ ] Support quantized models fine-tuned with GPTQ
5 changes: 4 additions & 1 deletion README_CN.md
@@ -27,6 +27,7 @@ Written in C++ Runtime, DashInfer provides both C++ and Python interfaces. DashInfer
- **Support for Mainstream Open-Source LLMs**: Supports mainstream open-source LLMs, including Qwen, LLaMA, ChatGLM, etc., and can load models in the Huggingface format.
- **Post Training Quantization (PTQ)**: With DashInfer's InstantQuant (IQ), weight-only quantization acceleration is achieved without fine-tuning, improving deployment efficiency. Accuracy evaluation shows that IQ has no impact on model accuracy. The current version supports weight-only 8-bit quantization on ARM CPUs.
- **Optimized Computation Kernels**: Combining OneDNN with self-developed assembly kernels, DashInfer maximizes hardware performance on both ARM and x86.
- **Flash Attention Support**: Significantly accelerates attention computation for long sequences, greatly reducing first-token latency.
- **NUMA-Aware Design**: Supports tensor-parallel inference across multiple NUMA nodes, fully leveraging the compute power of server-grade CPUs. With numactl and a multi-process architecture, the NUMA affinity of compute threads is precisely controlled to maximize multi-node CPU performance and avoid the degradation caused by cross-NUMA memory access. For NUMA performance guidance, see: [Optimizing Applications for NUMA - Intel](https://www.intel.com/content/dam/develop/external/us/en/documents/3-5-memmgt-optimizing-applications-for-numa-184398.pdf), [What is NUMA?](https://www.kernel.org/doc/html/v5.0/vm/numa.html)
- **Context Length**: The current version supports a 32k context length, with plans to support longer context lengths in the future.
- **Multi-Language API Interfaces**: Provides C++ and Python interfaces; the C++ interface can be used directly to bridge to other programming languages such as Java and Rust.
@@ -89,6 +90,7 @@ $$ x_{u8} = x_{fp32} / scale + zeropoint $$
| ChatGLMModel | ChatGLM | ChatGLM_v3 | [THUDM/chatglm3-6b](https://huggingface.co/THUDM/chatglm3-6b),<br>[THUDM/chatglm3-6b-32k](https://huggingface.co/THUDM/chatglm3-6b-32k) | [ZhipuAI/chatglm3-6b](https://modelscope.cn/models/ZhipuAI/chatglm3-6b/summary),<br>[ZhipuAI/chatglm3-6b-32k](https://modelscope.cn/models/ZhipuAI/chatglm3-6b-32k/summary) | / |
| ChatGLMModel | ChatGLM | ChatGLM_v4 | [THUDM/glm-4-9b-chat](https://huggingface.co/THUDM/glm-4-9b-chat) | [ZhipuAI/glm-4-9b-chat](https://modelscope.cn/models/ZhipuAI/glm-4-9b-chat/summary) | [dash-infer/glm-4-9b-chat-DI](https://modelscope.cn/models/dash-infer/glm-4-9b-chat-DI/summary) |
| LlamaForCausalLM | LLaMA-2 | LLaMA_v2 | [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf),<br>[meta-llama/Llama-2-13b-chat-hf](https://huggingface.co/meta-llama/Llama-2-13b-chat-hf) | [modelscope/Llama-2-7b-chat-ms](https://modelscope.cn/models/modelscope/Llama-2-7b-chat-ms/summary),<br>[modelscope/Llama-2-13b-chat-ms](https://modelscope.cn/models/modelscope/Llama-2-13b-chat-ms/summary) | / |
| LlamaForCausalLM | LLaMA-3 | LLaMA_v3 | [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) | [modelscope/Meta-Llama-3-8B-Instruct](https://modelscope.cn/models/modelscope/Meta-Llama-3-8B-Instruct/summary) | / |

# Software Architecture

@@ -188,12 +190,13 @@ $$ x_{u8} = x_{fp32} / scale + zeropoint $$

# Future Plans

- [ ] First-token acceleration: add CPU implementations of attention acceleration techniques such as Flash-Attention;
- [x] First-token acceleration: add CPU implementations of attention acceleration techniques such as Flash-Attention;
- [x] Context Length: extend beyond 32k;
- [ ] Low-bit quantization: support 4-bit quantization;
- [ ] QAT quantization support: support models quantized and fine-tuned with the GPTQ algorithm;
- [ ] MoE: support MoE models and architectures.


# License

The DashInfer source code is licensed under Apache 2.0; the full license text can be found in the root directory of this repository.
1 change: 1 addition & 0 deletions build.sh
@@ -73,6 +73,7 @@ elif [ "${with_platform,,}" == "armclang" ]; then
-DBUILD_PACKAGE=${build_package} \
-DALLSPARK_CBLAS=BLIS \
-DENABLE_AVX2=OFF \
-DENABLE_AVX512=OFF \
-DENABLE_ARM_V84_V9=ON \
-DENABLE_BF16=ON \
-DENABLE_FP16=ON \
11 changes: 11 additions & 0 deletions csrc/common/env_config.h
@@ -59,4 +59,15 @@ class EnvVarConfig {
  }
};

class AttentionEnvConfig {
 public:
  static int GetFlashThresh() {
    static int env_flash_thresh = -1;
    if (env_flash_thresh == -1) {
      env_flash_thresh = EnvVarConfig::GetInt("AS_FLASH_THRESH", 1024);
    }
    return env_flash_thresh;
  }
};

} // namespace allspark
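`AttentionEnvConfig::GetFlashThresh()` reads the `AS_FLASH_THRESH` environment variable once, defaults to 1024, and caches the result. A plausible use is gating whether the flash-attention path is taken for a given prompt length; the sketch below is an assumption about the call site (including the include path), not code from this commit:

```cpp
// Hypothetical gating sketch: take the flash-attention path only when the
// prompt is long enough to amortize its overhead. The threshold comes from
// AS_FLASH_THRESH (default 1024) via AttentionEnvConfig::GetFlashThresh().
#include <cstdlib>               // setenv (POSIX)
#include "common/env_config.h"   // assumed include path for AttentionEnvConfig

bool ShouldUseFlashAttention(int input_seq_len) {
  return input_seq_len >= allspark::AttentionEnvConfig::GetFlashThresh();
}

int main() {
  // An override must happen before the first GetFlashThresh() call,
  // because the value is cached in a static variable.
  setenv("AS_FLASH_THRESH", "512", /*overwrite=*/1);
  return ShouldUseFlashAttention(800) ? 0 : 1;
}
```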
14 changes: 14 additions & 0 deletions csrc/core/kernel/CMakeLists.txt
@@ -41,6 +41,12 @@ file(
  cpu/rotary.cpp
)

file(
  GLOB_RECURSE
  src_avx512
  cpu/mha.cpp
)

file(
  GLOB_RECURSE
  src_arm
@@ -53,6 +59,14 @@ if(ENABLE_AVX2)
  set_source_files_properties(${src_avx2} PROPERTIES COMPILE_FLAGS "${AVX2_FLAGS}")
endif(ENABLE_AVX2)

if(ENABLE_AVX512)
  set(AVX512_FLAGS "-mavx512f -mavx512bw -mavx512vl")
  message("AVX512 flags: ${AVX512_FLAGS}, files: ${src_avx512}")
  get_source_file_property(OTHER_FLAGS ${src_avx512} COMPILE_FLAGS)
  message("APPEND flags: ${OTHER_FLAGS}, files: ${src_avx512}")
  set_source_files_properties(${src_avx512} PROPERTIES COMPILE_FLAGS "${OTHER_FLAGS} ${AVX512_FLAGS}")
endif(ENABLE_AVX512)

if(NOT ENABLE_ARM_V84_V9)
  foreach(file ${src_arm})
    list(REMOVE_ITEM src_cpu_common "${file}")
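Because `cpu/mha.cpp` is now compiled with `-mavx512f -mavx512bw -mavx512vl`, the resulting object code can only execute on CPUs that support those subsets. A minimal, hypothetical runtime guard (not part of this commit) using GCC/Clang's `__builtin_cpu_supports` could look like:

```cpp
// Illustrative runtime check for the AVX-512 subsets implied by the compile
// flags above (-mavx512f -mavx512bw -mavx512vl). Requires GCC or Clang on x86.
bool CpuHasRequiredAvx512() {
  return __builtin_cpu_supports("avx512f") &&
         __builtin_cpu_supports("avx512bw") &&
         __builtin_cpu_supports("avx512vl");
}
```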
10 changes: 10 additions & 0 deletions csrc/core/kernel/cpu/cpu_kernel.h
@@ -25,6 +25,16 @@ void EmbeddingT5KernelLauncher(T* out_tensor, const int64_t* word_ids,
                               const T* embedding_table, int batch_size,
                               int seq_len, int hidden_size, int vocab_size,
                               bool use_decoder);

template <typename T>
void SelfScaledDpAttention(T* output, const T* query, const T* key,
                           const T* value, int q_num_heads, int kv_num_heads,
                           int size_per_head, int o_stride, int q_stride,
                           int kv_stride, int batch_size,
                           const int* input_seq_lens, const int* past_seq_lens,
                           void* workspace, int src_blk, int tgt_blk,
                           const float* mask, float scale, int num_thread);

template <typename T>
void GetBatchArrayLauncher(T* q, T* k, T* v, T* score, T* out, T** q_array,
                           T** k_array, T** v_array, T** score_array,
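`SelfScaledDpAttention` is the new CPU scaled dot-product attention entry point. The sketch below only shows how the scalar arguments conventionally relate to one another; the packed `[batch, seq, num_heads * size_per_head]` layout behind the stride values and the `1/sqrt(size_per_head)` softmax scale are assumptions for illustration, not details taken from this commit:

```cpp
// Hypothetical shape/stride setup for a SelfScaledDpAttention<float> call.
#include <cmath>

int main() {
  const int batch_size = 2, q_num_heads = 32, kv_num_heads = 8, size_per_head = 128;

  // Assuming Q, K/V, and the output are packed as [batch, seq, heads * head_size]:
  const int q_stride  = q_num_heads  * size_per_head;  // elements per query token
  const int kv_stride = kv_num_heads * size_per_head;  // elements per key/value token
  const int o_stride  = q_num_heads  * size_per_head;  // elements per output token

  // Conventional softmax scaling for scaled dot-product attention.
  const float scale = 1.0f / std::sqrt(static_cast<float>(size_per_head));

  (void)batch_size; (void)q_stride; (void)kv_stride; (void)o_stride; (void)scale;
  return 0;
}
```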