Skip to content

Commit

Permalink
[sw] FCL: explicitely declare struct fields
Browse files Browse the repository at this point in the history
  • Loading branch information
Viviane Potocnik committed Jun 28, 2024
1 parent 7a99f0a commit 240fea0
Showing 1 changed file with 53 additions and 44 deletions.
97 changes: 53 additions & 44 deletions sw/dnn/fused_concat_linear/src/fused_concat_linear.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// SPDX-License-Identifier: Apache-2.0
//
// Luca Colagrande <[email protected]>
// Viviane Potocnik <[email protected]>

#include "snrt.h"

Expand Down Expand Up @@ -47,28 +48,32 @@ static inline int fused_concat_linear_baseline(fused_concat_linear_layer_t l) {
uint32_t k = l.input_shape[1] * l.num_inputs;
uint32_t n = l.output_shape[1];

gemm_args_t gemm_args = {1.0,
l.dtype,
0,
1,
0,
snrt_cluster_num(),
1,
1,
0,
1,
1,
0,
0,
m,
n,
k,
l.concat_output,
l.weights,
0,
l.linear_output,
l.gemm_implementation};
gemm(&gemm_args);
gemm_args_t gemm_args;
gemm_args_t *local_args = (gemm_args_t *)&gemm_args;

local_args->alpha = 1.0;
local_args->prec = l.dtype;
local_args->setup_ssr = 0;
local_args->parallelize_m = 1;
local_args->parallelize_k = 0;
local_args->m_tiles = snrt_cluster_core_num();
local_args->n_tiles = 1;
local_args->k_tiles = 1;
local_args->load_a = 0;
local_args->load_b = 1;
local_args->load_c = 1;
local_args->transa = 0;
local_args->transb = 0;
local_args->M = m;
local_args->N = n;
local_args->K = k;
local_args->a = l.concat_output;
local_args->b = l.weights;
local_args->beta = 0;
local_args->c = l.linear_output;
local_args->gemm_fp = l.gemm_implementation;

gemm(&local_args);

snrt_global_barrier();

Expand All @@ -92,28 +97,32 @@ static inline int fused_concat_linear_optimized(fused_concat_linear_layer_t l) {
}
snrt_cluster_hw_barrier();

gemm_args_t gemm_args = {1.0,
l.dtype,
0,
0,
1,
1,
1,
l.num_inputs,
0,
1,
1,
0,
0,
m,
n,
concat_k,
a,
l.weights,
0,
l.linear_output,
l.gemm_implementation};
gemm(&gemm_args);
gemm_args_t gemm_args;
gemm_args_t *local_args = (gemm_args_t *)&gemm_args;

local_args->alpha = 1.0;
local_args->prec = l.dtype;
local_args->setup_ssr = 0;
local_args->parallelize_m = 0;
local_args->parallelize_k = 1;
local_args->m_tiles = 1;
local_args->n_tiles = 1;
local_args->k_tiles = l.num_inputs;
local_args->load_a = 0;
local_args->load_b = 1;
local_args->load_c = 1;
local_args->transa = 0;
local_args->transb = 0;
local_args->M = m;
local_args->N = n;
local_args->K = concat_k;
local_args->a = a;
local_args->b = l.weights;
local_args->beta = 0;
local_args->c = l.linear_output;
local_args->gemm_fp = l.gemm_implementation;

gemm(&local_args);

snrt_global_barrier();

Expand Down

0 comments on commit 240fea0

Please sign in to comment.