Skip to content

Commit

Permalink
Fixup nnx_quant_t nnx_norm_t
Browse files Browse the repository at this point in the history
  • Loading branch information
lukamac committed Jan 16, 2024
1 parent 94bbe08 commit 6341d0e
Showing 1 changed file with 13 additions and 14 deletions.
27 changes: 13 additions & 14 deletions test/app/src/nnx_layer.c
Original file line number Diff line number Diff line change
Expand Up @@ -104,18 +104,18 @@ typedef neureka_siracusa_conf_t nnx_bsp_conf_t;
#include "weight.h"

static void task_prepare(nnx_task_t *task) {
nnx_task_init(task, WEIGHT_HEIGHT, GROUPS > 1, INPUT_BITS, OUTPUT_BITS,
WEIGHT_BITS, weightOffsetModeLayerWise, WEIGHT_OFFSET,
(neureka_quant_t){.shift_amount = OUTSHIFT,
.mode = quantMode8Bit,
.function = HAS_RELU ? quantFunctionRelu
: quantFunctionIdentity,
.flag_rounding = nnxTaskFlagFalse},
(neureka_norm_t){.mode = normMode8Bit,
.flag_bias = HAS_BIAS ? nnxTaskFlagTrue
: nnxTaskFlagFalse,
.flag_shift = nnxTaskFlagFalse},
STRIDE_HEIGHT);
nnx_task_init(
task, WEIGHT_HEIGHT, GROUPS > 1, INPUT_BITS, OUTPUT_BITS, WEIGHT_BITS,
weightOffsetModeLayerWise, WEIGHT_OFFSET,
(nnx_quant_t){.shift_amount = OUTSHIFT,
.mode = quantMode8Bit,
.function =
HAS_RELU ? quantFunctionRelu : quantFunctionIdentity,
.flag_rounding = nnxTaskFlagFalse},
(nnx_norm_t){.mode = normMode8Bit,
.flag_bias = HAS_BIAS ? nnxTaskFlagTrue : nnxTaskFlagFalse,
.flag_shift = nnxTaskFlagFalse},
STRIDE_HEIGHT);

if (STRIDE_WIDTH == 2 && STRIDE_HEIGHT == 2) {
nnx_task_set_dims_stride2x2(
Expand Down Expand Up @@ -147,8 +147,7 @@ static void task_prepare(nnx_task_t *task) {
static void task_execute(nnx_task_t *task) {
nnx_dev_t *dev = nnx_bsp_get_dev();

nnx_gvsoc_log_activate(dev, NNX_GVSOC_LOG_LEVEL,
NNX_GVSOC_LOG_FORMAT);
nnx_gvsoc_log_activate(dev, NNX_GVSOC_LOG_LEVEL, NNX_GVSOC_LOG_FORMAT);

nnx_bsp_conf_t conf = {.max_stall = 8};
nnx_init(dev, &conf);
Expand Down

0 comments on commit 6341d0e

Please sign in to comment.