diff --git a/src/common/composed_schedules.py b/src/common/composed_schedules.py index 901398f9..40dcad82 100644 --- a/src/common/composed_schedules.py +++ b/src/common/composed_schedules.py @@ -709,7 +709,7 @@ def vectorize( # Now that we have a tail loop, the conditional in the main loop # can be removed - proc = remove_if(proc, inner_loop_cursor.body()[0]) + proc = eliminate_dead_code(proc, inner_loop_cursor.body()[0]) proc = vectorize_to_loops( proc, tail_loop_cursor, vec_width, memory_type, precision diff --git a/src/level1/asum.py b/src/level1/asum.py index ea130033..4111c6b0 100644 --- a/src/level1/asum.py +++ b/src/level1/asum.py @@ -35,7 +35,7 @@ def asum(n: size, x: [f32][n] @ DRAM, result: f32 @ DRAM): def schedule_asum_stride_1(asum, params): asum = generate_stride_1_proc(asum, params.precision) - if not isinstance(params.mem_type, AVX2): + if params.mem_type is not AVX2: return asum loop = asum.find_loop("i") @@ -52,7 +52,7 @@ def schedule_asum_stride_1(asum, params): loop = asum.forward(loop) asum = cut_loop(asum, loop, FormattedExprStr("_ - 1", loop.hi())) - asum = remove_if(asum, loop.body()[0].body()[0]) + asum = eliminate_dead_code(asum, loop.body()[0].body()[0]) asum = auto_stage_mem(asum, asum.find("x[_]"), "xReg") asum = stage_expr(asum, asum.find("select(_)"), "selectReg") asum = simplify(asum)