From b3aa2d86bd78dc508589cb1c09e11e9f69ce3ea7 Mon Sep 17 00:00:00 2001 From: Aba Date: Tue, 31 Oct 2023 09:07:30 -0700 Subject: [PATCH] Refactor runtime, enable all layers on py_test --- c/model.h | 26 +++++++------- c/runtime.h | 67 +++++++++++++++---------------------- test/py/param_test.py | 14 ++++---- test/py/pooling_no_np.ipynb | 37 ++++++++++---------- 4 files changed, 66 insertions(+), 78 deletions(-) diff --git a/c/model.h b/c/model.h index 6a8af31..6cc1d45 100644 --- a/c/model.h +++ b/c/model.h @@ -1,12 +1,12 @@ #define N_BUNDLES 7 Bundle_t bundles [N_BUNDLES] = { - {.n=8, .l=3, .kw=11, .coe=2, .coe_tl=2, .r_ll=8, .h=24, .w=32, .ci=3, .co=8, .w_kw2=27, .t=4, .p=3, .cm=1, .cm_p0=1, .w_bpt=140, .w_bpt_p0=140, .x_bpt=9992, .x_bpt_p0=9992, .o_bytes=18304, .is_bias=1, .is_flatten=0, .b_offset=0, .b_val_shift=5, .b_bias_shift=0, .ca_nzero=0, .ca_shift=8, .ca_pl_scale=0, .csh=2, .ch=12, .csh_shift=1, .pkh=1, .psh=1, .ph=12, .psh_shift=0, .csw=3, .cw=11, .csw_shift=1, .pkw=1, .psw=1, .pw=11, .psw_shift=0, .on=8, .oh=12, .ow=11, .oc=8, .x_header=414356454485524480, .x_header_p0=414356454485524480, .w_header=414602830989492224, .w_header_p0=414356454485524480 , .debug_nhwc_words=8448 }, - {.n=8, .l=2, .kw=1, .coe=24, .coe_tl=0, .r_ll=4, .h=12, .w=11, .ci=8, .co=8, .w_kw2=11, .t=1, .p=1, .cm=20, .cm_p0=8, .w_bpt=104, .w_bpt_p0=104, .x_bpt=18312, .x_bpt_p0=18312, .o_bytes=18304, .is_bias=0, .is_flatten=0, .b_offset=8, .b_val_shift=0, .b_bias_shift=0, .ca_nzero=1, .ca_shift=3, .ca_pl_scale=0, .csh=1, .ch=12, .csh_shift=0, .pkh=1, .psh=1, .ph=12, .psh_shift=0, .csw=1, .cw=11, .csw_shift=0, .pkw=1, .psw=1, .pw=11, .psw_shift=0, .on=8, .oh=12, .ow=11, .oc=8, .x_header=4071265058259206144, .x_header_p0=4071265058259206144, .w_header=4071511408993370112, .w_header_p0=4071265058259206144 , .debug_nhwc_words=8448 }, - {.n=8, .l=2, .kw=7, .coe=3, .coe_tl=2, .r_ll=4, .h=12, .w=11, .ci=8, .co=8, .w_kw2=8, .t=3, .p=4, .cm=2, .cm_p0=2, .w_bpt=176, .w_bpt_p0=176, .x_bpt=4584, .x_bpt_p0=4584, .o_bytes=18304, .is_bias=1, .is_flatten=0, .b_offset=8, .b_val_shift=5, .b_bias_shift=0, .ca_nzero=1, .ca_shift=8, .ca_pl_scale=0, .csh=1, .ch=12, .csh_shift=0, .pkh=1, .psh=1, .ph=12, .psh_shift=0, .csw=1, .cw=11, .csw_shift=0, .pkw=1, .psw=1, .pw=11, .psw_shift=0, .on=8, .oh=12, .ow=11, .oc=8, .x_header=828673326552449024, .x_header_p0=828673326552449024, .w_header=828919728826220544, .w_header_p0=828673326552449024 , .debug_nhwc_words=8448 }, - {.n=8, .l=2, .kw=5, .coe=4, .coe_tl=4, .r_ll=4, .h=12, .w=11, .ci=8, .co=8, .w_kw2=9, .t=2, .p=2, .cm=4, .cm_p0=4, .w_bpt=248, .w_bpt_p0=248, .x_bpt=9160, .x_bpt_p0=9160, .o_bytes=18304, .is_bias=0, .is_flatten=0, .b_offset=17, .b_val_shift=0, .b_bias_shift=0, .ca_nzero=1, .ca_shift=6, .ca_pl_scale=3, .csh=1, .ch=12, .csh_shift=0, .pkh=1, .psh=1, .ph=12, .psh_shift=0, .csw=1, .cw=11, .csw_shift=0, .pkw=1, .psw=1, .pw=11, .psw_shift=0, .on=8, .oh=12, .ow=11, .oc=8, .x_header=1909537237121368064, .x_header_p0=1909537237121368064, .w_header=1909783690934747136, .w_header_p0=1909537237121368064 , .debug_nhwc_words=8448 }, - {.n=8, .l=2, .kw=3, .coe=8, .coe_tl=8, .r_ll=4, .h=12, .w=11, .ci=8, .co=24, .w_kw2=10, .t=3, .p=2, .cm=6, .cm_p0=2, .w_bpt=224, .w_bpt_p0=80, .x_bpt=13736, .x_bpt_p0=4584, .o_bytes=54912, .is_bias=1, .is_flatten=0, .b_offset=17, .b_val_shift=5, .b_bias_shift=0, .ca_nzero=0, .ca_shift=8, .ca_pl_scale=0, .csh=1, .ch=12, .csh_shift=0, .pkh=1, .psh=1, .ph=12, .psh_shift=0, .csw=1, .cw=11, .csw_shift=0, .pkw=1, .psw=1, .pw=11, .psw_shift=0, .on=8, .oh=12, .ow=11, .oc=24, .x_header=2990401147690287104, .x_header_p0=684558138476593152, .w_header=2990647584323796992, .w_header_p0=684558138476593152 , .debug_nhwc_words=25344 }, - {.n=8, .l=2, .kw=1, .coe=24, .coe_tl=0, .r_ll=4, .h=12, .w=11, .ci=24, .co=5, .w_kw2=11, .t=1, .p=2, .cm=20, .cm_p0=4, .w_bpt=248, .w_bpt_p0=56, .x_bpt=45768, .x_bpt_p0=9160, .o_bytes=8580, .is_bias=0, .is_flatten=1, .b_offset=41, .b_val_shift=0, .b_bias_shift=0, .ca_nzero=1, .ca_shift=6, .ca_pl_scale=3, .csh=1, .ch=12, .csh_shift=0, .pkh=1, .psh=1, .ph=12, .psh_shift=0, .csw=1, .cw=11, .csw_shift=0, .pkw=1, .psw=1, .pw=11, .psw_shift=0, .on=1, .oh=8, .ow=1, .oc=660, .x_header=10988794085900288000, .x_header_p0=1765422049045512192, .w_header=10989040539713667072, .w_header_p0=1765422049045512192 , .debug_nhwc_words=5280 }, - {.n=1, .l=1, .kw=1, .coe=24, .coe_tl=0, .r_ll=8, .h=8, .w=1, .ci=660, .co=10, .w_kw2=1, .t=1, .p=33, .cm=20, .cm_p0=20, .w_bpt=248, .w_bpt_p0=248, .x_bpt=268, .x_bpt_p0=268, .o_bytes=80, .is_bias=1, .is_flatten=0, .b_offset=41, .b_val_shift=5, .b_bias_shift=0, .ca_nzero=1, .ca_shift=11, .ca_pl_scale=3, .csh=1, .ch=8, .csh_shift=0, .pkh=1, .psh=1, .ph=8, .psh_shift=0, .csw=1, .cw=1, .csw_shift=0, .pkw=1, .psw=1, .pw=1, .psw_shift=0, .on=1, .oh=8, .ow=1, .oc=10, .x_header=10952754293765046272, .x_header_p0=10952754293765046272, .w_header=10952754456973803520, .w_header_p0=10952754293765046272 , .debug_nhwc_words=80 } + {.n=8, .l=3, .kw=11, .coe=2, .coe_tl=2, .r_ll=2, .h=18, .w=18, .ci=3, .co=8, .w_kw2=13, .t=4, .p=3, .cm=1, .cm_p0=1, .w_bpt=140, .w_bpt_p0=140, .x_bpt=5624, .x_bpt_p0=5624, .o_bytes=9984, .is_bias=1, .is_flatten=0, .b_offset=0, .b_val_shift=5, .b_bias_shift=0, .ca_nzero=0, .ca_shift=8, .ca_pl_scale=0, .csh=1, .ch=18, .csh_shift=0, .pkh=3, .psh=2, .ph=9, .psh_shift=0, .csw=1, .cw=18, .csw_shift=0, .pkw=4, .psw=3, .pw=6, .psw_shift=0, .p_type=POOL_MAX, .on=8, .oh=9, .ow=6, .oc=8, .x_header=378324358931677184, .x_header_p0=378324358931677184, .w_header=378570735435644928, .w_header_p0=378324358931677184 , .debug_nhwc_words=3456 }, + {.n=8, .l=2, .kw=1, .coe=24, .coe_tl=0, .r_ll=1, .h=9, .w=6, .ci=8, .co=8, .w_kw2=6, .t=1, .p=1, .cm=20, .cm_p0=8, .w_bpt=104, .w_bpt_p0=104, .x_bpt=9992, .x_bpt_p0=9992, .o_bytes=9984, .is_bias=0, .is_flatten=0, .b_offset=8, .b_val_shift=0, .b_bias_shift=0, .ca_nzero=1, .ca_shift=3, .ca_pl_scale=0, .csh=1, .ch=9, .csh_shift=0, .pkh=1, .psh=1, .ph=9, .psh_shift=0, .csw=1, .cw=6, .csw_shift=0, .pkw=1, .psw=1, .pw=6, .psw_shift=0, .p_type=POOL_NONE, .on=8, .oh=9, .ow=6, .oc=8, .x_header=4053249560238096384, .x_header_p0=4053249560238096384, .w_header=4053495910972260352, .w_header_p0=4053249560238096384 , .debug_nhwc_words=3456 }, + {.n=8, .l=2, .kw=7, .coe=3, .coe_tl=2, .r_ll=1, .h=9, .w=6, .ci=8, .co=8, .w_kw2=3, .t=3, .p=4, .cm=2, .cm_p0=2, .w_bpt=176, .w_bpt_p0=176, .x_bpt=2504, .x_bpt_p0=2504, .o_bytes=9984, .is_bias=1, .is_flatten=0, .b_offset=8, .b_val_shift=5, .b_bias_shift=0, .ca_nzero=1, .ca_shift=8, .ca_pl_scale=0, .csh=1, .ch=9, .csh_shift=0, .pkh=1, .psh=1, .ph=9, .psh_shift=0, .csw=1, .cw=6, .csw_shift=0, .pkw=1, .psw=1, .pw=6, .psw_shift=0, .p_type=POOL_NONE, .on=8, .oh=9, .ow=6, .oc=8, .x_header=810657828531339264, .x_header_p0=810657828531339264, .w_header=810904230805110784, .w_header_p0=810657828531339264 , .debug_nhwc_words=3456 }, + {.n=8, .l=2, .kw=5, .coe=4, .coe_tl=4, .r_ll=1, .h=9, .w=6, .ci=8, .co=8, .w_kw2=4, .t=2, .p=2, .cm=4, .cm_p0=4, .w_bpt=248, .w_bpt_p0=248, .x_bpt=5000, .x_bpt_p0=5000, .o_bytes=9984, .is_bias=0, .is_flatten=0, .b_offset=17, .b_val_shift=0, .b_bias_shift=0, .ca_nzero=1, .ca_shift=6, .ca_pl_scale=3, .csh=1, .ch=9, .csh_shift=0, .pkh=1, .psh=1, .ph=9, .psh_shift=0, .csw=1, .cw=6, .csw_shift=0, .pkw=1, .psw=1, .pw=6, .psw_shift=0, .p_type=POOL_NONE, .on=8, .oh=9, .ow=6, .oc=8, .x_header=1891521739100258304, .x_header_p0=1891521739100258304, .w_header=1891768192913637376, .w_header_p0=1891521739100258304 , .debug_nhwc_words=3456 }, + {.n=8, .l=2, .kw=3, .coe=8, .coe_tl=8, .r_ll=1, .h=9, .w=6, .ci=8, .co=24, .w_kw2=5, .t=3, .p=2, .cm=6, .cm_p0=2, .w_bpt=224, .w_bpt_p0=80, .x_bpt=7496, .x_bpt_p0=2504, .o_bytes=29952, .is_bias=1, .is_flatten=0, .b_offset=17, .b_val_shift=5, .b_bias_shift=0, .ca_nzero=0, .ca_shift=8, .ca_pl_scale=0, .csh=1, .ch=9, .csh_shift=0, .pkh=1, .psh=1, .ph=9, .psh_shift=0, .csw=1, .cw=6, .csw_shift=0, .pkw=1, .psw=1, .pw=6, .psw_shift=0, .p_type=POOL_NONE, .on=8, .oh=9, .ow=6, .oc=24, .x_header=2972385649669177344, .x_header_p0=666542640455483392, .w_header=2972632086302687232, .w_header_p0=666542640455483392 , .debug_nhwc_words=10368 }, + {.n=8, .l=2, .kw=1, .coe=24, .coe_tl=0, .r_ll=1, .h=9, .w=6, .ci=24, .co=10, .w_kw2=6, .t=1, .p=2, .cm=20, .cm_p0=4, .w_bpt=248, .w_bpt_p0=56, .x_bpt=24968, .x_bpt_p0=5000, .o_bytes=7020, .is_bias=0, .is_flatten=1, .b_offset=41, .b_val_shift=0, .b_bias_shift=0, .ca_nzero=1, .ca_shift=6, .ca_pl_scale=3, .csh=1, .ch=9, .csh_shift=0, .pkh=1, .psh=1, .ph=9, .psh_shift=0, .csw=1, .cw=6, .csw_shift=0, .pkw=1, .psw=1, .pw=6, .psw_shift=0, .p_type=POOL_NONE, .on=1, .oh=8, .ow=1, .oc=540, .x_header=10970778587879178240, .x_header_p0=1747406551024402432, .w_header=10971025041692557312, .w_header_p0=1747406551024402432 , .debug_nhwc_words=4320 }, + {.n=1, .l=1, .kw=1, .coe=24, .coe_tl=0, .r_ll=8, .h=8, .w=1, .ci=540, .co=10, .w_kw2=1, .t=1, .p=27, .cm=20, .cm_p0=20, .w_bpt=248, .w_bpt_p0=248, .x_bpt=268, .x_bpt_p0=268, .o_bytes=80, .is_bias=1, .is_flatten=0, .b_offset=41, .b_val_shift=5, .b_bias_shift=0, .ca_nzero=1, .ca_shift=11, .ca_pl_scale=3, .csh=1, .ch=8, .csh_shift=0, .pkh=1, .psh=1, .ph=8, .psh_shift=0, .csw=1, .cw=1, .csw_shift=0, .pkw=1, .psw=1, .pw=1, .psw_shift=0, .p_type=POOL_NONE, .on=1, .oh=8, .ow=1, .oc=10, .x_header=10952754293765046272, .x_header_p0=10952754293765046272, .w_header=10952754456973803520, .w_header_p0=10952754293765046272 , .debug_nhwc_words=80 } }; #define X_BITS_L2 3 @@ -16,13 +16,13 @@ Bundle_t bundles [N_BUNDLES] = { #define PE_ROWS 8 #define PE_COLS 24 -#define WB_BYTES 14418 -#define W_BYTES 14288 -#define X_BYTES 29976 +#define WB_BYTES 12930 +#define W_BYTES 12800 +#define X_BYTES 16872 #define O_WORDS 80 -#define O_BYTES_MAX 54912 -#define X_BYTES_ALL 167036 -#define Y_BYTES 196616 +#define O_BYTES_MAX 29952 +#define X_BYTES_ALL 94084 +#define Y_BYTES 110600 #define B_TYPE signed short #define B_WORDS 65 #define DATA_DIR "D:/dnn-engine/test/vectors" diff --git a/c/runtime.h b/c/runtime.h index 6a81bea..9a05f20 100644 --- a/c/runtime.h +++ b/c/runtime.h @@ -69,6 +69,20 @@ static inline void write_x(signed char val, int ib, int ixp, int ixn, int ixl, i static inline void tile_write( int out_val, int ib, Bundle_t *p_bundle, int i_yn, int i_yh, int i_yw, int i_yc, int yn, int yh, int yw, int yc ) { + + // ------ FLATTEN ------ + if (p_bundle->is_flatten) { + i_yc = (i_yh*yw + i_yw)*yc + i_yc; // (H*W*C) -> C + i_yw = 0; // W=1 + i_yh = i_yn; // N -> H + i_yn = 0; // N=1 + + yc = yh*yw*yc; + yw = 1; + yh = yn; + yn = 1; + } + // Check assert_printf ("", yn == p_bundle->on, ": yn"); assert_printf ("", yh == p_bundle->oh, ": yh"); @@ -169,7 +183,7 @@ extern EXT_C void load_y (unsigned char *p_done, unsigned char *pt_done_proc, c // if out of bounds, early return - if (i_yh >= p_bundle->h || i_yc >= p_bundle->co) { + if (i_yh >= yh || i_yc >= yc) { if (ip == p_bundle->p-1) fprintf(fp_sum,"%d\n", 0); // Save summed output goto PROCESS_AND_STORE_DONE; @@ -222,23 +236,6 @@ extern EXT_C void load_y (unsigned char *p_done, unsigned char *pt_done_proc, c // ------ SOFTMAX ------ - // ------ FLATTEN ------ - if (p_bundle->is_flatten) { - // Pool & flatten are not compatible with each other - assert_printf (DBG, p_bundle->p_type == POOL_NONE, ": p_bundle->p_type == POOL_NONE"); - - i_yc = (i_yh*yw + i_yw)*yc + i_yc; // (H*W*C) -> C - i_yw = 0; // W=1 - i_yh = i_yn; // N -> H - i_yn = 0; // N=1 - - yc = yh*yw*yc; - yw = 1; - yh = yn; - yn = 1; - } - - // ------ MAX/AVG POOL --- if (p_bundle->p_type == POOL_NONE) { @@ -254,9 +251,6 @@ extern EXT_C void load_y (unsigned char *p_done, unsigned char *pt_done_proc, c iy_nhwc = ((i_yn*yh + i_yh)*yw + i_yw)*yc + i_yc; // store as nhwc for pooling mem.nhwc[iy_nhwc] = out_val; - ph_end = i_yh; // iy(h,w) is the bottom-right of pooling window -> All values in pooling window have been computed - pw_end = i_yw; - div_ixh = div(i_yh+p_bundle->psh_shift-p_bundle->pkh+1, p_bundle->psh); div_ixw = div(i_yw+p_bundle->psw_shift-p_bundle->pkw+1, p_bundle->psw); ixh_beg = div_ixh.quot; // ix(hw) that corresponds to the pooling window @@ -265,31 +259,27 @@ extern EXT_C void load_y (unsigned char *p_done, unsigned char *pt_done_proc, c if (ixh_beg < 0 || ixw_beg < 0) // skip when target ix(h,w) < 0 goto PROCESS_AND_STORE_DONE; - if (div_ixh.rem != 0) // invalid ixh - if (i_yh==yh-1) //but last yh. start sweeping - ixh_beg += 1; - else // not last yh. skip - goto PROCESS_AND_STORE_DONE; - + if (div_ixh.rem != 0) // invalid ixh + if (i_yh==yh-1) ixh_beg += 1; //but last yh. start sweeping + else goto PROCESS_AND_STORE_DONE; // not last yh. skip + if (div_ixw.rem != 0) - if (i_yw==yw-1) - ixw_beg += 1; - else - goto PROCESS_AND_STORE_DONE; + if (i_yw==yw-1) ixw_beg += 1; + else goto PROCESS_AND_STORE_DONE; + ph_end = i_yh; // iy(h,w) is the bottom-right of pooling window -> All values in pooling window have been computed + pw_end = i_yw; ph_beg_const = max(p_bundle->psh*ixh_beg-p_bundle->psh_shift, 0)-1; // p(h,w)_beg is the index of top left corner of pooling window. If negative, set to zero pw_beg_const = max(p_bundle->psw*ixw_beg-p_bundle->psw_shift, 0)-1; xh_sweep = i_yh == yh-1 ? p_bundle->ph : ixh_beg+1; // ix(hw) is sweeped from ix(hw)_beg to x(h,w)_sweep. Normally sweep is 1. xw_sweep = i_yw == yw-1 ? p_bundle->pw : ixw_beg+1; // But when iy(h,w) is at its edges, need to compute remaining ix(hw) pixels by sweeping - ph_beg = ph_beg_const; - for (int ixh = ixh_beg; ixh < xh_sweep; ixh++){ + // Sweep the pooling window + for (int ixh = ixh_beg, ph_beg = ph_beg_const; ixh < xh_sweep; ixh++, ph_beg += p_bundle->psh) { + for (int ixw = ixw_beg, pw_beg = pw_beg_const; ixw < xw_sweep; ixw++, pw_beg += p_bundle->psw) { - pw_beg = pw_beg_const; // move the pooling window back to start of sweep - for (int ixw = ixw_beg; ixw < xw_sweep; ixw++){ - - // Traverse the pool window & perform pooling + // Traverse each pool window & perform pooling int result = p_bundle->p_type == POOL_MAX ? INT_MIN : 0; for (int ipyh = ph_end; ipyh > ph_beg; ipyh--){ for (int ipyw = pw_end; ipyw > pw_beg; ipyw--){ @@ -308,10 +298,7 @@ extern EXT_C void load_y (unsigned char *p_done, unsigned char *pt_done_proc, c // ------ POOL ACTIVATION ------ tile_write(result, ib, p_bundle, i_yn, ixh, ixw, i_yc, yn, p_bundle->ph, p_bundle->pw, yc); // Write - - pw_beg += p_bundle->psw; // move pooling window by stride } - ph_beg += p_bundle->psh; // move pooling window by stride } yh = p_bundle->ph; yw = p_bundle->pw; diff --git a/test/py/param_test.py b/test/py/param_test.py index 70a10de..7c5d257 100644 --- a/test/py/param_test.py +++ b/test/py/param_test.py @@ -191,15 +191,15 @@ class Config: def test_dnn_engine(COMPILE): c = make_compile_params(COMPILE) - input_shape = (1,18,18,3) # (XN, XH, XW, CI) + input_shape = (8,18,18,3) # (XN, XH, XW, CI) model_config = [ - Config(11, 1, True , f'quantized_relu({c.X_BITS},0,negative_slope=0)', pool_d={'type':'max', 'size':(3,4), 'strides':(2,3), 'padding':'same', 'act_str':f'quantized_bits({c.X_BITS},0,False,False,1)'}), + Config(11, 8, True , f'quantized_relu({c.X_BITS},0,negative_slope=0)', pool_d={'type':'max', 'size':(3,4), 'strides':(2,3), 'padding':'same', 'act_str':f'quantized_bits({c.X_BITS},0,False,False,1)'}), Config(1 , 8, False, f'quantized_bits({c.X_BITS},0,False,False,1)'), - # Config(7 , 8, True , f'quantized_bits({c.X_BITS},0,False,True,1)'), - # Config(5 , 8, False, f'quantized_relu({c.X_BITS},0,negative_slope=0.125)'), - # Config(3 , 24, True , f'quantized_relu({c.X_BITS},0,negative_slope=0)'), - # Config(1 , 5 , False, f'quantized_relu({c.X_BITS},0,negative_slope=0.125)', flatten=True), - # Config(1 , 10, True , f'quantized_relu({c.X_BITS},0,negative_slope=0.125)', dense= True), + Config(7 , 8, True , f'quantized_bits({c.X_BITS},0,False,True,1)'), + Config(5 , 8, False, f'quantized_relu({c.X_BITS},0,negative_slope=0.125)'), + Config(3 , 24, True , f'quantized_relu({c.X_BITS},0,negative_slope=0)'), + Config(1 , 10 , False, f'quantized_relu({c.X_BITS},0,negative_slope=0.125)', flatten=True), + Config(1 , 10, True , f'quantized_relu({c.X_BITS},0,negative_slope=0.125)', dense= True), ] ''' diff --git a/test/py/pooling_no_np.ipynb b/test/py/pooling_no_np.ipynb index 052b970..a9ab83e 100644 --- a/test/py/pooling_no_np.ipynb +++ b/test/py/pooling_no_np.ipynb @@ -46,23 +46,6 @@ "# pStride = [(2,2)]\n", "# mode = ['same']\n", "\n", - "def window_op(ph_beg_end, pw_beg_end, pool_type, y_arr, n, c):\n", - " ph_beg_const, ph_end_const = ph_beg_end\n", - " pw_beg_const, pw_end_const = pw_beg_end\n", - "\n", - " result = -math.inf if pool_type == 'max' else 0\n", - "\n", - " for ipyh in range(ph_end_const, ph_beg_const,-1):\n", - " for ipyw in range(pw_end_const, pw_beg_const,-1):\n", - " \n", - " if pool_type=='max':\n", - " result = max(result, y_arr[n][ipyh][ipyw][c])\n", - " else:\n", - " result += y_arr[n,ipyh,ipyw,c]\n", - "\n", - " count = (ph_end_const-ph_beg_const)*(pw_end_const-pw_beg_const)\n", - " return result if pool_type=='max' else result/count\n", - "\n", "def myPooling2D(pool_type, y_arr, Pool_size, Stride, padding_type):\n", " assert len(Pool_size)==2 and len(Stride)==2, f\"{len(Pool_size)}, {len(Stride)}\"\n", " assert padding_type in {\"same\", \"valid\"}\n", @@ -118,7 +101,25 @@ " for ixh in range(ixh_beg, xh_sweep):\n", " pw_end, pw_beg = pw_end_const, pw_beg_const # move the pooling window back to start of sweep\n", " for ixw in range(ixw_beg, xw_sweep):\n", - " x_arr[n,ixh,ixw,c] = window_op((ph_beg, ph_end), (pw_beg, pw_end), pool_type, y_arr, n, c)\n", + "\n", + "\n", + " '''\n", + " Pooling\n", + " '''\n", + " result = -math.inf if pool_type == 'max' else 0\n", + " for ipyh in range(ph_end, ph_beg,-1):\n", + " for ipyw in range(pw_end, pw_beg,-1):\n", + " \n", + " if pool_type=='max':\n", + " result = max(result, y_arr[n,ipyh,ipyw,c])\n", + " else:\n", + " result += y_arr[n,ipyh,ipyw,c]\n", + "\n", + " count = (ph_end-ph_beg)*(pw_end-pw_beg)\n", + " result = result if pool_type=='max' else result/count\n", + "\n", + "\n", + " x_arr[n,ixh,ixw,c] = result\n", " pw_beg += PSW # move pooling window by stride\n", " pw_end = min(pw_end+PSW, YW-1)\n", " ph_beg += PSH # move pooling window by stride\n",