Skip to content

Commit

Permalink
Simplify assert macro
Browse files Browse the repository at this point in the history
  • Loading branch information
Aba committed Oct 31, 2023
1 parent 1d5a15d commit 14ac970
Showing 1 changed file with 33 additions and 44 deletions.
77 changes: 33 additions & 44 deletions c/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,22 @@ typedef struct {
} Memory_st;
Memory_st mem;


/*
 * assert_printf(v1, op, v2, optional_debug_info, ...)
 *
 * Checks the relation `v1 op v2`. When it holds, the macro evaluates to 1 and
 * prints nothing. When it fails, it prints the condition text, the two operand
 * values (converted to int), the optional label and the optional printf-style
 * message, and then calls assert() on the same condition.
 *
 *   v1, v2              : operands; evaluated more than once on failure, so they
 *                         must not have side effects.
 *   op                  : a comparison operator token such as < or ==.
 *   optional_debug_info : a string literal label, or empty.
 *   ...                 : optional printf format (string literal) and its
 *                         arguments, or empty.
 *
 * The condition text and the label are printed through "%s" so that a '%'
 * inside them is never interpreted as a conversion specifier.
 */
#define assert_printf(v1, op, v2, optional_debug_info,...) \
  (((v1) op (v2)) || ( \
    printf("ASSERT FAILED: \n CONDITION: "), \
    printf("( %s %s %s )", #v1, #op, #v2), \
    printf(", VALUES: ( %d %s %d ), ", (int)(v1), #op, (int)(v2)), \
    printf("DEBUG_INFO: %s", "" optional_debug_info), \
    printf(" " __VA_ARGS__), \
    printf("\n\n"), \
    assert((v1) op (v2)), 0))

/*
 * flatten_nhwc(in,ih,iw,ic, N,H,W,C, optional_debug_info, ...)
 *
 * Evaluates to the flat index ((in*H + ih)*W + iw)*C + ic of element
 * [in][ih][iw][ic] after asserting in < N, ih < H, iw < W and ic < C.
 * The trailing arguments are forwarded unchanged to assert_printf.
 *
 * The expansion is one parenthesised comma expression, so it can be used
 * anywhere an expression is allowed (initialiser, sub-expression, unbraced
 * if/else body). Every argument is parenthesised, so compound expressions such
 * as `i+1` index correctly. Arguments are evaluated more than once.
 */
#define flatten_nhwc(in,ih,iw,ic, N,H,W,C, optional_debug_info,...) \
  ( (void)assert_printf (in, <, N, optional_debug_info,__VA_ARGS__), \
    (void)assert_printf (ih, <, H, optional_debug_info,__VA_ARGS__), \
    (void)assert_printf (iw, <, W, optional_debug_info,__VA_ARGS__), \
    (void)assert_printf (ic, <, C, optional_debug_info,__VA_ARGS__), \
    ((((in)*(H) + (ih))*(W) + (iw))*(C) + (ic)) )

// Arithmetic helper macros. These are function-like macros: each argument may be
// evaluated more than once, so do not pass expressions with side effects.

// Larger / smaller of two values.
#define max(x, y) ((x) > (y) ? (x) : (y))
#define min(x, y) ((x) < (y) ? (x) : (y))
// Clamp x to the closed range [xmin, xmax].
#define clip(x, xmin, xmax) (((x) < (xmin)) ? (xmin) : ((x) > (xmax)) ? (xmax) : (x))
// Rounding right shift by s bits; requires s >= 1 because of the 1<<(s-1) term.
// The shift amount is parenthesised so that an expression argument (e.g. `a & b`) is used whole.
#define shift_round(n, s) (((n) + (1<<((s)-1)) - (~((n)>>(s))&1) ) >> (s)) // === np.around(n/2**s).astype(int)
// Rounding integer division a/b; b must be non-zero. The correction term drops one
// when both b and the truncated quotient a/b are even (appears to give round-half-to-even
// for non-negative operands -- confirm for negative inputs).
#define div_round(a, b) (((a)+((b)/2) - (~((b)|((a)/(b))) &1))/(b))

#define assert_printf(debug_info, condition,...) ((condition) || (printf(#condition), printf(__VA_ARGS__), printf(debug_info), assert(condition), 0))

static inline int quant_lrelu(int x, signed char nzero, signed char shift, signed char pl_scale){
x = x < 0 ? x*nzero : x << pl_scale;
x = x < 0 ? (nzero ? x: 0) : x << pl_scale; // Conditional, targeting ARM
x = shift_round(x, shift);
x = clip(x, -(1<<(X_BITS-pl_scale-1)), (1<<(X_BITS-1))-1);
return x;
Expand All @@ -54,17 +60,16 @@ static inline int quant_lrelu(int x, signed char nzero, signed char shift, signe

// write_x: store `val` into mem.nx at the element addressed by
// (ixp, ixn, ixl, ixw, ixcm, ixr), using the dimensions held in *pb_out and xcm.
// Every index is asserted to be below its bound; `ib` is only used in the
// assertion debug message.
// NOTE(review): this span comes from a rendered diff and shows both the
// pre-change lines (DBG-based asserts followed by the store) and the
// post-change lines (v1/op/v2 asserts followed by the store), so the offset
// computation and the store appear twice -- confirm which side is current.
static inline void write_x(signed char val, int ib, int ixp, int ixn, int ixl, int ixw, int ixcm, int ixr, Bundle_t *pb_out, int xcm ){

// Format string plus arguments describing all indices, for the assert messages below.
#define DBG "--- ib:%d ixp:%d ixn:%d ixl:%d ixw:%d ixcm:%d ixr:%d xcm :%d \n",ib,ixp,ixn,ixl,ixw,ixcm,ixr,xcm
// Bounds checks written against the (debug_info, condition, ...) form of assert_printf.
assert_printf(DBG, ixr < PE_ROWS+X_PAD, "ixr < PE_ROWS+X_PAD");
assert_printf(DBG, ixcm < xcm , "ixcm < xcm ");
assert_printf(DBG, ixw < pb_out->w , "ixw < pb_out->w ");
assert_printf(DBG, ixl < pb_out->l , "ixl < pb_out->l ");
assert_printf(DBG, ixn < pb_out->n , "ixn < pb_out->n ");
assert_printf(DBG, ixp < pb_out->p , "ixp < pb_out->p ");

// p_offset is 0 for ixp == 0; otherwise it skips cm_p0 + (ixp-1)*cm channel groups of n*l*w*(PE_ROWS+X_PAD) elements each.
int p_offset = (ixp == 0) ? 0 : (pb_out->cm_p0 + (ixp-1)*pb_out->cm) *pb_out->n*pb_out->l*pb_out->w*(PE_ROWS+X_PAD);
int flat_index_n2r = (((ixn*pb_out->l + ixl)*pb_out->w + ixw)*xcm + ixcm)*(PE_ROWS+X_PAD) + ixr; // multidim_index -> flat_index [n,l,w,cm,r]
mem.nx[p_offset + flat_index_n2r] = val;
// The same bounds checks written against the (v1, op, v2, optional_debug_info, ...) form of assert_printf.
assert_printf (ixr , <, PE_ROWS+X_PAD, "write_x", "--- ib:%d ixp:%d ixn:%d ixl:%d ixw:%d ixcm:%d ixr:%d xcm :%d \n",ib,ixp,ixn,ixl,ixw,ixcm,ixr,xcm);
assert_printf (ixcm, <, xcm , "write_x", "--- ib:%d ixp:%d ixn:%d ixl:%d ixw:%d ixcm:%d ixr:%d xcm :%d \n",ib,ixp,ixn,ixl,ixw,ixcm,ixr,xcm);
assert_printf (ixw , <, pb_out->w , "write_x", "--- ib:%d ixp:%d ixn:%d ixl:%d ixw:%d ixcm:%d ixr:%d xcm :%d \n",ib,ixp,ixn,ixl,ixw,ixcm,ixr,xcm);
assert_printf (ixl , <, pb_out->l , "write_x", "--- ib:%d ixp:%d ixn:%d ixl:%d ixw:%d ixcm:%d ixr:%d xcm :%d \n",ib,ixp,ixn,ixl,ixw,ixcm,ixr,xcm);
assert_printf (ixn , <, pb_out->n , "write_x", "--- ib:%d ixp:%d ixn:%d ixl:%d ixw:%d ixcm:%d ixr:%d xcm :%d \n",ib,ixp,ixn,ixl,ixw,ixcm,ixr,xcm);
assert_printf (ixp , <, pb_out->p , "write_x", "--- ib:%d ixp:%d ixn:%d ixl:%d ixw:%d ixcm:%d ixr:%d xcm :%d \n",ib,ixp,ixn,ixl,ixw,ixcm,ixr,xcm);

// Offset computation and store (second occurrence in this rendered diff).
int p_offset = (ixp == 0) ? 0 : (pb_out->cm_p0 + (ixp-1)*pb_out->cm) *pb_out->n*pb_out->l*pb_out->w*(PE_ROWS+X_PAD);
int flat_index_n2r = (((ixn*pb_out->l + ixl)*pb_out->w + ixw)*xcm + ixcm)*(PE_ROWS+X_PAD) + ixr; // multidim_index -> flat_index [n,l,w,cm,r]
mem.nx[p_offset + flat_index_n2r] = val;
}


Expand All @@ -84,10 +89,10 @@ static inline void tile_write( int out_val, int ib, Bundle_t *pb, int i_yn, int
}

// Check
assert_printf ("", yn == pb->on, ": yn");
assert_printf ("", yh == pb->oh, ": yh");
assert_printf ("", yw == pb->ow, ": yw");
assert_printf ("", yc == pb->oc, ": yc");
assert_printf (yn, ==, pb->on,,);
assert_printf (yh, ==, pb->oh,,);
assert_printf (yw, ==, pb->ow,,);
assert_printf (yc, ==, pb->oc,,);

// ------ TILING: Calculate X coordinates ------
// y [n,h,w,c] -> x[p, n, l, w,cmp, r+pad]
Expand All @@ -107,17 +112,12 @@ static inline void tile_write( int out_val, int ib, Bundle_t *pb, int i_yn, int

// ------ STORE ------

int iy_nhwc = ((i_yn*yh + i_yh)*yw + i_yw)*yc + i_yc;
int iy_nhwc = flatten_nhwc(i_yn,i_yh,i_yw,i_yc, yn,yh,yw,yc,,);
mem.debug_nhwc[iy_nhwc] = out_val;

if (ib == N_BUNDLES-1) {
// Last bundle: save as NHWC
assert_printf ("", i_yn < yn, ": i_yn < yn");
assert_printf ("", i_yh < yh, ": i_yh < yh");
assert_printf ("", i_yw < yw, ": i_yw < yw");
assert_printf ("", i_yc < yc, ": i_yc < yc");
mem.y[iy_nhwc] = out_val;
} else {
if (ib == N_BUNDLES-1)
mem.y[iy_nhwc] = out_val; // Last bundle: save as NHWC
else {

// Other bundles: pad & save as tiled
int yr_sweep = i_yh==yh-1 ? PE_ROWS : i_yr + 1;
Expand Down Expand Up @@ -158,14 +158,14 @@ extern EXT_C void load_y (unsigned char *p_done, unsigned char *pt_done_proc, c
//New iw_kw2:
int w_last = iw_kw2 == pb->w_kw2-1 ? pb->kw/2+1 : 1;
int sram_addr=0;
for (int icoe=0; icoe<pb->coe; icoe++) {
for (int icoe=0; icoe < pb->coe; icoe++) {
int i_bias = it_bias + icoe;

for (int iw_last=0; iw_last<w_last; iw_last++) {
for (int ir=0; ir<PE_ROWS; ir++) {
// Indexing: [b, p, t, n, l, w | coe, w_last, r]

#define DBG "--- ib:%d ip:%d it:%d in:%d il:%d iw_kw2:%d icoe:%d iw_last:%d ir:%d \n",ib,ip,it,in,il,iw_kw2,icoe,iw_last,ir
#define DEBUG_INFO "--- ib:%d ip:%d it:%d in:%d il:%d iw_kw2:%d icoe:%d iw_last:%d ir:%d \n",ib,ip,it,in,il,iw_kw2,icoe,iw_last,ir

int raw_val=0, out_val=0;

Expand Down Expand Up @@ -193,9 +193,8 @@ extern EXT_C void load_y (unsigned char *p_done, unsigned char *pt_done_proc, c

PROCESS_START:


// ------ ADD P PASSES ------
iy_nhwc = ((i_yn*yh + i_yh)*yw + i_yw)*yc + i_yc;
iy_nhwc = flatten_nhwc(i_yn,i_yh,i_yw,i_yc, yn,yh,yw,yc, "Before add P passes", DEBUG_INFO);

if (pb->p == 1) { // only p : proceed with value
} else if (ip == pb->p-1) {// last p : read, add, proceed
Expand Down Expand Up @@ -231,7 +230,6 @@ extern EXT_C void load_y (unsigned char *p_done, unsigned char *pt_done_proc, c
out_val = quant_lrelu(out_val, pb->ca_nzero, pb->ca_shift, pb->ca_pl_scale);



// ------ SOFTMAX ------


Expand All @@ -242,12 +240,7 @@ extern EXT_C void load_y (unsigned char *p_done, unsigned char *pt_done_proc, c
goto PROCESS_AND_STORE_DONE;
}

assert_printf ("write_temp", i_yn < yn, ": i_yn < yn");
assert_printf ("write_temp", i_yh < yh, ": i_yh < yh");
assert_printf ("write_temp", i_yw < yw, ": i_yw < yw");
assert_printf ("write_temp", i_yc < yc, ": i_yc < yc");

iy_nhwc = ((i_yn*yh + i_yh)*yw + i_yw)*yc + i_yc; // store as nhwc for pooling
iy_nhwc = flatten_nhwc(i_yn,i_yh,i_yw,i_yc, yn,yh,yw,yc, "Before maxpool", DEBUG_INFO);// store as nhwc for pooling
mem.nhwc[iy_nhwc] = out_val;

div_ixh = div(i_yh+pb->psh_shift-pb->pkh+1, pb->psh);
Expand All @@ -259,7 +252,7 @@ extern EXT_C void load_y (unsigned char *p_done, unsigned char *pt_done_proc, c
goto PROCESS_AND_STORE_DONE;

// Pool Striding
if (div_ixh.rem != 0) // invalid ixh
if (div_ixh.rem != 0) // invalid ixh
if (i_yh==yh-1) ixh_beg += 1; //but last yh. start sweeping
else goto PROCESS_AND_STORE_DONE; // not last yh. skip

Expand All @@ -284,12 +277,8 @@ extern EXT_C void load_y (unsigned char *p_done, unsigned char *pt_done_proc, c
for (int ipyh = ph_end; ipyh > ph_beg; ipyh--){
for (int ipyw = pw_end; ipyw > pw_beg; ipyw--){

assert_printf ("read", i_yn < yn, ": i_yn < yn");
assert_printf ("read", ipyh < yh, ": ipyh < yh");
assert_printf ("read", ipyw < yw, ": ipyw < yw");
assert_printf ("read", i_yc < yc, ": i_yc < yc");

int read_val = mem.nhwc[((i_yn*yh + ipyh)*yw + ipyw)*yc + i_yc];
int read_idx = flatten_nhwc(i_yn, ipyh, ipyw, i_yc, yn, yh, yw, yc, "Inside pool window", DEBUG_INFO);
int read_val = mem.nhwc[read_idx];
result = pb->pool==POOL_MAX ? max(result, read_val) : (result + read_val);
}
}
Expand Down

0 comments on commit 14ac970

Please sign in to comment.