Skip to content
This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Optimization of Layernormalization #103

Merged
merged 9 commits into from
Jan 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion bestla/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ project(bestla LANGUAGES CXX VERSION 0.1.0)
file(GLOB headers ${PROJECT_NAME}/*.h ${PROJECT_NAME}/*.hpp)
file(GLOB xbyak_headers ${PROJECT_NAME}/xbyak/*.h ${PROJECT_NAME}/xbyak/*.hpp)

option(BTLA_USE_OPENMP "Enable OpenMP thread pool" ON)
option(BTLA_USE_OPENMP "Enable OpenMP thread pool" OFF)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we already have a customized threadpool implemented?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's better not to set it ON as default. It can be set in neural_speed as it uses OMP as default.


option(BTLA_UT_ALL "Enable all unit tests" OFF)
option(BTLA_UT_DEBUG "Enable debug unit tests" OFF)
Expand Down Expand Up @@ -75,6 +75,7 @@ if(UT_BUILD)
add_executable(${PROJECT_NAME}_ut ${srcs} ${headers} ${ut_headers})
if(BTLA_UT_OPENMP)
include(FindOpenMP)
target_compile_definitions(${PROJECT_NAME} INTERFACE BTLA_USE_OPENMP)
target_link_libraries(${PROJECT_NAME}_ut PRIVATE OpenMP::OpenMP_CXX)
endif()
if(NOT WIN32)
Expand Down
10 changes: 7 additions & 3 deletions bestla/bestla/bestla_parallel.h
Original file line number Diff line number Diff line change
Expand Up @@ -600,10 +600,14 @@ class OMPThreading : public IThreading {
public:
explicit OMPThreading(int nthreads) : IThreading(nthreads) { omp_set_num_threads(nthreads); }
void parallel_for(const thread_func& func) const override {
if (mThreadNum > 1) {
#pragma omp parallel
{
int tidx = omp_get_thread_num();
func(tidx);
{
int tidx = omp_get_thread_num();
func(tidx);
}
} else {
func(0);
}
}
virtual void set_threads(int nthreads) override {
Expand Down
103 changes: 103 additions & 0 deletions bestla/bestla/kernel_avx2.h
Original file line number Diff line number Diff line change
Expand Up @@ -1015,6 +1015,109 @@ static inline BTLA_CODE fp32_cvt_bf16_2D_write_back(const void* raw_srcptr, void
return BTLA_CODE::Success;
}

static inline BTLA_CODE layernorm(const float* srcptr, const float* scaleptr, const float* biasptr, float epsilon,
int norm_size, float* dstptr, float* mean_out, float* mean_square_out,
bool simplified) {
int constexpr VLen = 8;
int norm_size8 = utils::padto_le(norm_size, VLen);
int h = 0;
__m256 vmean = _mm256_setzero_ps(), vmeansq = _mm256_setzero_ps();
for (; h < norm_size8; h += VLen) {
auto tmp = _mm256_loadu_ps(srcptr + h);
vmean = _mm256_add_ps(vmean, tmp);
tmp = _mm256_mul_ps(tmp, tmp);
vmeansq = _mm256_add_ps(vmeansq, tmp);
}
float mean = avx2_reduce_ps<AVX2_REDUCE_TYPE::ADD>(vmean);
float mean_square = avx2_reduce_ps<AVX2_REDUCE_TYPE::ADD>(vmeansq);
for (; h < norm_size; h++) {
mean += srcptr[h];
mean_square += srcptr[h] * srcptr[h];
}
mean = mean / norm_size;
if (simplified) {
mean_square = std::sqrt(mean_square / norm_size + epsilon);
} else {
mean_square = std::sqrt(mean_square / norm_size - mean * mean + epsilon);
}
auto vm = _mm256_set1_ps(mean);
float inv_meansq = 1.f / mean_square;
auto vms = _mm256_set1_ps(inv_meansq);
h = 0;
if (simplified) {
if (scaleptr) {
for (; h < norm_size8; h += VLen) {
auto inp = _mm256_loadu_ps(srcptr + h);
auto scale = _mm256_loadu_ps(scaleptr + h);
inp = _mm256_mul_ps(inp, scale);
inp = _mm256_mul_ps(inp, vms);
_mm256_storeu_ps(dstptr + h, inp);
}
for (; h < norm_size; h++) {
dstptr[h] = srcptr[h] * inv_meansq * scaleptr[h];
}
} else {
for (; h < norm_size8; h += VLen) {
auto inp = _mm256_loadu_ps(srcptr + h);
inp = _mm256_mul_ps(inp, vms);
_mm256_storeu_ps(dstptr + h, inp);
}
for (; h < norm_size; h++) {
dstptr[h] = srcptr[h] * inv_meansq;
}
}

} else {
if (scaleptr) {
if (biasptr == nullptr) {
for (; h < norm_size8; h += VLen) {
auto inp = _mm256_loadu_ps(srcptr + h);
auto scale = _mm256_loadu_ps(scaleptr + h);
inp = _mm256_sub_ps(inp, vm);
inp = _mm256_mul_ps(inp, scale);
inp = _mm256_mul_ps(inp, vms);
_mm256_storeu_ps(dstptr + h, inp);
}
for (; h < norm_size; h++) {
dstptr[h] = (srcptr[h] - mean) * inv_meansq * scaleptr[h];
}
} else {
for (; h < norm_size8; h += VLen) {
auto inp = _mm256_loadu_ps(srcptr + h);
auto scale = _mm256_loadu_ps(scaleptr + h);
inp = _mm256_sub_ps(inp, vm);
inp = _mm256_mul_ps(inp, vms);
inp = _mm256_mul_ps(inp, scale);
auto bias = _mm256_loadu_ps(biasptr + h);
inp = _mm256_add_ps(inp, bias);
_mm256_storeu_ps(dstptr + h, inp);
}
for (; h < norm_size; h++) {
dstptr[h] = (srcptr[h] - mean) * inv_meansq * scaleptr[h] + biasptr[h];
}
}
} else {
for (; h < norm_size8; h += VLen) {
auto inp = _mm256_loadu_ps(srcptr + h);
inp = _mm256_sub_ps(inp, vm);
inp = _mm256_mul_ps(inp, vms);
_mm256_storeu_ps(dstptr + h, inp);
}
for (; h < norm_size; h++) {
dstptr[h] = (srcptr[h] - mean) * inv_meansq;
}
}
}

if (mean_out) {
*mean_out = mean;
}
if (mean_square_out) {
*mean_square_out = mean_square;
}
return BTLA_CODE::Success;
}

#ifdef __GNUC__
#pragma GCC pop_options
#else
Expand Down
102 changes: 102 additions & 0 deletions bestla/bestla/kernel_avx512f.h
Original file line number Diff line number Diff line change
Expand Up @@ -2195,6 +2195,108 @@ struct padding_trans_interleave_cvt<utils::fp16, utils::bf16, 2> {
};
#endif

static inline BTLA_CODE layernorm(const float* srcptr, const float* scaleptr, const float* biasptr, float epsilon,
int norm_size, float* dstptr, float* mean_out, float* mean_square_out,
bool simplified) {
int constexpr VLen = 16;
int norm_size16 = utils::padto_le(norm_size, VLen);
int h = 0;
__m512 vmean = _mm512_setzero_ps(), vmeansq = _mm512_setzero_ps();
for (; h < norm_size16; h += VLen) {
auto tmp = _mm512_loadu_ps(srcptr + h);
vmean = _mm512_add_ps(vmean, tmp);
tmp = _mm512_mul_ps(tmp, tmp);
vmeansq = _mm512_add_ps(vmeansq, tmp);
}
float mean = _mm512_reduce_add_ps(vmean);
float mean_square = _mm512_reduce_add_ps(vmeansq);
for (; h < norm_size; h++) {
mean += srcptr[h];
mean_square += srcptr[h] * srcptr[h];
}
mean = mean / norm_size;
if (simplified) {
mean_square = std::sqrt(mean_square / norm_size + epsilon);
} else {
mean_square = std::sqrt(mean_square / norm_size - mean * mean + epsilon);
}
auto vm = _mm512_set1_ps(mean);
float inv_meansq = 1.f / mean_square;
auto vms = _mm512_set1_ps(inv_meansq);
h = 0;
if (simplified) {
if (scaleptr) {
for (; h < norm_size16; h += VLen) {
auto inp = _mm512_loadu_ps(srcptr + h);
auto scale = _mm512_loadu_ps(scaleptr + h);
inp = _mm512_mul_ps(inp, vms);
inp = _mm512_mul_ps(inp, scale);
_mm512_storeu_ps(dstptr + h, inp);
}
for (; h < norm_size; h++) {
dstptr[h] = srcptr[h] * inv_meansq * scaleptr[h];
}
} else {
for (; h < norm_size16; h += VLen) {
auto inp = _mm512_loadu_ps(srcptr + h);
inp = _mm512_mul_ps(inp, vms);
_mm512_storeu_ps(dstptr + h, inp);
}
for (; h < norm_size; h++) {
dstptr[h] = srcptr[h] * inv_meansq;
}
}

} else {
if (scaleptr) {
if (biasptr == nullptr) {
for (; h < norm_size16; h += VLen) {
auto inp = _mm512_loadu_ps(srcptr + h);
auto scale = _mm512_loadu_ps(scaleptr + h);
inp = _mm512_sub_ps(inp, vm);
inp = _mm512_mul_ps(inp, vms);
inp = _mm512_mul_ps(inp, scale);
_mm512_storeu_ps(dstptr + h, inp);
}
for (; h < norm_size; h++) {
dstptr[h] = (srcptr[h] - mean) * inv_meansq * scaleptr[h];
}
} else {
for (; h < norm_size16; h += VLen) {
auto inp = _mm512_loadu_ps(srcptr + h);
auto scale = _mm512_loadu_ps(scaleptr + h);
inp = _mm512_sub_ps(inp, vm);
inp = _mm512_mul_ps(inp, vms);
inp = _mm512_mul_ps(inp, scale);
auto bias = _mm512_loadu_ps(biasptr + h);
inp = _mm512_add_ps(inp, bias);
_mm512_storeu_ps(dstptr + h, inp);
}
for (; h < norm_size; h++) {
dstptr[h] = (srcptr[h] - mean) * inv_meansq * scaleptr[h] + biasptr[h];
}
}
} else {
for (; h < norm_size16; h += VLen) {
auto inp = _mm512_loadu_ps(srcptr + h);
inp = _mm512_sub_ps(inp, vm);
inp = _mm512_mul_ps(inp, vms);
_mm512_storeu_ps(dstptr + h, inp);
}
for (; h < norm_size; h++) {
dstptr[h] = (srcptr[h] - mean) * inv_meansq;
}
}
}

if (mean_out) {
*mean_out = mean;
}
if (mean_square_out) {
*mean_square_out = mean_square;
}
return BTLA_CODE::Success;
}
#ifdef __GNUC__
#pragma GCC pop_options
#else
Expand Down
55 changes: 55 additions & 0 deletions bestla/bestla/kernel_ref.h
Original file line number Diff line number Diff line change
Expand Up @@ -1426,6 +1426,61 @@ static inline BTLA_CODE remove_zeropoint_bias(float* accptr, int ldacc, int row,
}
return BTLA_CODE::Success;
}

template <typename T>
static inline BTLA_CODE layernorm(const T* srcptr, const T* scaleptr, const T* biasptr, T epsilon, int norm_size,
T* dstptr, T* mean_out, T* mean_square_out, bool simplified) {
T mean = 0;
T mean_square = 0;

for (int h = 0; h < norm_size; h++) {
mean += srcptr[h];
mean_square += srcptr[h] * srcptr[h];
}

mean = mean / norm_size;
if (simplified) {
mean_square = std::sqrt(mean_square / norm_size + epsilon);
} else {
mean_square = std::sqrt(mean_square / norm_size - mean * mean + epsilon);
}
float inv_mean_square = 1.f / mean_square;
if (simplified) {
if (scaleptr) {
for (int h = 0; h < norm_size; h++) {
dstptr[h] = srcptr[h] * inv_mean_square * scaleptr[h];
}
} else {
for (int h = 0; h < norm_size; h++) {
dstptr[h] = srcptr[h] * inv_mean_square;
}
}
} else {
if (scaleptr) {
if (biasptr == nullptr) {
for (int h = 0; h < norm_size; h++) {
dstptr[h] = (srcptr[h] - mean) * inv_mean_square * scaleptr[h];
}
} else {
for (int h = 0; h < norm_size; h++) {
dstptr[h] = (srcptr[h] - mean) * inv_mean_square * scaleptr[h] + biasptr[h];
}
}
} else {
for (int h = 0; h < norm_size; h++) {
dstptr[h] = (srcptr[h] - mean) * inv_mean_square;
}
}
}

if (mean_out) {
*mean_out = mean;
}
if (mean_square_out) {
*mean_square_out = mean_square;
}
return BTLA_CODE::Success;
}
} // namespace ref
} // namespace kernel
} // namespace bestla
29 changes: 29 additions & 0 deletions bestla/bestla/kernel_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -826,6 +826,35 @@ class RemoveZeroPointBias {
}
};

class LayerNormalization {
public:
template <BTLA_ISA ISA_T, typename T>
static inline BTLA_CODE forward(const T* srcptr, const T* scaleptr, const T* biasptr, T epsilon, int norm_size,
T* dstptr, T* mean, T* mean_square, bool simplified) {
if constexpr (utils::isa_base<ISA_T>::avx512f && std::is_same_v<T, float>) {
return avx512f::layernorm(srcptr, scaleptr, biasptr, epsilon, norm_size, dstptr, mean, mean_square, simplified);
}
if constexpr (utils::isa_base<ISA_T>::avx2 && std::is_same_v<T, float>) {
return avx2::layernorm(srcptr, scaleptr, biasptr, epsilon, norm_size, dstptr, mean, mean_square, simplified);
}
return ref::layernorm(srcptr, scaleptr, biasptr, epsilon, norm_size, dstptr, mean, mean_square, simplified);
}
template <typename T>
static inline BTLA_CODE forward_auto(const T* srcptr, const T* scaleptr, const T* biasptr, T epsilon, int norm_size,
T* dstptr, T* mean, T* mean_square, bool simplified) {
GetCPUDevice();
if (_cd->AVX512F()) {
return forward<BTLA_ISA::AVX512F, T>(srcptr, scaleptr, biasptr, epsilon, norm_size, dstptr, mean, mean_square,
simplified);
}
if (_cd->AVX2()) {
return forward<BTLA_ISA::AVX2, T>(srcptr, scaleptr, biasptr, epsilon, norm_size, dstptr, mean, mean_square,
simplified);
}
return forward<BTLA_ISA::NoSIMD, T>(srcptr, scaleptr, biasptr, epsilon, norm_size, dstptr, mean, mean_square,
simplified);
}
};
} // namespace wrapper
} // namespace kernel
} // namespace bestla
34 changes: 34 additions & 0 deletions bestla/bestla/ut/kernel_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,5 +196,39 @@ class UT_RevertPaddingInterleaveMN {
#ifdef BTLA_UT_KERNEL_WRAPPER
static UT_RevertPaddingInterleaveMN sUT_RevertPaddingInterleaveMN;
#endif

class UT_LayerNormalization {
public:
UT_LayerNormalization() {
UT_START();
ut<float, BTLA_ISA::AVX512F>(4096, false, true, true);
ut<float, BTLA_ISA::AVX512F>(4096, false, false, false);
ut<float, BTLA_ISA::AVX512F>(111, false, true, true);
ut<float, BTLA_ISA::AVX512F>(111, true, true, true);
ut<float, BTLA_ISA::AVX2>(4096, false, true, true);
ut<float, BTLA_ISA::AVX2>(4096, false, false, false);
ut<float, BTLA_ISA::AVX2>(111, false, true, true);
ut<float, BTLA_ISA::AVX2>(111, true, true, true);
}
template <typename T, BTLA_ISA ISA>
void ut(int norm_size, bool simplified, bool hasscale, bool hasbias) {
printf("%s %d\n", __FUNCTION__, norm_size);
aligned_vector<T> src(norm_size), dst(norm_size), bias(norm_size), scale(norm_size), ref(norm_size);
fill_buffer_randn(src.data(), src.size(), -0.5f, 0.5f);
fill_buffer_randn(bias.data(), bias.size(), -0.5f, 0.5f);
fill_buffer_randn(scale.data(), scale.size(), 0.1f, 1.f);
T mean = 0.f, mean_square = 0.f;
kernel::wrapper::LayerNormalization::forward<BTLA_ISA::NoSIMD>(src.data(), hasscale ? scale.data() : nullptr,
hasbias ? bias.data() : nullptr, 0.00001f, norm_size,
ref.data(), &mean, &mean_square, simplified);
kernel::wrapper::LayerNormalization::forward<ISA>(src.data(), hasscale ? scale.data() : nullptr,
hasbias ? bias.data() : nullptr, 0.00001f, norm_size, dst.data(),
&mean, &mean_square, simplified);
buffer_error(ref.data(), dst.data(), ref.size(), 0.01f);
}
};
#ifdef BTLA_UT_KERNEL_WRAPPER
UT_LayerNormalization sUT_LayerNormalization;
#endif
} // namespace ut
} // namespace bestla
Loading
Loading