trafo: add dram data section and fix memory layouts
Viviane Potocnik committed Sep 16, 2023
1 parent 6a5feb7 commit f7e07b4
Showing 3 changed files with 101 additions and 47 deletions.
103 changes: 64 additions & 39 deletions sw/apps/transformer/src/transformer.h
@@ -58,6 +58,11 @@ typedef struct transformer_layer_struct {
float *weights_q;
float *weights_k;
float *weights_v;
float *weights_o;
float *Q_lin;
float *K_lin;
float *V_lin;
float *O;
float *ofmap;
float *query;
float *key;
@@ -83,6 +88,10 @@ static inline void transformer_layer(transformer_layer_t *const l) {
uint32_t compute_id = snrt_global_core_idx();
uint32_t num_cores = snrt_cluster_compute_core_num();

/////////////////////////////////////////////////////////////////////
////////////// MULTI-HEAD SELF-ATTENTION LAYER /////////////////
/////////////////////////////////////////////////////////////////////


uint32_t S_TILE = 24;
uint32_t P_TILE = 4;
@@ -102,13 +111,7 @@ static inline void transformer_layer(transformer_layer_t *const l) {
uint32_t q_lin_size = l->seq_len * l->positional_embeddings * sizeof(float);
uint32_t k_lin_size = l->positional_embeddings * l->seq_len * sizeof(float);
uint32_t v_lin_size = l->positional_embeddings * l->seq_len * sizeof(float);
uint32_t attention_size = l->seq_len * l->seq_len * sizeof(float);
uint32_t output_size = l->seq_len * l->positional_embeddings * sizeof(float);
float *q_lin[q_lin_size];
float *k_lin[k_lin_size];
float *v_lin[v_lin_size];
float *attention_out[attention_size];
float *output[output_size];

void *tcdm_ptr = (float *)snrt_l1_next();
float *ifmap = tcdm_ptr;
@@ -150,7 +153,7 @@ static inline void transformer_layer(transformer_layer_t *const l) {
snrt_dma_start_2d(
ifmap, /* dst */
l->ifmap + ifmap_offset, /* src */
ifmap_size, /* size */
S_TILE * sizeof(float), /* size */
l->embeddings, /* dst_stride */
l->embeddings, /* src_stride */
l->seq_len); /* repetitions */
@@ -291,6 +294,7 @@ static inline void transformer_layer(transformer_layer_t *const l) {
uint32_t row_offset = (s_tile * S_TILE + (S_TILE / num_cores) * compute_id) * l->embeddings;

// compute a tile of the query matrix
// ifmap (S_TILE x E) * weights_q (E x P) = query (S_TILE x P)
if (p_tile == 0) {
gemm_fp32_baseline(S_TILE / num_cores, P_TILE, l->embeddings,
&ifmap[row_offset], l->embeddings, 0, weights_q,
@@ -414,32 +418,32 @@ static inline void transformer_layer(transformer_layer_t *const l) {
// write back the query matrix
snrt_dma_txid_t txid_q_lin =
snrt_dma_start_2d(
q_lin + q_lin_offset, /* dst */
&l->Q_lin[q_lin_offset], /* dst */
query, /* src */
P_TILE * sizeof(float), /* size */
l->positional_embeddings * sizeof(float), /* dst_stride */
P_TILE * sizeof(float), /* src_stride */
S_TILE / num_cores); /* repetitions */
S_TILE); /* repetitions */

// write back the key matrix
snrt_dma_txid_t txid_k_lin =
snrt_dma_start_2d(
k_lin + q_lin_offset, /* dst */
&l->K_lin[q_lin_offset], /* dst */
key, /* src */
P_TILE * sizeof(float), /* size */
l->positional_embeddings * sizeof(float), /* dst_stride */
P_TILE * sizeof(float), /* src_stride */
S_TILE / num_cores); /* repetitions */
S_TILE); /* repetitions */

// write back the value matrix
snrt_dma_txid_t txid_v_lin =
snrt_dma_start_2d(
v_lin + q_lin_offset, /* dst */
&l->V_lin[q_lin_offset], /* dst */
value, /* src */
P_TILE * sizeof(float), /* size */
l->positional_embeddings * sizeof(float), /* dst_stride */
P_TILE * sizeof(float), /* src_stride */
S_TILE / num_cores); /* repetitions */
S_TILE); /* repetitions */

snrt_dma_wait_all();

@@ -461,8 +465,6 @@ static inline void transformer_layer(transformer_layer_t *const l) {
// Compute the storage size of the matrices
uint32_t RowBlock = B_r * l->positional_embeddings * sizeof(float);
uint32_t ColBlock = B_c * l->positional_embeddings * sizeof(float);
uint32_t out_size = l->seq_len * l->positional_embeddings * sizeof(float);
float output_matrix[out_size];

// Reset the TCDM pointer
tcdm_ptr = (float *)snrt_l1_next();
@@ -502,20 +504,20 @@ static inline void transformer_layer(transformer_layer_t *const l) {
snrt_dma_txid_t txid_Q_tile =
snrt_dma_start_2d(
Q_tile, /* dst */
q_lin + i * B_r * l->positional_embeddings, /* src */
RowBlock, /* size */
&l->Q_lin[i * B_r * l->positional_embeddings], /* src */
l->positional_embeddings * sizeof(float), /* size */
l->positional_embeddings * sizeof(float), /* dst_stride */
B_r * l->positional_embeddings * sizeof(float), /* src_stride */
l->positional_embeddings * sizeof(float), /* src_stride */
B_r); /* repetitions */

snrt_dma_wait_all();
}

// Initialize the SoftMax statistics
for (int r = 0; r < B_r; r++) {
m_i[r] = -INFINITY;
m_i_prev[r] = -INFINITY;
l_j[r] = 0.0f;
for (int r = 0; r < B_r / num_cores; r++) {
m_i[r + compute_id * (B_r / num_cores)] = -INFINITY;
m_i_prev[r + compute_id * (B_r / num_cores)] = -INFINITY;
l_j[r + compute_id * (B_r / num_cores)] = 0.0f;
}

snrt_cluster_hw_barrier();
@@ -528,20 +530,20 @@ static inline void transformer_layer(transformer_layer_t *const l) {
snrt_dma_txid_t txid_K_tile =
snrt_dma_start_2d(
K_tile, /* dst */
k_lin + j * B_c * l->positional_embeddings,/* src */
ColBlock, /* size */
l->positional_embeddings * sizeof(float), /* dst_stride */
B_c * l->positional_embeddings * sizeof(float),/* src_stride */
B_c); /* repetitions */
&l->K_lin[j * B_c * l->positional_embeddings], /* src */
l->positional_embeddings * sizeof(float), /* size */
l->positional_embeddings * sizeof(float), /* dst_stride */
l->positional_embeddings * sizeof(float), /* src_stride */
B_c); /* repetitions */

snrt_dma_txid_t txid_V_tile =
snrt_dma_start_2d(
V_tile, /* dst */
v_lin + j * B_c * l->positional_embeddings,/* src */
ColBlock, /* size */
l->positional_embeddings * sizeof(float), /* dst_stride */
B_c * l->positional_embeddings * sizeof(float),/* src_stride */
B_c); /* repetitions */
&l->V_lin[j * B_c * l->positional_embeddings], /* src */
l->positional_embeddings * sizeof(float), /* size */
l->positional_embeddings * sizeof(float), /* dst_stride */
l->positional_embeddings * sizeof(float), /* src_stride */
B_c); /* repetitions */

snrt_dma_wait_all();
}
@@ -565,7 +567,7 @@ static inline void transformer_layer(transformer_layer_t *const l) {
// we parallelize the rows of Q_tile
gemm_fp32_baseline(B_r / num_cores, B_c, l->positional_embeddings,
&Q_tile[row_offset], l->positional_embeddings, 0, K_tile,
l->positional_embeddings, 0, A_tile, B_c, 0.0f);
l->positional_embeddings, 1, &A_tile[compute_id * (B_r / num_cores) * B_c], B_c, 0.0f);

// print the current attention matrix for debugging
if (DEBUG_VERBOSE == 1) {
@@ -638,8 +640,12 @@ static inline void transformer_layer(transformer_layer_t *const l) {
for (int k = 0; k < B_r / num_cores; k++) {
for (int p = 0; p < l->positional_embeddings; p++) {
// the current row will be scaled by the shifted exponential
O_tile[O_offset + k * l->positional_embeddings + p] =
(1 / shifted_exp[k + compute_id * (B_r / num_cores)]) * O_tile[O_offset + k * l->positional_embeddings + p];
if (i == 0) {
O_tile[O_offset + k * l->positional_embeddings + p] = 0.0f;
} else {
O_tile[O_offset + k * l->positional_embeddings + p] =
(1 / shifted_exp[k + compute_id * (B_r / num_cores)]) * O_tile[O_offset + k * l->positional_embeddings + p];
}
// we compute the matrix multiplication of P_tilde and V_tile
for (int o = 0; o < B_c; o++) {
O_tile[O_offset + k * (l->positional_embeddings) + p] +=
@@ -670,6 +676,7 @@ static inline void transformer_layer(transformer_layer_t *const l) {
snrt_cluster_hw_barrier();

// we write back O_i as the i-th block row of the output matrix O to DRAM
// O_i is of size B_r x E
if (!snrt_is_compute_core()) {
// printf("Writing back O_tile %d/%d\n", i, T_r);
// determine the current offset in the output matrix
@@ -678,16 +685,34 @@ static inline void transformer_layer(transformer_layer_t *const l) {
// write back the output matrix
snrt_dma_txid_t txid_output =
snrt_dma_start_2d(
output_matrix + output_offset, /* dst */
&l->O[output_offset], /* dst */
O_tile, /* src */
B_r * l->positional_embeddings * sizeof(float), /* size */
l->positional_embeddings * sizeof(float), /* size */
l->positional_embeddings * sizeof(float), /* dst_stride */
B_r * l->positional_embeddings * sizeof(float), /* src_stride */
1); /* repetitions */
l->positional_embeddings * sizeof(float), /* src_stride */
B_r); /* repetitions */

snrt_dma_wait_all();
}
} // end of T_r loop

snrt_cluster_hw_barrier();

// Self Attention
// determine the tile sizes
uint32_t s_Size = l->seq_len * sizeof(float);
int32_t num_s_tiles = l->seq_len / s_Size;
uint32_t e_Size = l->embeddings * sizeof(float) / 2;
int32_t num_e_tiles = l->embeddings / e_Size;

snrt_cluster_hw_barrier();

/////////////////////////////////////////////////
//// MULTI-LAYER PERCEPTRON ////
/////////////////////////////////////////////////

// reset the TCDM pointer
tcdm_ptr = (float *)snrt_l1_next();

snrt_global_barrier();
}
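Several hunks in this file change the size, stride, and repetition arguments handed to snrt_dma_start_2d so that they match the row-major layouts of the DRAM matrices. As a reading aid, the lines below sketch a minimal reference model of such a strided 2D copy; it only illustrates the argument semantics assumed at the call sites above, it is not the snrt runtime implementation, and the helper name dma_2d_model is made up.

#include <stddef.h>
#include <string.h>

// Reference model of a 2D copy: for each of `repetitions` rows, copy `size`
// bytes, then advance the destination by `dst_stride` bytes and the source by
// `src_stride` bytes. All arguments are byte counts, e.g. size =
// P_TILE * sizeof(float) and dst_stride = l->positional_embeddings * sizeof(float)
// in the Q/K/V write-back above.
static void dma_2d_model(void *dst, const void *src, size_t size,
                         size_t dst_stride, size_t src_stride,
                         size_t repetitions) {
    char *d = (char *)dst;
    const char *s = (const char *)src;
    for (size_t r = 0; r < repetitions; r++) {
        memcpy(d, s, size);
        d += dst_stride;
        s += src_stride;
    }
}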
40 changes: 32 additions & 8 deletions sw/blas/gemm/src/gemm.h
@@ -112,21 +112,45 @@ void gemm_fp32_baseline(uint32_t M, uint32_t N, uint32_t K, float* A,
} else if (ta && !tb) {
for (uint32_t m = 0; m < M; m++) {
for (uint32_t n = 0; n < N; n++) {
register float c0 = ALPHA * C[m * ldC + n];
for (uint32_t k = 0; k < K; k++) {
c0 += A[k * M * ldA + m * ldA] * B[k * ldB + n];
// register float c0 = ALPHA * C[m * ldC + n];
if (ALPHA == 0.0f) {
c0 = 0.0f;
} else {
c0 = ALPHA * C[m * ldC + n];
}
C[m * ldC + n] = c0;
c1 = 0.0f;
c2 = 0.0f;
c3 = 0.0f;
for (uint32_t k = 0; k < K; k+=4) {
c0 += A[(k + 0) * M * ldA + m * ldA] * B[(k + 0) * ldB + n];
c1 += A[(k + 1) * M * ldA + m * ldA] * B[(k + 1) * ldB + n];
c2 += A[(k + 2) * M * ldA + m * ldA] * B[(k + 2) * ldB + n];
c3 += A[(k + 3) * M * ldA + m * ldA] * B[(k + 3) * ldB + n];
}
C[m * ldC + n] = c0 + c1 + c2 + c3;
}
}
} else if (!ta && tb) {
for (uint32_t m = 0; m < M; m++) {
for (uint32_t n = 0; n < N; n++) {
register float c0 = ALPHA * C[m * ldC + n];
for (uint32_t k = 0; k < K; k++) {
c0 += A[k + m * ldA] * B[k + n * ldB];
// register float c0 = ALPHA * C[m * ldC + n];
if (ALPHA == 0.0f) {
c0 = 0.0f;
} else {
c0 = ALPHA * C[m * ldC + n];
}
C[m * ldC + n] = c0;
c1 = 0.0f;
c2 = 0.0f;
c3 = 0.0f;
for (uint32_t k = 0; k < K; k+=4) {
// c0 += A[k + m * ldA] * B[k + n * ldB];
c0 += A[(k + 0) + m * ldA] * B[(k + 0) + n * ldB];
c1 += A[(k + 1) + m * ldA] * B[(k + 1) + n * ldB];
c2 += A[(k + 2) + m * ldA] * B[(k + 2) + n * ldB];
c3 += A[(k + 3) + m * ldA] * B[(k + 3) + n * ldB];
}
// C[m * ldC + n] = c0;
C[m * ldC + n] = c0 + c1 + c2 + c3;
}
}
} else {
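The rewritten inner loops above keep four independent accumulators and step k by 4, which implicitly assumes that K is a multiple of 4. The snippet below is a hedged sketch of the same pattern with a scalar tail loop for arbitrary K; the helper name dotp_unrolled4 is illustrative and not part of gemm.h.

#include <stdint.h>

// Four-accumulator dot product with a scalar tail, using the same
// A[k + m * ldA] * B[k + n * ldB] indexing as the (!ta && tb) case above.
static inline float dotp_unrolled4(const float *A, const float *B, uint32_t K,
                                   uint32_t m, uint32_t n,
                                   uint32_t ldA, uint32_t ldB) {
    float c0 = 0.0f, c1 = 0.0f, c2 = 0.0f, c3 = 0.0f;
    uint32_t k = 0;
    for (; k + 3 < K; k += 4) {
        c0 += A[(k + 0) + m * ldA] * B[(k + 0) + n * ldB];
        c1 += A[(k + 1) + m * ldA] * B[(k + 1) + n * ldB];
        c2 += A[(k + 2) + m * ldA] * B[(k + 2) + n * ldB];
        c3 += A[(k + 3) + m * ldA] * B[(k + 3) + n * ldB];
    }
    for (; k < K; k++) {  // remainder iterations when K % 4 != 0
        c0 += A[k + m * ldA] * B[k + n * ldB];
    }
    return c0 + c1 + c2 + c3;
}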
5 changes: 5 additions & 0 deletions target/snitch_cluster/sw/apps/transformer/src/transformer.c
@@ -17,6 +17,11 @@ int main() {
transformer_l.weights_q = (float *)transformer_weights_q_dram;
transformer_l.weights_k = (float *)transformer_weights_k_dram;
transformer_l.weights_v = (float *)transformer_weights_v_dram;
transformer_l.weights_o = (float *)transformer_weights_o_dram;
transformer_l.Q_lin = (float *)transformer_Q_lin_dram;
transformer_l.K_lin = (float *)transformer_K_lin_dram;
transformer_l.V_lin = (float *)transformer_V_lin_dram;
transformer_l.O = (float *)transformer_O_dram;
// Results of query, key, value computation
// transformer_l.query = (double *)transformer_query_dram;
// transformer_l.key = (double *)transformer_key_dram;
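The new struct members above are bound to *_dram symbols which, per the commit title, live in a DRAM data section. A minimal sketch of how such buffers could be declared in a generated data header is shown below; the sizes SEQ_LEN and POS_EMB and the section name ".dram" are assumptions for illustration only, not the contents of the actual generated header.

// Hypothetical excerpt of a generated data header placing buffers in DRAM.
#define SEQ_LEN 128  // assumed sequence length
#define POS_EMB 64   // assumed positional-embedding width

__attribute__((section(".dram"))) float transformer_Q_lin_dram[SEQ_LEN * POS_EMB];
__attribute__((section(".dram"))) float transformer_K_lin_dram[SEQ_LEN * POS_EMB];
__attribute__((section(".dram"))) float transformer_V_lin_dram[SEQ_LEN * POS_EMB];
__attribute__((section(".dram"))) float transformer_O_dram[SEQ_LEN * POS_EMB];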
