From 240fea0bb8605938cc16aa3181a82063c8e13de1 Mon Sep 17 00:00:00 2001 From: Viviane Potocnik Date: Fri, 28 Jun 2024 11:21:02 +0200 Subject: [PATCH] [sw] FCL: explicitely declare struct fields --- .../src/fused_concat_linear.h | 97 ++++++++++--------- 1 file changed, 53 insertions(+), 44 deletions(-) diff --git a/sw/dnn/fused_concat_linear/src/fused_concat_linear.h b/sw/dnn/fused_concat_linear/src/fused_concat_linear.h index d1ba0d15b..22f7a2ec8 100644 --- a/sw/dnn/fused_concat_linear/src/fused_concat_linear.h +++ b/sw/dnn/fused_concat_linear/src/fused_concat_linear.h @@ -3,6 +3,7 @@ // SPDX-License-Identifier: Apache-2.0 // // Luca Colagrande +// Viviane Potocnik #include "snrt.h" @@ -47,28 +48,32 @@ static inline int fused_concat_linear_baseline(fused_concat_linear_layer_t l) { uint32_t k = l.input_shape[1] * l.num_inputs; uint32_t n = l.output_shape[1]; - gemm_args_t gemm_args = {1.0, - l.dtype, - 0, - 1, - 0, - snrt_cluster_num(), - 1, - 1, - 0, - 1, - 1, - 0, - 0, - m, - n, - k, - l.concat_output, - l.weights, - 0, - l.linear_output, - l.gemm_implementation}; - gemm(&gemm_args); + gemm_args_t gemm_args; + gemm_args_t *local_args = (gemm_args_t *)&gemm_args; + + local_args->alpha = 1.0; + local_args->prec = l.dtype; + local_args->setup_ssr = 0; + local_args->parallelize_m = 1; + local_args->parallelize_k = 0; + local_args->m_tiles = snrt_cluster_core_num(); + local_args->n_tiles = 1; + local_args->k_tiles = 1; + local_args->load_a = 0; + local_args->load_b = 1; + local_args->load_c = 1; + local_args->transa = 0; + local_args->transb = 0; + local_args->M = m; + local_args->N = n; + local_args->K = k; + local_args->a = l.concat_output; + local_args->b = l.weights; + local_args->beta = 0; + local_args->c = l.linear_output; + local_args->gemm_fp = l.gemm_implementation; + + gemm(&local_args); snrt_global_barrier(); @@ -92,28 +97,32 @@ static inline int fused_concat_linear_optimized(fused_concat_linear_layer_t l) { } snrt_cluster_hw_barrier(); - gemm_args_t gemm_args = {1.0, - l.dtype, - 0, - 0, - 1, - 1, - 1, - l.num_inputs, - 0, - 1, - 1, - 0, - 0, - m, - n, - concat_k, - a, - l.weights, - 0, - l.linear_output, - l.gemm_implementation}; - gemm(&gemm_args); + gemm_args_t gemm_args; + gemm_args_t *local_args = (gemm_args_t *)&gemm_args; + + local_args->alpha = 1.0; + local_args->prec = l.dtype; + local_args->setup_ssr = 0; + local_args->parallelize_m = 0; + local_args->parallelize_k = 1; + local_args->m_tiles = 1; + local_args->n_tiles = 1; + local_args->k_tiles = l.num_inputs; + local_args->load_a = 0; + local_args->load_b = 1; + local_args->load_c = 1; + local_args->transa = 0; + local_args->transb = 0; + local_args->M = m; + local_args->N = n; + local_args->K = concat_k; + local_args->a = a; + local_args->b = l.weights; + local_args->beta = 0; + local_args->c = l.linear_output; + local_args->gemm_fp = l.gemm_implementation; + + gemm(&local_args); snrt_global_barrier();