Skip to content

Commit

Permalink
Fix linting
Browse files: browse the repository at this point in the history
  • Loading branch information
colluca committed Feb 8, 2024
1 parent 917d638 commit d5bd7f4
Showing 1 changed file with 17 additions and 17 deletions.
34 changes: 17 additions & 17 deletions sw/dnn/flashattention_2/src/flashattention_2.h
Original file line number | Diff line number | Diff line change
Expand Up @@ -109,9 +109,9 @@ static inline void flashattention_2_layer(flashattention_2_layer_t layer) {
snrt_dma_txid_t txid_q_fa =
snrt_dma_start_2d(Q_fa, /* dst */
Q_l3 + q_fa_offset, /* src */
d * sizeof(float), /* size */
d * sizeof(float), /* dst_stride */
d * sizeof(float), /* src_stride */
d * sizeof(float), /* size */
d * sizeof(float), /* dst_stride */
d * sizeof(float), /* src_stride */
B_r); /* repetitions */

snrt_dma_wait_all();
Expand Down Expand Up @@ -149,19 +149,19 @@ static inline void flashattention_2_layer(flashattention_2_layer_t layer) {
if (!snrt_is_compute_core()) {
// K is in (d, N) format in main memory
snrt_dma_txid_t txid_k_fa =
snrt_dma_start_2d(K_fa, /* dst */
K_l3 + k_fa_offset, /* src */
snrt_dma_start_2d(K_fa, /* dst */
K_l3 + k_fa_offset, /* src */
B_c * sizeof(float), /* size */
B_c * sizeof(float), /* dst_stride */
N * sizeof(float), /* src_stride */
d); /* repetitions */
d); /* repetitions */

snrt_dma_txid_t txid_v_fa =
snrt_dma_start_2d(V_fa, /* dst */
V_l3 + v_fa_offset, /* src */
d * sizeof(float), /* size */
d * sizeof(float), /* dst_stride */
d * sizeof(float), /* src_stride */
d * sizeof(float), /* size */
d * sizeof(float), /* dst_stride */
d * sizeof(float), /* src_stride */
B_c); /* repetitions */

snrt_dma_wait_all();
Expand All @@ -176,8 +176,8 @@ static inline void flashattention_2_layer(flashattention_2_layer_t layer) {
// column block of K to calculate a tile of S: S = Q * K^T.
// The S tile is of form (B_r, B_c)
uint32_t start_gemm = snrt_mcycle();
sc_st_gemm(dtype, 0, 0, 0, 0, B_r, B_c, d, 1, Q_fa, d, K_fa, B_c,
0, S_fa, B_c, baseline);
sc_st_gemm(dtype, 0, 0, 0, 0, B_r, B_c, d, 1, Q_fa, d, K_fa,
B_c, 0, S_fa, B_c, baseline);
uint32_t end_gemm = snrt_mcycle();

snrt_cluster_hw_barrier();
Expand All @@ -200,13 +200,13 @@ static inline void flashattention_2_layer(flashattention_2_layer_t layer) {

// Calculate P tile as the "local" softmax of S
for (int col_idx = 0; col_idx < B_c; col_idx++) {
P_fa[row_idx * B_c + col_idx] = expf(S_fa[row_idx * B_c + col_idx] - m_i[row_idx]);
P_fa[row_idx * B_c + col_idx] =
expf(S_fa[row_idx * B_c + col_idx] - m_i[row_idx]);
row_sum += P_fa[row_idx * B_c + col_idx];
}

// Calculate rescaling factor l
shifted_exp =
expf(m_i_prev[row_idx] - m_i[row_idx]);
shifted_exp = expf(m_i_prev[row_idx] - m_i[row_idx]);
if (t_c != 0) {
l_i[row_idx] = l_i[row_idx] * shifted_exp + row_sum;
} else {
Expand Down Expand Up @@ -274,9 +274,9 @@ static inline void flashattention_2_layer(flashattention_2_layer_t layer) {
snrt_dma_txid_t txid_o_fa =
snrt_dma_start_2d(O_l3 + o_fa_offset, /* dst */
O_fa, /* src */
d * sizeof(float), /* size */
d * sizeof(float), /* dst_stride */
d * sizeof(float), /* src_stride */
d * sizeof(float), /* size */
d * sizeof(float), /* dst_stride */
d * sizeof(float), /* src_stride */
B_r); /* repetitions */

snrt_dma_wait_all();
Expand Down

0 comments on commit d5bd7f4

Please sign in to comment.