[sw] DNN: fix header include order in DNN header
[sw] FCL: fix struct declaration
Viviane Potocnik committed Jun 28, 2024
1 parent 240fea0 commit cec80a4
Showing 4 changed files with 48 additions and 56 deletions.
3 changes: 1 addition & 2 deletions .github/workflows/lint.yml
@@ -77,7 +77,6 @@ jobs:
           # yamllint enable rule:line-length
           match_regex: true
           exclude_paths: |
-            sw/dnn/flashattention_2/src/flashattention_2.h
             sw/snRuntime/src/omp/interface.h
             sw/math/arch/generic/*
             sw/math/arch/riscv64/bits/*
@@ -133,7 +132,7 @@ jobs:
         with:
           exclude: |
             ./sw/saris
-            ./sw/dnn/flashattention_2/src/flashattention_2.h
+            ./sw/dnn/src/dnn.h
           clangFormatVersion: 10
 
   ######################
1 change: 0 additions & 1 deletion sw/dnn/flashattention_2/src/flashattention_2.h
@@ -48,7 +48,6 @@ typedef struct {
     void *gemm_implementation;
 } flashattention_2_layer_t;
 
-#include "../transpose/src/transpose.h"
 #include "../flashattention_2/src/flashattention_2_fp16.h"
 #include "../flashattention_2/src/flashattention_2_fp32.h"
 #include "../flashattention_2/src/flashattention_2_fp8.h"
98 changes: 46 additions & 52 deletions sw/dnn/fused_concat_linear/src/fused_concat_linear.h
@@ -48,32 +48,29 @@ static inline int fused_concat_linear_baseline(fused_concat_linear_layer_t l) {
     uint32_t k = l.input_shape[1] * l.num_inputs;
     uint32_t n = l.output_shape[1];
 
-    gemm_args_t gemm_args;
-    gemm_args_t *local_args = (gemm_args_t *)&gemm_args;
-
-    local_args->alpha = 1.0;
-    local_args->prec = l.dtype;
-    local_args->setup_ssr = 0;
-    local_args->parallelize_m = 1;
-    local_args->parallelize_k = 0;
-    local_args->m_tiles = snrt_cluster_core_num();
-    local_args->n_tiles = 1;
-    local_args->k_tiles = 1;
-    local_args->load_a = 0;
-    local_args->load_b = 1;
-    local_args->load_c = 1;
-    local_args->transa = 0;
-    local_args->transb = 0;
-    local_args->M = m;
-    local_args->N = n;
-    local_args->K = k;
-    local_args->a = l.concat_output;
-    local_args->b = l.weights;
-    local_args->beta = 0;
-    local_args->c = l.linear_output;
-    local_args->gemm_fp = l.gemm_implementation;
-
-    gemm(&local_args);
+    gemm_args_t gemm_args = {.alpha = 1.0,
+                             .prec = l.dtype,
+                             .setup_ssr = 0,
+                             .parallelize_m = 1,
+                             .parallelize_k = 0,
+                             .m_tiles = snrt_cluster_num(),
+                             .n_tiles = 1,
+                             .k_tiles = 1,
+                             .load_a = 0,
+                             .load_b = 1,
+                             .load_c = 1,
+                             .transa = 0,
+                             .transb = 0,
+                             .M = m,
+                             .N = n,
+                             .K = k,
+                             .a = l.concat_output,
+                             .b = l.weights,
+                             .beta = 0,
+                             .c = l.linear_output,
+                             .gemm_fp = l.gemm_implementation};
+
+    gemm(&gemm_args);
 
     snrt_global_barrier();
 
@@ -97,32 +94,29 @@ static inline int fused_concat_linear_optimized(fused_concat_linear_layer_t l) {
     }
     snrt_cluster_hw_barrier();
 
-    gemm_args_t gemm_args;
-    gemm_args_t *local_args = (gemm_args_t *)&gemm_args;
-
-    local_args->alpha = 1.0;
-    local_args->prec = l.dtype;
-    local_args->setup_ssr = 0;
-    local_args->parallelize_m = 0;
-    local_args->parallelize_k = 1;
-    local_args->m_tiles = 1;
-    local_args->n_tiles = 1;
-    local_args->k_tiles = l.num_inputs;
-    local_args->load_a = 0;
-    local_args->load_b = 1;
-    local_args->load_c = 1;
-    local_args->transa = 0;
-    local_args->transb = 0;
-    local_args->M = m;
-    local_args->N = n;
-    local_args->K = concat_k;
-    local_args->a = a;
-    local_args->b = l.weights;
-    local_args->beta = 0;
-    local_args->c = l.linear_output;
-    local_args->gemm_fp = l.gemm_implementation;
-
-    gemm(&local_args);
+    gemm_args_t gemm_args = {.alpha = 1.0,
+                             .prec = l.dtype,
+                             .setup_ssr = 0,
+                             .parallelize_m = 0,
+                             .parallelize_k = 1,
+                             .m_tiles = 1,
+                             .n_tiles = 1,
+                             .k_tiles = l.num_inputs,
+                             .load_a = 0,
+                             .load_b = 1,
+                             .load_c = 1,
+                             .transa = 0,
+                             .transb = 0,
+                             .M = m,
+                             .N = n,
+                             .K = concat_k,
+                             .a = a,
+                             .b = l.weights,
+                             .beta = 0,
+                             .c = l.linear_output,
+                             .gemm_fp = l.gemm_implementation};
+
+    gemm(&gemm_args);
 
     snrt_global_barrier();
 
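Note on the change above: the old code aliased `gemm_args` through a self-pointing `local_args` pointer and then called `gemm(&local_args)`, passing the address of the pointer variable (a `gemm_args_t **`) rather than the address of the struct; this is presumably the "[sw] FCL: fix struct declaration" bug named in the commit message. The new code builds the struct with a C99 designated initializer and passes its address directly. A minimal compilable sketch of the two patterns, using a hypothetical reduced `gemm_args_sketch_t` (the real `gemm_args_t` has many more fields):

#include <stdio.h>

/* Hypothetical, reduced stand-in for gemm_args_t. */
typedef struct {
    double alpha;
    unsigned m_tiles;
    void *a;
} gemm_args_sketch_t;

static void gemm_sketch(gemm_args_sketch_t *args) {
    printf("alpha=%.1f m_tiles=%u a=%p\n", args->alpha, args->m_tiles, args->a);
}

int main(void) {
    /* Old pattern: a pointer aliasing the struct. Writing
     * gemm_sketch(&local_args) would pass a gemm_args_sketch_t **
     * (the address of the pointer), which is the mistake the diff removes. */
    gemm_args_sketch_t args_old;
    gemm_args_sketch_t *local_args = &args_old;
    local_args->alpha = 1.0;
    local_args->m_tiles = 8;
    local_args->a = NULL;
    gemm_sketch(local_args); /* correct: the pointer itself */

    /* New pattern: a designated initializer sets the named fields,
     * zero-initializes everything left unnamed, and leaves no
     * intermediate pointer to misuse. */
    gemm_args_sketch_t args_new = {.alpha = 1.0, .m_tiles = 8};
    gemm_sketch(&args_new); /* args_new.a is NULL by zero-initialization */
    return 0;
}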
2 changes: 1 addition & 1 deletion sw/dnn/src/dnn.h
@@ -201,10 +201,10 @@ typedef struct network_single_cluster_t_ {
 #include "../batchnorm/src/batchnorm.h"
 #include "../concat/src/concat.h"
 #include "../conv2d/src/conv2d.h"
+#include "../transpose/src/transpose.h"
 #include "../flashattention_2/src/flashattention_2.h"
 #include "../fused_concat_linear/src/fused_concat_linear.h"
 #include "../gelu/src/gelu.h"
 #include "../layernorm/src/layernorm.h"
 #include "../maxpool/src/maxpool.h"
 #include "../softmax/src/softmax.h"
-#include "../transpose/src/transpose.h"
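Net effect of the include moves in this commit: `flashattention_2.h` no longer pulls in `transpose.h` itself; instead the umbrella header `dnn.h` includes `transpose.h` ahead of `flashattention_2.h`, so the transpose declarations are already in scope when the FlashAttention-2 implementation headers need them. Because this order is deliberate rather than alphabetical, `dnn.h` presumably replaces `flashattention_2.h` in the lint excludes above. A minimal single-file sketch of this kind of ordering dependency (names and bodies are illustrative assumptions, not the actual kernel code):

/* transpose.h (sketch): declares the type the next header depends on. */
typedef struct {
    int M, N;        /* matrix dimensions */
    void *input;     /* source buffer */
    void *output;    /* destination buffer */
} transpose_layer_t; /* hypothetical, reduced version of the real struct */

/* flashattention_2_fp32.h (sketch): uses transpose_layer_t without
 * including transpose.h itself, so any translation unit must see
 * transpose.h first -- exactly the ordering dnn.h now guarantees. */
static inline void flashattention_2_fp32_sketch(transpose_layer_t *t) {
    (void)t; /* placeholder for the real kernel body */
}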
