transformer: changes to support FP64 transformer network
Viviane Potocnik committed Oct 16, 2023
1 parent 9f67a23 commit 4c77012
Showing 4 changed files with 172 additions and 32 deletions.
6 changes: 3 additions & 3 deletions sw/blas/gemm/src/gemm.h
@@ -21,8 +21,8 @@ typedef __fp16 v4f16 __attribute__((vector_size(8)));
typedef char v8f8 __attribute__((vector_size(8)));
#endif

-dump_float(query, 11); // = 0xb
-dump_uint(index, 9); // 14 = 0xe
+dump_float(gemm, 8);
+dump_uint(index, 9);


void gemm_fp64_baseline(uint32_t M, uint32_t N, uint32_t K, double* A,
@@ -31,7 +31,7 @@ void gemm_fp64_baseline(uint32_t M, uint32_t N, uint32_t K, double* A,
    if (!ta && !tb) {
        for (uint32_t m = 0; m < M; m++) {
            for (uint32_t n = 0; n < N; n++) {
-               register double c0 = ALPHA * C[m * ldC + n];
+               double c0 = ALPHA * C[m * ldC + n];
                for (uint32_t k = 0; k < K; k++) {
                    c0 += A[k + m * ldA] * B[k * ldB + n];
                }
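
For context, the non-transposed FP64 path touched here computes C = ALPHA * C + A * B; dropping the register qualifier on c0 is behavior-preserving, since modern compilers treat register as a hint at most. A minimal host-side reference of that loop nest follows — a sketch only, with parameter names beyond those visible in the hunk being assumptions, because the full kernel signature is collapsed above:

#include <stdint.h>

// Reference-only sketch: C = alpha * C + A * B with the same row-major
// indexing as the non-transposed path above (A[m * ldA + k] == A[k + m * ldA]).
// "alpha" stands in for the ALPHA term visible in the hunk (assumed scalar).
static void gemm_fp64_reference(uint32_t M, uint32_t N, uint32_t K,
                                const double *A, uint32_t ldA,
                                const double *B, uint32_t ldB,
                                double *C, uint32_t ldC, double alpha) {
    for (uint32_t m = 0; m < M; m++) {
        for (uint32_t n = 0; n < N; n++) {
            double c0 = alpha * C[m * ldC + n];  // scale existing output first
            for (uint32_t k = 0; k < K; k++) {
                c0 += A[m * ldA + k] * B[k * ldB + n];
            }
            C[m * ldC + n] = c0;
        }
    }
}
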
94 changes: 94 additions & 0 deletions sw/dnn/src/layernorm.h
@@ -9,6 +9,9 @@
// #include "printf.h"
#include "utils.h"

// add dump function for layernorm
dump_float(ln, 5);
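// NOTE: dump_float(ln, 5) is assumed to generate a dump_ln() tracing helper
// (used by the commented-out dump_ln() calls further down); the exact macro
// semantics depend on the runtime's dump utilities.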

/**
* @struct layernorm_layer_struct
* @brief This structure contains all parameters necessary
@@ -87,6 +90,97 @@ static inline void layernorm_fp32(float *input, float *output, int32_t ldI,
    snrt_cluster_hw_barrier();
}

/**
* Implementation of the LayerNorm layer for the Transformer model for FP64.
*/
static inline void transformer_layernorm_fp64(double *input, int32_t ldI,
                                               int32_t seq_len,
                                               int32_t embeddings, int32_t eps) {
    double mean = 0.0;  // mean of the current row
    double var = 0.0;   // variance of the current row

    uint32_t compute_id = snrt_global_core_idx();
    uint32_t num_cores = snrt_cluster_compute_core_num();

    for (int32_t s = 0; s < seq_len; s++) {
        mean = 0.0;
        var = 0.0;

        // compute the mean of the current row
        for (int32_t i = 0; i < embeddings; i++) {
            mean += input[s * ldI + i];
        }
        mean /= embeddings;

        // printf("mean[%d] = %f\n", b, mean);

        // compute the variance of the current row
        for (int32_t i = 0; i < embeddings; i++) {
            var += (input[s * ldI + i] - mean) * (input[s * ldI + i] - mean);
        }
        var /= embeddings;

        // printf("var[%d] = %f\n", b, var);

        // normalize the current row in place
        for (int32_t i = 0; i < embeddings; i++) {
            input[s * ldI + i] = (input[s * ldI + i] - mean) / sqrt(var + eps);
            // printf("output[%d][%d][%d] = %f\n", b, s + compute_id, i,
            //        output[s * ldI + i]);
            // dump_ln(input[s * ldI + i]);
        }
    }

    snrt_cluster_hw_barrier();
}


/**
* Implementation of the LayerNorm layer for the Transformer model for FP32.
*/
static inline void transformer_layernorm_fp32(float *input, int32_t ldI,
                                               int32_t seq_len,
                                               int32_t embeddings, int32_t eps) {
    float mean = 0.0;  // mean of the current row
    float var = 0.0;   // variance of the current row

    uint32_t compute_id = snrt_global_core_idx();
    uint32_t num_cores = snrt_cluster_compute_core_num();

    for (int32_t s = 0; s < seq_len; s++) {
        mean = 0.0;
        var = 0.0;

        // compute the mean of the current row
        for (int32_t i = 0; i < embeddings; i++) {
            mean += input[s * ldI + i];
        }
        mean /= embeddings;

        // printf("mean[%d] = %f\n", b, mean);

        // compute the variance of the current row
        for (int32_t i = 0; i < embeddings; i++) {
            var += (input[s * ldI + i] - mean) * (input[s * ldI + i] - mean);
        }
        var /= embeddings;

        // printf("var[%d] = %f\n", b, var);

        // normalize the current row in place
        for (int32_t i = 0; i < embeddings; i++) {
            input[s * ldI + i] = (input[s * ldI + i] - mean) / sqrtf(var + eps);
            // printf("output[%d][%d][%d] = %f\n", b, s + compute_id, i,
            //        output[s * ldI + i]);
            // dump_ln(input[s * ldI + i]);
        }
    }

    snrt_cluster_hw_barrier();
}

/**
* @brief layernorm layer
*
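
Both transformer_layernorm_* kernels above normalize each of the seq_len rows of the (seq_len x embeddings) activation in place, without a learnable scale or bias. In LaTeX notation, with E = embeddings and eps the epsilon argument:

\mu_s      = \frac{1}{E} \sum_{i=0}^{E-1} x_{s,i}
\sigma_s^2 = \frac{1}{E} \sum_{i=0}^{E-1} (x_{s,i} - \mu_s)^2
y_{s,i}    = \frac{x_{s,i} - \mu_s}{\sqrt{\sigma_s^2 + \epsilon}}

This is the standard LayerNorm definition with the affine gamma/beta terms omitted.
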
45 changes: 33 additions & 12 deletions target/snitch_cluster/sw/apps/transformer/src/data.h

Large diffs are not rendered by default.

59 changes: 42 additions & 17 deletions target/snitch_cluster/sw/apps/transformer/src/transformer.c
@@ -11,23 +11,48 @@
#include "data.h"

int main() {
-    // Input data
-    transformer_l.ifmap = (float *)transformer_ifmap_dram;
-    // Weights for query, key, value
-    transformer_l.weights_q = (float *)transformer_weights_q_dram;
-    transformer_l.weights_k = (float *)transformer_weights_k_dram;
-    transformer_l.weights_v = (float *)transformer_weights_v_dram;
-    transformer_l.weights_o = (float *)transformer_weights_o_dram;
-    transformer_l.Q_lin = (float *)transformer_Q_lin_dram;
-    transformer_l.K_lin = (float *)transformer_K_lin_dram;
-    transformer_l.V_lin = (float *)transformer_V_lin_dram;
-    transformer_l.O = (float *)transformer_O_dram;
-    // Results of query, key, value computation
-    // transformer_l.query = (double *)transformer_query_dram;
-    // transformer_l.key = (double *)transformer_key_dram;
-    // transformer_l.value = (double *)transformer_value_dram;
-
-    transformer_layer(&transformer_l);

+    switch (transformer_l.dtype) {
+        case FP64:
+            // Input data
+            transformer_l.ifmap = (double *)transformer_ifmap_dram;
+            // Weights for query, key, value
+            transformer_l.weights_q = (double *)transformer_weights_q_dram;
+            transformer_l.weights_k = (double *)transformer_weights_k_dram;
+            transformer_l.weights_v = (double *)transformer_weights_v_dram;
+            // transformer_l.weights_o = (double *)transformer_weights_o_dram;
+            // Write-back location for the output in DRAM
+            transformer_l.Q_lin = (double *)transformer_Q_lin_dram;
+            transformer_l.K_lin = (double *)transformer_K_lin_dram;
+            transformer_l.V_lin = (double *)transformer_V_lin_dram;
+            // Matrices for FlashAttention-2
+            transformer_l.Q_fa = (double *)transformer_q_fa_dram;
+            transformer_l.K_fa = (double *)transformer_k_fa_dram;
+            transformer_l.V_fa = (double *)transformer_v_fa_dram;
+            transformer_l.O = (double *)transformer_O_dram;
+
+            transformer_layer_fp64(&transformer_l);
+            break;
+
+        case FP32:
+            // Input data
+            transformer_l.ifmap = (float *)transformer_ifmap_dram;
+            // Weights for query, key, value
+            // transformer_l.weights_q = (float *)transformer_weights_q_dram;
+            // transformer_l.weights_k = (float *)transformer_weights_k_dram;
+            // transformer_l.weights_v = (float *)transformer_weights_v_dram;
+            // transformer_l.weights_o = (float *)transformer_weights_o_dram;
+            // Write-back location for the output in DRAM
+            transformer_l.Q_lin = (float *)transformer_Q_lin_dram;
+            transformer_l.K_lin = (float *)transformer_K_lin_dram;
+            transformer_l.V_lin = (float *)transformer_V_lin_dram;
+            transformer_l.O = (float *)transformer_O_dram;
+
+            transformer_layer_fp32(&transformer_l);
+            break;
+    }

    snrt_global_barrier();

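
The switch above implies that transformer_l carries a dtype tag and generic buffer pointers that each precision path casts as needed. A hypothetical sketch of such a layer struct, using only the fields referenced in main() — the struct and enum names and the void* typing are assumptions, not the committed definition:

// Hypothetical layout; field names mirror the assignments in main() above.
typedef enum { FP64, FP32 } precision_t;  // values taken from the switch cases

typedef struct {
    precision_t dtype;   // selects the FP64 or FP32 kernel path
    void *ifmap;         // input activations
    void *weights_q, *weights_k, *weights_v, *weights_o;
    void *Q_lin, *K_lin, *V_lin;   // write-back buffers for the linear projections
    void *Q_fa, *K_fa, *V_fa;      // operands for FlashAttention-2
    void *O;                       // final output
} transformer_layer_t;
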
