From 2cd26cd07f2e3d58cb180f3a305aab3c26b9a529 Mon Sep 17 00:00:00 2001 From: Samir Droubi Date: Fri, 6 Oct 2023 13:05:48 -0400 Subject: [PATCH] Rewrite scal (#26) --- src/common/machines/avx2.py | 40 +++++++ src/common/machines/machine.py | 2 + src/common/machines/neon.py | 2 + src/level1/scal.py | 193 ++++++++------------------------- 4 files changed, 87 insertions(+), 150 deletions(-) diff --git a/src/common/machines/avx2.py b/src/common/machines/avx2.py index 8e565c17..3ac7aa6b 100644 --- a/src/common/machines/avx2.py +++ b/src/common/machines/avx2.py @@ -530,6 +530,44 @@ def mm256_prefix_add_pd( out[i] = x[i] + y[i] +@instr( + """ +{{ +__m256i indices = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0); +__m256i prefix = _mm256_set1_epi32({bound}); +__m256i cmp = _mm256_cmpgt_epi32(prefix, indices); +{dst_data} = _mm256_blendv_ps ({dst_data}, _mm256_setzero_ps(), _mm256_castsi256_ps(cmp)); +}} +""" +) +def mm256_prefix_setzero_ps(dst: [f32][8] @ AVX2, bound: size): + assert stride(dst, 0) == 1 + assert bound <= 8 + + for i in seq(0, 8): + if i < bound: + dst[i] = 0.0 + + +@instr( + """ +{{ +__m256i indices = _mm256_set_epi64x(3, 2, 1, 0); +__m256i prefix = _mm256_set1_epi64x({bound}); +__m256i cmp = _mm256_cmpgt_epi64(prefix, indices); +{dst_data} = _mm256_blendv_pd ({dst_data}, _mm256_setzero_pd(), _mm256_castsi256_pd(cmp)); +}} +""" +) +def mm256_prefix_setzero_pd(dst: [f64][4] @ AVX2, bound: size): + assert stride(dst, 0) == 1 + assert bound <= 4 + + for i in seq(0, 4): + if i < bound: + dst[i] = 0.0 + + Machine = MachineParameters( name="avx2", mem_type=AVX2, @@ -552,6 +590,7 @@ def mm256_prefix_add_pd( fmadd_instr_f32=mm256_fmadd_ps, prefix_fmadd_instr_f32=mm256_prefix_fmadd_ps, set_zero_instr_f32=mm256_setzero_ps, + prefix_set_zero_instr_f32=mm256_prefix_setzero_ps, assoc_reduce_add_instr_f32=avx2_assoc_reduce_add_ps, assoc_reduce_add_f32_buffer=avx2_assoc_reduce_add_ps_buffer, mul_instr_f32=mm256_mul_ps, @@ -579,6 +618,7 @@ def mm256_prefix_add_pd( fmadd_instr_f64=mm256_fmadd_pd, prefix_fmadd_instr_f64=mm256_prefix_fmadd_pd, set_zero_instr_f64=mm256_setzero_pd, + prefix_set_zero_instr_f64=mm256_prefix_setzero_pd, assoc_reduce_add_instr_f64=avx2_assoc_reduce_add_pd, mul_instr_f64=mm256_mul_pd, prefix_mul_instr_f64=mm256_prefix_mul_pd, diff --git a/src/common/machines/machine.py b/src/common/machines/machine.py index 9b1dc60b..96b25ea2 100644 --- a/src/common/machines/machine.py +++ b/src/common/machines/machine.py @@ -37,6 +37,7 @@ class MachineParameters: fmadd_instr_f32: Any prefix_fmadd_instr_f32: Any set_zero_instr_f32: Any + prefix_set_zero_instr_f32: Any assoc_reduce_add_instr_f32: Any mul_instr_f32: Any prefix_mul_instr_f32: Any @@ -65,6 +66,7 @@ class MachineParameters: fmadd_instr_f64: Any prefix_fmadd_instr_f64: Any set_zero_instr_f64: Any + prefix_set_zero_instr_f64: Any assoc_reduce_add_instr_f64: Any mul_instr_f64: Any prefix_mul_instr_f64: Any diff --git a/src/common/machines/neon.py b/src/common/machines/neon.py index dae7e5d3..658660bf 100644 --- a/src/common/machines/neon.py +++ b/src/common/machines/neon.py @@ -101,6 +101,7 @@ def neon_vst_2xf64_backwards(dst: [f64][2] @ DRAM, src: [f64][2] @ Neon): fmadd_instr_f32=neon_vfmadd_4xf32_4xf32, prefix_fmadd_instr_f32=None, set_zero_instr_f32=neon_zero_4xf32, + prefix_set_zero_instr_f32=None, assoc_reduce_add_instr_f32=neon_assoc_reduce_add_instr_4xf32, mul_instr_f32=neon_vmul_4xf32, prefix_mul_instr_f32=None, @@ -128,6 +129,7 @@ def neon_vst_2xf64_backwards(dst: [f64][2] @ DRAM, src: [f64][2] @ Neon): fmadd_instr_f64=neon_vfmadd_2xf64_2xf64, prefix_fmadd_instr_f64=None, set_zero_instr_f64=neon_zero_2xf64, + prefix_set_zero_instr_f64=None, assoc_reduce_add_instr_f64=neon_assoc_reduce_add_instr_2xf64, assoc_reduce_add_f64_buffer=neon_assoc_reduce_add_instr_2xf64_buffer, mul_instr_f64=neon_vmul_2xf64, diff --git a/src/level1/scal.py b/src/level1/scal.py index a66ec42c..81865e17 100644 --- a/src/level1/scal.py +++ b/src/level1/scal.py @@ -9,12 +9,16 @@ import exo_blas_config as C from composed_schedules import ( - vectorize_to_loops, - interleave_execution, - hoist_stmt, apply_to_block, + hoist_stmt, ) - +from blas_composed_schedules import blas_vectorize +from codegen_helpers import ( + generate_stride_any_proc, + export_exo_proc, + generate_stride_1_proc, +) +from parameters import Level_1_Params ### EXO_LOC ALGORITHM START ### @proc @@ -33,156 +37,45 @@ def scal_template_alpha_0(n: size, x: [R][n]): ### EXO_LOC SCHEDULE START ### -def specialize_scal(precision, alpha): - prefix = "s" if precision == "f32" else "d" - specialized_scal = scal_template if alpha != 0 else scal_template_alpha_0 - specialized_scal_name = specialized_scal.name() - specialized_scal_name = specialized_scal_name.replace("_template", "") - specialized_scal = rename(specialized_scal, "exo_" + prefix + specialized_scal_name) - - args = ["x"] - if alpha != 0: - args.append("alpha") - - for arg in args: - specialized_scal = set_precision(specialized_scal, arg, precision) - return specialized_scal - - -def schedule_scal_stride_1( - VEC_W, INTERLEAVE_FACTOR, memory, instructions, precision, alpha -): - simple_stride_1 = specialize_scal(precision, alpha) - simple_stride_1 = rename(simple_stride_1, simple_stride_1.name() + "_stride_1") - simple_stride_1 = simple_stride_1.add_assertion("stride(x, 0) == 1") - - main_loop = simple_stride_1.find_loop("i") - simple_stride_1 = vectorize_to_loops( - simple_stride_1, main_loop, VEC_W, memory, precision +def schedule_scal_stride_1(scal, params): + scal = generate_stride_1_proc(scal, params.precision) + main_loop = scal.find_loop("i") + scal = blas_vectorize(scal, main_loop, params) + main_loop = scal.find_loop("ioo") + scal = add_unsafe_guard( + scal, + main_loop.as_block(), + FormattedExprStr("_ < _", main_loop.lo(), main_loop.hi()), ) - simple_stride_1 = interleave_execution( - simple_stride_1, simple_stride_1.find_loop("io"), INTERLEAVE_FACTOR + main_loop = scal.find_loop("ioo") + scal = apply_to_block(scal, main_loop.body(), hoist_stmt) + middle_loop = scal.find_loop("ioi") + scal = add_unsafe_guard( + scal, + middle_loop.as_block(), + FormattedExprStr("_ < _", middle_loop.lo(), middle_loop.hi()), ) - simple_stride_1 = apply_to_block( - simple_stride_1, simple_stride_1.find_loop("ioo").body(), hoist_stmt - ) - simple_stride_1 = replace_all(simple_stride_1, instructions) - return simple_stride_1 - - -################################################# -# Generate specialized kernels for f32 precision -################################################# + middle_loop = scal.find_loop("ioi") + scal = apply_to_block(scal, middle_loop.body(), hoist_stmt) + return simplify(scal) -INTERLEAVE_FACTOR = C.Machine.vec_units * 2 -exo_sscal_stride_any = specialize_scal("f32", None) -exo_sscal_stride_any = rename( - exo_sscal_stride_any, exo_sscal_stride_any.name() + "_stride_any" -) -exo_sscal_alpha_0_stride_any = specialize_scal("f32", 0) -exo_sscal_alpha_0_stride_any = rename( - exo_sscal_alpha_0_stride_any, exo_sscal_alpha_0_stride_any.name() + "_stride_any" -) - -f32_instructions = [ - C.Machine.load_instr_f32, - C.Machine.store_instr_f32, - C.Machine.mul_instr_f32, - C.Machine.broadcast_scalar_instr_f32, - C.Machine.set_zero_instr_f32, +template_sched_list = [ + (scal_template, schedule_scal_stride_1), + (scal_template_alpha_0, schedule_scal_stride_1), ] -if None not in f32_instructions: - exo_sscal_stride_1 = schedule_scal_stride_1( - C.Machine.vec_width, - INTERLEAVE_FACTOR, - C.Machine.mem_type, - f32_instructions, - "f32", - None, - ) - exo_sscal_alpha_0_stride_1 = schedule_scal_stride_1( - C.Machine.vec_width, - INTERLEAVE_FACTOR, - C.Machine.mem_type, - f32_instructions, - "f32", - 0, - ) -else: - exo_sscal_stride_1 = specialize_scal("f32", None) - exo_sscal_stride_1 = rename( - exo_sscal_stride_1, exo_sscal_stride_1.name() + "_stride_1" - ) - exo_sscal_alpha_0_stride_1 = specialize_scal("f32", 0) - exo_sscal_alpha_0_stride_1 = rename( - exo_sscal_alpha_0_stride_1, exo_sscal_alpha_0_stride_1.name() + "_stride_1" - ) - -################################################# -# Generate specialized kernels for f64 precision -################################################# - -exo_dscal_stride_any = specialize_scal("f64", None) -exo_dscal_stride_any = rename( - exo_dscal_stride_any, exo_dscal_stride_any.name() + "_stride_any" -) -exo_dscal_alpha_0_stride_any = specialize_scal("f64", 0) -exo_dscal_alpha_0_stride_any = rename( - exo_dscal_alpha_0_stride_any, exo_dscal_alpha_0_stride_any.name() + "_stride_any" -) - -f64_instructions = [ - C.Machine.load_instr_f64, - C.Machine.store_instr_f64, - C.Machine.mul_instr_f64, - C.Machine.broadcast_scalar_instr_f64, - C.Machine.set_zero_instr_f64, -] - -if None not in f64_instructions: - exo_dscal_stride_1 = schedule_scal_stride_1( - C.Machine.vec_width // 2, - INTERLEAVE_FACTOR, - C.Machine.mem_type, - f64_instructions, - "f64", - None, - ) - exo_dscal_alpha_0_stride_1 = schedule_scal_stride_1( - C.Machine.vec_width // 2, - INTERLEAVE_FACTOR, - C.Machine.mem_type, - f64_instructions, - "f64", - 0, - ) -else: - exo_dscal_stride_1 = specialize_scal("f64", None) - exo_dscal_stride_1 = rename( - exo_dscal_stride_1, exo_dscal_stride_1.name() + "_stride_1" - ) - exo_dscal_alpha_0_stride_1 = specialize_scal("f64", 0) - exo_dscal_alpha_0_stride_1 = rename( - exo_dscal_alpha_0_stride_1, exo_dscal_alpha_0_stride_1.name() + "_stride_1" - ) +# TODO: Debug alpha zero case + +for precision in ("f32", "f64"): + for template, sched in template_sched_list: + proc_stride_any = generate_stride_any_proc(template, precision) + export_exo_proc(globals(), proc_stride_any) + proc_stride_1 = sched( + template, + Level_1_Params( + precision=precision, accumulators_count=1, interleave_factor=4 + ), + ) + export_exo_proc(globals(), proc_stride_1) ### EXO_LOC SCHEDULE END ### - -entry_points = [ - exo_sscal_stride_any, - exo_sscal_stride_1, - exo_sscal_alpha_0_stride_1, - exo_sscal_alpha_0_stride_any, - exo_dscal_stride_any, - exo_dscal_stride_1, - exo_dscal_alpha_0_stride_1, - exo_dscal_alpha_0_stride_any, -] - - -if __name__ == "__main__": - for p in entry_points: - print(p) - -__all__ = [p.name() for p in entry_points]