From 9e23aacee8b2184affb99b3531e6ed9a0d8e2435 Mon Sep 17 00:00:00 2001 From: Luca Colagrande Date: Tue, 31 Oct 2023 18:20:58 +0100 Subject: [PATCH] dnn: Add FlashAttention-2 layer --- sw/blas/gemm/src/gemm.h | 117 +++--- sw/dnn/flashattention_2/data/datagen.py | 175 ++++++++ sw/dnn/flashattention_2/data/params.hjson | 11 + .../flashattention_2/src/flashattention_2.h | 385 ++++++++++++++++++ sw/dnn/flashattention_2/src/main.c | 14 + sw/dnn/flashattention_2/verify.py | 94 +++++ sw/dnn/layernorm/src/layernorm.h | 3 +- sw/dnn/requirements.txt | 2 +- sw/dnn/src/dnn.h | 1 + target/snitch_cluster/sw/apps/Makefile | 1 + .../sw/apps/dnn/flashattention_2/Makefile | 12 + target/snitch_cluster/sw/run.yaml | 3 + .../sw/runtime/banshee/src/dump.h | 28 -- target/snitch_cluster/sw/toolchain.mk | 2 +- 14 files changed, 769 insertions(+), 79 deletions(-) create mode 100755 sw/dnn/flashattention_2/data/datagen.py create mode 100644 sw/dnn/flashattention_2/data/params.hjson create mode 100644 sw/dnn/flashattention_2/src/flashattention_2.h create mode 100644 sw/dnn/flashattention_2/src/main.c create mode 100755 sw/dnn/flashattention_2/verify.py create mode 100644 target/snitch_cluster/sw/apps/dnn/flashattention_2/Makefile delete mode 100644 target/snitch_cluster/sw/runtime/banshee/src/dump.h diff --git a/sw/blas/gemm/src/gemm.h b/sw/blas/gemm/src/gemm.h index ea2c865636..3bd57aa79f 100644 --- a/sw/blas/gemm/src/gemm.h +++ b/sw/blas/gemm/src/gemm.h @@ -21,8 +21,18 @@ typedef __fp16 v4f16 __attribute__((vector_size(8))); typedef char v8f8 __attribute__((vector_size(8))); #endif -dump_float(gemm, 8); -dump_uint(index, 9); +// Floating-point multiplications by zero cannot be optimized as in some +// edge cases they do not yield zero: +// - 0f * NaN = NaN +// - 0f * INFINITY == NaN +// Thus in order to optimize it, we need to test for zero. You can use this +// function for free when `multiplier` is a constant. +static inline double multiply_opt(double multiplicand, double multiplier) { + if (multiplier) + return multiplicand * multiplier; + else + return 0; +} void gemm_fp64_baseline(uint32_t M, uint32_t N, uint32_t K, double* A, uint32_t ldA, uint32_t ta, double* B, uint32_t ldB, @@ -30,10 +40,15 @@ void gemm_fp64_baseline(uint32_t M, uint32_t N, uint32_t K, double* A, if (!ta && !tb) { for (uint32_t m = 0; m < M; m++) { for (uint32_t n = 0; n < N; n++) { - double c0 = BETA * C[m * ldC + n]; + double c0 = multiply_opt(C[m * ldC + n], BETA); for (uint32_t k = 0; k < K; k++) { // dump_index(k + m * ldA); // dump_gemm(A[k + m * ldA]); + + // if (snrt_cluster_core_idx() == 7) { + // printf("k = %d, m = %d, n = %d, ldA = %d, ldB = + // %d\n", k, m, n, ldA, ldB); + // } c0 += A[k + m * ldA] * B[k * ldB + n]; } C[m * ldC + n] = c0; @@ -42,7 +57,7 @@ void gemm_fp64_baseline(uint32_t M, uint32_t N, uint32_t K, double* A, } else if (ta && !tb) { for (uint32_t m = 0; m < M; m++) { for (uint32_t n = 0; n < N; n++) { - register double c0 = BETA * C[m * ldC + n]; + double c0 = multiply_opt(C[m * ldC + n], BETA); for (uint32_t k = 0; k < K; k++) { c0 += A[k * M * ldA + m * ldA] * B[k * ldB + n]; } @@ -52,7 +67,7 @@ void gemm_fp64_baseline(uint32_t M, uint32_t N, uint32_t K, double* A, } else if (!ta && tb) { for (uint32_t m = 0; m < M; m++) { for (uint32_t n = 0; n < N; n++) { - register double c0 = BETA * C[m * ldC + n]; + double c0 = multiply_opt(C[m * ldC + n], BETA); for (uint32_t k = 0; k < K; k++) { c0 += A[k + m * ldA] * B[k + n * ldB]; } @@ -62,7 +77,7 @@ void gemm_fp64_baseline(uint32_t M, uint32_t N, uint32_t K, double* A, } else { for (uint32_t m = 0; m < M; m++) { for (uint32_t n = 0; n < N; n++) { - register double c0 = BETA * C[m * ldC + n]; + double c0 = multiply_opt(C[m * ldC + n], BETA); for (uint32_t k = 0; k < K; k++) { c0 += A[k * M * ldA + m * ldA] * B[k + n * ldB]; } @@ -999,46 +1014,54 @@ void gemm(precision_t prec, uint32_t expand, uint32_t setup_ssr, uint32_t transa, uint32_t transb, uint32_t m, uint32_t n, uint32_t k, double alpha, void* a, uint32_t lda, void* b, uint32_t ldb, uint32_t beta, void* c, uint32_t ldc) { - const uint32_t compute_num = snrt_cluster_compute_core_num(); - const uint32_t compute_id = snrt_cluster_core_idx(); - - // Compute cores work not on contiguous blocks but on strided rows - uint32_t lda_strided = compute_num * lda; - uint32_t ldc_strided = compute_num * ldc; - - // Compute cores access A and C at offsets of one row from each other - uint32_t offsetA = compute_id * lda; - uint32_t offsetC = compute_id * ldc; - - // Compute fraction of C rows every core computes - uint32_t frac_m = m / compute_num; - - switch (prec) { - case FP64: - gemm_fp64_opt(frac_m, n, k, (double*)a + offsetA, lda_strided, - transa, (double*)b, ldb, transb, (double*)c + offsetC, - ldc_strided, &beta, setup_ssr); - break; - case FP32: - gemm_fp32_opt(frac_m, n, k, (float*)a + offsetA, lda_strided, - (float*)b, ldb, (float*)c + offsetC, ldc_strided, - &beta, setup_ssr); - break; - case FP16: - if (expand) { - gemm_fp16_ex_opt( - frac_m, n, k, (__fp16*)a + offsetA, lda_strided, (__fp16*)b, - ldb, (__fp16*)c + offsetC, ldc_strided, &beta, setup_ssr); - } else { - gemm_fp16_opt(frac_m, n, k, (__fp16*)a + offsetA, lda_strided, - (__fp16*)b, ldb, (__fp16*)c + offsetC, - ldc_strided, &beta, setup_ssr); - } - break; - case FP8: - gemm_fp8_ex_opt(frac_m, n, k, (char*)a + offsetA, lda, (char*)b, - ldb, (char*)c + offsetC, ldc_strided, &beta, - setup_ssr); - break; + if (snrt_is_compute_core()) { + const uint32_t compute_num = snrt_cluster_compute_core_num(); + const uint32_t compute_id = snrt_cluster_core_idx(); + + // Compute cores work not on contiguous blocks but on strided rows + uint32_t lda_strided = compute_num * lda; + uint32_t ldc_strided = compute_num * ldc; + + // Compute cores access A and C at offsets of one row from each other + uint32_t offsetA = compute_id * lda; + uint32_t offsetC = compute_id * ldc; + + // Compute fraction of C rows every core computes + uint32_t frac_m = m / compute_num; + + switch (prec) { + case FP64: + // gemm_fp64_opt(frac_m, n, k, (double*)a + offsetA, + // lda_strided, + // transa, (double*)b, ldb, transb, (double*)c + + // offsetC, ldc_strided, &beta, setup_ssr); + gemm_fp64_baseline(frac_m, n, k, (double*)a + offsetA, + lda_strided, transa, (double*)b, ldb, transb, + (double*)c + offsetC, ldc_strided, beta); + break; + case FP32: + gemm_fp32_opt(frac_m, n, k, (float*)a + offsetA, lda_strided, + (float*)b, ldb, (float*)c + offsetC, ldc_strided, + &beta, setup_ssr); + break; + case FP16: + if (expand) { + gemm_fp16_ex_opt(frac_m, n, k, (__fp16*)a + offsetA, + lda_strided, (__fp16*)b, ldb, + (__fp16*)c + offsetC, ldc_strided, &beta, + setup_ssr); + } else { + gemm_fp16_opt(frac_m, n, k, (__fp16*)a + offsetA, + lda_strided, (__fp16*)b, ldb, + (__fp16*)c + offsetC, ldc_strided, &beta, + setup_ssr); + } + break; + case FP8: + gemm_fp8_ex_opt(frac_m, n, k, (char*)a + offsetA, lda, (char*)b, + ldb, (char*)c + offsetC, ldc_strided, &beta, + setup_ssr); + break; + } } } diff --git a/sw/dnn/flashattention_2/data/datagen.py b/sw/dnn/flashattention_2/data/datagen.py new file mode 100755 index 0000000000..5fde48b260 --- /dev/null +++ b/sw/dnn/flashattention_2/data/datagen.py @@ -0,0 +1,175 @@ +#!/usr/bin/env python3 +# Copyright 2023 ETH Zurich and University of Bologna. +# Licensed under the Apache License, Version 2.0, see LICENSE for details. +# SPDX-License-Identifier: Apache-2.0 +# +# Viviane Potocnik +# Luca Colagrande + +import argparse +import numpy as np +import pathlib +import hjson +import sys +import os +import torch + +sys.path.append(os.path.join(os.path.dirname(__file__), "../../../../util/sim/")) +import data_utils # noqa: E402 +from data_utils import emit_license, \ + format_struct_definition, format_array_definition, \ + format_array_declaration, format_ifdef_wrapper # noqa: E402 + +torch.manual_seed(42) + +# AXI splits bursts crossing 4KB address boundaries. To minimize +# the occurrence of these splits the data should be aligned to 4KB +BURST_ALIGNMENT = 4096 + +PRECISION = { + 'FP64': '64', + 'FP32': '32', + 'FP16': '16', + 'FP8': '8' +} + + +def torch_golden_model(Q, K, V): + return torch.nn.functional.scaled_dot_product_attention(Q, K, V) + + +def exact_golden_model(Q, K, V, B_r, B_c): + # Convert torch tensors to numpy arrays + Q = Q.numpy() + K = K.numpy() + V = V.numpy() + # Get layer dimensions + N = Q.shape[0] + # Calculate tiling parameters + T_r = N // B_r + T_c = N // B_c + # Transpose K + K_t = np.transpose(K) + # Iterate tiles + O_tiles = [] + for i in range(T_r): + # Tile Q + start_row = i * B_r + end_row = start_row + B_r + Q_i = Q[start_row:end_row, :] + # Initialize l_i, m_i, O_i + m_i = np.full((B_r, 1), -np.inf) + for j in range(T_c): + # Tile K_t and V + start_col = j * B_c + end_col = start_col + B_c + K_t_j = K_t[:, start_col:end_col] + V_j = V[start_col:end_col, ] + # Compute O tile update + S_ij = np.matmul(Q_i, K_t_j) + m_i_prev = m_i + m_i = np.maximum(m_i_prev, np.max(S_ij, 1, keepdims=True)) + shifted_exp = np.exp(m_i_prev - m_i) + P_ij = np.exp(S_ij - m_i) + PxV = np.matmul(P_ij, V_j) + if j == 0: + l_i = np.sum(P_ij, 1, keepdims=True) + O_i = PxV + else: + l_i = (shifted_exp * l_i) + np.sum(P_ij, 1, keepdims=True) + diag = np.diag(shifted_exp[:, 0]) + diag_inv = np.linalg.inv(diag) + O_i = np.matmul(diag_inv, O_i) + O_i += PxV + # Finalize O tile + diag_l_i = np.diag(l_i[:, 0]) + diag_l_inv_i = np.linalg.inv(diag_l_i) + O_i = np.matmul(diag_l_inv_i, O_i) + O_tiles.append(O_i) + return np.concatenate(O_tiles, 0) + + +def emit_header(section, params): + N = params['N'] + d = params['d'] + B_r = params['B_r'] + B_c = params['B_c'] + prec = PRECISION[params['dtype']] + + # Verify layer parameters are valid + assert (N % B_r) == 0, 'N is not an integer multiple of B_r' + assert (N % B_c) == 0, 'N is not an integer multiple of B_c' + assert (B_r % 8) == 0, 'B_r must be an integer multiple of the number of cores in a cluster' + + torch_type = data_utils.floating_point_torch_type(prec) + + Q = torch.rand(N, d, requires_grad=False, dtype=torch_type) + K = torch.rand(N, d, requires_grad=False, dtype=torch_type) + V = torch.rand(N, d, requires_grad=False, dtype=torch_type) + + output = exact_golden_model(Q, K, V, B_r, B_c) + + # Layer implementation assumes K is in (d, N) layout + K = torch.transpose(K, 0, 1) + + ctype = data_utils.floating_point_ctype(prec) + + q_uid = 'Q' + k_uid = 'K' + v_uid = 'V' + o_uid = 'O' + + layer_cfg = { + **params, + 'Q': q_uid, + 'K': k_uid, + 'V': v_uid, + 'O': o_uid, + } + + data_str = [emit_license()] + data_str += [format_array_declaration(ctype, q_uid, Q.shape)] + data_str += [format_array_declaration(ctype, k_uid, K.shape)] + data_str += [format_array_declaration(ctype, v_uid, V.shape)] + data_str += [format_array_declaration(ctype, o_uid, output.shape)] + data_str += [format_struct_definition('flashattention_2_layer_t', 'layer', layer_cfg)] + data_str += [format_array_definition(ctype, q_uid, Q)] + data_str += [format_array_definition(ctype, k_uid, K)] + data_str += [format_array_definition(ctype, v_uid, V)] + result_def = format_array_definition(ctype, 'golden', output) + data_str += [format_ifdef_wrapper('BIST', result_def)] + data_str = '\n\n'.join(data_str) + + return data_str + + +def main(): + + parser = argparse.ArgumentParser(description='Generate data for layernorm kernel') + parser.add_argument( + "-c", "--cfg", + type=pathlib.Path, + required=True, + help='Select param config file kernel' + ) + parser.add_argument( + '--section', + type=str, + help='Section to store matrices in') + parser.add_argument( + 'output', + type=pathlib.Path, + help='Path of the output header file') + args = parser.parse_args() + + # Load param config file + with args.cfg.open() as f: + param = hjson.loads(f.read()) + + # Emit header file + with open(args.output, 'w') as f: + f.write(emit_header(args.section, param)) + + +if __name__ == '__main__': + main() diff --git a/sw/dnn/flashattention_2/data/params.hjson b/sw/dnn/flashattention_2/data/params.hjson new file mode 100644 index 0000000000..6376dd4596 --- /dev/null +++ b/sw/dnn/flashattention_2/data/params.hjson @@ -0,0 +1,11 @@ +// Copyright 2020 ETH Zurich and University of Bologna. +// Solderpad Hardware License, Version 0.51, see LICENSE for details. +// SPDX-License-Identifier: SHL-0.51 + +{ + N: 128 + d: 128 + B_r: 32 + B_c: 16 + dtype: FP64 +} \ No newline at end of file diff --git a/sw/dnn/flashattention_2/src/flashattention_2.h b/sw/dnn/flashattention_2/src/flashattention_2.h new file mode 100644 index 0000000000..5d9f6c9cf3 --- /dev/null +++ b/sw/dnn/flashattention_2/src/flashattention_2.h @@ -0,0 +1,385 @@ +// Copyright 2020 ETH Zurich and University of Bologna. +// Licensed under the Apache License, Version 2.0, see LICENSE for details. +// SPDX-License-Identifier: Apache-2.0 +// +// Author: Viviane Potocnik +// Luca Colagrande + +#include "blas.h" +#include "snrt.h" + +// #define ENABLE_PRINTS + +/** + * @struct flashattention_2_layer_t + * @brief This structure contains all parameters necessary + * for computing a FlashAttention-2 layer. Refer to + * "FlashAttention-2: Faster Attention with Better + * Parallelism and Work Partitioning" for more info + * @var flashattention_2_layer_t::N + * Sequence length in number of tokens + * @var flashattention_2_layer_t::d + * Head dimension + * @var flashattention_2_layer_t::Q + * Pointer to query tensor + * @var flashattention_2_layer_t::K + * Pointer to key tensor + * @var flashattention_2_layer_t::V + * Pointer to value tensor + * @var flashattention_2_layer_t::O + * Pointer to output tensor + */ +typedef struct { + uint32_t N; + uint32_t d; + uint32_t B_r; + uint32_t B_c; + void *Q; + void *K; + void *V; + void *O; + precision_t dtype; +} flashattention_2_layer_t; + +static inline void flashattention_2_layer(flashattention_2_layer_t layer) { + // alias layer parameters + uint32_t N = layer.N; + uint32_t d = layer.d; + uint32_t B_r = layer.B_r; + uint32_t B_c = layer.B_c; + double *Q_l3 = layer.Q; + double *K_l3 = layer.K; + double *V_l3 = layer.V; + double *O_l3 = layer.O; + + // alias system parameters + // TODO adapt these for Occamy + uint32_t compute_id = snrt_global_core_idx(); + uint32_t cluster_id = snrt_cluster_idx(); + uint32_t num_cores = snrt_cluster_compute_core_num(); + uint32_t num_clusters = snrt_cluster_num(); + + // compute the tiling parameters + uint32_t T_r = N / B_r; // number of row blocks + uint32_t T_c = N / B_c; // number of column blocks + + // compute the size of the matrices + uint32_t q_fa_size = B_r * d * sizeof(double); + uint32_t k_fa_size = B_c * d * sizeof(double); + uint32_t v_fa_size = B_c * d * sizeof(double); + uint32_t s_fa_size = B_r * B_c * sizeof(double); + uint32_t p_fa_size = B_r * B_c * sizeof(double); + uint32_t o_fa_size = B_r * d * sizeof(double); + uint32_t m_i_size = B_r * sizeof(double); + uint32_t m_i_prev_size = m_i_size; + uint32_t l_i_size = B_r * sizeof(double); + uint32_t shifted_exp_size = B_r * sizeof(double); + + // allocate memory in TCDM + void *tcdm_ptr = (double *)snrt_l1_next(); + double *Q_fa = tcdm_ptr; + tcdm_ptr += q_fa_size; + double *K_fa = tcdm_ptr; + tcdm_ptr += k_fa_size; + double *V_fa = tcdm_ptr; + tcdm_ptr += v_fa_size; + double *S_fa = tcdm_ptr; + tcdm_ptr += s_fa_size; + double *P_fa = tcdm_ptr; + tcdm_ptr += p_fa_size; + double *O_fa = tcdm_ptr; + tcdm_ptr += o_fa_size; + double *m_i = tcdm_ptr; + tcdm_ptr += m_i_size; + double *m_i_prev = tcdm_ptr; + tcdm_ptr += m_i_prev_size; + double *l_i = tcdm_ptr; + tcdm_ptr += l_i_size; + double shifted_exp; + double row_sum; + + float used_memory_kB = + (float)((uint64_t)tcdm_ptr - (uint64_t)snrt_l1_next()) / 1024.0f; + + DUMP(used_memory_kB); + + // Iterate row blocks of Q + uint32_t start_loop_outer = snrt_mcycle(); + for (int t_r = 0; t_r < T_r; t_r++) { + // DMA copy Q row block to TCDM + uint32_t start_dma = snrt_mcycle(); + if (snrt_is_dm_core()) { + uint32_t q_fa_offset = t_r * B_r * d; + + snrt_dma_txid_t txid_q_fa = + snrt_dma_start_2d(Q_fa, /* dst */ + Q_l3 + q_fa_offset, /* src */ + d * sizeof(double), /* size */ + d * sizeof(double), /* dst_stride */ + d * sizeof(double), /* src_stride */ + B_r); /* repetitions */ + + snrt_dma_wait_all(); + +#ifdef ENABLE_PRINTS + printf("Q_fa:\n"); + for (int i = 0; i < B_r * d; i++) { + DUMP((float)(Q_fa[i])); + } +#endif + } + uint32_t end_dma = snrt_mcycle(); + + snrt_cluster_hw_barrier(); + + // Initialize m_i, m_i_prev, l_i, row_sum + uint32_t rows_per_core = B_r / num_cores; + uint32_t start_row = rows_per_core * compute_id; + uint32_t end_row = start_row + rows_per_core; + if (snrt_is_compute_core()) { + for (int row_idx = start_row; row_idx < end_row; row_idx++) { + m_i[row_idx] = -INFINITY; + m_i_prev[row_idx] = -INFINITY; + l_i[row_idx] = 0.0f; + } + } + + snrt_cluster_hw_barrier(); + +#ifdef ENABLE_PRINTS + if (snrt_cluster_core_idx() == 0) { + printf("m_i:\n"); + for (int i = 0; i < B_r; i++) { + DUMP((float)(m_i[i])); + } + printf("l_i:\n"); + for (int i = 0; i < B_r; i++) { + DUMP((float)(l_i[i])); + } + } +#endif + + snrt_cluster_hw_barrier(); + + // Iterate column blocks of K (corresponding to row blocks of V) + uint32_t start_loop_inner = snrt_mcycle(); + for (int t_c = 0; t_c < T_c; t_c++) { +#ifdef ENABLE_PRINTS + if (snrt_cluster_core_idx() == 0) { + printf("(i, j) = (%d, %d)\n", t_r, t_c); + } +#endif + + snrt_cluster_hw_barrier(); + + // DMA copy K column block (d, B_c) and V row block (B_c, d) to + // TCDM. K is stored in (d, N) form in memory, V in (N, d) form + uint32_t k_fa_offset = t_c * B_c; + uint32_t v_fa_offset = t_c * B_c * d; + uint32_t start_dma = snrt_mcycle(); + if (!snrt_is_compute_core()) { + // K is in (d, N) format in main memory + snrt_dma_txid_t txid_k_fa = + snrt_dma_start_2d(K_fa, /* dst */ + K_l3 + k_fa_offset, /* src */ + B_c * sizeof(double), /* size */ + B_c * sizeof(double), /* dst_stride */ + N * sizeof(double), /* src_stride */ + d); /* repetitions */ + + snrt_dma_txid_t txid_v_fa = + snrt_dma_start_2d(V_fa, /* dst */ + V_l3 + v_fa_offset, /* src */ + d * sizeof(double), /* size */ + d * sizeof(double), /* dst_stride */ + d * sizeof(double), /* src_stride */ + B_c); /* repetitions */ + + snrt_dma_wait_all(); + +#ifdef ENABLE_PRINTS + printf("K_fa:\n"); + for (int i = 0; i < B_c * d; i++) { + DUMP((float)(K_fa[i])); + } + printf("V_fa:\n"); + for (int i = 0; i < B_c * d; i++) { + DUMP((float)(V_fa[i])); + } +#endif + } + uint32_t end_dma = snrt_mcycle(); + + snrt_cluster_hw_barrier(); + + // Calculate O tile from Q, K and V tiles + if (snrt_is_compute_core()) { + // Matrix multiplication between row block of Q and transposed + // column block of K to calculate a tile of S: S = Q * K^T. + // The S tile is of form (B_r, B_c) + uint32_t start_gemm = snrt_mcycle(); + gemm(FP64, 0, 0, 0, 0, B_r, B_c, d, 1, Q_fa, d, K_fa, B_c, 0, + S_fa, B_c); + uint32_t end_gemm = snrt_mcycle(); + + snrt_cluster_hw_barrier(); + +#ifdef ENABLE_PRINTS + if (snrt_cluster_core_idx() == 0) { + printf("S_ij:\n"); + for (int i = 0; i < B_r * B_c; i++) { + DUMP((float)(S_fa[i])); + } + } +#endif + + // Iterate over the rows of the S row block, distributing + // the rows to the cores + for (int row_idx = start_row; row_idx < end_row; row_idx++) { + // Save m of current tile to rescale next tile + m_i_prev[row_idx] = m_i[row_idx]; + + // Initialize "local" row_sum to zero + row_sum = 0.0; + + // Iterate over all columns to calculate maximum for the + // current row + for (int col_idx = 0; col_idx < B_c; col_idx++) { + double val = S_fa[row_idx * B_c + col_idx]; + if (val > m_i[row_idx]) m_i[row_idx] = val; + } + + // Calculate P tile as the "local" softmax of S + for (int col_idx = 0; col_idx < B_c; col_idx++) { + P_fa[row_idx * B_c + col_idx] = + expf(S_fa[row_idx * B_c + col_idx] - m_i[row_idx]); + row_sum += P_fa[row_idx * B_c + col_idx]; + } + + // Calculate rescaling factor l + shifted_exp = expf(m_i_prev[row_idx] - m_i[row_idx]); + l_i[row_idx] = l_i[row_idx] * shifted_exp + row_sum; + + // If not in first t_c iteration, update + // O_ij = diag(shifted_exp)^(-1) * O_i(j-1) + if (t_c != 0) { + for (int col_idx = 0; col_idx < d; col_idx++) { + O_fa[row_idx * d + col_idx] /= shifted_exp; + } + } + } + + snrt_cluster_hw_barrier(); + +#ifdef ENABLE_PRINTS + if (snrt_cluster_core_idx() == 0) { + printf("m_i:\n"); + for (int i = 0; i < B_r; i++) { + DUMP((float)(m_i[i])); + } + printf("P_ij:\n"); + for (int i = 0; i < B_r * B_c; i++) { + DUMP((float)(P_fa[i])); + } + printf("l_i:\n"); + for (int i = 0; i < B_r; i++) { + DUMP((float)(l_i[i])); + } + if (t_c != 0) { + printf("O_ij:\n"); + for (int i = 0; i < B_r * d; i++) { + DUMP((float)(O_fa[i])); + } + } + } +#endif + + snrt_cluster_hw_barrier(); + + // Calculate O tile (O_ij) of size (B_r, d). + // The P tile is of size (B_r, B_c) and V of size (B_c, d) + if (t_c == 0) { + // In first t_c iteration, initialize O_ij to P_ij * V_j + gemm(FP64, 0, 0, 0, 0, B_r, d, B_c, 1, P_fa, B_c, V_fa, d, + 0.0f, O_fa, d); + } else { + // In successive t_c iterations, O_ij += P_ij * V_j + gemm(FP64, 0, 0, 0, 0, B_r, d, B_c, 1, P_fa, B_c, V_fa, d, + 1.0f, O_fa, d); + } + + uint32_t end_stats = snrt_mcycle(); + + snrt_cluster_hw_barrier(); + +#ifdef ENABLE_PRINTS + if (snrt_cluster_core_idx() == 0) { + printf("O_ij:\n"); + for (int i = 0; i < B_r * d; i++) { + DUMP((float)(O_fa[i])); + } + } +#endif + } else { + snrt_cluster_hw_barrier(); + snrt_cluster_hw_barrier(); + snrt_cluster_hw_barrier(); + snrt_cluster_hw_barrier(); + } + } // end of T_c loop + + snrt_cluster_hw_barrier(); + + // Rescaling for last t_c iteration + // O_i = diag(l_i_Tc)^-1 * O_i + if (snrt_is_compute_core()) { + for (int row_idx = start_row; row_idx < end_row; row_idx++) { + for (int col_idx = 0; col_idx < d; col_idx++) { + O_fa[row_idx * d + col_idx] /= l_i[row_idx]; + } + } + } + + snrt_fpu_fence(); + + snrt_cluster_hw_barrier(); + +#ifdef ENABLE_PRINTS + if (snrt_cluster_core_idx() == 0) { + printf("O_ij:\n"); + for (int i = 0; i < B_r * d; i++) { + DUMP((float)(O_fa[i])); + } + } +#endif + + snrt_cluster_hw_barrier(); + + // Write back O row block (B_r, d) to DRAM + uint32_t start_dma_write_back = snrt_mcycle(); + if (snrt_is_dm_core()) { + uint32_t o_fa_offset = t_r * B_r * d; + snrt_dma_txid_t txid_o_fa = + snrt_dma_start_2d(O_l3 + o_fa_offset, /* dst */ + O_fa, /* src */ + d * sizeof(double), /* size */ + d * sizeof(double), /* dst_stride */ + d * sizeof(double), /* src_stride */ + B_r); /* repetitions */ + + snrt_dma_wait_all(); + +#ifdef ENABLE_PRINTS + printf("O_l3:\n"); + for (int i = 0; i < B_r * d; i++) { + DUMP((float)(O_l3[o_fa_offset + i])); + } +#endif + } + uint32_t end_dma_write_back = snrt_mcycle(); + + } // end of T_r loop + uint32_t end_loop_outer = snrt_mcycle(); + + snrt_cluster_hw_barrier(); +} diff --git a/sw/dnn/flashattention_2/src/main.c b/sw/dnn/flashattention_2/src/main.c new file mode 100644 index 0000000000..daadeede6b --- /dev/null +++ b/sw/dnn/flashattention_2/src/main.c @@ -0,0 +1,14 @@ +// Copyright 2023 ETH Zurich and University of Bologna. +// Licensed under the Apache License, Version 2.0, see LICENSE for details. +// SPDX-License-Identifier: Apache-2.0 +// +// Luca Colagrande + +#include "dnn.h" + +#include "data.h" + +int main() { + flashattention_2_layer(layer); + return 0; +} diff --git a/sw/dnn/flashattention_2/verify.py b/sw/dnn/flashattention_2/verify.py new file mode 100755 index 0000000000..1e141fc92a --- /dev/null +++ b/sw/dnn/flashattention_2/verify.py @@ -0,0 +1,94 @@ +#!/usr/bin/env python3 +# Copyright 2023 ETH Zurich and University of Bologna. +# Licensed under the Apache License, Version 2.0, see LICENSE for details. +# SPDX-License-Identifier: Apache-2.0 +# +# Luca Colagrande + +import sys +from pathlib import Path +import numpy as np +import torch +from data.datagen import exact_golden_model + +sys.path.append(str(Path(__file__).parent / '../../../util/sim/')) +import verification # noqa: E402 +from elf import Elf # noqa: E402 +from data_utils import bytes_to_float, bytes_to_struct # noqa: E402 + + +ERR_THRESHOLD = 1E-6 + +PRECISION_T = { + 8: '64', + 4: '32', + 2: '16', + 1: '8' +} + +NUMPY_T = { + '64': np.float64, + '32': np.float32, + '16': np.float16 +} + + +def main(): + # Run simulation and get outputs + args = verification.parse_args() + raw_results = verification.simulate(sim_bin=args.sim_bin, + snitch_bin=args.snitch_bin, + symbols_bin=args.symbols_bin, + log=args.log, + output_uids=['O']) + + # Extract input operands from ELF file + if args.symbols_bin: + elf = Elf(args.symbols_bin) + else: + elf = Elf(args.snitch_bin) + + layer_struct = { + 'N': 'I', + 'd': 'I', + 'B_r': 'I', + 'B_c': 'I', + 'Q': 'I', + 'K': 'I', + 'V': 'I', + 'O': 'I', + 'dtype': 'I' + } + layer = bytes_to_struct(elf.get_symbol_contents('layer'), layer_struct) + N = layer['N'] + d = layer['d'] + B_r = layer['B_r'] + B_c = layer['B_c'] + prec = PRECISION_T[layer['dtype']] + + Q = np.array(bytes_to_float(elf.get_symbol_contents('Q'), prec), dtype=NUMPY_T[prec]) + K = np.array(bytes_to_float(elf.get_symbol_contents('K'), prec), dtype=NUMPY_T[prec]) + V = np.array(bytes_to_float(elf.get_symbol_contents('V'), prec), dtype=NUMPY_T[prec]) + Q = torch.from_numpy(Q.reshape(N, d)) + V = torch.from_numpy(V.reshape(N, d)) + # Golden model expects key matrix in (N, d) form, while Snitch binary stores it in (d, N) + K = torch.from_numpy(K.reshape(d, N)) + K = torch.transpose(K, 0, 1) + + # Verify results + O_actual = np.array(bytes_to_float(raw_results['O'], prec), dtype=NUMPY_T[prec]) + O_golden = exact_golden_model(Q, K, V, B_r, B_c).flatten() + # O_golden = torch_golden_model(Q, K, V).detach().numpy().flatten() + + relative_err = np.absolute((O_golden - O_actual) / O_golden) + fail = np.any(relative_err > ERR_THRESHOLD) + if (fail): + verification.dump_results_to_csv([O_golden, O_actual, relative_err], + Path.cwd() / 'flashattention_2_results.csv') + print('Maximum relative error:', np.max(relative_err)) + + return int(fail) + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/sw/dnn/layernorm/src/layernorm.h b/sw/dnn/layernorm/src/layernorm.h index e09771541a..9dff3cf795 100644 --- a/sw/dnn/layernorm/src/layernorm.h +++ b/sw/dnn/layernorm/src/layernorm.h @@ -4,10 +4,9 @@ #pragma once +#include "dnn.h" #include "math.h" #include "snrt.h" -// #include "printf.h" -#include "dnn.h" /** * @struct layernorm_layer_struct diff --git a/sw/dnn/requirements.txt b/sw/dnn/requirements.txt index 7524fee3a7..a05999073a 100644 --- a/sw/dnn/requirements.txt +++ b/sw/dnn/requirements.txt @@ -3,5 +3,5 @@ # SPDX-License-Identifier: Apache-2.0 numpy -torch>=1.8 +torch==2.1 hjson diff --git a/sw/dnn/src/dnn.h b/sw/dnn/src/dnn.h index a38cc9c76e..5c41041c83 100644 --- a/sw/dnn/src/dnn.h +++ b/sw/dnn/src/dnn.h @@ -197,6 +197,7 @@ typedef struct network_single_cluster_t_ { // #include "conv2d.h" #include "../batchnorm/src/batchnorm.h" +#include "../flashattention_2/src/flashattention_2.h" #include "../gelu/src/gelu.h" #include "../gemm/src/gemm.h" #include "../layernorm/src/layernorm.h" diff --git a/target/snitch_cluster/sw/apps/Makefile b/target/snitch_cluster/sw/apps/Makefile index 1a8ce51ae2..596b37e4ff 100644 --- a/target/snitch_cluster/sw/apps/Makefile +++ b/target/snitch_cluster/sw/apps/Makefile @@ -18,6 +18,7 @@ SUBDIRS += dnn/layernorm SUBDIRS += dnn/linear SUBDIRS += dnn/maxpool SUBDIRS += dnn/softmax +SUBDIRS += dnn/flashattention_2 SUBDIRS += montecarlo/pi_estimation .PHONY: all clean $(SUBDIRS) diff --git a/target/snitch_cluster/sw/apps/dnn/flashattention_2/Makefile b/target/snitch_cluster/sw/apps/dnn/flashattention_2/Makefile new file mode 100644 index 0000000000..75585b2d65 --- /dev/null +++ b/target/snitch_cluster/sw/apps/dnn/flashattention_2/Makefile @@ -0,0 +1,12 @@ +# Copyright 2023 ETH Zurich and University of Bologna. +# Licensed under the Apache License, Version 2.0, see LICENSE for details. +# SPDX-License-Identifier: Apache-2.0 +# +# Luca Colagrande + +APP ?= flashattention_2 + +include ../../../../../../sw/dnn/common.mk +include ../../common.mk + +$(DEP): $(DATA_H) diff --git a/target/snitch_cluster/sw/run.yaml b/target/snitch_cluster/sw/run.yaml index 3d3629c90d..8156be18d3 100644 --- a/target/snitch_cluster/sw/run.yaml +++ b/target/snitch_cluster/sw/run.yaml @@ -84,4 +84,7 @@ runs: # - elf: apps/dnn/fusedconv/build/fusedconv.elf # fails newly - elf: apps/dnn/softmax/build/softmax.elf # Illegal FDIV without FDIV unit cmd: ../../sw/dnn/softmax/verify.py {sim_bin} {elf} + # Illegal FDIV without FDIV unit + - elf: apps/dnn/flashattention_2/build/flashattention_2.elf + cmd: ../../sw/dnn/flashattention_2/verify.py {sim_bin} {elf} - elf: apps/montecarlo/pi_estimation/build/pi_estimation.elf diff --git a/target/snitch_cluster/sw/runtime/banshee/src/dump.h b/target/snitch_cluster/sw/runtime/banshee/src/dump.h deleted file mode 100644 index 8f24cc1b92..0000000000 --- a/target/snitch_cluster/sw/runtime/banshee/src/dump.h +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright 2020 ETH Zurich and University of Bologna. -// Licensed under the Apache License, Version 2.0, see LICENSE for details. -// SPDX-License-Identifier: Apache-2.0 -// -// Authors: Samuel Riedel, ETH Zurich -// Viviane Potocnik, ETH Zurich - -// Dump a value via CSR -// !!! Careful: This is only supported in simulation and an experimental -// feature. All writes to unimplemented CSR registers will be dumped by Snitch. -// This can be exploited to quickly print measurement values from all cores -// simultaneously without the hassle of printf. To specify multiple metrics, -// different CSRs can be used. The macro will define a function that will then -// always print via the same CSR. E.g., `dump(errors, 8)` will define a function -// with the following signature: `dump_errors(uint32_t val)`, which will print -// the given value via the 8th register. Alternatively, the `write_csr(reg, -// val)` macro can be used directly. - -#define dump_float(name, reg) \ - static __attribute__((always_inline)) inline void dump_##name(float val) { \ - asm volatile("csrw " #reg ", %0" ::"rK"(val)); \ - } - -#define dump_uint(name, reg) \ - static \ - __attribute__((always_inline)) inline void dump_##name(uint32_t val) { \ - asm volatile("csrw " #reg ", %0" ::"rK"(val)); \ - } \ No newline at end of file diff --git a/target/snitch_cluster/sw/toolchain.mk b/target/snitch_cluster/sw/toolchain.mk index ca5f6a4820..d8e3aa6cb7 100644 --- a/target/snitch_cluster/sw/toolchain.mk +++ b/target/snitch_cluster/sw/toolchain.mk @@ -32,7 +32,7 @@ RISCV_CFLAGS += -menable-experimental-extensions RISCV_CFLAGS += -mabi=ilp32d RISCV_CFLAGS += -mcmodel=medany # RISCV_CFLAGS += -mno-fdiv # Not supported by Clang -RISCV_CFLAGS += -ffast-math +# RISCV_CFLAGS += -ffast-math RISCV_CFLAGS += -fno-builtin-printf RISCV_CFLAGS += -fno-builtin-sqrtf RISCV_CFLAGS += -fno-common