Skip to content

Commit

Permalink
lint: FA-2 and GEMM
Browse files Browse the repository at this point in the history
  • Loading branch information
Viviane Potocnik committed Jan 31, 2024
1 parent 3ecc31b commit 1bdcc98
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 24 deletions.
42 changes: 22 additions & 20 deletions sw/blas/gemm/src/gemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -1089,29 +1089,31 @@ void sc_st_gemm(precision_t prec, uint32_t expand, uint32_t setup_ssr,

switch (prec) {
case FP64:
if (BASELINE == 1) {
gemm_fp64_baseline(frac_m, n, k, (double*)a + offsetA,
lda_strided, transa, (double*)b, ldb, transb,
(double*)c + offsetC, ldc_strided,
(double)beta);
} else {
gemm_fp64_opt(frac_m, n, k, (double*)a + offsetA,
lda_strided, transa, (double*)b, ldb, transb,
(double*)c + offsetC, ldc_strided, &beta, setup_ssr);
}
if (BASELINE == 1) {
gemm_fp64_baseline(frac_m, n, k, (double*)a + offsetA,
lda_strided, transa, (double*)b, ldb,
transb, (double*)c + offsetC,
ldc_strided, (double)beta);
} else {
gemm_fp64_opt(frac_m, n, k, (double*)a + offsetA,
lda_strided, transa, (double*)b, ldb, transb,
(double*)c + offsetC, ldc_strided, &beta,
setup_ssr);
}
break;
case FP32:
if (BASELINE == 1) {
if (BASELINE == 1) {
gemm_fp32_baseline(frac_m, n, k, (float*)a + offsetA,
lda_strided, transa, (float*)b, ldb, transb,
(float*)c + offsetC, ldc_strided,
(float)beta);
} else {
gemm_fp32_opt(frac_m, n, k, (float*)a + offsetA, lda_strided,
(float*)b, ldb, (float*)c + offsetC, ldc_strided,
&beta, setup_ssr);
}

lda_strided, transa, (float*)b, ldb,
transb, (float*)c + offsetC, ldc_strided,
(float)beta);
} else {
gemm_fp32_opt(frac_m, n, k, (float*)a + offsetA,
lda_strided, (float*)b, ldb,
(float*)c + offsetC, ldc_strided, &beta,
setup_ssr);
}

break;
case FP16:
if (expand) {
Expand Down
9 changes: 5 additions & 4 deletions sw/dnn/flashattention_2/src/flashattention_2.h
Original file line number Diff line number Diff line change
Expand Up @@ -275,14 +275,15 @@ static inline void flashattention_2_layer(flashattention_2_layer_t layer) {

// Calculate P tile as the "local" softmax of S
for (int col_idx = 0; col_idx < B_c; col_idx++) {
P_fa[row_idx * B_c + col_idx] =
double_dummy_exp(S_fa[row_idx * B_c + col_idx] - m_i[row_idx]);
// expf(S_fa[row_idx * B_c + col_idx] - m_i[row_idx]);
P_fa[row_idx * B_c + col_idx] = double_dummy_exp(
S_fa[row_idx * B_c + col_idx] - m_i[row_idx]);
// expf(S_fa[row_idx * B_c + col_idx] - m_i[row_idx]);
row_sum += P_fa[row_idx * B_c + col_idx];
}

// Calculate rescaling factor l
shifted_exp = double_dummy_exp(m_i_prev[row_idx] - m_i[row_idx]);
shifted_exp =
double_dummy_exp(m_i_prev[row_idx] - m_i[row_idx]);
if (t_c != 0) {
l_i[row_idx] = l_i[row_idx] * shifted_exp + row_sum;
} else {
Expand Down

0 comments on commit 1bdcc98

Please sign in to comment.