Skip to content

Commit

Permalink
Fix avg pool & act with shift=0
Browse files Browse the repository at this point in the history
  • Loading branch information
Aba committed Nov 11, 2023
1 parent 44e1dad commit d6df904
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 40 deletions.
18 changes: 9 additions & 9 deletions c/model.h
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
#define N_BUNDLES 7
Bundle_t bundles [N_BUNDLES] = {
{.n=8 , .l=3 , .kw=11 , .coe=2 , .coe_tl=2 , .r_ll=2 , .h=18 , .w=18 , .ci=3 , .co=8 , .w_kw2=13 , .t=4 , .p=3 , .cm=1 , .cm_p0=1 , .xp_words=6048, .out_buffer_idx=0 , .w_bpt=140 , .w_bpt_p0=140 , .x_bpt=3032 , .x_bpt_p0=3032 , .o_words=5376 , .o_bytes=2696 , .is_bias=1 , .is_flatten=0 , .b_offset=0 , .b_val_shift=9 , .b_bias_shift=0 , .ca_nzero=0 , .ca_shift=12 , .ca_pl_scale=0 , .csh=2 , .ch=9 , .csh_shift=1 , .pkh=3 , .psh=2 , .ph=5 , .psh_shift=1 , .csw=1 , .cw=18 , .csw_shift=0 , .pkw=4 , .psw=3 , .pw=6 , .psw_shift=0 , .pool=POOL_MAX , .on=8 , .oh=5 , .ow=6 , .oc=8 , .x_header= 17055749u, .x_header_p0= 17055749u, .w_header= 347372535813u, .w_header_p0= 17055749u , .debug_nhwc_words=1920 },
{.n=8 , .l=1 , .kw=1 , .coe=24 , .coe_tl=0 , .r_ll=5 , .h=5 , .w=6 , .ci=8 , .co=8 , .w_kw2=6 , .t=1 , .p=1 , .cm=20 , .cm_p0=8 , .xp_words=672, .out_buffer_idx=1 , .w_bpt=104 , .w_bpt_p0=104 , .x_bpt=2696 , .x_bpt_p0=2696 , .o_words=5376 , .o_bytes=2720 , .is_bias=0 , .is_flatten=0 , .b_offset=8 , .b_val_shift=0 , .b_bias_shift=0 , .ca_nzero=1 , .ca_shift=3 , .ca_pl_scale=0 , .csh=1 , .ch=5 , .csh_shift=0 , .pkh=1 , .psh=1 , .ph=5 , .psh_shift=0 , .csw=1 , .cw=6 , .csw_shift=0 , .pkw=1 , .psw=1 , .pw=6 , .psw_shift=0 , .pool=POOL_NONE , .on=8 , .oh=5 , .ow=6 , .oc=8 , .x_header= 81976u, .x_header_p0= 81976u, .w_header= 244276346936u, .w_header_p0= 81976u , .debug_nhwc_words=1920 },
{.n=8 , .l=1 , .kw=7 , .coe=3 , .coe_tl=2 , .r_ll=5 , .h=5 , .w=6 , .ci=8 , .co=8 , .w_kw2=3 , .t=3 , .p=4 , .cm=2 , .cm_p0=2 , .xp_words=672, .out_buffer_idx=0 , .w_bpt=176 , .w_bpt_p0=176 , .x_bpt=680 , .x_bpt_p0=680 , .o_words=5376 , .o_bytes=2704 , .is_bias=1 , .is_flatten=0 , .b_offset=8 , .b_val_shift=9 , .b_bias_shift=0 , .ca_nzero=1 , .ca_shift=12 , .ca_pl_scale=0 , .csh=1 , .ch=5 , .csh_shift=0 , .pkh=1 , .psh=1 , .ph=5 , .psh_shift=0 , .csw=1 , .cw=6 , .csw_shift=0 , .pkw=1 , .psw=1 , .pw=6 , .psw_shift=0 , .pool=POOL_NONE , .on=8 , .oh=5 , .ow=6 , .oc=8 , .x_header= 81931u, .x_header_p0= 81931u, .w_header= 450434777099u, .w_header_p0= 81931u , .debug_nhwc_words=1920 },
{.n=8 , .l=1 , .kw=5 , .coe=4 , .coe_tl=4 , .r_ll=5 , .h=5 , .w=6 , .ci=8 , .co=8 , .w_kw2=4 , .t=2 , .p=2 , .cm=4 , .cm_p0=4 , .xp_words=672, .out_buffer_idx=1 , .w_bpt=248 , .w_bpt_p0=248 , .x_bpt=1352 , .x_bpt_p0=1352 , .o_words=5376 , .o_bytes=2704 , .is_bias=0 , .is_flatten=0 , .b_offset=17 , .b_val_shift=0 , .b_bias_shift=0 , .ca_nzero=1 , .ca_shift=6 , .ca_pl_scale=3 , .csh=1 , .ch=5 , .csh_shift=0 , .pkh=1 , .psh=1 , .ph=5 , .psh_shift=0 , .csw=1 , .cw=6 , .csw_shift=0 , .pkw=1 , .psw=1 , .pw=6 , .psw_shift=0 , .pool=POOL_NONE , .on=8 , .oh=5 , .ow=6 , .oc=8 , .x_header= 81946u, .x_header_p0= 81946u, .w_header= 656593207322u, .w_header_p0= 81946u , .debug_nhwc_words=1920 },
{.n=8 , .l=1 , .kw=3 , .coe=8 , .coe_tl=8 , .r_ll=5 , .h=5 , .w=6 , .ci=8 , .co=24 , .w_kw2=5 , .t=3 , .p=2 , .cm=6 , .cm_p0=2 , .xp_words=672, .out_buffer_idx=0 , .w_bpt=224 , .w_bpt_p0=80 , .x_bpt=2024 , .x_bpt_p0=680 , .o_words=16128, .o_bytes=8080 , .is_bias=1 , .is_flatten=0 , .b_offset=17 , .b_val_shift=9 , .b_bias_shift=0 , .ca_nzero=0 , .ca_shift=12 , .ca_pl_scale=0 , .csh=1 , .ch=5 , .csh_shift=0 , .pkh=1 , .psh=1 , .ph=5 , .psh_shift=0 , .csw=1 , .cw=6 , .csw_shift=0 , .pkw=1 , .psw=1 , .pw=6 , .psw_shift=0 , .pool=POOL_NONE , .on=8 , .oh=5 , .ow=6 , .oc=24 , .x_header= 81961u, .x_header_p0= 81929u, .w_header= 587873730601u, .w_header_p0= 81929u , .debug_nhwc_words=5760 },
{.n=8 , .l=1 , .kw=1 , .coe=24 , .coe_tl=0 , .r_ll=5 , .h=5 , .w=6 , .ci=24 , .co=10 , .w_kw2=6 , .t=1 , .p=2 , .cm=20 , .cm_p0=4 , .xp_words=672, .out_buffer_idx=1 , .w_bpt=248 , .w_bpt_p0=56 , .x_bpt=6728 , .x_bpt_p0=1352 , .o_words=4200 , .o_bytes=2220 , .is_bias=0 , .is_flatten=1 , .b_offset=41 , .b_val_shift=0 , .b_bias_shift=0 , .ca_nzero=1 , .ca_shift=6 , .ca_pl_scale=3 , .csh=1 , .ch=5 , .csh_shift=0 , .pkh=1 , .psh=1 , .ph=5 , .psh_shift=0 , .csw=1 , .cw=6 , .csw_shift=0 , .pkw=1 , .psw=1 , .pw=6 , .psw_shift=0 , .pool=POOL_NONE , .on=1 , .oh=8 , .ow=1 , .oc=300, .x_header= 82072u, .x_header_p0= 81944u, .w_header= 656593207448u, .w_header_p0= 81944u , .debug_nhwc_words=2400 },
{.n=1 , .l=1 , .kw=1 , .coe=24 , .coe_tl=0 , .r_ll=8 , .h=8 , .w=1 , .ci=300 , .co=10 , .w_kw2=1 , .t=1 , .p=15 , .cm=20 , .cm_p0=20 , .xp_words=14 , .out_buffer_idx=-1, .w_bpt=248 , .w_bpt_p0=248 , .x_bpt=148 , .x_bpt_p0=148 , .o_words=80 , .o_bytes=320 , .is_bias=1 , .is_flatten=0 , .b_offset=41 , .b_val_shift=9 , .b_bias_shift=0 , .ca_nzero=1 , .ca_shift=15 , .ca_pl_scale=3 , .csh=1 , .ch=8 , .csh_shift=0 , .pkh=1 , .psh=1 , .ph=8 , .psh_shift=0 , .csw=1 , .cw=1 , .csw_shift=0 , .pkw=1 , .psw=1 , .pw=1 , .psw_shift=0 , .pool=POOL_NONE , .on=1 , .oh=8 , .ow=1 , .oc=10 , .x_header= 152u, .x_header_p0= 152u, .w_header= 652835029144u, .w_header_p0= 152u , .debug_nhwc_words=80 }
{.n=8 , .l=3 , .kw=11 , .coe=2 , .coe_tl=2 , .r_ll=2 , .h=18 , .w=18 , .ci=3 , .co=8 , .w_kw2=13 , .t=4 , .p=3 , .cm=1 , .cm_p0=1 , .xp_words=6048, .out_buffer_idx=0 , .add_buffer_idx=-1, .w_bpt=140 , .w_bpt_p0=140 , .x_bpt=3032 , .x_bpt_p0=3032 , .o_words=5376 , .o_bytes=2696 , .is_bias=1 , .is_flatten=0 , .b_offset=0 , .b_val_shift=9 , .b_bias_shift=0 , .ca_nzero=0 , .ca_shift=12 , .ca_pl_scale=0 , .add_act_shift=0 , .pool_act_shift=0 , .csh=2 , .ch=9 , .csh_shift=1 , .pkh=3 , .psh=2 , .ph=5 , .psh_shift=1 , .csw=1 , .cw=18 , .csw_shift=0 , .pkw=4 , .psw=3 , .pw=6 , .psw_shift=0 , .pool=POOL_AVG , .on=8 , .oh=5 , .ow=6 , .oc=8 , .x_header= 17055749u, .x_header_p0= 17055749u, .w_header= 347372535813u, .w_header_p0= 17055749u , .debug_nhwc_words=1920 },
{.n=8 , .l=1 , .kw=1 , .coe=24 , .coe_tl=0 , .r_ll=5 , .h=5 , .w=6 , .ci=8 , .co=8 , .w_kw2=6 , .t=1 , .p=1 , .cm=20 , .cm_p0=8 , .xp_words=672, .out_buffer_idx=1 , .add_buffer_idx=-1, .w_bpt=104 , .w_bpt_p0=104 , .x_bpt=2696 , .x_bpt_p0=2696 , .o_words=5376 , .o_bytes=2720 , .is_bias=1 , .is_flatten=0 , .b_offset=8 , .b_val_shift=9 , .b_bias_shift=0 , .ca_nzero=1 , .ca_shift=12 , .ca_pl_scale=0 , .add_act_shift=0 , .pool_act_shift=0 , .csh=1 , .ch=5 , .csh_shift=0 , .pkh=1 , .psh=1 , .ph=5 , .psh_shift=0 , .csw=1 , .cw=6 , .csw_shift=0 , .pkw=1 , .psw=1 , .pw=6 , .psw_shift=0 , .pool=POOL_NONE , .on=8 , .oh=5 , .ow=6 , .oc=8 , .x_header= 81976u, .x_header_p0= 81976u, .w_header= 244276346936u, .w_header_p0= 81976u , .debug_nhwc_words=1920 },
{.n=8 , .l=1 , .kw=7 , .coe=3 , .coe_tl=2 , .r_ll=5 , .h=5 , .w=6 , .ci=8 , .co=8 , .w_kw2=3 , .t=3 , .p=4 , .cm=2 , .cm_p0=2 , .xp_words=672, .out_buffer_idx=0 , .add_buffer_idx=-1, .w_bpt=176 , .w_bpt_p0=176 , .x_bpt=680 , .x_bpt_p0=680 , .o_words=5376 , .o_bytes=2704 , .is_bias=0 , .is_flatten=0 , .b_offset=32 , .b_val_shift=0 , .b_bias_shift=0 , .ca_nzero=1 , .ca_shift=3 , .ca_pl_scale=0 , .add_act_shift=0 , .pool_act_shift=0 , .csh=1 , .ch=5 , .csh_shift=0 , .pkh=1 , .psh=1 , .ph=5 , .psh_shift=0 , .csw=1 , .cw=6 , .csw_shift=0 , .pkw=1 , .psw=1 , .pw=6 , .psw_shift=0 , .pool=POOL_NONE , .on=8 , .oh=5 , .ow=6 , .oc=8 , .x_header= 81931u, .x_header_p0= 81931u, .w_header= 450434777099u, .w_header_p0= 81931u , .debug_nhwc_words=1920 },
{.n=8 , .l=1 , .kw=5 , .coe=4 , .coe_tl=4 , .r_ll=5 , .h=5 , .w=6 , .ci=8 , .co=8 , .w_kw2=4 , .t=2 , .p=2 , .cm=4 , .cm_p0=4 , .xp_words=672, .out_buffer_idx=1 , .add_buffer_idx=-1, .w_bpt=248 , .w_bpt_p0=248 , .x_bpt=1352 , .x_bpt_p0=1352 , .o_words=5376 , .o_bytes=2704 , .is_bias=1 , .is_flatten=0 , .b_offset=32 , .b_val_shift=9 , .b_bias_shift=0 , .ca_nzero=1 , .ca_shift=15 , .ca_pl_scale=3 , .add_act_shift=0 , .pool_act_shift=0 , .csh=1 , .ch=5 , .csh_shift=0 , .pkh=1 , .psh=1 , .ph=5 , .psh_shift=0 , .csw=1 , .cw=6 , .csw_shift=0 , .pkw=1 , .psw=1 , .pw=6 , .psw_shift=0 , .pool=POOL_NONE , .on=8 , .oh=5 , .ow=6 , .oc=8 , .x_header= 81946u, .x_header_p0= 81946u, .w_header= 656593207322u, .w_header_p0= 81946u , .debug_nhwc_words=1920 },
{.n=8 , .l=1 , .kw=3 , .coe=8 , .coe_tl=8 , .r_ll=5 , .h=5 , .w=6 , .ci=8 , .co=24 , .w_kw2=5 , .t=3 , .p=2 , .cm=6 , .cm_p0=2 , .xp_words=672, .out_buffer_idx=0 , .add_buffer_idx=-1, .w_bpt=224 , .w_bpt_p0=80 , .x_bpt=2024 , .x_bpt_p0=680 , .o_words=16128, .o_bytes=8080 , .is_bias=1 , .is_flatten=0 , .b_offset=40 , .b_val_shift=9 , .b_bias_shift=0 , .ca_nzero=0 , .ca_shift=12 , .ca_pl_scale=0 , .add_act_shift=0 , .pool_act_shift=0 , .csh=1 , .ch=5 , .csh_shift=0 , .pkh=1 , .psh=1 , .ph=5 , .psh_shift=0 , .csw=1 , .cw=6 , .csw_shift=0 , .pkw=1 , .psw=1 , .pw=6 , .psw_shift=0 , .pool=POOL_NONE , .on=8 , .oh=5 , .ow=6 , .oc=24 , .x_header= 81961u, .x_header_p0= 81929u, .w_header= 587873730601u, .w_header_p0= 81929u , .debug_nhwc_words=5760 },
{.n=8 , .l=1 , .kw=1 , .coe=24 , .coe_tl=0 , .r_ll=5 , .h=5 , .w=6 , .ci=24 , .co=10 , .w_kw2=6 , .t=1 , .p=2 , .cm=20 , .cm_p0=4 , .xp_words=672, .out_buffer_idx=1 , .add_buffer_idx=-1, .w_bpt=248 , .w_bpt_p0=56 , .x_bpt=6728 , .x_bpt_p0=1352 , .o_words=4200 , .o_bytes=2220 , .is_bias=1 , .is_flatten=1 , .b_offset=64 , .b_val_shift=9 , .b_bias_shift=0 , .ca_nzero=1 , .ca_shift=15 , .ca_pl_scale=3 , .add_act_shift=0 , .pool_act_shift=0 , .csh=1 , .ch=5 , .csh_shift=0 , .pkh=1 , .psh=1 , .ph=5 , .psh_shift=0 , .csw=1 , .cw=6 , .csw_shift=0 , .pkw=1 , .psw=1 , .pw=6 , .psw_shift=0 , .pool=POOL_NONE , .on=1 , .oh=8 , .ow=1 , .oc=300, .x_header= 82072u, .x_header_p0= 81944u, .w_header= 656593207448u, .w_header_p0= 81944u , .debug_nhwc_words=2400 },
{.n=1 , .l=1 , .kw=1 , .coe=24 , .coe_tl=0 , .r_ll=8 , .h=8 , .w=1 , .ci=300 , .co=10 , .w_kw2=1 , .t=1 , .p=15 , .cm=20 , .cm_p0=20 , .xp_words=14 , .out_buffer_idx=-1, .add_buffer_idx=-1, .w_bpt=248 , .w_bpt_p0=248 , .x_bpt=148 , .x_bpt_p0=148 , .o_words=80 , .o_bytes=320 , .is_bias=1 , .is_flatten=0 , .b_offset=88 , .b_val_shift=9 , .b_bias_shift=0 , .ca_nzero=1 , .ca_shift=15 , .ca_pl_scale=3 , .add_act_shift=0 , .pool_act_shift=0 , .csh=1 , .ch=8 , .csh_shift=0 , .pkh=1 , .psh=1 , .ph=8 , .psh_shift=0 , .csw=1 , .cw=1 , .csw_shift=0 , .pkw=1 , .psw=1 , .pw=1 , .psw_shift=0 , .pool=POOL_NONE , .on=1 , .oh=8 , .ow=1 , .oc=10 , .x_header= 152u, .x_header_p0= 152u, .w_header= 652835029144u, .w_header_p0= 152u , .debug_nhwc_words=80 }
};

#define X_BITS_L2 2
Expand All @@ -17,7 +17,7 @@ Bundle_t bundles [N_BUNDLES] = {
#define PE_COLS 24

#define N_BUF 2
#define WB_BYTES 9954
#define WB_BYTES 10048
#define W_BYTES 9824
#define X_BYTES 9096
#define O_WORDS 80
Expand All @@ -26,7 +26,7 @@ Bundle_t bundles [N_BUNDLES] = {
#define X_BYTES_ALL 30220
#define Y_BYTES 110600
#define B_TYPE int16_t
#define B_WORDS 65
#define B_WORDS 112
#define DATA_DIR "D:/dnn-engine/test/vectors"

static const uint8_t X_POSITION_INVERTED_MASKS [] = { 240, 15 };
25 changes: 19 additions & 6 deletions c/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
#endif

typedef const struct {
const int32_t n, l, kw, coe, coe_tl, r_ll, h, w, ci, co, w_kw2, t, p, cm, cm_p0, xp_words, out_buffer_idx;
const int32_t n, l, kw, coe, coe_tl, r_ll, h, w, ci, co, w_kw2, t, p, cm, cm_p0, xp_words, out_buffer_idx, add_buffer_idx;
const int32_t w_bpt, w_bpt_p0, x_bpt, x_bpt_p0, o_words, o_bytes; // bytes per transfer
const int8_t is_bias, is_pool, is_flatten;
const int32_t b_offset, b_val_shift, b_bias_shift;
const int8_t ca_nzero, ca_shift, ca_pl_scale;
const int8_t ca_nzero, ca_shift, ca_pl_scale, add_act_shift, pool_act_shift;
const int32_t csh, ch, csh_shift, pkh, psh, ph, psh_shift, csw, cw, csw_shift, pkw, psw, pw, psw_shift, pool, on, oh, ow, oc;
const uint64_t x_header, x_header_p0, w_header, w_header_p0; // 64 bits (at least)
const int32_t debug_nhwc_words;
Expand Down Expand Up @@ -52,7 +52,7 @@ volatile char is_bundle_write_done = 1;
#define max(x, y) ((x) > (y) ? (x) : (y))
#define min(x, y) ((x) < (y) ? (x) : (y))
#define clip(x, xmin, xmax) (((x) < (xmin)) ? (xmin) : ((x) > (xmax)) ? (xmax) : (x))
#define shift_round(n, s) (((n) + (1<<((s)-1)) - (~((n)>>(s))&1) ) >> s) // === np.around(n/2**s).astype(int32_t)
#define shift_round(n, s) (((n) + ((s)>0 ? (1<<((s)-1)) - (~((n)>>(s))&1) : 0)) >> s) // === np.around(n/2**s).astype(int32_t)
#define div_round(a, b) (((a)+((b)/2) - (~((b)|(a)/(b)) &1))/(b))


Expand Down Expand Up @@ -88,6 +88,14 @@ static inline void write_x(int8_t val, int8_t *p_out_buffer, int32_t ib, int32_t

assert_printf (packed_index , <, bundles[ib].o_bytes, "write_x", WRITEX_DEBUG_INFO);

// // ------ RESIDUAL ADD ----
// if (bundles[ib].add_buffer_idx != -1){
// uint8_t add_byte = mem.buffers[bundles[ib].add_buffer_idx][packed_index];
// uint8_t add_byte_cleaned = X_POSITION_INVERTED_MASKS[packed_position] & add_byte;
// uint8_t add_byte_unpacked = (add_byte_cleaned >> (packed_position * X_BITS)) & X_BITS_MASK;
// int8_t add_val = add_byte_unpacked | ~X_BITS_MASK);
// }

uint8_t packed_val = ((uint8_t)val & X_BITS_MASK) << (packed_position * X_BITS);
uint8_t mem_val = p_out_buffer[packed_index];
uint8_t mem_val_cleaned = X_POSITION_INVERTED_MASKS[packed_position] & mem_val;
Expand Down Expand Up @@ -322,10 +330,15 @@ extern EXT_C void load_y (uint8_t *p_done, uint8_t *pt_done_proc, const uint32_
result = pb->pool==POOL_MAX ? max(result, read_val) : (result + read_val);
}
}
int32_t count = (ph_end-ph_beg)*(pw_end-pw_beg);
result = pb->pool==POOL_MAX ? result : div_round(result, count);

// ------ POOL ACTIVATION ------
// ------ AVG POOL: Divide & Activation ------
if (pb->pool == POOL_AVG) {
int32_t count = (ph_end-ph_beg)*(pw_end-pw_beg);
result = div_round(result, count);
result = shift_round(result, pb->pool_act_shift);
result = clip(result, -(1<<(X_BITS-1)), (1<<(X_BITS-1))-1);
}

tile_write(result, p_out_buffer, ib, pb, i_yn, ixh, ixw, i_yc, yn, pb->ph, pb->pw, yc); // Write
}
}
Expand Down
35 changes: 19 additions & 16 deletions test/py/bundle.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def call(self, x, x_1=None):
if x_1 is not None:
if hasattr(x_1, "bundle"):
self.add['bundle'] = x_1.bundle
self.x_1.out_tensor_dest += [self.idx]
x_1.bundle.out_tensor_dest += [self.idx]
else:
self.add['bundle'] = None
x = Add()([x, x_1])
Expand Down Expand Up @@ -304,7 +304,8 @@ def add (p, p_frac, p_bits, q, q_frac, q_bits):
def shift_round(n,s):
'''Performs integer division with round-to-nearest-even.
Eq: np.around(n/2**s).astype(int)'''
return (n + (1<<(s-1)) - (~(n>>s)&1) ) >> s
half_b = 1<<(s-1) if s>0 else 0
return (n + half_b - (s>0)*(~(n>>s)&1) ) >> s

def div_round(n,d):
'''Performs integer division with round-to-nearest-even for d>0.
Expand All @@ -330,16 +331,16 @@ def apply_act(act_dict):

if self.add is not None:
a = self.add['bundle']
out_frac_add, out_bits_add = max(self.proc['frac'], a.out['frac']), max(self.proc['bits'], a.out['bits'])

a_arr_cast = a.out['int'] * 2** (out_frac_add - a.out['frac'])
out_arr_cast = self.proc['int'] * 2 **(out_frac_add - self.proc['frac'])

self.proc['int'] = out_arr_cast.astype(np.int64) + a_arr_cast.astype(np.int64)
self.proc['bits'], self.proc['frac'] = out_bits_add, out_frac_add
(self.proc['int'], self.proc['frac'], self.proc['bits']), (self.add_val_shift, self.add_a_shift) = add(
self.proc['int'] , self.proc['frac'], self.proc['bits'],
a.out ['int'].astype(int), a.out ['frac'], a.out ['bits']
)
assert self.proc['bits'] <= c.INT_BITS, f"After residual addition, resulting bits {self.proc['bits']} are more than bits for integer in CPU {c.INT_BITS}. Reduce bits or increase integer bits of bias to continue"
apply_act(self.add['act'])

assert np.all(self.proc['int'] == self.add['tensor'].numpy() * 2**self.proc['frac']), f"Add + act output of bundle {self.idx} is not a fixed point"
else:
self.add_val_shift, self.add_a_shift = 0, 0

if self.pool_layer:

Expand Down Expand Up @@ -368,7 +369,7 @@ def apply_act(act_dict):
q_st = max((PSW*(PXW-1)+PKW-YW)//2, 0)

for n in range(YN):
for c in range(YC):
for ic in range(YC):
for iyh in range(YH):
for iyw in range(YW):

Expand Down Expand Up @@ -404,23 +405,25 @@ def apply_act(act_dict):
for ipyw in range(pw_end, pw_beg,-1):

if self.pool['type']=='max':
result = max(result, in_arr[n,ipyh,ipyw,c])
result = max(result, in_arr[n,ipyh,ipyw,ic])
else:
result += in_arr[n,ipyh,ipyw,c]
result += in_arr[n,ipyh,ipyw,ic]

count = (ph_end-ph_beg)*(pw_end-pw_beg)
result = result if self.pool['type']=='max' else result/count
result = result if self.pool['type']=='max' else div_round(result, count)
''' Writing '''
out_arr[n,ixh,ixw,c] = result
out_arr[n,ixh,ixw,ic] = result

pw_beg += PSW # move pooling window by stride
pw_end = min(pw_end+PSW, YW-1)
ph_beg += PSH # move pooling window by stride
ph_end = min(ph_end+PSH, YH-1)

self.proc['int'] = out_arr
self.proc['bits'] += 4
# apply_act(self.pool['act'])
if self.pool['type'] == 'avg':
self.proc['bits'] += int(np.ceil(np.log2(PKH*PKW)))
assert self.proc['bits'] <= c.INT_BITS, f"When summing avg pool, resulting bits {self.proc['bits']} are more than bits for integer in CPU {c.INT_BITS}. Reduce bits or increase integer bits of bias to continue"
apply_act(self.pool['act'])
assert np.all(self.proc['int'] == self.pool['tensor'].numpy() * 2**self.proc['frac']), f"Pool + act output of bundle {self.idx} is not a fixed point"

if self.flatten:
Expand Down
Loading

0 comments on commit d6df904

Please sign in to comment.