From cec80a493054186a193d14d479d73b9a0ed71912 Mon Sep 17 00:00:00 2001
From: Viviane Potocnik
Date: Fri, 28 Jun 2024 11:23:33 +0200
Subject: [PATCH] [sw] DNN: fix header include order in DNN header
 [sw] FCL: fix struct declaration

---
 .github/workflows/lint.yml                    |  3 +-
 .../flashattention_2/src/flashattention_2.h   |  1 -
 .../src/fused_concat_linear.h                 | 98 +++++++++----------
 sw/dnn/src/dnn.h                              |  2 +-
 4 files changed, 48 insertions(+), 56 deletions(-)

diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
index feb107bfb7..c208c7843a 100644
--- a/.github/workflows/lint.yml
+++ b/.github/workflows/lint.yml
@@ -77,7 +77,6 @@ jobs:
       # yamllint enable rule:line-length
       match_regex: true
       exclude_paths: |
-        sw/dnn/flashattention_2/src/flashattention_2.h
         sw/snRuntime/src/omp/interface.h
         sw/math/arch/generic/*
         sw/math/arch/riscv64/bits/*
@@ -133,7 +132,7 @@ jobs:
       with:
         exclude: |
           ./sw/saris
-          ./sw/dnn/flashattention_2/src/flashattention_2.h
+          ./sw/dnn/src/dnn.h
         clangFormatVersion: 10
 
 ######################
diff --git a/sw/dnn/flashattention_2/src/flashattention_2.h b/sw/dnn/flashattention_2/src/flashattention_2.h
index ed3662de1a..c5480411b5 100644
--- a/sw/dnn/flashattention_2/src/flashattention_2.h
+++ b/sw/dnn/flashattention_2/src/flashattention_2.h
@@ -48,7 +48,6 @@ typedef struct {
     void *gemm_implementation;
 } flashattention_2_layer_t;
 
-#include "../transpose/src/transpose.h"
 #include "../flashattention_2/src/flashattention_2_fp16.h"
 #include "../flashattention_2/src/flashattention_2_fp32.h"
 #include "../flashattention_2/src/flashattention_2_fp8.h"
diff --git a/sw/dnn/fused_concat_linear/src/fused_concat_linear.h b/sw/dnn/fused_concat_linear/src/fused_concat_linear.h
index 22f7a2ec89..6cea297cae 100644
--- a/sw/dnn/fused_concat_linear/src/fused_concat_linear.h
+++ b/sw/dnn/fused_concat_linear/src/fused_concat_linear.h
@@ -48,32 +48,29 @@ static inline int fused_concat_linear_baseline(fused_concat_linear_layer_t l) {
     uint32_t k = l.input_shape[1] * l.num_inputs;
     uint32_t n = l.output_shape[1];
 
-    gemm_args_t gemm_args;
-    gemm_args_t *local_args = (gemm_args_t *)&gemm_args;
-
-    local_args->alpha = 1.0;
-    local_args->prec = l.dtype;
-    local_args->setup_ssr = 0;
-    local_args->parallelize_m = 1;
-    local_args->parallelize_k = 0;
-    local_args->m_tiles = snrt_cluster_core_num();
-    local_args->n_tiles = 1;
-    local_args->k_tiles = 1;
-    local_args->load_a = 0;
-    local_args->load_b = 1;
-    local_args->load_c = 1;
-    local_args->transa = 0;
-    local_args->transb = 0;
-    local_args->M = m;
-    local_args->N = n;
-    local_args->K = k;
-    local_args->a = l.concat_output;
-    local_args->b = l.weights;
-    local_args->beta = 0;
-    local_args->c = l.linear_output;
-    local_args->gemm_fp = l.gemm_implementation;
-
-    gemm(&local_args);
+    gemm_args_t gemm_args = {.alpha = 1.0,
+                             .prec = l.dtype,
+                             .setup_ssr = 0,
+                             .parallelize_m = 1,
+                             .parallelize_k = 0,
+                             .m_tiles = snrt_cluster_num(),
+                             .n_tiles = 1,
+                             .k_tiles = 1,
+                             .load_a = 0,
+                             .load_b = 1,
+                             .load_c = 1,
+                             .transa = 0,
+                             .transb = 0,
+                             .M = m,
+                             .N = n,
+                             .K = k,
+                             .a = l.concat_output,
+                             .b = l.weights,
+                             .beta = 0,
+                             .c = l.linear_output,
+                             .gemm_fp = l.gemm_implementation};
+
+    gemm(&gemm_args);
 
     snrt_global_barrier();
 
@@ -97,32 +94,29 @@ static inline int fused_concat_linear_optimized(fused_concat_linear_layer_t l) {
     }
     snrt_cluster_hw_barrier();
 
-    gemm_args_t gemm_args;
-    gemm_args_t *local_args = (gemm_args_t *)&gemm_args;
-
-    local_args->alpha = 1.0;
-    local_args->prec = l.dtype;
-    local_args->setup_ssr = 0;
-    local_args->parallelize_m = 0;
-    local_args->parallelize_k = 1;
-    local_args->m_tiles = 1;
-    local_args->n_tiles = 1;
-    local_args->k_tiles = l.num_inputs;
-    local_args->load_a = 0;
-    local_args->load_b = 1;
-    local_args->load_c = 1;
-    local_args->transa = 0;
-    local_args->transb = 0;
-    local_args->M = m;
-    local_args->N = n;
-    local_args->K = concat_k;
-    local_args->a = a;
-    local_args->b = l.weights;
-    local_args->beta = 0;
-    local_args->c = l.linear_output;
-    local_args->gemm_fp = l.gemm_implementation;
-
-    gemm(&local_args);
+    gemm_args_t gemm_args = {.alpha = 1.0,
+                             .prec = l.dtype,
+                             .setup_ssr = 0,
+                             .parallelize_m = 0,
+                             .parallelize_k = 1,
+                             .m_tiles = 1,
+                             .n_tiles = 1,
+                             .k_tiles = l.num_inputs,
+                             .load_a = 0,
+                             .load_b = 1,
+                             .load_c = 1,
+                             .transa = 0,
+                             .transb = 0,
+                             .M = m,
+                             .N = n,
+                             .K = concat_k,
+                             .a = a,
+                             .b = l.weights,
+                             .beta = 0,
+                             .c = l.linear_output,
+                             .gemm_fp = l.gemm_implementation};
+
+    gemm(&gemm_args);
 
     snrt_global_barrier();
 
diff --git a/sw/dnn/src/dnn.h b/sw/dnn/src/dnn.h
index 56f91e5897..4673cb8cfb 100644
--- a/sw/dnn/src/dnn.h
+++ b/sw/dnn/src/dnn.h
@@ -201,10 +201,10 @@ typedef struct network_single_cluster_t_ {
 #include "../batchnorm/src/batchnorm.h"
 #include "../concat/src/concat.h"
 #include "../conv2d/src/conv2d.h"
+#include "../transpose/src/transpose.h"
 #include "../flashattention_2/src/flashattention_2.h"
 #include "../fused_concat_linear/src/fused_concat_linear.h"
 #include "../gelu/src/gelu.h"
 #include "../layernorm/src/layernorm.h"
 #include "../maxpool/src/maxpool.h"
 #include "../softmax/src/softmax.h"
-#include "../transpose/src/transpose.h"
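Note on the FCL struct-declaration fix: the removed code aliased the stack struct through a local pointer and then called gemm(&local_args), which passes a gemm_args_t ** where the call sites otherwise pass a gemm_args_t *; such a mismatch only compiles silently when the parameter is declared void *. The replacement builds the struct with a C99 designated initializer and passes its address directly. The standalone sketch below contrasts the two patterns; args_t and consume() are hypothetical stand-ins for gemm_args_t and gemm(), not the actual API.

#include <stdio.h>

/* Hypothetical stand-in for gemm_args_t; the real struct carries
 * many more fields (tiling, precision, data pointers, ...). */
typedef struct {
    double alpha;
    unsigned m_tiles;
} args_t;

/* Hypothetical stand-in for gemm(): expects a pointer to the
 * argument struct, as the call sites in this patch assume. */
static void consume(const args_t *args) {
    printf("alpha=%.1f m_tiles=%u\n", args->alpha, args->m_tiles);
}

int main(void) {
    /* Removed pattern: a local pointer aliasing the struct. Writing
     * consume(&local_args) here would pass an args_t **; the bug
     * goes unnoticed only when the parameter is declared void *. */
    args_t args;
    args_t *local_args = &args;
    local_args->alpha = 1.0;
    local_args->m_tiles = 8;
    consume(local_args); /* pass the pointer itself, not &local_args */

    /* New pattern from the patch: designated initializer, struct
     * address passed directly, so no double-pointer hazard. */
    args_t args2 = {.alpha = 1.0, .m_tiles = 8};
    consume(&args2);
    return 0;
}

A side benefit of the initializer form: any gemm_args_t field not named in the initializer is zero-initialized, whereas the removed field-by-field assignments left unmentioned fields indeterminate.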