dnn: Add FlashAttention-2 layer
colluca committed Nov 1, 2023
1 parent 9b6d35d commit 2bea0fe
Showing 12 changed files with 695 additions and 12 deletions.
35 changes: 26 additions & 9 deletions sw/blas/gemm/src/gemm.h
@@ -21,19 +21,33 @@ typedef __fp16 v4f16 __attribute__((vector_size(8)));
typedef char v8f8 __attribute__((vector_size(8)));
#endif

dump_float(gemm, 8);
dump_uint(index, 9);
// Floating-point multiplications by zero cannot simply be skipped, because
// in some edge cases they do not yield zero:
// - 0.0 * NaN = NaN
// - 0.0 * INFINITY = NaN
// To optimize them away safely we explicitly test for zero. The test is
// folded away by the compiler when `multiplier` is a compile-time constant,
// so in that case this function comes for free.
static inline double multiply_opt(double multiplicand, double multiplier) {
    if (multiplier)
        return multiplicand * multiplier;
    else
        return 0;
}

void gemm_fp64_baseline(uint32_t M, uint32_t N, uint32_t K, double* A,
                        uint32_t ldA, uint32_t ta, double* B, uint32_t ldB,
                        uint32_t tb, double* C, uint32_t ldC, double BETA) {
    if (!ta && !tb) {
        for (uint32_t m = 0; m < M; m++) {
            for (uint32_t n = 0; n < N; n++) {
                double c0 = BETA * C[m * ldC + n];
                double c0 = multiply_opt(C[m * ldC + n], BETA);
                for (uint32_t k = 0; k < K; k++) {
                    // dump_index(k + m * ldA);
                    // dump_gemm(A[k + m * ldA]);

                    // if (snrt_cluster_core_idx() == 7) {
                    //     printf("k = %d, m = %d, n = %d, ldA = %d, ldB = %d\n", k, m, n, ldA, ldB);
                    // }
                    c0 += A[k + m * ldA] * B[k * ldB + n];
                }
                C[m * ldC + n] = c0;
@@ -42,7 +56,7 @@ void gemm_fp64_baseline(uint32_t M, uint32_t N, uint32_t K, double* A,
    } else if (ta && !tb) {
        for (uint32_t m = 0; m < M; m++) {
            for (uint32_t n = 0; n < N; n++) {
                register double c0 = BETA * C[m * ldC + n];
                double c0 = multiply_opt(C[m * ldC + n], BETA);
                for (uint32_t k = 0; k < K; k++) {
                    c0 += A[k * M * ldA + m * ldA] * B[k * ldB + n];
                }
@@ -52,7 +66,7 @@ void gemm_fp64_baseline(uint32_t M, uint32_t N, uint32_t K, double* A,
    } else if (!ta && tb) {
        for (uint32_t m = 0; m < M; m++) {
            for (uint32_t n = 0; n < N; n++) {
                register double c0 = BETA * C[m * ldC + n];
                double c0 = multiply_opt(C[m * ldC + n], BETA);
                for (uint32_t k = 0; k < K; k++) {
                    c0 += A[k + m * ldA] * B[k + n * ldB];
                }
@@ -62,7 +76,7 @@ void gemm_fp64_baseline(uint32_t M, uint32_t N, uint32_t K, double* A,
    } else {
        for (uint32_t m = 0; m < M; m++) {
            for (uint32_t n = 0; n < N; n++) {
                register double c0 = BETA * C[m * ldC + n];
                double c0 = multiply_opt(C[m * ldC + n], BETA);
                for (uint32_t k = 0; k < K; k++) {
                    c0 += A[k * M * ldA + m * ldA] * B[k + n * ldB];
                }
@@ -1015,9 +1029,12 @@ void gemm(precision_t prec, uint32_t expand, uint32_t setup_ssr,

    switch (prec) {
        case FP64:
            gemm_fp64_opt(frac_m, n, k, (double*)a + offsetA, lda_strided,
                          transa, (double*)b, ldb, transb, (double*)c + offsetC,
                          ldc_strided, &beta, setup_ssr);
            // gemm_fp64_opt(frac_m, n, k, (double*)a + offsetA, lda_strided,
            //               transa, (double*)b, ldb, transb, (double*)c + offsetC,
            //               ldc_strided, &beta, setup_ssr);
            gemm_fp64_baseline(frac_m, n, k, (double*)a + offsetA, lda_strided,
                               transa, (double*)b, ldb, transb, (double*)c + offsetC,
                               ldc_strided, beta);
            break;
        case FP32:
            gemm_fp32_opt(frac_m, n, k, (float*)a + offsetA, lda_strided,
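The zero test in multiply_opt exists because IEEE-754 multiplication by zero is not always zero. A minimal Python sketch of the same edge cases and the same guard (illustrative only; the Python multiply_opt below is a hypothetical mirror of the C helper above):

import math

print(0.0 * math.inf)   # nan: 0 * infinity is an invalid operation
print(0.0 * math.nan)   # nan: NaN propagates through the product

def multiply_opt(multiplicand, multiplier):
    # Mirror of the C helper: multiply only when the multiplier is non-zero,
    # so a zero BETA cleanly zeroes the accumulator instead of producing NaN.
    return multiplicand * multiplier if multiplier else 0.0

print(multiply_opt(math.inf, 0.0))   # 0.0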
175 changes: 175 additions & 0 deletions sw/dnn/flashattention_2/data/datagen.py
@@ -0,0 +1,175 @@
#!/usr/bin/env python3
# Copyright 2023 ETH Zurich and University of Bologna.
# Licensed under the Apache License, Version 2.0, see LICENSE for details.
# SPDX-License-Identifier: Apache-2.0
#
# Viviane Potocnik <[email protected]>
# Luca Colagrande <[email protected]>

import argparse
import numpy as np
import pathlib
import hjson
import sys
import os
import torch

sys.path.append(os.path.join(os.path.dirname(__file__), "../../../../util/sim/"))
import data_utils # noqa: E402
from data_utils import emit_license, \
format_struct_definition, format_array_definition, \
format_array_declaration, format_ifdef_wrapper # noqa: E402

torch.manual_seed(42)

# AXI splits bursts crossing 4KB address boundaries. To minimize
# the occurrence of these splits the data should be aligned to 4KB
BURST_ALIGNMENT = 4096

PRECISION = {
    'FP64': '64',
    'FP32': '32',
    'FP16': '16',
    'FP8': '8'
}


def torch_golden_model(Q, K, V):
    return torch.nn.functional.scaled_dot_product_attention(Q, K, V)


def exact_golden_model(Q, K, V, B_r, B_c):
    # Convert torch tensors to numpy arrays
    Q = Q.numpy()
    K = K.numpy()
    V = V.numpy()
    # Get layer dimensions
    N = Q.shape[0]
    d = Q.shape[1]
    # Calculate tiling parameters
    T_r = N // B_r
    T_c = N // B_c
    # Transpose K
    K_t = np.transpose(K)
    # Iterate tiles
    O = []
    for i in range(T_r):
        # Tile Q
        start_row = i * B_r
        end_row = start_row + B_r
        Q_i = Q[start_row:end_row, :]
        # Initialize l_i, m_i, O_i
        m_i = np.full((B_r, 1), -np.inf)
        for j in range(T_c):
            # Tile K_t and V
            start_col = j * B_c
            end_col = start_col + B_c
            K_t_j = K_t[:, start_col:end_col]
            V_j = V[start_col:end_col, ]
            # Compute O tile update
            S_ij = np.matmul(Q_i, K_t_j)
            m_i_prev = m_i
            m_i = np.maximum(m_i_prev, np.max(S_ij, 1, keepdims=True))
            shifted_exp = np.exp(m_i_prev - m_i)
            P_ij = np.exp(S_ij - m_i)
            if j == 0:
                l_i = np.sum(P_ij, 1, keepdims=True)
                O_i = np.matmul(P_ij, V_j)
            else:
                l_i = (shifted_exp * l_i) + np.sum(P_ij, 1, keepdims=True)
                diag = np.diag(shifted_exp[:, 0])
                diag_inv = np.linalg.inv(diag)
                O_i = np.matmul(diag_inv, O_i) + np.matmul(P_ij, V_j)
        # Finalize O tile
        diag_l_i = np.diag(l_i[:, 0])
        diag_l_inv_i = np.linalg.inv(diag_l_i)
        O_i = np.matmul(diag_l_inv_i, O_i)
        O.append(O_i)
    return np.concatenate(O, 0)


def emit_header(section, params):
    batch_size = 1
    N = params['N']
    d = params['d']
    B_r = params['B_r']
    B_c = params['B_c']
    prec = PRECISION[params['dtype']]

    # Verify layer parameters are valid
    assert (N % B_r) == 0, 'N is not an integer multiple of B_r'
    assert (N % B_c) == 0, 'N is not an integer multiple of B_c'
    assert (B_r % 8) == 0, 'B_r must be an integer multiple of the number of cores in a cluster'

    torch_type = data_utils.floating_point_torch_type(prec)

    Q = torch.rand(N, d, requires_grad=False, dtype=torch_type)
    K = torch.rand(N, d, requires_grad=False, dtype=torch_type)
    V = torch.rand(N, d, requires_grad=False, dtype=torch_type)

    O = exact_golden_model(Q, K, V, B_r, B_c)

    # Layer implementation assumes K is in (d, N) layout
    K = torch.transpose(K, 0, 1)

    ctype = data_utils.floating_point_ctype(prec)

    q_uid = 'Q'
    k_uid = 'K'
    v_uid = 'V'
    o_uid = 'O'

    layer_cfg = {
        **params,
        'Q': q_uid,
        'K': k_uid,
        'V': v_uid,
        'O': o_uid,
    }

    data_str = [emit_license()]
    data_str += [format_array_declaration(ctype, q_uid, Q.shape)]
    data_str += [format_array_declaration(ctype, k_uid, K.shape)]
    data_str += [format_array_declaration(ctype, v_uid, V.shape)]
    data_str += [format_array_declaration(ctype, o_uid, O.shape)]
    data_str += [format_struct_definition('flashattention_2_layer_t', 'layer', layer_cfg)]
    data_str += [format_array_definition(ctype, q_uid, Q)]
    data_str += [format_array_definition(ctype, k_uid, K)]
    data_str += [format_array_definition(ctype, v_uid, V)]
    result_def = format_array_definition(ctype, 'golden', O)
    data_str += [format_ifdef_wrapper('BIST', result_def)]
    data_str = '\n\n'.join(data_str)

    return data_str


def main():

    parser = argparse.ArgumentParser(description='Generate data for flashattention_2 kernel')
    parser.add_argument(
        "-c", "--cfg",
        type=pathlib.Path,
        required=True,
        help='Select param config file for the kernel'
    )
    parser.add_argument(
        '--section',
        type=str,
        help='Section to store matrices in')
    parser.add_argument(
        'output',
        type=pathlib.Path,
        help='Path of the output header file')
    args = parser.parse_args()

    # Load param config file
    with args.cfg.open() as f:
        param = hjson.loads(f.read())

    # Emit header file
    with open(args.output, 'w') as f:
        f.write(emit_header(args.section, param))


if __name__ == '__main__':
    main()
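exact_golden_model above builds each output row tile with an online softmax: for every column tile it updates a running row maximum m_i and a rescaled running sum l_i, and only normalizes O_i once the row is finished. A minimal NumPy sketch of that max/sum recurrence on a single row of scores (the values and tile count here are arbitrary), showing it reproduces the one-pass softmax statistics:

import numpy as np

x = np.random.rand(16)            # one row of attention scores S
m, l = -np.inf, 0.0
for chunk in np.split(x, 4):      # stream over four column tiles
    m_new = max(m, chunk.max())
    l = np.exp(m - m_new) * l + np.exp(chunk - m_new).sum()
    m = m_new

# The streamed (m, l) agree with the single-pass softmax statistics
assert np.isclose(m, x.max())
assert np.isclose(l, np.exp(x - x.max()).sum())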
11 changes: 11 additions & 0 deletions sw/dnn/flashattention_2/data/params.hjson
@@ -0,0 +1,11 @@
// Copyright 2020 ETH Zurich and University of Bologna.
// Solderpad Hardware License, Version 0.51, see LICENSE for details.
// SPDX-License-Identifier: SHL-0.51

{
    N: 16
    d: 16
    B_r: 8
    B_c: 8
    dtype: FP64
}
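For reference, this configuration satisfies the parameter checks in datagen.py above (N divisible by both block sizes, B_r a multiple of the 8 cores per cluster) and splits the attention computation into 2x2 tiles. A quick sketch of the implied tiling, with the values copied from the fields above:

N, d, B_r, B_c = 16, 16, 8, 8            # values from this config
assert N % B_r == 0 and N % B_c == 0     # checks enforced by datagen.py
assert B_r % 8 == 0                      # B_r rows are split across the cluster's 8 cores
T_r, T_c = N // B_r, N // B_c            # number of row and column tiles
print(T_r, T_c)                          # 2 2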