From d5bd7f405370c43485f80058b8cc1d6d0fc3a55e Mon Sep 17 00:00:00 2001 From: Luca Colagrande Date: Thu, 8 Feb 2024 16:25:43 +0100 Subject: [PATCH] Fix linting --- .../flashattention_2/src/flashattention_2.h | 34 +++++++++---------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/sw/dnn/flashattention_2/src/flashattention_2.h b/sw/dnn/flashattention_2/src/flashattention_2.h index 700f74f810..a7279719d9 100644 --- a/sw/dnn/flashattention_2/src/flashattention_2.h +++ b/sw/dnn/flashattention_2/src/flashattention_2.h @@ -109,9 +109,9 @@ static inline void flashattention_2_layer(flashattention_2_layer_t layer) { snrt_dma_txid_t txid_q_fa = snrt_dma_start_2d(Q_fa, /* dst */ Q_l3 + q_fa_offset, /* src */ - d * sizeof(float), /* size */ - d * sizeof(float), /* dst_stride */ - d * sizeof(float), /* src_stride */ + d * sizeof(float), /* size */ + d * sizeof(float), /* dst_stride */ + d * sizeof(float), /* src_stride */ B_r); /* repetitions */ snrt_dma_wait_all(); @@ -149,19 +149,19 @@ static inline void flashattention_2_layer(flashattention_2_layer_t layer) { if (!snrt_is_compute_core()) { // K is in (d, N) format in main memory snrt_dma_txid_t txid_k_fa = - snrt_dma_start_2d(K_fa, /* dst */ - K_l3 + k_fa_offset, /* src */ + snrt_dma_start_2d(K_fa, /* dst */ + K_l3 + k_fa_offset, /* src */ B_c * sizeof(float), /* size */ B_c * sizeof(float), /* dst_stride */ N * sizeof(float), /* src_stride */ - d); /* repetitions */ + d); /* repetitions */ snrt_dma_txid_t txid_v_fa = snrt_dma_start_2d(V_fa, /* dst */ V_l3 + v_fa_offset, /* src */ - d * sizeof(float), /* size */ - d * sizeof(float), /* dst_stride */ - d * sizeof(float), /* src_stride */ + d * sizeof(float), /* size */ + d * sizeof(float), /* dst_stride */ + d * sizeof(float), /* src_stride */ B_c); /* repetitions */ snrt_dma_wait_all(); @@ -176,8 +176,8 @@ static inline void flashattention_2_layer(flashattention_2_layer_t layer) { // column block of K to calculate a tile of S: S = Q * K^T. // The S tile is of form (B_r, B_c) uint32_t start_gemm = snrt_mcycle(); - sc_st_gemm(dtype, 0, 0, 0, 0, B_r, B_c, d, 1, Q_fa, d, K_fa, B_c, - 0, S_fa, B_c, baseline); + sc_st_gemm(dtype, 0, 0, 0, 0, B_r, B_c, d, 1, Q_fa, d, K_fa, + B_c, 0, S_fa, B_c, baseline); uint32_t end_gemm = snrt_mcycle(); snrt_cluster_hw_barrier(); @@ -200,13 +200,13 @@ static inline void flashattention_2_layer(flashattention_2_layer_t layer) { // Calculate P tile as the "local" softmax of S for (int col_idx = 0; col_idx < B_c; col_idx++) { - P_fa[row_idx * B_c + col_idx] = expf(S_fa[row_idx * B_c + col_idx] - m_i[row_idx]); + P_fa[row_idx * B_c + col_idx] = + expf(S_fa[row_idx * B_c + col_idx] - m_i[row_idx]); row_sum += P_fa[row_idx * B_c + col_idx]; } // Calculate rescaling factor l - shifted_exp = - expf(m_i_prev[row_idx] - m_i[row_idx]); + shifted_exp = expf(m_i_prev[row_idx] - m_i[row_idx]); if (t_c != 0) { l_i[row_idx] = l_i[row_idx] * shifted_exp + row_sum; } else { @@ -274,9 +274,9 @@ static inline void flashattention_2_layer(flashattention_2_layer_t layer) { snrt_dma_txid_t txid_o_fa = snrt_dma_start_2d(O_l3 + o_fa_offset, /* dst */ O_fa, /* src */ - d * sizeof(float), /* size */ - d * sizeof(float), /* dst_stride */ - d * sizeof(float), /* src_stride */ + d * sizeof(float), /* size */ + d * sizeof(float), /* dst_stride */ + d * sizeof(float), /* src_stride */ B_r); /* repetitions */ snrt_dma_wait_all();