[CPU] Support dynamic activation sparsity #27974
base: master
Conversation
The new functionality is really cool for simplifying JIT code: automatic physical-register mapping, simulated C-like control flow, and expressions over generic registers. Maybe these features could be extended to SIMD registers in the future; then we could write JIT code just like intrinsics!
MemoryPtr m_scales;
ActSparseFCNode::Config& m_config;

void show(const char* name, uint8_t* src, int stride, int rows, int cols) {
To be removed.
Sure, will remove
template <class T>
T* ActSparseFcKernel::scratch_alloc(size_t cnt) {
#    if defined(__GNUC__) || defined(__clang__)
    thread_local uint8_t scratch[1024 * 1024 * 2] __attribute__((aligned(4096)));
We'd better use the scratch in the GraphContext.
changed
int OC,
int n0,
int n1) {
    static auto repack_2xsimdw = jit_compile_repack_2xsimdw(WeightCompressionType::FP16);
We'd better add this to the primitive cache.
added primitive cache support
auto jit = std::make_shared<SIMDJit>(__func__);
auto simd_width = SIMDJit::vmm_width<float>();

auto zp_input_u8 = jit->get_sreg(0);
The input parameters of a JIT function are different from normal variables; we'd better add a new function to get the input parameters.
Sure, will add get_arg()
const float* scales,
const uint8_t* zp) {
    const auto SIMDW = SIMDJit::vmm_width<float>();
    if (OC % (2 * SIMDW)) {
OPENVINO_ASSERT should be simpler.
Sure, will change
    return jit;
}

static void gemm6x2_Mx2(const float* pA,
What's the meaning of x2 in Mx2?
x2 means two SIMD register widths in units of fp32. For example, with AVX2, x2 means 16 fp32 values; with AVX-512, it means 32 fp32 values.
Details:
Activation sparsity exploits the fact that activations in the MLPs of LLMs are sparse: input channels of the activation with small magnitude can be set to zero with an acceptable accuracy drop.
The distribution of the sparse activation channels is dynamic (known only at runtime) and varies a lot from token to token, so the optimization opportunity exists only in the 2nd-token generation process with batch size fixed to 1 (which is exactly the typical use case for client-side LLM inference). In that case, the weight-memory reads corresponding to the skipped input channels can be saved.
The best weight-memory layout for this optimization is plain [IC, OC], so the weights corresponding to each input channel are dense, and the non-sparse input channels can enjoy the boost the CPU's HW prefetcher gives to continuous streaming access. If we used the current blocked weight layout set by the oneDNN fork, weights from non-sparse & sparse channels would be mixed together within cache lines, which would hurt performance due both to the prefetcher-unfriendly access pattern and to DDR's physical-page granularity.
But choosing the plain [IC, OC] layout poses a challenge for 1st-token latency, because the blocked layout is best for the 1st-token/compute-bound case; so in this PR we also have to minimize the degradation of 1st-token latency.
Performance data on i9-13900K
We can see that there is no regression in 1st-token latency, and a ~20% reduction in 2nd-token latency.
SIMDJit
In this PR we introduce a new way of writing JIT kernels, an enhanced version of existing attempts at making JIT programming more friendly:
These efforts all focus on making xbyak-based JIT programming more user-friendly as an EDSL, and in this PR I go further in that direction:
In the future, we can port it to ARM64 & RISC-V, and also try another level of abstraction over SIMD vector registers to (maybe) make a single kernel work on all CPU platforms with much less effort (in terms of implementing & porting).
Tickets: