Skip to content

Commit

Permalink
dnn: Add FlashAttention-2 layer
Browse files Browse the repository at this point in the history
  • Loading branch information
colluca committed Nov 2, 2023
1 parent 9b6d35d commit 9e23aac
Show file tree
Hide file tree
Showing 14 changed files with 769 additions and 79 deletions.
117 changes: 70 additions & 47 deletions sw/blas/gemm/src/gemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,34 @@ typedef __fp16 v4f16 __attribute__((vector_size(8)));
typedef char v8f8 __attribute__((vector_size(8)));
#endif

dump_float(gemm, 8);
dump_uint(index, 9);
// Floating-point multiplications by zero cannot be optimized as in some
// edge cases they do not yield zero:
// - 0f * NaN = NaN
// - 0f * INFINITY == NaN
// Thus in order to optimize it, we need to test for zero. You can use this
// function for free when `multiplier` is a constant.
static inline double multiply_opt(double multiplicand, double multiplier) {
if (multiplier)
return multiplicand * multiplier;
else
return 0;
}

void gemm_fp64_baseline(uint32_t M, uint32_t N, uint32_t K, double* A,
uint32_t ldA, uint32_t ta, double* B, uint32_t ldB,
uint32_t tb, double* C, uint32_t ldC, double BETA) {
if (!ta && !tb) {
for (uint32_t m = 0; m < M; m++) {
for (uint32_t n = 0; n < N; n++) {
double c0 = BETA * C[m * ldC + n];
double c0 = multiply_opt(C[m * ldC + n], BETA);
for (uint32_t k = 0; k < K; k++) {
// dump_index(k + m * ldA);
// dump_gemm(A[k + m * ldA]);

// if (snrt_cluster_core_idx() == 7) {
// printf("k = %d, m = %d, n = %d, ldA = %d, ldB =
// %d\n", k, m, n, ldA, ldB);
// }
c0 += A[k + m * ldA] * B[k * ldB + n];
}
C[m * ldC + n] = c0;
Expand All @@ -42,7 +57,7 @@ void gemm_fp64_baseline(uint32_t M, uint32_t N, uint32_t K, double* A,
} else if (ta && !tb) {
for (uint32_t m = 0; m < M; m++) {
for (uint32_t n = 0; n < N; n++) {
register double c0 = BETA * C[m * ldC + n];
double c0 = multiply_opt(C[m * ldC + n], BETA);
for (uint32_t k = 0; k < K; k++) {
c0 += A[k * M * ldA + m * ldA] * B[k * ldB + n];
}
Expand All @@ -52,7 +67,7 @@ void gemm_fp64_baseline(uint32_t M, uint32_t N, uint32_t K, double* A,
} else if (!ta && tb) {
for (uint32_t m = 0; m < M; m++) {
for (uint32_t n = 0; n < N; n++) {
register double c0 = BETA * C[m * ldC + n];
double c0 = multiply_opt(C[m * ldC + n], BETA);
for (uint32_t k = 0; k < K; k++) {
c0 += A[k + m * ldA] * B[k + n * ldB];
}
Expand All @@ -62,7 +77,7 @@ void gemm_fp64_baseline(uint32_t M, uint32_t N, uint32_t K, double* A,
} else {
for (uint32_t m = 0; m < M; m++) {
for (uint32_t n = 0; n < N; n++) {
register double c0 = BETA * C[m * ldC + n];
double c0 = multiply_opt(C[m * ldC + n], BETA);
for (uint32_t k = 0; k < K; k++) {
c0 += A[k * M * ldA + m * ldA] * B[k + n * ldB];
}
Expand Down Expand Up @@ -999,46 +1014,54 @@ void gemm(precision_t prec, uint32_t expand, uint32_t setup_ssr,
uint32_t transa, uint32_t transb, uint32_t m, uint32_t n, uint32_t k,
double alpha, void* a, uint32_t lda, void* b, uint32_t ldb,
uint32_t beta, void* c, uint32_t ldc) {
const uint32_t compute_num = snrt_cluster_compute_core_num();
const uint32_t compute_id = snrt_cluster_core_idx();

// Compute cores work not on contiguous blocks but on strided rows
uint32_t lda_strided = compute_num * lda;
uint32_t ldc_strided = compute_num * ldc;

// Compute cores access A and C at offsets of one row from each other
uint32_t offsetA = compute_id * lda;
uint32_t offsetC = compute_id * ldc;

// Compute fraction of C rows every core computes
uint32_t frac_m = m / compute_num;

switch (prec) {
case FP64:
gemm_fp64_opt(frac_m, n, k, (double*)a + offsetA, lda_strided,
transa, (double*)b, ldb, transb, (double*)c + offsetC,
ldc_strided, &beta, setup_ssr);
break;
case FP32:
gemm_fp32_opt(frac_m, n, k, (float*)a + offsetA, lda_strided,
(float*)b, ldb, (float*)c + offsetC, ldc_strided,
&beta, setup_ssr);
break;
case FP16:
if (expand) {
gemm_fp16_ex_opt(
frac_m, n, k, (__fp16*)a + offsetA, lda_strided, (__fp16*)b,
ldb, (__fp16*)c + offsetC, ldc_strided, &beta, setup_ssr);
} else {
gemm_fp16_opt(frac_m, n, k, (__fp16*)a + offsetA, lda_strided,
(__fp16*)b, ldb, (__fp16*)c + offsetC,
ldc_strided, &beta, setup_ssr);
}
break;
case FP8:
gemm_fp8_ex_opt(frac_m, n, k, (char*)a + offsetA, lda, (char*)b,
ldb, (char*)c + offsetC, ldc_strided, &beta,
setup_ssr);
break;
if (snrt_is_compute_core()) {
const uint32_t compute_num = snrt_cluster_compute_core_num();
const uint32_t compute_id = snrt_cluster_core_idx();

// Compute cores work not on contiguous blocks but on strided rows
uint32_t lda_strided = compute_num * lda;
uint32_t ldc_strided = compute_num * ldc;

// Compute cores access A and C at offsets of one row from each other
uint32_t offsetA = compute_id * lda;
uint32_t offsetC = compute_id * ldc;

// Compute fraction of C rows every core computes
uint32_t frac_m = m / compute_num;

switch (prec) {
case FP64:
// gemm_fp64_opt(frac_m, n, k, (double*)a + offsetA,
// lda_strided,
// transa, (double*)b, ldb, transb, (double*)c +
// offsetC, ldc_strided, &beta, setup_ssr);
gemm_fp64_baseline(frac_m, n, k, (double*)a + offsetA,
lda_strided, transa, (double*)b, ldb, transb,
(double*)c + offsetC, ldc_strided, beta);
break;
case FP32:
gemm_fp32_opt(frac_m, n, k, (float*)a + offsetA, lda_strided,
(float*)b, ldb, (float*)c + offsetC, ldc_strided,
&beta, setup_ssr);
break;
case FP16:
if (expand) {
gemm_fp16_ex_opt(frac_m, n, k, (__fp16*)a + offsetA,
lda_strided, (__fp16*)b, ldb,
(__fp16*)c + offsetC, ldc_strided, &beta,
setup_ssr);
} else {
gemm_fp16_opt(frac_m, n, k, (__fp16*)a + offsetA,
lda_strided, (__fp16*)b, ldb,
(__fp16*)c + offsetC, ldc_strided, &beta,
setup_ssr);
}
break;
case FP8:
gemm_fp8_ex_opt(frac_m, n, k, (char*)a + offsetA, lda, (char*)b,
ldb, (char*)c + offsetC, ldc_strided, &beta,
setup_ssr);
break;
}
}
}
175 changes: 175 additions & 0 deletions sw/dnn/flashattention_2/data/datagen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
#!/usr/bin/env python3
# Copyright 2023 ETH Zurich and University of Bologna.
# Licensed under the Apache License, Version 2.0, see LICENSE for details.
# SPDX-License-Identifier: Apache-2.0
#
# Viviane Potocnik <[email protected]>
# Luca Colagrande <[email protected]>

import argparse
import numpy as np
import pathlib
import hjson
import sys
import os
import torch

sys.path.append(os.path.join(os.path.dirname(__file__), "../../../../util/sim/"))
import data_utils # noqa: E402
from data_utils import emit_license, \
format_struct_definition, format_array_definition, \
format_array_declaration, format_ifdef_wrapper # noqa: E402

torch.manual_seed(42)

# AXI splits bursts crossing 4KB address boundaries. To minimize
# the occurrence of these splits the data should be aligned to 4KB
BURST_ALIGNMENT = 4096

PRECISION = {
'FP64': '64',
'FP32': '32',
'FP16': '16',
'FP8': '8'
}


def torch_golden_model(Q, K, V):
return torch.nn.functional.scaled_dot_product_attention(Q, K, V)


def exact_golden_model(Q, K, V, B_r, B_c):
# Convert torch tensors to numpy arrays
Q = Q.numpy()
K = K.numpy()
V = V.numpy()
# Get layer dimensions
N = Q.shape[0]
# Calculate tiling parameters
T_r = N // B_r
T_c = N // B_c
# Transpose K
K_t = np.transpose(K)
# Iterate tiles
O_tiles = []
for i in range(T_r):
# Tile Q
start_row = i * B_r
end_row = start_row + B_r
Q_i = Q[start_row:end_row, :]
# Initialize l_i, m_i, O_i
m_i = np.full((B_r, 1), -np.inf)
for j in range(T_c):
# Tile K_t and V
start_col = j * B_c
end_col = start_col + B_c
K_t_j = K_t[:, start_col:end_col]
V_j = V[start_col:end_col, ]
# Compute O tile update
S_ij = np.matmul(Q_i, K_t_j)
m_i_prev = m_i
m_i = np.maximum(m_i_prev, np.max(S_ij, 1, keepdims=True))
shifted_exp = np.exp(m_i_prev - m_i)
P_ij = np.exp(S_ij - m_i)
PxV = np.matmul(P_ij, V_j)
if j == 0:
l_i = np.sum(P_ij, 1, keepdims=True)
O_i = PxV
else:
l_i = (shifted_exp * l_i) + np.sum(P_ij, 1, keepdims=True)
diag = np.diag(shifted_exp[:, 0])
diag_inv = np.linalg.inv(diag)
O_i = np.matmul(diag_inv, O_i)
O_i += PxV
# Finalize O tile
diag_l_i = np.diag(l_i[:, 0])
diag_l_inv_i = np.linalg.inv(diag_l_i)
O_i = np.matmul(diag_l_inv_i, O_i)
O_tiles.append(O_i)
return np.concatenate(O_tiles, 0)


def emit_header(section, params):
N = params['N']
d = params['d']
B_r = params['B_r']
B_c = params['B_c']
prec = PRECISION[params['dtype']]

# Verify layer parameters are valid
assert (N % B_r) == 0, 'N is not an integer multiple of B_r'
assert (N % B_c) == 0, 'N is not an integer multiple of B_c'
assert (B_r % 8) == 0, 'B_r must be an integer multiple of the number of cores in a cluster'

torch_type = data_utils.floating_point_torch_type(prec)

Q = torch.rand(N, d, requires_grad=False, dtype=torch_type)
K = torch.rand(N, d, requires_grad=False, dtype=torch_type)
V = torch.rand(N, d, requires_grad=False, dtype=torch_type)

output = exact_golden_model(Q, K, V, B_r, B_c)

# Layer implementation assumes K is in (d, N) layout
K = torch.transpose(K, 0, 1)

ctype = data_utils.floating_point_ctype(prec)

q_uid = 'Q'
k_uid = 'K'
v_uid = 'V'
o_uid = 'O'

layer_cfg = {
**params,
'Q': q_uid,
'K': k_uid,
'V': v_uid,
'O': o_uid,
}

data_str = [emit_license()]
data_str += [format_array_declaration(ctype, q_uid, Q.shape)]
data_str += [format_array_declaration(ctype, k_uid, K.shape)]
data_str += [format_array_declaration(ctype, v_uid, V.shape)]
data_str += [format_array_declaration(ctype, o_uid, output.shape)]
data_str += [format_struct_definition('flashattention_2_layer_t', 'layer', layer_cfg)]
data_str += [format_array_definition(ctype, q_uid, Q)]
data_str += [format_array_definition(ctype, k_uid, K)]
data_str += [format_array_definition(ctype, v_uid, V)]
result_def = format_array_definition(ctype, 'golden', output)
data_str += [format_ifdef_wrapper('BIST', result_def)]
data_str = '\n\n'.join(data_str)

return data_str


def main():

parser = argparse.ArgumentParser(description='Generate data for layernorm kernel')
parser.add_argument(
"-c", "--cfg",
type=pathlib.Path,
required=True,
help='Select param config file kernel'
)
parser.add_argument(
'--section',
type=str,
help='Section to store matrices in')
parser.add_argument(
'output',
type=pathlib.Path,
help='Path of the output header file')
args = parser.parse_args()

# Load param config file
with args.cfg.open() as f:
param = hjson.loads(f.read())

# Emit header file
with open(args.output, 'w') as f:
f.write(emit_header(args.section, param))


if __name__ == '__main__':
main()
11 changes: 11 additions & 0 deletions sw/dnn/flashattention_2/data/params.hjson
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// Copyright 2020 ETH Zurich and University of Bologna.
// Solderpad Hardware License, Version 0.51, see LICENSE for details.
// SPDX-License-Identifier: SHL-0.51

{
N: 128
d: 128
B_r: 32
B_c: 16
dtype: FP64
}
Loading

0 comments on commit 9e23aac

Please sign in to comment.