diff --git a/c/runtime.h b/c/runtime.h index 003fb1d..aab7916 100644 --- a/c/runtime.h +++ b/c/runtime.h @@ -36,16 +36,22 @@ typedef struct { } Memory_st; Memory_st mem; + +#define assert_printf(v1, op, v2, optional_debug_info,...) ((v1 op v2) || (printf("ASSERT FAILED: \n CONDITION: "), printf("( " #v1 " " #op " " #v2 " )"), printf(", VALUES: ( %d %s %d ), ", v1, #op, v2), printf("DEBUG_INFO: " optional_debug_info), printf(" " __VA_ARGS__), printf("\n\n"), assert(v1 op v2), 0)) + +#define flatten_nhwc(in,ih,iw,ic, N,H,W,C, optional_debug_info,...)\ + ((in*H + ih)*W + iw)*C + ic;\ + assert_printf (in, <, N, optional_debug_info,__VA_ARGS__); assert_printf (ih, <, H, optional_debug_info,__VA_ARGS__); assert_printf (iw, <, W, optional_debug_info,__VA_ARGS__); assert_printf (ic, <, C, optional_debug_info,__VA_ARGS__); + #define max(x, y) ((x) > (y) ? (x) : (y)) #define min(x, y) ((x) < (y) ? (x) : (y)) #define clip(x, xmin, xmax) (((x) < (xmin)) ? (xmin) : ((x) > (xmax)) ? (xmax) : (x)) #define shift_round(n, s) (((n) + (1<<((s)-1)) - (~((n)>>(s))&1) ) >> s) // === np.around(n/2**s).astype(int) #define div_round(a, b) (((a)+((b)/2) - (~((b)|(a)/(b)) &1))/(b)) -#define assert_printf(debug_info, condition,...) ((condition) || (printf(#condition), printf(__VA_ARGS__), printf(debug_info), assert(condition), 0)) static inline int quant_lrelu(int x, signed char nzero, signed char shift, signed char pl_scale){ - x = x < 0 ? x*nzero : x << pl_scale; + x = x < 0 ? (nzero ? x: 0) : x << pl_scale; // Conditional, targeting ARM x = shift_round(x, shift); x = clip(x, -(1<<(X_BITS-pl_scale-1)), (1<<(X_BITS-1))-1); return x; @@ -54,17 +60,16 @@ static inline int quant_lrelu(int x, signed char nzero, signed char shift, signe static inline void write_x(signed char val, int ib, int ixp, int ixn, int ixl, int ixw, int ixcm, int ixr, Bundle_t *pb_out, int xcm ){ -#define DBG "--- ib:%d ixp:%d ixn:%d ixl:%d ixw:%d ixcm:%d ixr:%d xcm :%d \n",ib,ixp,ixn,ixl,ixw,ixcm,ixr,xcm - assert_printf(DBG, ixr < PE_ROWS+X_PAD, "ixr < PE_ROWS+X_PAD"); - assert_printf(DBG, ixcm < xcm , "ixcm < xcm "); - assert_printf(DBG, ixw < pb_out->w , "ixw < pb_out->w "); - assert_printf(DBG, ixl < pb_out->l , "ixl < pb_out->l "); - assert_printf(DBG, ixn < pb_out->n , "ixn < pb_out->n "); - assert_printf(DBG, ixp < pb_out->p , "ixp < pb_out->p "); - - int p_offset = (ixp == 0) ? 0 : (pb_out->cm_p0 + (ixp-1)*pb_out->cm) *pb_out->n*pb_out->l*pb_out->w*(PE_ROWS+X_PAD); - int flat_index_n2r = (((ixn*pb_out->l + ixl)*pb_out->w + ixw)*xcm + ixcm)*(PE_ROWS+X_PAD) + ixr; // multidim_index -> flat_index [n,l,w,cm,r] - mem.nx[p_offset + flat_index_n2r] = val; + assert_printf (ixr , <, PE_ROWS+X_PAD, "write_x", "--- ib:%d ixp:%d ixn:%d ixl:%d ixw:%d ixcm:%d ixr:%d xcm :%d \n",ib,ixp,ixn,ixl,ixw,ixcm,ixr,xcm); + assert_printf (ixcm, <, xcm , "write_x", "--- ib:%d ixp:%d ixn:%d ixl:%d ixw:%d ixcm:%d ixr:%d xcm :%d \n",ib,ixp,ixn,ixl,ixw,ixcm,ixr,xcm); + assert_printf (ixw , <, pb_out->w , "write_x", "--- ib:%d ixp:%d ixn:%d ixl:%d ixw:%d ixcm:%d ixr:%d xcm :%d \n",ib,ixp,ixn,ixl,ixw,ixcm,ixr,xcm); + assert_printf (ixl , <, pb_out->l , "write_x", "--- ib:%d ixp:%d ixn:%d ixl:%d ixw:%d ixcm:%d ixr:%d xcm :%d \n",ib,ixp,ixn,ixl,ixw,ixcm,ixr,xcm); + assert_printf (ixn , <, pb_out->n , "write_x", "--- ib:%d ixp:%d ixn:%d ixl:%d ixw:%d ixcm:%d ixr:%d xcm :%d \n",ib,ixp,ixn,ixl,ixw,ixcm,ixr,xcm); + assert_printf (ixp , <, pb_out->p , "write_x", "--- ib:%d ixp:%d ixn:%d ixl:%d ixw:%d ixcm:%d ixr:%d xcm :%d \n",ib,ixp,ixn,ixl,ixw,ixcm,ixr,xcm); + + int p_offset = (ixp == 0) ? 0 : (pb_out->cm_p0 + (ixp-1)*pb_out->cm) *pb_out->n*pb_out->l*pb_out->w*(PE_ROWS+X_PAD); + int flat_index_n2r = (((ixn*pb_out->l + ixl)*pb_out->w + ixw)*xcm + ixcm)*(PE_ROWS+X_PAD) + ixr; // multidim_index -> flat_index [n,l,w,cm,r] + mem.nx[p_offset + flat_index_n2r] = val; } @@ -84,10 +89,10 @@ static inline void tile_write( int out_val, int ib, Bundle_t *pb, int i_yn, int } // Check - assert_printf ("", yn == pb->on, ": yn"); - assert_printf ("", yh == pb->oh, ": yh"); - assert_printf ("", yw == pb->ow, ": yw"); - assert_printf ("", yc == pb->oc, ": yc"); + assert_printf (yn, ==, pb->on,,); + assert_printf (yh, ==, pb->oh,,); + assert_printf (yw, ==, pb->ow,,); + assert_printf (yc, ==, pb->oc,,); // ------ TILING: Calculate X coordinates ------ // y [n,h,w,c] -> x[p, n, l, w,cmp, r+pad] @@ -107,17 +112,12 @@ static inline void tile_write( int out_val, int ib, Bundle_t *pb, int i_yn, int // ------ STORE ------ - int iy_nhwc = ((i_yn*yh + i_yh)*yw + i_yw)*yc + i_yc; + int iy_nhwc = flatten_nhwc(i_yn,i_yh,i_yw,i_yc, yn,yh,yw,yc,,); mem.debug_nhwc[iy_nhwc] = out_val; - if (ib == N_BUNDLES-1) { - // Last bundle: save as NHWC - assert_printf ("", i_yn < yn, ": i_yn < yn"); - assert_printf ("", i_yh < yh, ": i_yh < yh"); - assert_printf ("", i_yw < yw, ": i_yw < yw"); - assert_printf ("", i_yc < yc, ": i_yc < yc"); - mem.y[iy_nhwc] = out_val; - } else { + if (ib == N_BUNDLES-1) + mem.y[iy_nhwc] = out_val; // Last bundle: save as NHWC + else { // Other bundles: pad & save as tiled int yr_sweep = i_yh==yh-1 ? PE_ROWS : i_yr + 1; @@ -158,14 +158,14 @@ extern EXT_C void load_y (unsigned char *p_done, unsigned char *pt_done_proc, c //New iw_kw2: int w_last = iw_kw2 == pb->w_kw2-1 ? pb->kw/2+1 : 1; int sram_addr=0; - for (int icoe=0; icoecoe; icoe++) { + for (int icoe=0; icoe < pb->coe; icoe++) { int i_bias = it_bias + icoe; for (int iw_last=0; iw_lastp == 1) { // only p : proceed with value } else if (ip == pb->p-1) {// last p : read, add, proceed @@ -231,7 +230,6 @@ extern EXT_C void load_y (unsigned char *p_done, unsigned char *pt_done_proc, c out_val = quant_lrelu(out_val, pb->ca_nzero, pb->ca_shift, pb->ca_pl_scale); - // ------ SOFTMAX ------ @@ -242,12 +240,7 @@ extern EXT_C void load_y (unsigned char *p_done, unsigned char *pt_done_proc, c goto PROCESS_AND_STORE_DONE; } - assert_printf ("write_temp", i_yn < yn, ": i_yn < yn"); - assert_printf ("write_temp", i_yh < yh, ": i_yh < yh"); - assert_printf ("write_temp", i_yw < yw, ": i_yw < yw"); - assert_printf ("write_temp", i_yc < yc, ": i_yc < yc"); - - iy_nhwc = ((i_yn*yh + i_yh)*yw + i_yw)*yc + i_yc; // store as nhwc for pooling + iy_nhwc = flatten_nhwc(i_yn,i_yh,i_yw,i_yc, yn,yh,yw,yc, "Before maxpool", DEBUG_INFO);// store as nhwc for pooling mem.nhwc[iy_nhwc] = out_val; div_ixh = div(i_yh+pb->psh_shift-pb->pkh+1, pb->psh); @@ -259,7 +252,7 @@ extern EXT_C void load_y (unsigned char *p_done, unsigned char *pt_done_proc, c goto PROCESS_AND_STORE_DONE; // Pool Striding - if (div_ixh.rem != 0) // invalid ixh + if (div_ixh.rem != 0) // invalid ixh if (i_yh==yh-1) ixh_beg += 1; //but last yh. start sweeping else goto PROCESS_AND_STORE_DONE; // not last yh. skip @@ -284,12 +277,8 @@ extern EXT_C void load_y (unsigned char *p_done, unsigned char *pt_done_proc, c for (int ipyh = ph_end; ipyh > ph_beg; ipyh--){ for (int ipyw = pw_end; ipyw > pw_beg; ipyw--){ - assert_printf ("read", i_yn < yn, ": i_yn < yn"); - assert_printf ("read", ipyh < yh, ": ipyh < yh"); - assert_printf ("read", ipyw < yw, ": ipyw < yw"); - assert_printf ("read", i_yc < yc, ": i_yc < yc"); - - int read_val = mem.nhwc[((i_yn*yh + ipyh)*yw + ipyw)*yc + i_yc]; + int read_idx = flatten_nhwc(i_yn, ipyh, ipyw, i_yc, yn, yh, yw, yc, "Inside pool window", DEBUG_INFO); + int read_val = mem.nhwc[read_idx]; result = pb->pool==POOL_MAX ? max(result, read_val) : (result + read_val); } }