Skip to content

Commit

Permalink
Refactor runtime, enable all layers on py_test
Browse files Browse the repository at this point in the history
  • Loading branch information
Aba committed Oct 31, 2023
1 parent 983a601 commit b3aa2d8
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 78 deletions.
26 changes: 13 additions & 13 deletions c/model.h
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
#define N_BUNDLES 7
Bundle_t bundles [N_BUNDLES] = {
{.n=8, .l=3, .kw=11, .coe=2, .coe_tl=2, .r_ll=8, .h=24, .w=32, .ci=3, .co=8, .w_kw2=27, .t=4, .p=3, .cm=1, .cm_p0=1, .w_bpt=140, .w_bpt_p0=140, .x_bpt=9992, .x_bpt_p0=9992, .o_bytes=18304, .is_bias=1, .is_flatten=0, .b_offset=0, .b_val_shift=5, .b_bias_shift=0, .ca_nzero=0, .ca_shift=8, .ca_pl_scale=0, .csh=2, .ch=12, .csh_shift=1, .pkh=1, .psh=1, .ph=12, .psh_shift=0, .csw=3, .cw=11, .csw_shift=1, .pkw=1, .psw=1, .pw=11, .psw_shift=0, .on=8, .oh=12, .ow=11, .oc=8, .x_header=414356454485524480, .x_header_p0=414356454485524480, .w_header=414602830989492224, .w_header_p0=414356454485524480 , .debug_nhwc_words=8448 },
{.n=8, .l=2, .kw=1, .coe=24, .coe_tl=0, .r_ll=4, .h=12, .w=11, .ci=8, .co=8, .w_kw2=11, .t=1, .p=1, .cm=20, .cm_p0=8, .w_bpt=104, .w_bpt_p0=104, .x_bpt=18312, .x_bpt_p0=18312, .o_bytes=18304, .is_bias=0, .is_flatten=0, .b_offset=8, .b_val_shift=0, .b_bias_shift=0, .ca_nzero=1, .ca_shift=3, .ca_pl_scale=0, .csh=1, .ch=12, .csh_shift=0, .pkh=1, .psh=1, .ph=12, .psh_shift=0, .csw=1, .cw=11, .csw_shift=0, .pkw=1, .psw=1, .pw=11, .psw_shift=0, .on=8, .oh=12, .ow=11, .oc=8, .x_header=4071265058259206144, .x_header_p0=4071265058259206144, .w_header=4071511408993370112, .w_header_p0=4071265058259206144 , .debug_nhwc_words=8448 },
{.n=8, .l=2, .kw=7, .coe=3, .coe_tl=2, .r_ll=4, .h=12, .w=11, .ci=8, .co=8, .w_kw2=8, .t=3, .p=4, .cm=2, .cm_p0=2, .w_bpt=176, .w_bpt_p0=176, .x_bpt=4584, .x_bpt_p0=4584, .o_bytes=18304, .is_bias=1, .is_flatten=0, .b_offset=8, .b_val_shift=5, .b_bias_shift=0, .ca_nzero=1, .ca_shift=8, .ca_pl_scale=0, .csh=1, .ch=12, .csh_shift=0, .pkh=1, .psh=1, .ph=12, .psh_shift=0, .csw=1, .cw=11, .csw_shift=0, .pkw=1, .psw=1, .pw=11, .psw_shift=0, .on=8, .oh=12, .ow=11, .oc=8, .x_header=828673326552449024, .x_header_p0=828673326552449024, .w_header=828919728826220544, .w_header_p0=828673326552449024 , .debug_nhwc_words=8448 },
{.n=8, .l=2, .kw=5, .coe=4, .coe_tl=4, .r_ll=4, .h=12, .w=11, .ci=8, .co=8, .w_kw2=9, .t=2, .p=2, .cm=4, .cm_p0=4, .w_bpt=248, .w_bpt_p0=248, .x_bpt=9160, .x_bpt_p0=9160, .o_bytes=18304, .is_bias=0, .is_flatten=0, .b_offset=17, .b_val_shift=0, .b_bias_shift=0, .ca_nzero=1, .ca_shift=6, .ca_pl_scale=3, .csh=1, .ch=12, .csh_shift=0, .pkh=1, .psh=1, .ph=12, .psh_shift=0, .csw=1, .cw=11, .csw_shift=0, .pkw=1, .psw=1, .pw=11, .psw_shift=0, .on=8, .oh=12, .ow=11, .oc=8, .x_header=1909537237121368064, .x_header_p0=1909537237121368064, .w_header=1909783690934747136, .w_header_p0=1909537237121368064 , .debug_nhwc_words=8448 },
{.n=8, .l=2, .kw=3, .coe=8, .coe_tl=8, .r_ll=4, .h=12, .w=11, .ci=8, .co=24, .w_kw2=10, .t=3, .p=2, .cm=6, .cm_p0=2, .w_bpt=224, .w_bpt_p0=80, .x_bpt=13736, .x_bpt_p0=4584, .o_bytes=54912, .is_bias=1, .is_flatten=0, .b_offset=17, .b_val_shift=5, .b_bias_shift=0, .ca_nzero=0, .ca_shift=8, .ca_pl_scale=0, .csh=1, .ch=12, .csh_shift=0, .pkh=1, .psh=1, .ph=12, .psh_shift=0, .csw=1, .cw=11, .csw_shift=0, .pkw=1, .psw=1, .pw=11, .psw_shift=0, .on=8, .oh=12, .ow=11, .oc=24, .x_header=2990401147690287104, .x_header_p0=684558138476593152, .w_header=2990647584323796992, .w_header_p0=684558138476593152 , .debug_nhwc_words=25344 },
{.n=8, .l=2, .kw=1, .coe=24, .coe_tl=0, .r_ll=4, .h=12, .w=11, .ci=24, .co=5, .w_kw2=11, .t=1, .p=2, .cm=20, .cm_p0=4, .w_bpt=248, .w_bpt_p0=56, .x_bpt=45768, .x_bpt_p0=9160, .o_bytes=8580, .is_bias=0, .is_flatten=1, .b_offset=41, .b_val_shift=0, .b_bias_shift=0, .ca_nzero=1, .ca_shift=6, .ca_pl_scale=3, .csh=1, .ch=12, .csh_shift=0, .pkh=1, .psh=1, .ph=12, .psh_shift=0, .csw=1, .cw=11, .csw_shift=0, .pkw=1, .psw=1, .pw=11, .psw_shift=0, .on=1, .oh=8, .ow=1, .oc=660, .x_header=10988794085900288000, .x_header_p0=1765422049045512192, .w_header=10989040539713667072, .w_header_p0=1765422049045512192 , .debug_nhwc_words=5280 },
{.n=1, .l=1, .kw=1, .coe=24, .coe_tl=0, .r_ll=8, .h=8, .w=1, .ci=660, .co=10, .w_kw2=1, .t=1, .p=33, .cm=20, .cm_p0=20, .w_bpt=248, .w_bpt_p0=248, .x_bpt=268, .x_bpt_p0=268, .o_bytes=80, .is_bias=1, .is_flatten=0, .b_offset=41, .b_val_shift=5, .b_bias_shift=0, .ca_nzero=1, .ca_shift=11, .ca_pl_scale=3, .csh=1, .ch=8, .csh_shift=0, .pkh=1, .psh=1, .ph=8, .psh_shift=0, .csw=1, .cw=1, .csw_shift=0, .pkw=1, .psw=1, .pw=1, .psw_shift=0, .on=1, .oh=8, .ow=1, .oc=10, .x_header=10952754293765046272, .x_header_p0=10952754293765046272, .w_header=10952754456973803520, .w_header_p0=10952754293765046272 , .debug_nhwc_words=80 }
{.n=8, .l=3, .kw=11, .coe=2, .coe_tl=2, .r_ll=2, .h=18, .w=18, .ci=3, .co=8, .w_kw2=13, .t=4, .p=3, .cm=1, .cm_p0=1, .w_bpt=140, .w_bpt_p0=140, .x_bpt=5624, .x_bpt_p0=5624, .o_bytes=9984, .is_bias=1, .is_flatten=0, .b_offset=0, .b_val_shift=5, .b_bias_shift=0, .ca_nzero=0, .ca_shift=8, .ca_pl_scale=0, .csh=1, .ch=18, .csh_shift=0, .pkh=3, .psh=2, .ph=9, .psh_shift=0, .csw=1, .cw=18, .csw_shift=0, .pkw=4, .psw=3, .pw=6, .psw_shift=0, .p_type=POOL_MAX, .on=8, .oh=9, .ow=6, .oc=8, .x_header=378324358931677184, .x_header_p0=378324358931677184, .w_header=378570735435644928, .w_header_p0=378324358931677184 , .debug_nhwc_words=3456 },
{.n=8, .l=2, .kw=1, .coe=24, .coe_tl=0, .r_ll=1, .h=9, .w=6, .ci=8, .co=8, .w_kw2=6, .t=1, .p=1, .cm=20, .cm_p0=8, .w_bpt=104, .w_bpt_p0=104, .x_bpt=9992, .x_bpt_p0=9992, .o_bytes=9984, .is_bias=0, .is_flatten=0, .b_offset=8, .b_val_shift=0, .b_bias_shift=0, .ca_nzero=1, .ca_shift=3, .ca_pl_scale=0, .csh=1, .ch=9, .csh_shift=0, .pkh=1, .psh=1, .ph=9, .psh_shift=0, .csw=1, .cw=6, .csw_shift=0, .pkw=1, .psw=1, .pw=6, .psw_shift=0, .p_type=POOL_NONE, .on=8, .oh=9, .ow=6, .oc=8, .x_header=4053249560238096384, .x_header_p0=4053249560238096384, .w_header=4053495910972260352, .w_header_p0=4053249560238096384 , .debug_nhwc_words=3456 },
{.n=8, .l=2, .kw=7, .coe=3, .coe_tl=2, .r_ll=1, .h=9, .w=6, .ci=8, .co=8, .w_kw2=3, .t=3, .p=4, .cm=2, .cm_p0=2, .w_bpt=176, .w_bpt_p0=176, .x_bpt=2504, .x_bpt_p0=2504, .o_bytes=9984, .is_bias=1, .is_flatten=0, .b_offset=8, .b_val_shift=5, .b_bias_shift=0, .ca_nzero=1, .ca_shift=8, .ca_pl_scale=0, .csh=1, .ch=9, .csh_shift=0, .pkh=1, .psh=1, .ph=9, .psh_shift=0, .csw=1, .cw=6, .csw_shift=0, .pkw=1, .psw=1, .pw=6, .psw_shift=0, .p_type=POOL_NONE, .on=8, .oh=9, .ow=6, .oc=8, .x_header=810657828531339264, .x_header_p0=810657828531339264, .w_header=810904230805110784, .w_header_p0=810657828531339264 , .debug_nhwc_words=3456 },
{.n=8, .l=2, .kw=5, .coe=4, .coe_tl=4, .r_ll=1, .h=9, .w=6, .ci=8, .co=8, .w_kw2=4, .t=2, .p=2, .cm=4, .cm_p0=4, .w_bpt=248, .w_bpt_p0=248, .x_bpt=5000, .x_bpt_p0=5000, .o_bytes=9984, .is_bias=0, .is_flatten=0, .b_offset=17, .b_val_shift=0, .b_bias_shift=0, .ca_nzero=1, .ca_shift=6, .ca_pl_scale=3, .csh=1, .ch=9, .csh_shift=0, .pkh=1, .psh=1, .ph=9, .psh_shift=0, .csw=1, .cw=6, .csw_shift=0, .pkw=1, .psw=1, .pw=6, .psw_shift=0, .p_type=POOL_NONE, .on=8, .oh=9, .ow=6, .oc=8, .x_header=1891521739100258304, .x_header_p0=1891521739100258304, .w_header=1891768192913637376, .w_header_p0=1891521739100258304 , .debug_nhwc_words=3456 },
{.n=8, .l=2, .kw=3, .coe=8, .coe_tl=8, .r_ll=1, .h=9, .w=6, .ci=8, .co=24, .w_kw2=5, .t=3, .p=2, .cm=6, .cm_p0=2, .w_bpt=224, .w_bpt_p0=80, .x_bpt=7496, .x_bpt_p0=2504, .o_bytes=29952, .is_bias=1, .is_flatten=0, .b_offset=17, .b_val_shift=5, .b_bias_shift=0, .ca_nzero=0, .ca_shift=8, .ca_pl_scale=0, .csh=1, .ch=9, .csh_shift=0, .pkh=1, .psh=1, .ph=9, .psh_shift=0, .csw=1, .cw=6, .csw_shift=0, .pkw=1, .psw=1, .pw=6, .psw_shift=0, .p_type=POOL_NONE, .on=8, .oh=9, .ow=6, .oc=24, .x_header=2972385649669177344, .x_header_p0=666542640455483392, .w_header=2972632086302687232, .w_header_p0=666542640455483392 , .debug_nhwc_words=10368 },
{.n=8, .l=2, .kw=1, .coe=24, .coe_tl=0, .r_ll=1, .h=9, .w=6, .ci=24, .co=10, .w_kw2=6, .t=1, .p=2, .cm=20, .cm_p0=4, .w_bpt=248, .w_bpt_p0=56, .x_bpt=24968, .x_bpt_p0=5000, .o_bytes=7020, .is_bias=0, .is_flatten=1, .b_offset=41, .b_val_shift=0, .b_bias_shift=0, .ca_nzero=1, .ca_shift=6, .ca_pl_scale=3, .csh=1, .ch=9, .csh_shift=0, .pkh=1, .psh=1, .ph=9, .psh_shift=0, .csw=1, .cw=6, .csw_shift=0, .pkw=1, .psw=1, .pw=6, .psw_shift=0, .p_type=POOL_NONE, .on=1, .oh=8, .ow=1, .oc=540, .x_header=10970778587879178240, .x_header_p0=1747406551024402432, .w_header=10971025041692557312, .w_header_p0=1747406551024402432 , .debug_nhwc_words=4320 },
{.n=1, .l=1, .kw=1, .coe=24, .coe_tl=0, .r_ll=8, .h=8, .w=1, .ci=540, .co=10, .w_kw2=1, .t=1, .p=27, .cm=20, .cm_p0=20, .w_bpt=248, .w_bpt_p0=248, .x_bpt=268, .x_bpt_p0=268, .o_bytes=80, .is_bias=1, .is_flatten=0, .b_offset=41, .b_val_shift=5, .b_bias_shift=0, .ca_nzero=1, .ca_shift=11, .ca_pl_scale=3, .csh=1, .ch=8, .csh_shift=0, .pkh=1, .psh=1, .ph=8, .psh_shift=0, .csw=1, .cw=1, .csw_shift=0, .pkw=1, .psw=1, .pw=1, .psw_shift=0, .p_type=POOL_NONE, .on=1, .oh=8, .ow=1, .oc=10, .x_header=10952754293765046272, .x_header_p0=10952754293765046272, .w_header=10952754456973803520, .w_header_p0=10952754293765046272 , .debug_nhwc_words=80 }
};

#define X_BITS_L2 3
Expand All @@ -16,13 +16,13 @@ Bundle_t bundles [N_BUNDLES] = {
#define PE_ROWS 8
#define PE_COLS 24

#define WB_BYTES 14418
#define W_BYTES 14288
#define X_BYTES 29976
#define WB_BYTES 12930
#define W_BYTES 12800
#define X_BYTES 16872
#define O_WORDS 80
#define O_BYTES_MAX 54912
#define X_BYTES_ALL 167036
#define Y_BYTES 196616
#define O_BYTES_MAX 29952
#define X_BYTES_ALL 94084
#define Y_BYTES 110600
#define B_TYPE signed short
#define B_WORDS 65
#define DATA_DIR "D:/dnn-engine/test/vectors"
Expand Down
67 changes: 27 additions & 40 deletions c/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,20 @@ static inline void write_x(signed char val, int ib, int ixp, int ixn, int ixl, i


static inline void tile_write( int out_val, int ib, Bundle_t *p_bundle, int i_yn, int i_yh, int i_yw, int i_yc, int yn, int yh, int yw, int yc ) {

// ------ FLATTEN ------
if (p_bundle->is_flatten) {
i_yc = (i_yh*yw + i_yw)*yc + i_yc; // (H*W*C) -> C
i_yw = 0; // W=1
i_yh = i_yn; // N -> H
i_yn = 0; // N=1

yc = yh*yw*yc;
yw = 1;
yh = yn;
yn = 1;
}

// Check
assert_printf ("", yn == p_bundle->on, ": yn");
assert_printf ("", yh == p_bundle->oh, ": yh");
Expand Down Expand Up @@ -169,7 +183,7 @@ extern EXT_C void load_y (unsigned char *p_done, unsigned char *pt_done_proc, c


// if out of bounds, early return
if (i_yh >= p_bundle->h || i_yc >= p_bundle->co) {
if (i_yh >= yh || i_yc >= yc) {
if (ip == p_bundle->p-1)
fprintf(fp_sum,"%d\n", 0); // Save summed output
goto PROCESS_AND_STORE_DONE;
Expand Down Expand Up @@ -222,23 +236,6 @@ extern EXT_C void load_y (unsigned char *p_done, unsigned char *pt_done_proc, c
// ------ SOFTMAX ------


// ------ FLATTEN ------
if (p_bundle->is_flatten) {
// Pool & flatten are not compatible with each other
assert_printf (DBG, p_bundle->p_type == POOL_NONE, ": p_bundle->p_type == POOL_NONE");

i_yc = (i_yh*yw + i_yw)*yc + i_yc; // (H*W*C) -> C
i_yw = 0; // W=1
i_yh = i_yn; // N -> H
i_yn = 0; // N=1

yc = yh*yw*yc;
yw = 1;
yh = yn;
yn = 1;
}


// ------ MAX/AVG POOL ---

if (p_bundle->p_type == POOL_NONE) {
Expand All @@ -254,9 +251,6 @@ extern EXT_C void load_y (unsigned char *p_done, unsigned char *pt_done_proc, c
iy_nhwc = ((i_yn*yh + i_yh)*yw + i_yw)*yc + i_yc; // store as nhwc for pooling
mem.nhwc[iy_nhwc] = out_val;

ph_end = i_yh; // iy(h,w) is the bottom-right of pooling window -> All values in pooling window have been computed
pw_end = i_yw;

div_ixh = div(i_yh+p_bundle->psh_shift-p_bundle->pkh+1, p_bundle->psh);
div_ixw = div(i_yw+p_bundle->psw_shift-p_bundle->pkw+1, p_bundle->psw);
ixh_beg = div_ixh.quot; // ix(hw) that corresponds to the pooling window
Expand All @@ -265,31 +259,27 @@ extern EXT_C void load_y (unsigned char *p_done, unsigned char *pt_done_proc, c
if (ixh_beg < 0 || ixw_beg < 0) // skip when target ix(h,w) < 0
goto PROCESS_AND_STORE_DONE;

if (div_ixh.rem != 0) // invalid ixh
if (i_yh==yh-1) //but last yh. start sweeping
ixh_beg += 1;
else // not last yh. skip
goto PROCESS_AND_STORE_DONE;

if (div_ixh.rem != 0) // invalid ixh
if (i_yh==yh-1) ixh_beg += 1; //but last yh. start sweeping
else goto PROCESS_AND_STORE_DONE; // not last yh. skip

if (div_ixw.rem != 0)
if (i_yw==yw-1)
ixw_beg += 1;
else
goto PROCESS_AND_STORE_DONE;
if (i_yw==yw-1) ixw_beg += 1;
else goto PROCESS_AND_STORE_DONE;

ph_end = i_yh; // iy(h,w) is the bottom-right of pooling window -> All values in pooling window have been computed
pw_end = i_yw;
ph_beg_const = max(p_bundle->psh*ixh_beg-p_bundle->psh_shift, 0)-1; // p(h,w)_beg is the index of top left corner of pooling window. If negative, set to zero
pw_beg_const = max(p_bundle->psw*ixw_beg-p_bundle->psw_shift, 0)-1;

xh_sweep = i_yh == yh-1 ? p_bundle->ph : ixh_beg+1; // ix(h,w) is swept from ix(h,w)_beg up to (but excluding) x(h,w)_sweep. Normally the sweep covers a single pixel.
xw_sweep = i_yw == yw-1 ? p_bundle->pw : ixw_beg+1; // But when iy(h,w) is at its last row/column, the remaining ix(h,w) pixels must also be computed by sweeping

ph_beg = ph_beg_const;
for (int ixh = ixh_beg; ixh < xh_sweep; ixh++){
// Sweep the pooling window
for (int ixh = ixh_beg, ph_beg = ph_beg_const; ixh < xh_sweep; ixh++, ph_beg += p_bundle->psh) {
for (int ixw = ixw_beg, pw_beg = pw_beg_const; ixw < xw_sweep; ixw++, pw_beg += p_bundle->psw) {

pw_beg = pw_beg_const; // move the pooling window back to start of sweep
for (int ixw = ixw_beg; ixw < xw_sweep; ixw++){

// Traverse the pool window & perform pooling
// Traverse each pool window & perform pooling
int result = p_bundle->p_type == POOL_MAX ? INT_MIN : 0;
for (int ipyh = ph_end; ipyh > ph_beg; ipyh--){
for (int ipyw = pw_end; ipyw > pw_beg; ipyw--){
Expand All @@ -308,10 +298,7 @@ extern EXT_C void load_y (unsigned char *p_done, unsigned char *pt_done_proc, c

// ------ POOL ACTIVATION ------
tile_write(result, ib, p_bundle, i_yn, ixh, ixw, i_yc, yn, p_bundle->ph, p_bundle->pw, yc); // Write

pw_beg += p_bundle->psw; // move pooling window by stride
}
ph_beg += p_bundle->psh; // move pooling window by stride
}
yh = p_bundle->ph;
yw = p_bundle->pw;
Expand Down
14 changes: 7 additions & 7 deletions test/py/param_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,15 +191,15 @@ class Config:
def test_dnn_engine(COMPILE):
c = make_compile_params(COMPILE)

input_shape = (1,18,18,3) # (XN, XH, XW, CI)
input_shape = (8,18,18,3) # (XN, XH, XW, CI)
model_config = [
Config(11, 1, True , f'quantized_relu({c.X_BITS},0,negative_slope=0)', pool_d={'type':'max', 'size':(3,4), 'strides':(2,3), 'padding':'same', 'act_str':f'quantized_bits({c.X_BITS},0,False,False,1)'}),
Config(11, 8, True , f'quantized_relu({c.X_BITS},0,negative_slope=0)', pool_d={'type':'max', 'size':(3,4), 'strides':(2,3), 'padding':'same', 'act_str':f'quantized_bits({c.X_BITS},0,False,False,1)'}),
Config(1 , 8, False, f'quantized_bits({c.X_BITS},0,False,False,1)'),
# Config(7 , 8, True , f'quantized_bits({c.X_BITS},0,False,True,1)'),
# Config(5 , 8, False, f'quantized_relu({c.X_BITS},0,negative_slope=0.125)'),
# Config(3 , 24, True , f'quantized_relu({c.X_BITS},0,negative_slope=0)'),
# Config(1 , 5 , False, f'quantized_relu({c.X_BITS},0,negative_slope=0.125)', flatten=True),
# Config(1 , 10, True , f'quantized_relu({c.X_BITS},0,negative_slope=0.125)', dense= True),
Config(7 , 8, True , f'quantized_bits({c.X_BITS},0,False,True,1)'),
Config(5 , 8, False, f'quantized_relu({c.X_BITS},0,negative_slope=0.125)'),
Config(3 , 24, True , f'quantized_relu({c.X_BITS},0,negative_slope=0)'),
Config(1 , 10 , False, f'quantized_relu({c.X_BITS},0,negative_slope=0.125)', flatten=True),
Config(1 , 10, True , f'quantized_relu({c.X_BITS},0,negative_slope=0.125)', dense= True),
]

'''
Expand Down
37 changes: 19 additions & 18 deletions test/py/pooling_no_np.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -46,23 +46,6 @@
"# pStride = [(2,2)]\n",
"# mode = ['same']\n",
"\n",
"def window_op(ph_beg_end, pw_beg_end, pool_type, y_arr, n, c):\n",
" ph_beg_const, ph_end_const = ph_beg_end\n",
" pw_beg_const, pw_end_const = pw_beg_end\n",
"\n",
" result = -math.inf if pool_type == 'max' else 0\n",
"\n",
" for ipyh in range(ph_end_const, ph_beg_const,-1):\n",
" for ipyw in range(pw_end_const, pw_beg_const,-1):\n",
" \n",
" if pool_type=='max':\n",
" result = max(result, y_arr[n][ipyh][ipyw][c])\n",
" else:\n",
" result += y_arr[n,ipyh,ipyw,c]\n",
"\n",
" count = (ph_end_const-ph_beg_const)*(pw_end_const-pw_beg_const)\n",
" return result if pool_type=='max' else result/count\n",
"\n",
"def myPooling2D(pool_type, y_arr, Pool_size, Stride, padding_type):\n",
" assert len(Pool_size)==2 and len(Stride)==2, f\"{len(Pool_size)}, {len(Stride)}\"\n",
" assert padding_type in {\"same\", \"valid\"}\n",
Expand Down Expand Up @@ -118,7 +101,25 @@
" for ixh in range(ixh_beg, xh_sweep):\n",
" pw_end, pw_beg = pw_end_const, pw_beg_const # move the pooling window back to start of sweep\n",
" for ixw in range(ixw_beg, xw_sweep):\n",
" x_arr[n,ixh,ixw,c] = window_op((ph_beg, ph_end), (pw_beg, pw_end), pool_type, y_arr, n, c)\n",
"\n",
"\n",
" '''\n",
" Pooling\n",
" '''\n",
" result = -math.inf if pool_type == 'max' else 0\n",
" for ipyh in range(ph_end, ph_beg,-1):\n",
" for ipyw in range(pw_end, pw_beg,-1):\n",
" \n",
" if pool_type=='max':\n",
" result = max(result, y_arr[n,ipyh,ipyw,c])\n",
" else:\n",
" result += y_arr[n,ipyh,ipyw,c]\n",
"\n",
" count = (ph_end-ph_beg)*(pw_end-pw_beg)\n",
" result = result if pool_type=='max' else result/count\n",
"\n",
"\n",
" x_arr[n,ixh,ixw,c] = result\n",
" pw_beg += PSW # move pooling window by stride\n",
" pw_end = min(pw_end+PSW, YW-1)\n",
" ph_beg += PSH # move pooling window by stride\n",
Expand Down

0 comments on commit b3aa2d8

Please sign in to comment.