Skip to content

Commit

Permalink
Add residual add
Browse files Browse the repository at this point in the history
  • Loading branch information
Aba committed Nov 13, 2023
1 parent d6df904 commit a63a3b3
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 52 deletions.
18 changes: 9 additions & 9 deletions c/model.h
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
#define N_BUNDLES 7
Bundle_t bundles [N_BUNDLES] = {
{.n=8 , .l=3 , .kw=11 , .coe=2 , .coe_tl=2 , .r_ll=2 , .h=18 , .w=18 , .ci=3 , .co=8 , .w_kw2=13 , .t=4 , .p=3 , .cm=1 , .cm_p0=1 , .xp_words=6048, .out_buffer_idx=0 , .add_buffer_idx=-1, .w_bpt=140 , .w_bpt_p0=140 , .x_bpt=3032 , .x_bpt_p0=3032 , .o_words=5376 , .o_bytes=2696 , .is_bias=1 , .is_flatten=0 , .b_offset=0 , .b_val_shift=9 , .b_bias_shift=0 , .ca_nzero=0 , .ca_shift=12 , .ca_pl_scale=0 , .add_act_shift=0 , .pool_act_shift=0 , .csh=2 , .ch=9 , .csh_shift=1 , .pkh=3 , .psh=2 , .ph=5 , .psh_shift=1 , .csw=1 , .cw=18 , .csw_shift=0 , .pkw=4 , .psw=3 , .pw=6 , .psw_shift=0 , .pool=POOL_AVG , .on=8 , .oh=5 , .ow=6 , .oc=8 , .x_header= 17055749u, .x_header_p0= 17055749u, .w_header= 347372535813u, .w_header_p0= 17055749u , .debug_nhwc_words=1920 },
{.n=8 , .l=1 , .kw=1 , .coe=24 , .coe_tl=0 , .r_ll=5 , .h=5 , .w=6 , .ci=8 , .co=8 , .w_kw2=6 , .t=1 , .p=1 , .cm=20 , .cm_p0=8 , .xp_words=672, .out_buffer_idx=1 , .add_buffer_idx=-1, .w_bpt=104 , .w_bpt_p0=104 , .x_bpt=2696 , .x_bpt_p0=2696 , .o_words=5376 , .o_bytes=2720 , .is_bias=1 , .is_flatten=0 , .b_offset=8 , .b_val_shift=9 , .b_bias_shift=0 , .ca_nzero=1 , .ca_shift=12 , .ca_pl_scale=0 , .add_act_shift=0 , .pool_act_shift=0 , .csh=1 , .ch=5 , .csh_shift=0 , .pkh=1 , .psh=1 , .ph=5 , .psh_shift=0 , .csw=1 , .cw=6 , .csw_shift=0 , .pkw=1 , .psw=1 , .pw=6 , .psw_shift=0 , .pool=POOL_NONE , .on=8 , .oh=5 , .ow=6 , .oc=8 , .x_header= 81976u, .x_header_p0= 81976u, .w_header= 244276346936u, .w_header_p0= 81976u , .debug_nhwc_words=1920 },
{.n=8 , .l=1 , .kw=7 , .coe=3 , .coe_tl=2 , .r_ll=5 , .h=5 , .w=6 , .ci=8 , .co=8 , .w_kw2=3 , .t=3 , .p=4 , .cm=2 , .cm_p0=2 , .xp_words=672, .out_buffer_idx=0 , .add_buffer_idx=-1, .w_bpt=176 , .w_bpt_p0=176 , .x_bpt=680 , .x_bpt_p0=680 , .o_words=5376 , .o_bytes=2704 , .is_bias=0 , .is_flatten=0 , .b_offset=32 , .b_val_shift=0 , .b_bias_shift=0 , .ca_nzero=1 , .ca_shift=3 , .ca_pl_scale=0 , .add_act_shift=0 , .pool_act_shift=0 , .csh=1 , .ch=5 , .csh_shift=0 , .pkh=1 , .psh=1 , .ph=5 , .psh_shift=0 , .csw=1 , .cw=6 , .csw_shift=0 , .pkw=1 , .psw=1 , .pw=6 , .psw_shift=0 , .pool=POOL_NONE , .on=8 , .oh=5 , .ow=6 , .oc=8 , .x_header= 81931u, .x_header_p0= 81931u, .w_header= 450434777099u, .w_header_p0= 81931u , .debug_nhwc_words=1920 },
{.n=8 , .l=1 , .kw=5 , .coe=4 , .coe_tl=4 , .r_ll=5 , .h=5 , .w=6 , .ci=8 , .co=8 , .w_kw2=4 , .t=2 , .p=2 , .cm=4 , .cm_p0=4 , .xp_words=672, .out_buffer_idx=1 , .add_buffer_idx=-1, .w_bpt=248 , .w_bpt_p0=248 , .x_bpt=1352 , .x_bpt_p0=1352 , .o_words=5376 , .o_bytes=2704 , .is_bias=1 , .is_flatten=0 , .b_offset=32 , .b_val_shift=9 , .b_bias_shift=0 , .ca_nzero=1 , .ca_shift=15 , .ca_pl_scale=3 , .add_act_shift=0 , .pool_act_shift=0 , .csh=1 , .ch=5 , .csh_shift=0 , .pkh=1 , .psh=1 , .ph=5 , .psh_shift=0 , .csw=1 , .cw=6 , .csw_shift=0 , .pkw=1 , .psw=1 , .pw=6 , .psw_shift=0 , .pool=POOL_NONE , .on=8 , .oh=5 , .ow=6 , .oc=8 , .x_header= 81946u, .x_header_p0= 81946u, .w_header= 656593207322u, .w_header_p0= 81946u , .debug_nhwc_words=1920 },
{.n=8 , .l=1 , .kw=3 , .coe=8 , .coe_tl=8 , .r_ll=5 , .h=5 , .w=6 , .ci=8 , .co=24 , .w_kw2=5 , .t=3 , .p=2 , .cm=6 , .cm_p0=2 , .xp_words=672, .out_buffer_idx=0 , .add_buffer_idx=-1, .w_bpt=224 , .w_bpt_p0=80 , .x_bpt=2024 , .x_bpt_p0=680 , .o_words=16128, .o_bytes=8080 , .is_bias=1 , .is_flatten=0 , .b_offset=40 , .b_val_shift=9 , .b_bias_shift=0 , .ca_nzero=0 , .ca_shift=12 , .ca_pl_scale=0 , .add_act_shift=0 , .pool_act_shift=0 , .csh=1 , .ch=5 , .csh_shift=0 , .pkh=1 , .psh=1 , .ph=5 , .psh_shift=0 , .csw=1 , .cw=6 , .csw_shift=0 , .pkw=1 , .psw=1 , .pw=6 , .psw_shift=0 , .pool=POOL_NONE , .on=8 , .oh=5 , .ow=6 , .oc=24 , .x_header= 81961u, .x_header_p0= 81929u, .w_header= 587873730601u, .w_header_p0= 81929u , .debug_nhwc_words=5760 },
{.n=8 , .l=1 , .kw=1 , .coe=24 , .coe_tl=0 , .r_ll=5 , .h=5 , .w=6 , .ci=24 , .co=10 , .w_kw2=6 , .t=1 , .p=2 , .cm=20 , .cm_p0=4 , .xp_words=672, .out_buffer_idx=1 , .add_buffer_idx=-1, .w_bpt=248 , .w_bpt_p0=56 , .x_bpt=6728 , .x_bpt_p0=1352 , .o_words=4200 , .o_bytes=2220 , .is_bias=1 , .is_flatten=1 , .b_offset=64 , .b_val_shift=9 , .b_bias_shift=0 , .ca_nzero=1 , .ca_shift=15 , .ca_pl_scale=3 , .add_act_shift=0 , .pool_act_shift=0 , .csh=1 , .ch=5 , .csh_shift=0 , .pkh=1 , .psh=1 , .ph=5 , .psh_shift=0 , .csw=1 , .cw=6 , .csw_shift=0 , .pkw=1 , .psw=1 , .pw=6 , .psw_shift=0 , .pool=POOL_NONE , .on=1 , .oh=8 , .ow=1 , .oc=300, .x_header= 82072u, .x_header_p0= 81944u, .w_header= 656593207448u, .w_header_p0= 81944u , .debug_nhwc_words=2400 },
{.n=1 , .l=1 , .kw=1 , .coe=24 , .coe_tl=0 , .r_ll=8 , .h=8 , .w=1 , .ci=300 , .co=10 , .w_kw2=1 , .t=1 , .p=15 , .cm=20 , .cm_p0=20 , .xp_words=14 , .out_buffer_idx=-1, .add_buffer_idx=-1, .w_bpt=248 , .w_bpt_p0=248 , .x_bpt=148 , .x_bpt_p0=148 , .o_words=80 , .o_bytes=320 , .is_bias=1 , .is_flatten=0 , .b_offset=88 , .b_val_shift=9 , .b_bias_shift=0 , .ca_nzero=1 , .ca_shift=15 , .ca_pl_scale=3 , .add_act_shift=0 , .pool_act_shift=0 , .csh=1 , .ch=8 , .csh_shift=0 , .pkh=1 , .psh=1 , .ph=8 , .psh_shift=0 , .csw=1 , .cw=1 , .csw_shift=0 , .pkw=1 , .psw=1 , .pw=1 , .psw_shift=0 , .pool=POOL_NONE , .on=1 , .oh=8 , .ow=1 , .oc=10 , .x_header= 152u, .x_header_p0= 152u, .w_header= 652835029144u, .w_header_p0= 152u , .debug_nhwc_words=80 }
{.n=8 , .l=3 , .kw=11 , .coe=2 , .coe_tl=2 , .r_ll=2 , .h=18 , .w=18 , .ci=3 , .co=8 , .w_kw2=13 , .t=4 , .p=3 , .cm=1 , .cm_p0=1 , .xp_words=6048, .w_bpt=140 , .w_bpt_p0=140 , .x_bpt=3032 , .x_bpt_p0=3032 , .o_words=5376 , .o_bytes=2696 , .out_buffer_idx=0 , .add_out_buffer_idx=0 , .add_in_buffer_idx=-1, .is_bias=1 , .is_flatten=0 , .b_offset=0 , .b_val_shift=9 , .b_bias_shift=0 , .ca_nzero=0 , .ca_shift=12 , .ca_pl_scale=0 , .add_act_shift=0 , .pool_act_shift=0 , .csh=2 , .ch=9 , .csh_shift=1 , .pkh=3 , .psh=2 , .ph=5 , .psh_shift=1 , .csw=1 , .cw=18 , .csw_shift=0 , .pkw=4 , .psw=3 , .pw=6 , .psw_shift=0 , .pool=POOL_AVG , .on=8 , .oh=5 , .ow=6 , .oc=8 , .x_header= 17055749u, .x_header_p0= 17055749u, .w_header= 347372535813u, .w_header_p0= 17055749u , .debug_nhwc_words=1920 },
{.n=8 , .l=1 , .kw=1 , .coe=24 , .coe_tl=0 , .r_ll=5 , .h=5 , .w=6 , .ci=8 , .co=8 , .w_kw2=6 , .t=1 , .p=1 , .cm=20 , .cm_p0=8 , .xp_words=672, .w_bpt=104 , .w_bpt_p0=104 , .x_bpt=2696 , .x_bpt_p0=2696 , .o_words=5376 , .o_bytes=2720 , .out_buffer_idx=1 , .add_out_buffer_idx=-1, .add_in_buffer_idx=-1, .is_bias=1 , .is_flatten=0 , .b_offset=8 , .b_val_shift=9 , .b_bias_shift=0 , .ca_nzero=1 , .ca_shift=12 , .ca_pl_scale=0 , .add_act_shift=0 , .pool_act_shift=0 , .csh=1 , .ch=5 , .csh_shift=0 , .pkh=1 , .psh=1 , .ph=5 , .psh_shift=0 , .csw=1 , .cw=6 , .csw_shift=0 , .pkw=1 , .psw=1 , .pw=6 , .psw_shift=0 , .pool=POOL_NONE , .on=8 , .oh=5 , .ow=6 , .oc=8 , .x_header= 81976u, .x_header_p0= 81976u, .w_header= 244276346936u, .w_header_p0= 81976u , .debug_nhwc_words=1920 },
{.n=8 , .l=1 , .kw=7 , .coe=3 , .coe_tl=2 , .r_ll=5 , .h=5 , .w=6 , .ci=8 , .co=8 , .w_kw2=3 , .t=3 , .p=4 , .cm=2 , .cm_p0=2 , .xp_words=672, .w_bpt=176 , .w_bpt_p0=176 , .x_bpt=680 , .x_bpt_p0=680 , .o_words=5376 , .o_bytes=2704 , .out_buffer_idx=0 , .add_out_buffer_idx=-1, .add_in_buffer_idx=0 , .is_bias=0 , .is_flatten=0 , .b_offset=32 , .b_val_shift=0 , .b_bias_shift=0 , .ca_nzero=1 , .ca_shift=3 , .ca_pl_scale=0 , .add_act_shift=0 , .pool_act_shift=0 , .csh=1 , .ch=5 , .csh_shift=0 , .pkh=1 , .psh=1 , .ph=5 , .psh_shift=0 , .csw=1 , .cw=6 , .csw_shift=0 , .pkw=1 , .psw=1 , .pw=6 , .psw_shift=0 , .pool=POOL_NONE , .on=8 , .oh=5 , .ow=6 , .oc=8 , .x_header= 81931u, .x_header_p0= 81931u, .w_header= 450434777099u, .w_header_p0= 81931u , .debug_nhwc_words=1920 },
{.n=8 , .l=1 , .kw=5 , .coe=4 , .coe_tl=4 , .r_ll=5 , .h=5 , .w=6 , .ci=8 , .co=8 , .w_kw2=4 , .t=2 , .p=2 , .cm=4 , .cm_p0=4 , .xp_words=672, .w_bpt=248 , .w_bpt_p0=248 , .x_bpt=1352 , .x_bpt_p0=1352 , .o_words=5376 , .o_bytes=2704 , .out_buffer_idx=1 , .add_out_buffer_idx=-1, .add_in_buffer_idx=-1, .is_bias=1 , .is_flatten=0 , .b_offset=32 , .b_val_shift=9 , .b_bias_shift=0 , .ca_nzero=1 , .ca_shift=15 , .ca_pl_scale=3 , .add_act_shift=0 , .pool_act_shift=0 , .csh=1 , .ch=5 , .csh_shift=0 , .pkh=1 , .psh=1 , .ph=5 , .psh_shift=0 , .csw=1 , .cw=6 , .csw_shift=0 , .pkw=1 , .psw=1 , .pw=6 , .psw_shift=0 , .pool=POOL_NONE , .on=8 , .oh=5 , .ow=6 , .oc=8 , .x_header= 81946u, .x_header_p0= 81946u, .w_header= 656593207322u, .w_header_p0= 81946u , .debug_nhwc_words=1920 },
{.n=8 , .l=1 , .kw=3 , .coe=8 , .coe_tl=8 , .r_ll=5 , .h=5 , .w=6 , .ci=8 , .co=24 , .w_kw2=5 , .t=3 , .p=2 , .cm=6 , .cm_p0=2 , .xp_words=672, .w_bpt=224 , .w_bpt_p0=80 , .x_bpt=2024 , .x_bpt_p0=680 , .o_words=16128, .o_bytes=8080 , .out_buffer_idx=0 , .add_out_buffer_idx=-1, .add_in_buffer_idx=-1, .is_bias=1 , .is_flatten=0 , .b_offset=40 , .b_val_shift=9 , .b_bias_shift=0 , .ca_nzero=0 , .ca_shift=12 , .ca_pl_scale=0 , .add_act_shift=0 , .pool_act_shift=0 , .csh=1 , .ch=5 , .csh_shift=0 , .pkh=1 , .psh=1 , .ph=5 , .psh_shift=0 , .csw=1 , .cw=6 , .csw_shift=0 , .pkw=1 , .psw=1 , .pw=6 , .psw_shift=0 , .pool=POOL_NONE , .on=8 , .oh=5 , .ow=6 , .oc=24 , .x_header= 81961u, .x_header_p0= 81929u, .w_header= 587873730601u, .w_header_p0= 81929u , .debug_nhwc_words=5760 },
{.n=8 , .l=1 , .kw=1 , .coe=24 , .coe_tl=0 , .r_ll=5 , .h=5 , .w=6 , .ci=24 , .co=10 , .w_kw2=6 , .t=1 , .p=2 , .cm=20 , .cm_p0=4 , .xp_words=672, .w_bpt=248 , .w_bpt_p0=56 , .x_bpt=6728 , .x_bpt_p0=1352 , .o_words=4200 , .o_bytes=2220 , .out_buffer_idx=1 , .add_out_buffer_idx=-1, .add_in_buffer_idx=-1, .is_bias=1 , .is_flatten=1 , .b_offset=64 , .b_val_shift=9 , .b_bias_shift=0 , .ca_nzero=1 , .ca_shift=15 , .ca_pl_scale=3 , .add_act_shift=0 , .pool_act_shift=0 , .csh=1 , .ch=5 , .csh_shift=0 , .pkh=1 , .psh=1 , .ph=5 , .psh_shift=0 , .csw=1 , .cw=6 , .csw_shift=0 , .pkw=1 , .psw=1 , .pw=6 , .psw_shift=0 , .pool=POOL_NONE , .on=1 , .oh=8 , .ow=1 , .oc=300, .x_header= 82072u, .x_header_p0= 81944u, .w_header= 656593207448u, .w_header_p0= 81944u , .debug_nhwc_words=2400 },
{.n=1 , .l=1 , .kw=1 , .coe=24 , .coe_tl=0 , .r_ll=8 , .h=8 , .w=1 , .ci=300 , .co=10 , .w_kw2=1 , .t=1 , .p=15 , .cm=20 , .cm_p0=20 , .xp_words=14 , .w_bpt=248 , .w_bpt_p0=248 , .x_bpt=148 , .x_bpt_p0=148 , .o_words=80 , .o_bytes=320 , .out_buffer_idx=-1, .add_out_buffer_idx=-1, .add_in_buffer_idx=-1, .is_bias=1 , .is_flatten=0 , .b_offset=88 , .b_val_shift=9 , .b_bias_shift=0 , .ca_nzero=1 , .ca_shift=15 , .ca_pl_scale=3 , .add_act_shift=0 , .pool_act_shift=0 , .csh=1 , .ch=8 , .csh_shift=0 , .pkh=1 , .psh=1 , .ph=8 , .psh_shift=0 , .csw=1 , .cw=1 , .csw_shift=0 , .pkw=1 , .psw=1 , .pw=1 , .psw_shift=0 , .pool=POOL_NONE , .on=1 , .oh=8 , .ow=1 , .oc=10 , .x_header= 152u, .x_header_p0= 152u, .w_header= 652835029144u, .w_header_p0= 152u , .debug_nhwc_words=80 }
};

#define X_BITS_L2 2
Expand All @@ -16,15 +16,15 @@ Bundle_t bundles [N_BUNDLES] = {
#define PE_ROWS 8
#define PE_COLS 24

#define N_BUF 2
#define N_ADD_BUF 1
#define WB_BYTES 10048
#define W_BYTES 9824
#define X_BYTES 9096
#define O_WORDS 80
#define O_WORDS_MAX 16128
#define O_BYTES_MAX 8080
#define X_BYTES_ALL 30220
#define Y_BYTES 110600
#define NHWC_WORDS 20736
#define B_TYPE int16_t
#define B_WORDS 112
#define DATA_DIR "D:/dnn-engine/test/vectors"
Expand Down
40 changes: 24 additions & 16 deletions c/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@
#endif

typedef const struct {
const int32_t n, l, kw, coe, coe_tl, r_ll, h, w, ci, co, w_kw2, t, p, cm, cm_p0, xp_words, out_buffer_idx, add_buffer_idx;
const int32_t n, l, kw, coe, coe_tl, r_ll, h, w, ci, co, w_kw2, t, p, cm, cm_p0, xp_words;
const int32_t w_bpt, w_bpt_p0, x_bpt, x_bpt_p0, o_words, o_bytes; // bytes per transfer
const int8_t out_buffer_idx, add_out_buffer_idx, add_in_buffer_idx;
const int8_t is_bias, is_pool, is_flatten;
const int32_t b_offset, b_val_shift, b_bias_shift;
const int8_t ca_nzero, ca_shift, ca_pl_scale, add_act_shift, pool_act_shift;
Expand All @@ -31,12 +32,13 @@ typedef enum {POOL_NONE, POOL_MAX, POOL_AVG} Pool_t;
typedef struct {
int8_t w [W_BYTES ];
B_TYPE b [B_WORDS ]; // keep next to w. weights are loaded to w_ptr
int8_t buffers [N_BUF][O_BYTES_MAX ];
int8_t x [X_BYTES_ALL ];
int32_t y [O_WORDS ];
int32_t nhwc [Y_BYTES/4 ];
int32_t nhwc [NHWC_WORDS ];
int8_t debug_tiled [O_WORDS_MAX ];
int32_t debug_nhwc [Y_BYTES/4 ];
int32_t debug_nhwc [NHWC_WORDS ];
int8_t out_buffers [2 ][O_BYTES_MAX ];
int8_t add_buffers [N_ADD_BUF ][NHWC_WORDS ];
} Memory_st;
Memory_st mem;

Expand All @@ -47,7 +49,7 @@ volatile char is_bundle_write_done = 1;

#define flatten_nhwc(in,ih,iw,ic, N,H,W,C, optional_debug_info,...)\
((in*H + ih)*W + iw)*C + ic;\
assert_printf (in, <, N, optional_debug_info,__VA_ARGS__); assert_printf (ih, <, H, optional_debug_info,__VA_ARGS__); assert_printf (iw, <, W, optional_debug_info,__VA_ARGS__); assert_printf (ic, <, C, optional_debug_info,__VA_ARGS__);
assert_printf (in, <, N, optional_debug_info,__VA_ARGS__); assert_printf (ih, <, H, optional_debug_info,__VA_ARGS__); assert_printf (iw, <, W, optional_debug_info,__VA_ARGS__); assert_printf (ic, <, C, optional_debug_info,__VA_ARGS__); assert_printf ((((in*H + ih)*W + iw)*C + ic), <, NHWC_WORDS, optional_debug_info,__VA_ARGS__);

#define max(x, y) ((x) > (y) ? (x) : (y))
#define min(x, y) ((x) < (y) ? (x) : (y))
Expand Down Expand Up @@ -88,14 +90,6 @@ static inline void write_x(int8_t val, int8_t *p_out_buffer, int32_t ib, int32_t

assert_printf (packed_index , <, bundles[ib].o_bytes, "write_x", WRITEX_DEBUG_INFO);

// // ------ RESIDUAL ADD ----
// if (bundles[ib].add_buffer_idx != -1){
// uint8_t add_byte = mem.buffers[bundles[ib].add_buffer_idx][packed_index];
// uint8_t add_byte_cleaned = X_POSITION_INVERTED_MASKS[packed_position] & add_byte;
// uint8_t add_byte_unpacked = (add_byte_cleaned >> (packed_position * X_BITS)) & X_BITS_MASK;
// int8_t add_val = add_byte_unpacked | ~X_BITS_MASK);
// }

uint8_t packed_val = ((uint8_t)val & X_BITS_MASK) << (packed_position * X_BITS);
uint8_t mem_val = p_out_buffer[packed_index];
uint8_t mem_val_cleaned = X_POSITION_INVERTED_MASKS[packed_position] & mem_val;
Expand Down Expand Up @@ -147,6 +141,10 @@ static inline void tile_write( int32_t out_val, int8_t *p_out_buffer, int32_t ib
int32_t iy_nhwc = flatten_nhwc(i_yn,i_yh,i_yw,i_yc, yn,yh,yw,yc,,);
mem.debug_nhwc[iy_nhwc] = out_val;

// Store for residual add
if (pb->add_out_buffer_idx != -1)
mem.add_buffers[pb->add_out_buffer_idx][iy_nhwc] = (int8_t)out_val;

if (ib == N_BUNDLES-1)
mem.y[iy_nhwc] = out_val; // Last bundle: save as NHWC
else {
Expand Down Expand Up @@ -174,7 +172,7 @@ extern EXT_C void load_y (uint8_t *p_done, uint8_t *pt_done_proc, const uint32_
static Bundle_t *pb = &bundles[0];
static int32_t it_bias=0;
static int32_t ib=0, ip=0, it=0, in=0, il=0, iw_kw2=0;
static int8_t *p_out_buffer = (int8_t*)&mem.buffers[0];
static int8_t *p_out_buffer = (int8_t*)&mem.out_buffers[0];
const int32_t *p_sram = (const int32_t *)p_sram_u32;

int32_t iy_nhwc;
Expand Down Expand Up @@ -281,6 +279,16 @@ extern EXT_C void load_y (uint8_t *p_done, uint8_t *pt_done_proc, const uint32_
// ------ SOFTMAX ------


// ------ RESIDUAL ADD ---

if (pb->add_in_buffer_idx != -1) {
iy_nhwc = flatten_nhwc(i_yn,i_yh,i_yw,i_yc, yn,yh,yw,yc, "Before add", DEBUG_INFO);// store as nhwc for pooling
out_val += mem.add_buffers[pb->add_in_buffer_idx][iy_nhwc];
out_val = shift_round(out_val, pb->add_act_shift);
out_val = clip(out_val, -(1<<(X_BITS-1)), (1<<(X_BITS-1))-1);
}


// ------ MAX/AVG POOL ---

if (pb->pool == POOL_NONE) {
Expand Down Expand Up @@ -396,7 +404,7 @@ extern EXT_C void load_y (uint8_t *p_done, uint8_t *pt_done_proc, const uint32_
}//new(ib):

pb = &bundles[ib];
p_out_buffer = (int8_t*)&mem.buffers[pb->out_buffer_idx];
p_out_buffer = (int8_t*)&mem.out_buffers[pb->out_buffer_idx];
if (ib != N_BUNDLES-1) write_x_header = 1; // Make write_x write new headers

}//new(ip):
Expand All @@ -413,7 +421,7 @@ extern EXT_C void load_x (uint8_t *p_done, uint8_t *bundle_read_done, uint64_t *

static int32_t ib=0, ip=0, it=0, offset_next=0;

int8_t *p_buffer_base = (ib==0) ? mem.x : mem.buffers[bundles[ib-1].out_buffer_idx];
int8_t *p_buffer_base = (ib==0) ? mem.x : mem.out_buffers[bundles[ib-1].out_buffer_idx];

*p_base_addr = (uint64_t)p_buffer_base + offset_next;
*p_bpt = ip == 0 ? bundles[ib].x_bpt_p0 : bundles[ib].x_bpt;
Expand Down
7 changes: 3 additions & 4 deletions test/py/bundle.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ def __init__(self,
# Store reference to bundle object here, not just a idx number
self.prev_bundle = None
self.add_bundle = None
self.out_tensor_dest = []
self.buffer_idx = None
self.add_tensor_dest = []
self.add_out_buffer_idx = None

def extract_act(signature):
ilayer = QActivation(signature)
Expand Down Expand Up @@ -161,7 +161,6 @@ def call(self, x, x_1=None):
if hasattr(x, "bundle"):
self.prev_bundle = x.bundle
self.idx = self.prev_bundle.idx + 1
self.prev_bundle.out_tensor_dest += [self.idx]
else:
self.prev_bundle = None
self.idx = 0
Expand All @@ -175,7 +174,7 @@ def call(self, x, x_1=None):
if x_1 is not None:
if hasattr(x_1, "bundle"):
self.add['bundle'] = x_1.bundle
x_1.bundle.out_tensor_dest += [self.idx]
x_1.bundle.add_tensor_dest += [self.idx]
else:
self.add['bundle'] = None
x = Add()([x, x_1])
Expand Down
Loading

0 comments on commit a63a3b3

Please sign in to comment.