diff --git a/sw/apps/transformer/src/transformer.h b/sw/apps/transformer/src/transformer.h
index f3af1c86fd..0b25832a82 100644
--- a/sw/apps/transformer/src/transformer.h
+++ b/sw/apps/transformer/src/transformer.h
@@ -83,6 +83,61 @@
 dump_float(ifmap, 6);
 dump_float(weights, 10);  // = 0xa
 dump_float(value, 12);    // = 0xc
+/**
+ * Implementation of the GELU activation (tanh approximation)
+ */
+#ifndef M_PI
+#define M_PI 3.14159265358979323846
+#endif
+static inline float transformer_gelu_fp32(float x) {
+    float y = 0.5f * x * (1.0f + tanhf(sqrtf(2.0f / (float)M_PI) * (x + 0.044715f * x * x * x)));
+    return y;
+}
+
+/**
+ * Implementation of the LayerNorm layer fused with a linear layer and a GELU activation.
+ * Input is a 2D matrix of size (S x E), where S is the sequence length and E is the number of embeddings.
+ * The weights matrix is of size (E x F), where F is the number of hidden nodes; the output is of size (S x F).
+ * Note: every input row is overwritten with its normalized values so that the normalization is computed only once.
+ */
+static inline void fused_mlp_baseline(float *input, float *output, int32_t ldI, int32_t ldO, float *weights, int32_t ldW,
+                                      int32_t seq_len, int32_t embeddings, int32_t ff_len, float eps) {
+    float mean = 0.0f;  // mean of the current input row
+    float var = 0.0f;   // variance of the current input row
+    float acc = 0.0f;   // accumulator for the linear layer
+
+    uint32_t compute_id = snrt_global_core_idx();           // core index (unused in this single-core baseline)
+    uint32_t num_cores = snrt_cluster_compute_core_num();   // number of compute cores (unused in this single-core baseline)
+
+    // compute the mean and variance along the innermost dimension
+    for (int s = 0; s < seq_len; s++) {
+        mean = 0.0f;
+        var = 0.0f;
+        for (int e = 0; e < embeddings; e++) {
+            mean += input[s * ldI + e];
+        }
+        mean /= embeddings;
+
+        for (int e = 0; e < embeddings; e++) {
+            var += (input[s * ldI + e] - mean) * (input[s * ldI + e] - mean);
+        }
+        var /= embeddings;
+        // normalize the row once, in place, so it can be reused for every column of the weight matrix
+        for (int e = 0; e < embeddings; e++) {
+            input[s * ldI + e] = (input[s * ldI + e] - mean) / sqrtf(var + eps);
+        }
+
+        // multiply the normalized row with the columns of the weight matrix and apply the GELU activation
+        for (int f = 0; f < ff_len; f++) {
+            acc = 0.0f;
+            for (int e = 0; e < embeddings; e++) {
+                acc += input[s * ldI + e] * weights[e * ldW + f];
+            }
+            output[s * ldO + f] = transformer_gelu_fp32(acc);
+        }
+    }
+}
+
 /**
  * @brief Transformer layer
  *
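
The tanh expression used in transformer_gelu_fp32 is the standard approximation of the exact GELU, GELU(x) = 0.5 * x * (1 + erf(x / sqrt(2))). A minimal host-side sanity check, a sketch only, compiled with a regular C99 toolchain and not part of the Snitch application (all names below are illustrative), could compare the two forms:

#include <math.h>
#include <stdio.h>

/* Exact GELU based on the error function (C99 math.h). */
static float gelu_exact_fp32(float x) {
    return 0.5f * x * (1.0f + erff(x / sqrtf(2.0f)));
}

/* Tanh approximation, mirroring transformer_gelu_fp32 from the patch above. */
static float gelu_tanh_fp32(float x) {
    const float k = sqrtf(2.0f / 3.14159265358979323846f);
    return 0.5f * x * (1.0f + tanhf(k * (x + 0.044715f * x * x * x)));
}

int main(void) {
    /* Print both variants over a small range; they should agree closely (to roughly 1e-3). */
    for (float x = -4.0f; x <= 4.0f; x += 0.5f) {
        printf("x = % .2f  exact = % .6f  approx = % .6f\n",
               x, gelu_exact_fp32(x), gelu_tanh_fp32(x));
    }
    return 0;
}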
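
A possible call site for the new kernel, assuming dense row-major buffers in local memory; the dimensions, buffer names, and eps value below are purely illustrative, and the real layer would take them from its configuration (and would normally DMA the tiles in first):

// Hypothetical dimensions, for illustration only.
#define SEQ_LEN    4
#define EMBEDDINGS 8
#define FF_LEN     16

static float ifmap[SEQ_LEN][EMBEDDINGS];       // (S x E) input activations, normalized in place
static float mlp_weights[EMBEDDINGS][FF_LEN];  // (E x F) linear weights
static float ofmap[SEQ_LEN][FF_LEN];           // (S x F) output activations

static void run_fused_mlp_example(void) {
    // Leading dimensions equal the row lengths because the buffers are dense.
    fused_mlp_baseline(&ifmap[0][0], &ofmap[0][0],
                       EMBEDDINGS /* ldI */, FF_LEN /* ldO */,
                       &mlp_weights[0][0], FF_LEN /* ldW */,
                       SEQ_LEN, EMBEDDINGS, FF_LEN, 1e-5f /* eps */);
}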