diff --git a/sw/apps/transformer/src/transformer.h b/sw/apps/transformer/src/transformer.h
index d92b8d26fa..88e8766bdd 100644
--- a/sw/apps/transformer/src/transformer.h
+++ b/sw/apps/transformer/src/transformer.h
@@ -58,6 +58,11 @@ typedef struct transformer_layer_struct {
     float *weights_q;
     float *weights_k;
     float *weights_v;
+    float *weights_o;
+    float *Q_lin;
+    float *K_lin;
+    float *V_lin;
+    float *O;
     float *ofmap;
     float *query;
     float *key;
@@ -83,6 +88,10 @@ static inline void transformer_layer(transformer_layer_t *const l) {
     uint32_t compute_id = snrt_global_core_idx();
     uint32_t num_cores = snrt_cluster_compute_core_num();
 
+    /////////////////////////////////////////////////////////////////////
+    //////////////    MULTI-HEAD SELF-ATTENTION LAYER    /////////////////
+    /////////////////////////////////////////////////////////////////////
+
     uint32_t S_TILE = 24;
     uint32_t P_TILE = 4;
 
@@ -102,13 +111,7 @@ static inline void transformer_layer(transformer_layer_t *const l) {
     uint32_t q_lin_size = l->seq_len * l->positional_embeddings * sizeof(float);
     uint32_t k_lin_size = l->positional_embeddings * l->seq_len * sizeof(float);
     uint32_t v_lin_size = l->positional_embeddings * l->seq_len * sizeof(float);
-    uint32_t attention_size = l->seq_len * l->seq_len * sizeof(float);
     uint32_t output_size = l->seq_len * l->positional_embeddings * sizeof(float);
-    float *q_lin[q_lin_size];
-    float *k_lin[k_lin_size];
-    float *v_lin[v_lin_size];
-    float *attention_out[attention_size];
-    float *output[output_size];
 
     void *tcdm_ptr = (float *)snrt_l1_next();
     float *ifmap = tcdm_ptr;
@@ -150,7 +153,7 @@ static inline void transformer_layer(transformer_layer_t *const l) {
             snrt_dma_start_2d(
                 ifmap,                                      /* dst */
                 l->ifmap + ifmap_offset,                    /* src */
-                ifmap_size,                                 /* size */
+                S_TILE * sizeof(float),                     /* size */
                 l->embeddings,                              /* dst_stride */
                 l->embeddings,                              /* src_stride */
                 l->seq_len);                                /* repetitions */
@@ -291,6 +294,7 @@ static inline void transformer_layer(transformer_layer_t *const l) {
             uint32_t row_offset = (s_tile * S_TILE + (S_TILE / num_cores) * compute_id) * l->embeddings;
 
             // compute a tile of the query matrix
+            // ifmap (S_TILE x E) * weights_q (E x P) = query (S_TILE x P)
             if (p_tile == 0) {
                 gemm_fp32_baseline(S_TILE / num_cores, P_TILE, l->embeddings,
                                    &ifmap[row_offset], l->embeddings, 0, weights_q,
@@ -414,32 +418,32 @@ static inline void transformer_layer(transformer_layer_t *const l) {
             // write back the query matrix
             snrt_dma_txid_t txid_q_lin = snrt_dma_start_2d(
-                q_lin + q_lin_offset,                       /* dst */
+                &l->Q_lin[q_lin_offset],                    /* dst */
                 query,                                      /* src */
                 P_TILE * sizeof(float),                     /* size */
                 l->positional_embeddings * sizeof(float),   /* dst_stride */
                 P_TILE * sizeof(float),                     /* src_stride */
-                S_TILE / num_cores);                        /* repetitions */
+                S_TILE);                                    /* repetitions */
 
             // write back the key matrix
             snrt_dma_txid_t txid_k_lin = snrt_dma_start_2d(
-                k_lin + q_lin_offset,                       /* dst */
+                &l->K_lin[q_lin_offset],                    /* dst */
                 key,                                        /* src */
                 P_TILE * sizeof(float),                     /* size */
                 l->positional_embeddings * sizeof(float),   /* dst_stride */
                 P_TILE * sizeof(float),                     /* src_stride */
-                S_TILE / num_cores);                        /* repetitions */
+                S_TILE);                                    /* repetitions */
 
             // write back the value matrix
             snrt_dma_txid_t txid_v_lin = snrt_dma_start_2d(
-                v_lin + q_lin_offset,                       /* dst */
+                &l->V_lin[q_lin_offset],                    /* dst */
                 value,                                      /* src */
                 P_TILE * sizeof(float),                     /* size */
                 l->positional_embeddings * sizeof(float),   /* dst_stride */
                 P_TILE * sizeof(float),                     /* src_stride */
-                S_TILE / num_cores);                        /* repetitions */
+                S_TILE);                                    /* repetitions */
 
             snrt_dma_wait_all();
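The hunks above tile the Q/K/V linear projections: each core computes an (S_TILE / num_cores) x P_TILE block of query = ifmap x weights_q, and the DMA core streams the finished tiles out to the new Q_lin/K_lin/V_lin DRAM buffers. A minimal host-side reference of one such projection tile, assuming row-major layouts; the helper name q_tile_reference and its parameters are illustrative, not symbols from the sources:

    #include <stddef.h>

    /* One s_len x p_len block of Q = X * W_q, where X is (S x E) row-major and
     * W_q is (E x P) row-major.  Sketch only: the kernel additionally splits
     * the rows of each block across the compute cores. */
    static void q_tile_reference(const float *X, const float *Wq, float *Q,
                                 size_t E, size_t P,
                                 size_t s0, size_t s_len,
                                 size_t p0, size_t p_len) {
        for (size_t s = s0; s < s0 + s_len; s++) {      /* rows of the tile */
            for (size_t p = p0; p < p0 + p_len; p++) {  /* columns of the tile */
                float acc = 0.0f;
                for (size_t e = 0; e < E; e++) {        /* reduce over embedding dim */
                    acc += X[s * E + e] * Wq[e * P + p];
                }
                Q[s * P + p] = acc;
            }
        }
    }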
@@ -461,8 +465,6 @@ static inline void transformer_layer(transformer_layer_t *const l) {
     // Compute the storage size of the matrices
     uint32_t RowBlock = B_r * l->positional_embeddings * sizeof(float);
     uint32_t ColBlock = B_c * l->positional_embeddings * sizeof(float);
-    uint32_t out_size = l->seq_len * l->positional_embeddings * sizeof(float);
-    float output_matrix[out_size];
 
     // Reset the TCDM pointer
     tcdm_ptr = (float *)snrt_l1_next();
@@ -502,20 +504,20 @@ static inline void transformer_layer(transformer_layer_t *const l) {
            snrt_dma_txid_t txid_Q_tile = snrt_dma_start_2d(
                Q_tile,                                           /* dst */
-               q_lin + i * B_r * l->positional_embeddings,       /* src */
-               RowBlock,                                         /* size */
+               &l->Q_lin[i * B_r * l->positional_embeddings],    /* src */
+               l->positional_embeddings * sizeof(float),         /* size */
                l->positional_embeddings * sizeof(float),         /* dst_stride */
-               B_r * l->positional_embeddings * sizeof(float),   /* src_stride */
+               l->positional_embeddings * sizeof(float),         /* src_stride */
                B_r);                                             /* repetitions */
 
            snrt_dma_wait_all();
        }
 
        // Initialize the SoftMax statistics
-       for (int r = 0; r < B_r; r++) {
-           m_i[r] = -INFINITY;
-           m_i_prev[r] = -INFINITY;
-           l_j[r] = 0.0f;
+       for (int r = 0; r < B_r / num_cores; r++) {
+           m_i[r + compute_id * (B_r / num_cores)] = -INFINITY;
+           m_i_prev[r + compute_id * (B_r / num_cores)] = -INFINITY;
+           l_j[r + compute_id * (B_r / num_cores)] = 0.0f;
        }
 
        snrt_cluster_hw_barrier();
@@ -528,20 +530,20 @@ static inline void transformer_layer(transformer_layer_t *const l) {
            snrt_dma_txid_t txid_K_tile = snrt_dma_start_2d(
                K_tile,                                           /* dst */
-               k_lin + j * B_c * l->positional_embeddings,       /* src */
-               ColBlock,                                         /* size */
-               l->positional_embeddings * sizeof(float),         /* dst_stride */
-               B_c * l->positional_embeddings * sizeof(float),   /* src_stride */
-               B_c);                                             /* repetitions */
+               &l->K_lin[j * B_c * l->positional_embeddings],    /* src */
+               l->positional_embeddings * sizeof(float),         /* size */
+               l->positional_embeddings * sizeof(float),         /* dst_stride */
+               l->positional_embeddings * sizeof(float),         /* src_stride */
+               B_c);                                             /* repetitions */
 
            snrt_dma_txid_t txid_V_tile = snrt_dma_start_2d(
                V_tile,                                           /* dst */
-               v_lin + j * B_c * l->positional_embeddings,       /* src */
-               ColBlock,                                         /* size */
-               l->positional_embeddings * sizeof(float),         /* dst_stride */
-               B_c * l->positional_embeddings * sizeof(float),   /* src_stride */
-               B_c);                                             /* repetitions */
+               &l->V_lin[j * B_c * l->positional_embeddings],    /* src */
+               l->positional_embeddings * sizeof(float),         /* size */
+               l->positional_embeddings * sizeof(float),         /* dst_stride */
+               l->positional_embeddings * sizeof(float),         /* src_stride */
+               B_c);                                             /* repetitions */
 
            snrt_dma_wait_all();
        }
@@ -565,7 +567,7 @@ static inline void transformer_layer(transformer_layer_t *const l) {
            // we parallelize the rows of Q_tile
            gemm_fp32_baseline(B_r / num_cores, B_c, l->positional_embeddings,
                               &Q_tile[row_offset], l->positional_embeddings, 0, K_tile,
-                              l->positional_embeddings, 0, A_tile, B_c, 0.0f);
+                              l->positional_embeddings, 1, &A_tile[compute_id * (B_r / num_cores) * B_c], B_c, 0.0f);
 
            // print the current attention matrix for debugging
            if (DEBUG_VERBOSE == 1) {
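The hunks above belong to a FlashAttention-style streaming attention: m_i / m_i_prev hold the running row maxima and l_j the running sums of exponentials, now initialized per core, while each core computes its own (B_r / num_cores) x B_c slice of Q_tile * K_tile^T. A single-row sketch of how such online-softmax statistics are typically updated; the names are illustrative, and the `rescale` factor plays the role the kernel's shifted_exp appears to play:

    #include <math.h>

    /* Online softmax statistics for one row of scores, processed in chunks of
     * B_c columns.  m is the running row maximum, l the running sum of
     * exponentials; when a new chunk raises the maximum, the old sum is
     * rescaled by exp(m_old - m_new). */
    static void online_softmax_row(const float *scores, int n_chunks, int B_c,
                                   float *m, float *l) {
        *m = -INFINITY;
        *l = 0.0f;
        for (int j = 0; j < n_chunks; j++) {
            const float *chunk = &scores[j * B_c];
            float m_new = *m;
            for (int c = 0; c < B_c; c++) {           /* chunk-local maximum */
                if (chunk[c] > m_new) m_new = chunk[c];
            }
            float rescale = expf(*m - m_new);         /* corresponds to shifted_exp */
            float sum = 0.0f;
            for (int c = 0; c < B_c; c++) {           /* sum of shifted exponentials */
                sum += expf(chunk[c] - m_new);
            }
            *l = *l * rescale + sum;
            *m = m_new;
        }
    }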
@@ -638,8 +640,12 @@ static inline void transformer_layer(transformer_layer_t *const l) {
            for (int k = 0; k < B_r / num_cores; k++) {
                for (int p = 0; p < l->positional_embeddings; p++) {
                    // the current row will be scaled by the shifted exponential
-                   O_tile[O_offset + k * l->positional_embeddings + p] =
-                       (1 / shifted_exp[k + compute_id * (B_r / num_cores)]) * O_tile[O_offset + k * l->positional_embeddings + p];
+                   if (i == 0) {
+                       O_tile[O_offset + k * l->positional_embeddings + p] = 0.0f;
+                   } else {
+                       O_tile[O_offset + k * l->positional_embeddings + p] =
+                           (1 / shifted_exp[k + compute_id * (B_r / num_cores)]) * O_tile[O_offset + k * l->positional_embeddings + p];
+                   }
                    // we compute the matrix multiplication of P_tilde and V_tile
                    for (int o = 0; o < B_c; o++) {
                        O_tile[O_offset + k * (l->positional_embeddings) + p] +=
@@ -670,6 +676,7 @@ static inline void transformer_layer(transformer_layer_t *const l) {
        snrt_cluster_hw_barrier();
 
        // we write back O_i as the i-th block row of the output matrix O to DRAM
+       // O_i is of size B_r x E
        if (!snrt_is_compute_core()) {
            // printf("Writing back O_tile %d/%d\n", i, T_r);
            // determine the current offset in the output matrix
@@ -678,16 +685,34 @@ static inline void transformer_layer(transformer_layer_t *const l) {
            // write back the output matrix
            snrt_dma_txid_t txid_output = snrt_dma_start_2d(
-               output_matrix + output_offset,                    /* dst */
+               &l->O[output_offset],                             /* dst */
                O_tile,                                           /* src */
-               B_r * l->positional_embeddings * sizeof(float),   /* size */
+               l->positional_embeddings * sizeof(float),         /* size */
                l->positional_embeddings * sizeof(float),         /* dst_stride */
-               B_r * l->positional_embeddings * sizeof(float),   /* src_stride */
-               1);                                               /* repetitions */
+               l->positional_embeddings * sizeof(float),         /* src_stride */
+               B_r);                                             /* repetitions */
 
            snrt_dma_wait_all();
        }
    }  // end of T_r loop
 
+    snrt_cluster_hw_barrier();
+
+    // Self Attention
+    // determine the tile sizes
+    uint32_t s_Size = l->seq_len * sizeof(float);
+    int32_t num_s_tiles = l->seq_len / s_Size;
+    uint32_t e_Size = l->embeddings * sizeof(float) / 2;
+    int32_t num_e_tiles = l->embeddings / e_Size;
+
+    snrt_cluster_hw_barrier();
+
+    /////////////////////////////////////////////////
+    ////         MULTI-LAYER PERCEPTRON          ////
+    /////////////////////////////////////////////////
+
+    // reset the TCDM pointer
+    tcdm_ptr = (float *)snrt_l1_next();
+
     snrt_global_barrier();
 }
\ No newline at end of file
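For context on the i == 0 / shifted_exp branch added above: in a streaming attention kernel the output block of a row tile is accumulated across all column tiles of V, zero-initialized on the first tile and rescaled whenever the running row maximum changes. A generic single-row sketch under those assumptions; the helper name and parameters are illustrative, not the kernel's:

    /* Accumulate one output row across column tiles of V.  'rescale' is the
     * factor applied to the previously accumulated output when the running
     * row maximum changes (exp(m_old - m_new)); on the first tile the
     * accumulator is simply initialized to zero. */
    static void attention_row_accumulate(float *o_row, const float *p_row,
                                         const float *V_tile, int B_c, int E,
                                         float rescale, int first_tile) {
        for (int e = 0; e < E; e++) {
            if (first_tile) {
                o_row[e] = 0.0f;             /* nothing accumulated yet */
            } else {
                o_row[e] *= rescale;         /* correct earlier tiles for the new max */
            }
            for (int c = 0; c < B_c; c++) {  /* o_row += p_row * V_tile */
                o_row[e] += p_row[c] * V_tile[c * E + e];
            }
        }
    }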
diff --git a/sw/blas/gemm/src/gemm.h b/sw/blas/gemm/src/gemm.h
index 5417f14f59..dfdebbf0bc 100644
--- a/sw/blas/gemm/src/gemm.h
+++ b/sw/blas/gemm/src/gemm.h
@@ -112,21 +112,45 @@ void gemm_fp32_baseline(uint32_t M, uint32_t N, uint32_t K, float* A,
     } else if (ta && !tb) {
         for (uint32_t m = 0; m < M; m++) {
             for (uint32_t n = 0; n < N; n++) {
-                register float c0 = ALPHA * C[m * ldC + n];
-                for (uint32_t k = 0; k < K; k++) {
-                    c0 += A[k * M * ldA + m * ldA] * B[k * ldB + n];
+                register float c0, c1, c2, c3;
+                if (ALPHA == 0.0f) {
+                    c0 = 0.0f;
+                } else {
+                    c0 = ALPHA * C[m * ldC + n];
                 }
-                C[m * ldC + n] = c0;
+                c1 = 0.0f;
+                c2 = 0.0f;
+                c3 = 0.0f;
+                for (uint32_t k = 0; k < K; k += 4) {
+                    c0 += A[(k + 0) * M * ldA + m * ldA] * B[(k + 0) * ldB + n];
+                    c1 += A[(k + 1) * M * ldA + m * ldA] * B[(k + 1) * ldB + n];
+                    c2 += A[(k + 2) * M * ldA + m * ldA] * B[(k + 2) * ldB + n];
+                    c3 += A[(k + 3) * M * ldA + m * ldA] * B[(k + 3) * ldB + n];
+                }
+                C[m * ldC + n] = c0 + c1 + c2 + c3;
             }
         }
     } else if (!ta && tb) {
         for (uint32_t m = 0; m < M; m++) {
             for (uint32_t n = 0; n < N; n++) {
-                register float c0 = ALPHA * C[m * ldC + n];
-                for (uint32_t k = 0; k < K; k++) {
-                    c0 += A[k + m * ldA] * B[k + n * ldB];
+                register float c0, c1, c2, c3;
+                if (ALPHA == 0.0f) {
+                    c0 = 0.0f;
+                } else {
+                    c0 = ALPHA * C[m * ldC + n];
                 }
-                C[m * ldC + n] = c0;
+                c1 = 0.0f;
+                c2 = 0.0f;
+                c3 = 0.0f;
+                for (uint32_t k = 0; k < K; k += 4) {
+                    // c0 += A[k + m * ldA] * B[k + n * ldB];
+                    c0 += A[(k + 0) + m * ldA] * B[(k + 0) + n * ldB];
+                    c1 += A[(k + 1) + m * ldA] * B[(k + 1) + n * ldB];
+                    c2 += A[(k + 2) + m * ldA] * B[(k + 2) + n * ldB];
+                    c3 += A[(k + 3) + m * ldA] * B[(k + 3) + n * ldB];
+                }
+                // C[m * ldC + n] = c0;
+                C[m * ldC + n] = c0 + c1 + c2 + c3;
             }
         }
     } else {
diff --git a/target/snitch_cluster/sw/apps/transformer/src/transformer.c b/target/snitch_cluster/sw/apps/transformer/src/transformer.c
index 128db110ac..b390bd9932 100644
--- a/target/snitch_cluster/sw/apps/transformer/src/transformer.c
+++ b/target/snitch_cluster/sw/apps/transformer/src/transformer.c
@@ -17,6 +17,11 @@ int main() {
     transformer_l.weights_q = (float *)transformer_weights_q_dram;
     transformer_l.weights_k = (float *)transformer_weights_k_dram;
     transformer_l.weights_v = (float *)transformer_weights_v_dram;
+    transformer_l.weights_o = (float *)transformer_weights_o_dram;
+    transformer_l.Q_lin = (float *)transformer_Q_lin_dram;
+    transformer_l.K_lin = (float *)transformer_K_lin_dram;
+    transformer_l.V_lin = (float *)transformer_V_lin_dram;
+    transformer_l.O = (float *)transformer_O_dram;
     // Results of query, key, value computation
     // transformer_l.query = (double *)transformer_query_dram;
     // transformer_l.key = (double *)transformer_key_dram;
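The gemm.h change replaces the single-accumulator reduction over K with four independent partial sums, which breaks the add dependency chain on the FPU; as written it assumes K is a multiple of 4. A standalone sketch of the same unrolling pattern, with a scalar tail added to also cover other K (illustrative, not part of the diff):

    /* Dot product with a 4-way unrolled reduction: four independent partial
     * sums hide the floating-point add latency; the scalar tail handles
     * K % 4 != 0. */
    static float dot_unrolled4(const float *a, const float *b, unsigned K) {
        float c0 = 0.0f, c1 = 0.0f, c2 = 0.0f, c3 = 0.0f;
        unsigned k = 0;
        for (; k + 3 < K; k += 4) {
            c0 += a[k + 0] * b[k + 0];
            c1 += a[k + 1] * b[k + 1];
            c2 += a[k + 2] * b[k + 2];
            c3 += a[k + 3] * b[k + 3];
        }
        for (; k < K; k++) {    /* remainder iterations */
            c0 += a[k] * b[k];
        }
        return c0 + c1 + c2 + c3;
    }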