From 1bdcc988134bae124d17199599c8d5dc4bb66d49 Mon Sep 17 00:00:00 2001 From: Viviane Potocnik Date: Wed, 31 Jan 2024 01:12:41 +0100 Subject: [PATCH] lint: FA-2 and GEMM --- sw/blas/gemm/src/gemm.h | 42 ++++++++++--------- .../flashattention_2/src/flashattention_2.h | 9 ++-- 2 files changed, 27 insertions(+), 24 deletions(-) diff --git a/sw/blas/gemm/src/gemm.h b/sw/blas/gemm/src/gemm.h index 2d313f65db..715fb2f5a1 100644 --- a/sw/blas/gemm/src/gemm.h +++ b/sw/blas/gemm/src/gemm.h @@ -1089,29 +1089,31 @@ void sc_st_gemm(precision_t prec, uint32_t expand, uint32_t setup_ssr, switch (prec) { case FP64: - if (BASELINE == 1) { - gemm_fp64_baseline(frac_m, n, k, (double*)a + offsetA, - lda_strided, transa, (double*)b, ldb, transb, - (double*)c + offsetC, ldc_strided, - (double)beta); - } else { - gemm_fp64_opt(frac_m, n, k, (double*)a + offsetA, - lda_strided, transa, (double*)b, ldb, transb, - (double*)c + offsetC, ldc_strided, &beta, setup_ssr); - } + if (BASELINE == 1) { + gemm_fp64_baseline(frac_m, n, k, (double*)a + offsetA, + lda_strided, transa, (double*)b, ldb, + transb, (double*)c + offsetC, + ldc_strided, (double)beta); + } else { + gemm_fp64_opt(frac_m, n, k, (double*)a + offsetA, + lda_strided, transa, (double*)b, ldb, transb, + (double*)c + offsetC, ldc_strided, &beta, + setup_ssr); + } break; case FP32: - if (BASELINE == 1) { + if (BASELINE == 1) { gemm_fp32_baseline(frac_m, n, k, (float*)a + offsetA, - lda_strided, transa, (float*)b, ldb, transb, - (float*)c + offsetC, ldc_strided, - (float)beta); - } else { - gemm_fp32_opt(frac_m, n, k, (float*)a + offsetA, lda_strided, - (float*)b, ldb, (float*)c + offsetC, ldc_strided, - &beta, setup_ssr); - } - + lda_strided, transa, (float*)b, ldb, + transb, (float*)c + offsetC, ldc_strided, + (float)beta); + } else { + gemm_fp32_opt(frac_m, n, k, (float*)a + offsetA, + lda_strided, (float*)b, ldb, + (float*)c + offsetC, ldc_strided, &beta, + setup_ssr); + } + break; case FP16: if (expand) { diff --git a/sw/dnn/flashattention_2/src/flashattention_2.h b/sw/dnn/flashattention_2/src/flashattention_2.h index 4dd9503381..18655f2262 100644 --- a/sw/dnn/flashattention_2/src/flashattention_2.h +++ b/sw/dnn/flashattention_2/src/flashattention_2.h @@ -275,14 +275,15 @@ static inline void flashattention_2_layer(flashattention_2_layer_t layer) { // Calculate P tile as the "local" softmax of S for (int col_idx = 0; col_idx < B_c; col_idx++) { - P_fa[row_idx * B_c + col_idx] = - double_dummy_exp(S_fa[row_idx * B_c + col_idx] - m_i[row_idx]); - // expf(S_fa[row_idx * B_c + col_idx] - m_i[row_idx]); + P_fa[row_idx * B_c + col_idx] = double_dummy_exp( + S_fa[row_idx * B_c + col_idx] - m_i[row_idx]); + // expf(S_fa[row_idx * B_c + col_idx] - m_i[row_idx]); row_sum += P_fa[row_idx * B_c + col_idx]; } // Calculate rescaling factor l - shifted_exp = double_dummy_exp(m_i_prev[row_idx] - m_i[row_idx]); + shifted_exp = + double_dummy_exp(m_i_prev[row_idx] - m_i[row_idx]); if (t_c != 0) { l_i[row_idx] = l_i[row_idx] * shifted_exp + row_sum; } else {