diff --git a/src/common/codegen_helpers.py b/src/common/codegen_helpers.py index eea7b8c5..973cb003 100644 --- a/src/common/codegen_helpers.py +++ b/src/common/codegen_helpers.py @@ -9,44 +9,53 @@ from exo.API_cursors import * from introspection import * +from higher_order import * +import exo_blas_config as C -def specialize_precision(template_proc, precision): +def specialize_precision(proc, precision, all_buffs=True): + assert precision in {"f32", "f64"} prefix = "s" if precision == "f32" else "d" - template_name = template_proc.name() - template_name = template_name.replace("_template", "") - specialized_proc = rename(template_proc, "exo_" + prefix + template_name) + template_name = proc.name() + proc = rename(proc, prefix + template_name) - for arg in template_proc.args(): - if arg.type().is_numeric(): - specialized_proc = set_precision(specialized_proc, arg, precision) + def has_type_R(proc, s, *arg): + if not isinstance(s, (AllocCursor, ArgCursor)): + return False + return s.type() == ExoType.R - for stmt in lrn_stmts(template_proc): - if isinstance(stmt, AllocCursor): - specialized_proc = set_precision(specialized_proc, stmt, precision) - return specialized_proc + set_R_type = predicate(set_precision, has_type_R) + def is_numeric(proc, s, *arg): + if not isinstance(s, (AllocCursor, ArgCursor)): + return False + return s.type().is_numeric() -def generate_stride_any_proc(template_proc, precision): - proc = specialize_precision(template_proc, precision) + set_numerics = predicate(set_precision, is_numeric) + + set_type = set_numerics if all_buffs else set_R_type + proc = apply(set_type)(proc, proc.args(), precision) + proc = make_pass(set_type)(proc, proc.body(), precision) + return proc + + +def generate_stride_any_proc(proc): proc = rename(proc, proc.name() + "_stride_any") - proc = stage_scalar_args(proc) return proc -def generate_stride_1_proc(template_proc, precision): - proc = specialize_precision(template_proc, precision) +def generate_stride_1_proc(proc): proc = rename(proc, proc.name() + "_stride_1") for arg in proc.args(): if arg.is_tensor(): proc = proc.add_assertion( f"stride({arg.name()}, {len(arg.shape()) - 1}) == 1" ) - proc = stage_scalar_args(proc) return proc def export_exo_proc(globals, proc): + proc = rename(proc, f"exo_{proc.name()}") globals[proc.name()] = simplify(proc) globals.setdefault("__all__", []).append(proc.name()) @@ -70,3 +79,23 @@ def stage_scalar_args(proc): if arg.type().is_numeric() and not arg.is_tensor(): proc = stage_mem(proc, proc.body(), arg.name(), f"{arg.name()}_") return proc + + +def variants_generator(blas_op): + def generate(proc, loop_name, *args, globals=None): + for precision in ("f32", "f64"): + proc_variant = specialize_precision(proc, precision) + + proc_variant = stage_scalar_args(proc_variant) + + stride_any = generate_stride_any_proc(proc_variant) + stride_any = bind_builtins_args(stride_any, stride_any.body(), precision) + export_exo_proc(globals, stride_any) + + stride_1 = generate_stride_1_proc(proc_variant) + loop = stride_1.find_loop(loop_name) + stride_1 = blas_op(stride_1, loop, precision, C.Machine, *args) + stride_1 = bind_builtins_args(stride_1, stride_1.body(), precision) + export_exo_proc(globals, stride_1) + + return generate diff --git a/src/level1/asum.py b/src/level1/asum.py index e932b669..55342a6c 100644 --- a/src/level1/asum.py +++ b/src/level1/asum.py @@ -1,13 +1,8 @@ from __future__ import annotations from exo import * -from exo.libs.memories import DRAM_STATIC from exo.platforms.x86 import * -from exo.syntax import * -from exo.stdlib.scheduling import * -import exo_blas_config as C -from composed_schedules import * from blaslib import * from codegen_helpers import * @@ -22,32 +17,11 @@ def asum(n: size, x: [f32][n] @ DRAM, result: f32 @ DRAM): ### EXO_LOC ALGORITHM END ### ### EXO_LOC SCHEDULE START ### -def schedule_asum_stride_1(asum, precision): - asum = generate_stride_1_proc(asum, precision) - if C.Machine.mem_type is not AVX2: +def schedule_asum(asum, loop, precision, machine, interleave_factor): + if machine.mem_type is not AVX2: return asum - asum = optimize_level_1(asum, asum.find_loop("i"), precision, C.Machine, 7) - return asum - - -template_sched_list = [ - (asum, schedule_asum_stride_1), -] - -for precision in ("f32", "f64"): - for template, sched in template_sched_list: - proc_stride_any = generate_stride_any_proc(template, precision) - proc_stride_any = bind_builtins_args( - proc_stride_any, proc_stride_any.body(), precision - ) - export_exo_proc(globals(), proc_stride_any) - proc_stride_1 = sched( - template, - precision, - ) - proc_stride_1 = bind_builtins_args( - proc_stride_1, proc_stride_1.body(), precision - ) - export_exo_proc(globals(), proc_stride_1) + return optimize_level_1(asum, loop, precision, machine, interleave_factor) + +variants_generator(schedule_asum)(asum, "i", 7, globals=globals()) ### EXO_LOC SCHEDULE END ### diff --git a/src/level1/axpy.py b/src/level1/axpy.py index 61ad6fcc..6c755700 100644 --- a/src/level1/axpy.py +++ b/src/level1/axpy.py @@ -1,26 +1,19 @@ from __future__ import annotations from exo import * -from exo.libs.memories import DRAM_STATIC -from exo.platforms.x86 import * -from exo.syntax import * -from exo.stdlib.scheduling import * -from exo.API_cursors import * - -import exo_blas_config as C -from composed_schedules import * + from blaslib import * from codegen_helpers import * ### EXO_LOC ALGORITHM START ### @proc -def axpy_template(n: size, alpha: R, x: [R][n], y: [R][n]): +def axpy(n: size, alpha: R, x: [R][n], y: [R][n]): for i in seq(0, n): y[i] += alpha * x[i] @proc -def axpy_template_alpha_1(n: size, x: [R][n], y: [R][n]): +def axpy_alpha_1(n: size, x: [R][n], y: [R][n]): for i in seq(0, n): y[i] += x[i] @@ -29,22 +22,6 @@ def axpy_template_alpha_1(n: size, x: [R][n], y: [R][n]): ### EXO_LOC SCHEDULE START ### -def schedule_axpy_stride_1(axpy, precision): - axpy = generate_stride_1_proc(axpy, precision) - main_loop = axpy.find_loop("i") - axpy = optimize_level_1(axpy, main_loop, precision, C.Machine, 4) - return simplify(axpy) - - -template_sched_list = [ - (axpy_template, schedule_axpy_stride_1), - (axpy_template_alpha_1, schedule_axpy_stride_1), -] - -for precision in ("f32", "f64"): - for template, sched in template_sched_list: - proc_stride_any = generate_stride_any_proc(template, precision) - export_exo_proc(globals(), proc_stride_any) - proc_stride_1 = sched(template, precision) - export_exo_proc(globals(), proc_stride_1) +for proc in axpy, axpy_alpha_1: + variants_generator(optimize_level_1)(proc, "i", 4, globals=globals()) ### EXO_LOC SCHEDULE END ### diff --git a/src/level1/dot.py b/src/level1/dot.py index da7392e7..78f2daec 100644 --- a/src/level1/dot.py +++ b/src/level1/dot.py @@ -1,21 +1,14 @@ from __future__ import annotations from exo import * -from exo.libs.memories import DRAM_STATIC -from exo.platforms.x86 import * -from exo.syntax import * -from exo.stdlib.scheduling import * -import exo.API_cursors as pc - -import exo_blas_config as C -from composed_schedules import * + from blaslib import * from codegen_helpers import * ### EXO_LOC ALGORITHM START ### @proc -def dot_template(n: size, x: [R][n], y: [R][n], result: R): +def dot(n: size, x: [R][n], y: [R][n], result: R): result = 0.0 for i in seq(0, n): result += x[i] * y[i] @@ -25,21 +18,5 @@ def dot_template(n: size, x: [R][n], y: [R][n], result: R): ### EXO_LOC SCHEDULE START ### -def schedule_dot_stride_1(dot, precision): - dot = generate_stride_1_proc(dot, precision) - main_loop = dot.find_loop("i") - dot = optimize_level_1(dot, main_loop, precision, C.Machine, 4) - return simplify(dot) - - -template_sched_list = [ - (dot_template, schedule_dot_stride_1), -] - -for precision in ("f32", "f64"): - for template, sched in template_sched_list: - proc_stride_any = generate_stride_any_proc(template, precision) - export_exo_proc(globals(), proc_stride_any) - proc_stride_1 = sched(template, precision) - export_exo_proc(globals(), proc_stride_1) +variants_generator(optimize_level_1)(dot, "i", 4, globals=globals()) ### EXO_LOC SCHEDULE END ### diff --git a/src/level1/dsdot.py b/src/level1/dsdot.py index 8eee2452..f13ddfa6 100644 --- a/src/level1/dsdot.py +++ b/src/level1/dsdot.py @@ -1,20 +1,15 @@ from __future__ import annotations from exo import * -from exo.libs.memories import DRAM_STATIC from exo.platforms.x86 import * -from exo.platforms.neon import * -from exo.syntax import * -from exo.stdlib.scheduling import * -import exo_blas_config as C -from composed_schedules import * from blaslib import * from codegen_helpers import * +import exo_blas_config as C ### EXO_LOC ALGORITHM START ### @proc -def dsdot_template(n: size, x: [f32][n], y: [f32][n], result: f64): +def dsdot(n: size, x: [f32][n], y: [f32][n], result: f64): d_dot: f64 d_dot = 0.0 for i in seq(0, n): @@ -27,7 +22,7 @@ def dsdot_template(n: size, x: [f32][n], y: [f32][n], result: f64): @proc -def sdsdot_template(n: size, sb: f32, x: [f32][n], y: [f32][n], result: f32): +def sdsdot(n: size, sb: f32, x: [f32][n], y: [f32][n], result: f32): d_result: f64 d_dot: f64 d_dot = 0.0 @@ -46,26 +41,11 @@ def sdsdot_template(n: size, sb: f32, x: [f32][n], y: [f32][n], result: f32): ### EXO_LOC SCHEDULE START ### -def schedule_dsdot_stride_1(proc, name): - proc = rename(proc, name) - proc = proc.add_assertion("stride(x, 0) == 1") - proc = proc.add_assertion("stride(y, 0) == 1") - - if C.Machine.mem_type is not AVX2: - return proc - proc = optimize_level_1(proc, proc.find_loop("i"), "f64", C.Machine, 4) - return proc - - -export_exo_proc(globals(), rename(dsdot_template, "exo_dsdot_stride_any")) -export_exo_proc(globals(), rename(sdsdot_template, "exo_sdsdot_stride_any")) -export_exo_proc( - globals(), - schedule_dsdot_stride_1(dsdot_template, "exo_dsdot_stride_1"), -) -export_exo_proc( - globals(), - schedule_dsdot_stride_1(sdsdot_template, "exo_sdsdot_stride_1"), -) - +for proc in dsdot, sdsdot: + export_exo_proc(globals(), generate_stride_any_proc(proc)) + stride_1 = generate_stride_1_proc(proc) + if C.Machine.mem_type is AVX2: + loop = stride_1.find_loop("i") + stride_1 = optimize_level_1(stride_1, loop, "f64", C.Machine, 4) + export_exo_proc(globals(), stride_1) ### EXO_LOC SCHEDULE END ### diff --git a/src/level1/exo_copy.py b/src/level1/exo_copy.py index 7470d799..5588ac64 100644 --- a/src/level1/exo_copy.py +++ b/src/level1/exo_copy.py @@ -1,19 +1,14 @@ from __future__ import annotations from exo import * -from exo.libs.memories import DRAM_STATIC -from exo.platforms.x86 import * -from exo.syntax import * -from exo.stdlib.scheduling import * from blaslib import * from codegen_helpers import * -import exo_blas_config as C ### EXO_LOC ALGORITHM START ### @proc -def copy_template(n: size, x: [R][n], y: [R][n]): +def copy(n: size, x: [R][n], y: [R][n]): for i in seq(0, n): y[i] = x[i] @@ -21,22 +16,5 @@ def copy_template(n: size, x: [R][n], y: [R][n]): ### EXO_LOC ALGORITHM END ### ### EXO_LOC SCHEDULE START ### -def schedule_copy(exo_copy, precision): - exo_copy = generate_stride_1_proc(exo_copy, precision) - main_loop = exo_copy.find_loop("i") - exo_copy = optimize_level_1(exo_copy, main_loop, precision, C.Machine, 4) - return simplify(exo_copy) - - -template_sched_list = [ - (copy_template, schedule_copy), -] - -for precision in ("f32", "f64"): - for template, sched in template_sched_list: - proc_stride_any = generate_stride_any_proc(template, precision) - export_exo_proc(globals(), proc_stride_any) - proc_stride_1 = sched(template, precision) - export_exo_proc(globals(), proc_stride_1) - +variants_generator(optimize_level_1)(copy, "i", 4, globals=globals()) ### EXO_LOC SCHEDULE END ### diff --git a/src/level1/nrm2.py b/src/level1/nrm2.py index 02df5625..3adea736 100644 --- a/src/level1/nrm2.py +++ b/src/level1/nrm2.py @@ -11,7 +11,7 @@ ### EXO_LOC ALGORITHM START ### @proc -def nrm2_template(n: size, x: [R][n], result: R): +def nrm2(n: size, x: [R][n], result: R): result = 0.0 for i in seq(0, n): result += x[i] * x[i] @@ -23,7 +23,7 @@ def nrm2_template(n: size, x: [R][n], result: R): ### EXO_LOC SCHEDULE START ### def specialize_precision(precision): prefix = "s" if precision == "f32" else "d" - specialized_copy = rename(nrm2_template, "exo_" + prefix + "nrm2") + specialized_copy = rename(nrm2, "exo_" + prefix + "nrm2") for arg in ["x", "result"]: specialized_copy = set_precision(specialized_copy, arg, precision) return specialized_copy diff --git a/src/level1/rot.py b/src/level1/rot.py index 24397fb2..05c7ea61 100644 --- a/src/level1/rot.py +++ b/src/level1/rot.py @@ -1,20 +1,13 @@ from __future__ import annotations from exo import * -from exo.libs.memories import DRAM_STATIC -from exo.platforms.x86 import * -from exo.syntax import * -from exo.stdlib.scheduling import * -import exo.API_cursors as pc -import exo_blas_config as C -from composed_schedules import * from blaslib import * from codegen_helpers import * ### EXO_LOC ALGORITHM START ### @proc -def rot_template(n: size, x: [R][n], y: [R][n], c: R, s: R): +def rot(n: size, x: [R][n], y: [R][n], c: R, s: R): for i in seq(0, n): xReg: R xReg = x[i] @@ -26,22 +19,5 @@ def rot_template(n: size, x: [R][n], y: [R][n], c: R, s: R): ### EXO_LOC SCHEDULE START ### -def schedule_rot_stride_1(rot, precision): - rot = generate_stride_1_proc(rot, precision) - loop_cursor = rot.find_loop("i") - rot = optimize_level_1(rot, loop_cursor, precision, C.Machine, 4) - return rot - - -template_sched_list = [ - (rot_template, schedule_rot_stride_1), -] - -for precision in ("f32", "f64"): - for template, sched in template_sched_list: - proc_stride_any = generate_stride_any_proc(template, precision) - export_exo_proc(globals(), proc_stride_any) - proc_stride_1 = sched(template, precision) - export_exo_proc(globals(), proc_stride_1) - +variants_generator(optimize_level_1)(rot, "i", 4, globals=globals()) ### EXO_LOC SCHEDULE END ### diff --git a/src/level1/rotm.py b/src/level1/rotm.py index 8483a8e7..cddc337d 100644 --- a/src/level1/rotm.py +++ b/src/level1/rotm.py @@ -1,19 +1,13 @@ from __future__ import annotations from exo import * -from exo.libs.memories import DRAM_STATIC -from exo.platforms.x86 import * -from exo.syntax import * -from exo.stdlib.scheduling import * -import exo_blas_config as C -from composed_schedules import * from blaslib import * from codegen_helpers import * ### EXO_LOC ALGORITHM START ### @proc -def rotm_template_flag_neg_one(n: size, x: [R][n], y: [R][n], H: R[2, 2]): +def rotm_flag_neg_one(n: size, x: [R][n], y: [R][n], H: R[2, 2]): for i in seq(0, n): xReg: R xReg = x[i] @@ -22,7 +16,7 @@ def rotm_template_flag_neg_one(n: size, x: [R][n], y: [R][n], H: R[2, 2]): @proc -def rotm_template_flag_zero(n: size, x: [R][n], y: [R][n], H: R[2, 2]): +def rotm_flag_zero(n: size, x: [R][n], y: [R][n], H: R[2, 2]): for i in seq(0, n): xReg: R xReg = x[i] @@ -31,7 +25,7 @@ def rotm_template_flag_zero(n: size, x: [R][n], y: [R][n], H: R[2, 2]): @proc -def rotm_template_flag_one(n: size, x: [R][n], y: [R][n], H: R[2, 2]): +def rotm_flag_one(n: size, x: [R][n], y: [R][n], H: R[2, 2]): for i in seq(0, n): xReg: R xReg = x[i] @@ -43,26 +37,6 @@ def rotm_template_flag_one(n: size, x: [R][n], y: [R][n], H: R[2, 2]): ### EXO_LOC SCHEDULE START ### - - -def schedule_rotm_stride_1(rotm, precision): - rotm = generate_stride_1_proc(rotm, precision) - loop_cursor = rotm.find_loop("i") - rotm = optimize_level_1(rotm, loop_cursor, precision, C.Machine, 4) - return rotm - - -template_sched_list = [ - (rotm_template_flag_neg_one, schedule_rotm_stride_1), - (rotm_template_flag_zero, schedule_rotm_stride_1), - (rotm_template_flag_one, schedule_rotm_stride_1), -] - -for precision in ("f32", "f64"): - for template, sched in template_sched_list: - proc_stride_any = generate_stride_any_proc(template, precision) - export_exo_proc(globals(), proc_stride_any) - proc_stride_1 = sched(template, precision) - export_exo_proc(globals(), proc_stride_1) - +for proc in rotm_flag_neg_one, rotm_flag_zero, rotm_flag_one: + variants_generator(optimize_level_1)(proc, "i", 4, globals=globals()) ### EXO_LOC SCHEDULE END ### diff --git a/src/level1/scal.py b/src/level1/scal.py index 94f7c097..83e90f0d 100644 --- a/src/level1/scal.py +++ b/src/level1/scal.py @@ -1,26 +1,19 @@ from __future__ import annotations from exo import * -from exo.libs.memories import DRAM_STATIC -from exo.platforms.x86 import * -from exo.syntax import * -from exo.stdlib.scheduling import * -import exo.API_cursors as pc - -import exo_blas_config as C -from composed_schedules import * + from blaslib import * from codegen_helpers import * ### EXO_LOC ALGORITHM START ### @proc -def scal_template(n: size, alpha: R, x: [R][n]): +def scal(n: size, alpha: R, x: [R][n]): for i in seq(0, n): x[i] = alpha * x[i] @proc -def scal_template_alpha_0(n: size, x: [R][n]): +def scal_alpha_0(n: size, x: [R][n]): for i in seq(0, n): x[i] = 0.0 @@ -29,24 +22,6 @@ def scal_template_alpha_0(n: size, x: [R][n]): ### EXO_LOC SCHEDULE START ### -def schedule_scal_stride_1(scal, precision): - scal = generate_stride_1_proc(scal, precision) - main_loop = scal.find_loop("i") - scal = optimize_level_1(scal, main_loop, precision, C.Machine, 4) - return simplify(scal) - - -template_sched_list = [ - (scal_template, schedule_scal_stride_1), - (scal_template_alpha_0, schedule_scal_stride_1), -] - -# TODO: Debug alpha zero case - -for precision in ("f32", "f64"): - for template, sched in template_sched_list: - proc_stride_any = generate_stride_any_proc(template, precision) - export_exo_proc(globals(), proc_stride_any) - proc_stride_1 = sched(template, precision) - export_exo_proc(globals(), proc_stride_1) +for proc in scal, scal_alpha_0: + variants_generator(optimize_level_1)(proc, "i", 4, globals=globals()) ### EXO_LOC SCHEDULE END ### diff --git a/src/level1/swap.py b/src/level1/swap.py index 33dca430..e7d32f20 100644 --- a/src/level1/swap.py +++ b/src/level1/swap.py @@ -1,18 +1,13 @@ from __future__ import annotations from exo import * -from exo.libs.memories import DRAM_STATIC -from exo.platforms.x86 import * -from exo.syntax import * -from exo.stdlib.scheduling import * -import exo_blas_config as C from blaslib import * from codegen_helpers import * ### EXO_LOC ALGORITHM START ### @proc -def swap_template(n: size, x: [R][n], y: [R][n]): +def swap(n: size, x: [R][n], y: [R][n]): for i in seq(0, n): tmp: R tmp = x[i] @@ -24,22 +19,5 @@ def swap_template(n: size, x: [R][n], y: [R][n]): ### EXO_LOC SCHEDULE START ### -def schedule_swap(swap, precision): - swap = generate_stride_1_proc(swap, precision) - main_loop = swap.find_loop("i") - swap = optimize_level_1(swap, main_loop, precision, C.Machine, 4) - return simplify(swap) - - -template_sched_list = [ - (swap_template, schedule_swap), -] - -for precision in ("f32", "f64"): - for template, sched in template_sched_list: - proc_stride_any = generate_stride_any_proc(template, precision) - export_exo_proc(globals(), proc_stride_any) - proc_stride_1 = sched(template, precision) - export_exo_proc(globals(), proc_stride_1) - +variants_generator(optimize_level_1)(swap, "i", 4, globals=globals()) ### EXO_LOC SCHEDULE END ### diff --git a/src/level2/gbmv.py b/src/level2/gbmv.py index a545bce4..5a64e211 100644 --- a/src/level2/gbmv.py +++ b/src/level2/gbmv.py @@ -6,7 +6,7 @@ from exo.syntax import * from exo.stdlib.scheduling import * -from dot import exo_sdot_stride_1, dot_template, exo_ddot_stride_1 +from dot import exo_sdot_stride_1, dot, exo_ddot_stride_1 import exo_blas_config as C from composed_schedules import ( scalar_to_simd, @@ -64,7 +64,7 @@ def gbmv_row_major_NonTrans( ### EXO_LOC SCHEDULE START ### def specialize_sdot(precision): - specialized = sdot_template + specialized = sdot for arg in ["x", "y", "result"]: specialized = set_precision(specialized, arg, precision) @@ -91,11 +91,11 @@ def schedule_stride_1(precision): stride_1 = stride_1.add_assertion("stride(a, 1) == 1") scheduled_sdot = exo_sdot_stride_1 if precision == "f32" else exo_ddot_stride_1 - dot_template.unsafe_assert_eq(scheduled_sdot) + dot.unsafe_assert_eq(scheduled_sdot) for i in range(4): - stride_1 = replace(stride_1, stride_1.find_loop("j").expand(1, 0), dot_template) - stride_1 = call_eqv(stride_1, "dot_template", scheduled_sdot) + stride_1 = replace(stride_1, stride_1.find_loop("j").expand(1, 0), dot) + stride_1 = call_eqv(stride_1, "dot", scheduled_sdot) return simplify(stride_1) diff --git a/src/level2/gemv.py b/src/level2/gemv.py index 9256c15f..6389eb8b 100644 --- a/src/level2/gemv.py +++ b/src/level2/gemv.py @@ -1,16 +1,7 @@ from __future__ import annotations from exo import * -from exo.libs.memories import DRAM_STATIC -from exo.platforms.x86 import * -from exo.platforms.neon import * -from exo.syntax import * -from exo.stdlib.scheduling import * -from exo.API import compile_procs - -from blas_common_schedules import * -import exo_blas_config as C -from composed_schedules import * + from blaslib import * from codegen_helpers import * @@ -25,6 +16,11 @@ def gemv_rm_nt(m: size, n: size, alpha: R, beta: R, A: [R][m, n], x: [R][n], y: y[i] += alpha * (x[j] * A[i, j]) +### EXO_LOC ALGORITHM END ### + + +### EXO_LOC SCHEDULE START ### + gemv_rm_t = gemv_rm_nt.transpose(gemv_rm_nt.args()[4]) gemv_rm_t = rename(gemv_rm_t, "gemv_rm_t") @@ -37,24 +33,7 @@ def gemv_rm_nt(m: size, n: size, alpha: R, beta: R, A: [R][m, n], x: [R][n], y: gemv_rm_t = reorder_loops(gemv_rm_t, gemv_rm_t.find_loop("i #1")) gemv_rm_t = left_reassociate_expr(gemv_rm_t, gemv_rm_t.find("alpha * _")) -### EXO_LOC ALGORITHM END ### - - -### EXO_LOC SCHEDULE START ### - -template_sched_list = [ - (gemv_rm_nt, "i"), - (gemv_rm_t, "j"), -] - -for precision in ("f32", "f64"): - for template, it in template_sched_list: - proc_stride_any = generate_stride_any_proc(template, precision) - export_exo_proc(globals(), proc_stride_any) - proc_stride_1 = generate_stride_1_proc(template, precision) - proc_stride_1 = optimize_level_2( - proc_stride_1, proc_stride_1.find_loop(it), precision, C.Machine, 4, 2 - ) - export_exo_proc(globals(), proc_stride_1) +variants_generator(optimize_level_2)(gemv_rm_nt, "i", 4, 2, globals=globals()) +variants_generator(optimize_level_2)(gemv_rm_t, "j", 4, 2, globals=globals()) ### EXO_LOC SCHEDULE END ### diff --git a/src/level2/ger.py b/src/level2/ger.py index 88e5db12..6a7113a0 100644 --- a/src/level2/ger.py +++ b/src/level2/ger.py @@ -11,9 +11,7 @@ ### EXO_LOC ALGORITHM START ### @proc -def ger_row_major_template( - m: size, n: size, alpha: R, x: [R][m], y: [R][n], A: [R][m, n] -): +def ger_row_major(m: size, n: size, alpha: R, x: [R][m], y: [R][n], A: [R][m, n]): assert stride(A, 1) == 1 for i in seq(0, m): @@ -22,7 +20,7 @@ def ger_row_major_template( @proc -def ger_row_major_template_alpha_1( +def ger_row_major_alpha_1( m: size, n: size, alpha: R, x: [R][m], y: [R][n], A: [R][m, n] ): assert stride(A, 1) == 1 @@ -38,10 +36,8 @@ def ger_row_major_template_alpha_1( ### EXO_LOC SCHEDULE START ### def specialize_ger(precision, alpha): prefix = "s" if precision == "f32" else "d" - specialized = ( - ger_row_major_template if alpha != 1 else ger_row_major_template_alpha_1 - ) - name = specialized.name().replace("_template", "") + specialized = ger_row_major if alpha != 1 else ger_row_major_alpha_1 + name = specialized.name().replace("", "") specialized = rename(specialized, "exo_" + prefix + name) args = ["alpha", "x", "y", "A"] diff --git a/src/level2/sbmv.py b/src/level2/sbmv.py index 67c030cd..017dc615 100644 --- a/src/level2/sbmv.py +++ b/src/level2/sbmv.py @@ -17,7 +17,7 @@ def sbmv_scal_y(n: size, beta: R, y: [R][n]): @proc -def sbmv_row_major_Upper_template( +def sbmv_row_major_Upper( n: size, k: size, alpha: R, A: [R][n, k + 1], x: [R][n], y: [R][n] ): assert stride(A, 1) == 1 @@ -36,7 +36,7 @@ def sbmv_row_major_Upper_template( @proc -def sbmv_row_major_Lower_template( +def sbmv_row_major_Lower( n: size, k: size, alpha: R, A: [R][n, k + 1], x: [R][n], y: [R][n] ): assert stride(A, 1) == 1 @@ -57,7 +57,7 @@ def sbmv_row_major_Lower_template( def specialize_sbmv(sbmv, precision): prefix = "s" if precision == "f32" else "d" name = sbmv.name() - name = name.replace("_template", "") + name = name.replace("", "") specialized = rename(sbmv, "exo_" + prefix + name) if "scal" in sbmv.name(): @@ -104,16 +104,12 @@ def schedule_interleave_sbmv_row_major_stride_1( exo_ssbmv_scal_y_stride_any = rename( exo_ssbmv_scal_y_stride_any, exo_ssbmv_scal_y_stride_any.name() + "_stride_any" ) -exo_ssbmv_row_major_Upper_stride_any = specialize_sbmv( - sbmv_row_major_Upper_template, "f32" -) +exo_ssbmv_row_major_Upper_stride_any = specialize_sbmv(sbmv_row_major_Upper, "f32") exo_ssbmv_row_major_Upper_stride_any = rename( exo_ssbmv_row_major_Upper_stride_any, exo_ssbmv_row_major_Upper_stride_any.name() + "_stride_any", ) -exo_ssbmv_row_major_Lower_stride_any = specialize_sbmv( - sbmv_row_major_Lower_template, "f32" -) +exo_ssbmv_row_major_Lower_stride_any = specialize_sbmv(sbmv_row_major_Lower, "f32") exo_ssbmv_row_major_Lower_stride_any = rename( exo_ssbmv_row_major_Lower_stride_any, exo_ssbmv_row_major_Lower_stride_any.name() + "_stride_any", @@ -128,7 +124,7 @@ def schedule_interleave_sbmv_row_major_stride_1( ] exo_ssbmv_row_major_Upper_stride_1 = schedule_interleave_sbmv_row_major_stride_1( - sbmv_row_major_Upper_template, + sbmv_row_major_Upper, C.Machine.f32_vec_width, 1, C.Machine.mem_type, @@ -136,7 +132,7 @@ def schedule_interleave_sbmv_row_major_stride_1( "f32", ) exo_ssbmv_row_major_Lower_stride_1 = schedule_interleave_sbmv_row_major_stride_1( - sbmv_row_major_Lower_template, + sbmv_row_major_Lower, C.Machine.f32_vec_width, 1, C.Machine.mem_type, @@ -155,16 +151,12 @@ def schedule_interleave_sbmv_row_major_stride_1( exo_dsbmv_scal_y_stride_any = rename( exo_dsbmv_scal_y_stride_any, exo_dsbmv_scal_y_stride_any.name() + "_stride_any" ) -exo_dsbmv_row_major_Upper_stride_any = specialize_sbmv( - sbmv_row_major_Upper_template, "f64" -) +exo_dsbmv_row_major_Upper_stride_any = specialize_sbmv(sbmv_row_major_Upper, "f64") exo_dsbmv_row_major_Upper_stride_any = rename( exo_dsbmv_row_major_Upper_stride_any, exo_dsbmv_row_major_Upper_stride_any.name() + "_stride_any", ) -exo_dsbmv_row_major_Lower_stride_any = specialize_sbmv( - sbmv_row_major_Lower_template, "f64" -) +exo_dsbmv_row_major_Lower_stride_any = specialize_sbmv(sbmv_row_major_Lower, "f64") exo_dsbmv_row_major_Lower_stride_any = rename( exo_dsbmv_row_major_Lower_stride_any, exo_dsbmv_row_major_Lower_stride_any.name() + "_stride_any", @@ -180,7 +172,7 @@ def schedule_interleave_sbmv_row_major_stride_1( ] exo_dsbmv_row_major_Upper_stride_1 = schedule_interleave_sbmv_row_major_stride_1( - sbmv_row_major_Upper_template, + sbmv_row_major_Upper, C.Machine.f32_vec_width // 2, 1, C.Machine.mem_type, @@ -188,7 +180,7 @@ def schedule_interleave_sbmv_row_major_stride_1( "f64", ) exo_dsbmv_row_major_Lower_stride_1 = schedule_interleave_sbmv_row_major_stride_1( - sbmv_row_major_Lower_template, + sbmv_row_major_Lower, C.Machine.f32_vec_width // 2, 1, C.Machine.mem_type, diff --git a/src/level2/symv.py b/src/level2/symv.py index 4f0fd03a..a5ce2bd1 100644 --- a/src/level2/symv.py +++ b/src/level2/symv.py @@ -17,9 +17,7 @@ def symv_scal_y(n: size, beta: R, y: [R][n]): @proc -def symv_row_major_Upper_template( - n: size, alpha: R, A: [R][n, n], x: [R][n], y: [R][n] -): +def symv_row_major_Upper(n: size, alpha: R, A: [R][n, n], x: [R][n], y: [R][n]): assert stride(A, 1) == 1 for i in seq(0, n): @@ -34,9 +32,7 @@ def symv_row_major_Upper_template( @proc -def symv_row_major_Lower_template( - n: size, alpha: R, A: [R][n, n], x: [R][n], y: [R][n] -): +def symv_row_major_Lower(n: size, alpha: R, A: [R][n, n], x: [R][n], y: [R][n]): assert stride(A, 1) == 1 for i in seq(0, n): @@ -53,7 +49,7 @@ def symv_row_major_Lower_template( def specialize_symv(symv, precision): prefix = "s" if precision == "f32" else "d" name = symv.name() - name = name.replace("_template", "") + name = name.replace("", "") specialized = rename(symv, "exo_" + prefix + name) if "scal" in symv.name(): @@ -100,16 +96,12 @@ def schedule_interleave_symv_row_major_stride_1( exo_ssymv_scal_y_stride_any = rename( exo_ssymv_scal_y_stride_any, exo_ssymv_scal_y_stride_any.name() + "_stride_any" ) -exo_ssymv_row_major_Upper_stride_any = specialize_symv( - symv_row_major_Upper_template, "f32" -) +exo_ssymv_row_major_Upper_stride_any = specialize_symv(symv_row_major_Upper, "f32") exo_ssymv_row_major_Upper_stride_any = rename( exo_ssymv_row_major_Upper_stride_any, exo_ssymv_row_major_Upper_stride_any.name() + "_stride_any", ) -exo_ssymv_row_major_Lower_stride_any = specialize_symv( - symv_row_major_Lower_template, "f32" -) +exo_ssymv_row_major_Lower_stride_any = specialize_symv(symv_row_major_Lower, "f32") exo_ssymv_row_major_Lower_stride_any = rename( exo_ssymv_row_major_Lower_stride_any, exo_ssymv_row_major_Lower_stride_any.name() + "_stride_any", @@ -124,7 +116,7 @@ def schedule_interleave_symv_row_major_stride_1( ] exo_ssymv_row_major_Upper_stride_1 = schedule_interleave_symv_row_major_stride_1( - symv_row_major_Upper_template, + symv_row_major_Upper, C.Machine.f32_vec_width, 1, C.Machine.mem_type, @@ -132,7 +124,7 @@ def schedule_interleave_symv_row_major_stride_1( "f32", ) exo_ssymv_row_major_Lower_stride_1 = schedule_interleave_symv_row_major_stride_1( - symv_row_major_Lower_template, + symv_row_major_Lower, C.Machine.f32_vec_width, 1, C.Machine.mem_type, @@ -151,16 +143,12 @@ def schedule_interleave_symv_row_major_stride_1( exo_dsymv_scal_y_stride_any = rename( exo_dsymv_scal_y_stride_any, exo_dsymv_scal_y_stride_any.name() + "_stride_any" ) -exo_dsymv_row_major_Upper_stride_any = specialize_symv( - symv_row_major_Upper_template, "f64" -) +exo_dsymv_row_major_Upper_stride_any = specialize_symv(symv_row_major_Upper, "f64") exo_dsymv_row_major_Upper_stride_any = rename( exo_dsymv_row_major_Upper_stride_any, exo_dsymv_row_major_Upper_stride_any.name() + "_stride_any", ) -exo_dsymv_row_major_Lower_stride_any = specialize_symv( - symv_row_major_Lower_template, "f64" -) +exo_dsymv_row_major_Lower_stride_any = specialize_symv(symv_row_major_Lower, "f64") exo_dsymv_row_major_Lower_stride_any = rename( exo_dsymv_row_major_Lower_stride_any, exo_dsymv_row_major_Lower_stride_any.name() + "_stride_any", @@ -176,7 +164,7 @@ def schedule_interleave_symv_row_major_stride_1( ] exo_dsymv_row_major_Upper_stride_1 = schedule_interleave_symv_row_major_stride_1( - symv_row_major_Upper_template, + symv_row_major_Upper, C.Machine.f32_vec_width // 2, 1, C.Machine.mem_type, @@ -184,7 +172,7 @@ def schedule_interleave_symv_row_major_stride_1( "f64", ) exo_dsymv_row_major_Lower_stride_1 = schedule_interleave_symv_row_major_stride_1( - symv_row_major_Lower_template, + symv_row_major_Lower, C.Machine.f32_vec_width // 2, 1, C.Machine.mem_type, diff --git a/src/level2/syr.py b/src/level2/syr.py index 622e9312..faf2a99e 100644 --- a/src/level2/syr.py +++ b/src/level2/syr.py @@ -13,7 +13,7 @@ ### EXO_LOC ALGORITHM START ### @proc -def syr_row_major_Upper_template(n: size, alpha: R, x: [R][n], A: [R][n, n]): +def syr_row_major_Upper(n: size, alpha: R, x: [R][n], A: [R][n, n]): assert stride(A, 1) == 1 for i in seq(0, n): @@ -22,7 +22,7 @@ def syr_row_major_Upper_template(n: size, alpha: R, x: [R][n], A: [R][n, n]): @proc -def syr_row_major_Lower_template(n: size, alpha: R, x: [R][n], A: [R][n, n]): +def syr_row_major_Lower(n: size, alpha: R, x: [R][n], A: [R][n, n]): assert stride(A, 1) == 1 for i in seq(0, n): @@ -37,7 +37,7 @@ def syr_row_major_Lower_template(n: size, alpha: R, x: [R][n], A: [R][n, n]): def specialize_syr(syr, precision): prefix = "s" if precision == "f32" else "d" name = syr.name() - name = name.replace("_template", "") + name = name.replace("", "") specialized = rename(syr, "exo_" + prefix + name) args = ["alpha", "x", "A"] @@ -70,16 +70,12 @@ def schedule_interleave_syr_row_major_stride_1( # Generate specialized kernels for f32 precision ################################################# -exo_ssyr_row_major_Upper_stride_any = specialize_syr( - syr_row_major_Upper_template, "f32" -) +exo_ssyr_row_major_Upper_stride_any = specialize_syr(syr_row_major_Upper, "f32") exo_ssyr_row_major_Upper_stride_any = rename( exo_ssyr_row_major_Upper_stride_any, exo_ssyr_row_major_Upper_stride_any.name() + "_stride_any", ) -exo_ssyr_row_major_Lower_stride_any = specialize_syr( - syr_row_major_Lower_template, "f32" -) +exo_ssyr_row_major_Lower_stride_any = specialize_syr(syr_row_major_Lower, "f32") exo_ssyr_row_major_Lower_stride_any = rename( exo_ssyr_row_major_Lower_stride_any, exo_ssyr_row_major_Lower_stride_any.name() + "_stride_any", @@ -94,7 +90,7 @@ def schedule_interleave_syr_row_major_stride_1( ] exo_ssyr_row_major_Upper_stride_1 = schedule_interleave_syr_row_major_stride_1( - syr_row_major_Upper_template, + syr_row_major_Upper, C.Machine.f32_vec_width, 1, C.Machine.mem_type, @@ -102,7 +98,7 @@ def schedule_interleave_syr_row_major_stride_1( "f32", ) exo_ssyr_row_major_Lower_stride_1 = schedule_interleave_syr_row_major_stride_1( - syr_row_major_Lower_template, + syr_row_major_Lower, C.Machine.f32_vec_width, 1, C.Machine.mem_type, @@ -114,16 +110,12 @@ def schedule_interleave_syr_row_major_stride_1( # Generate specialized kernels for f64 precision ################################################# -exo_dsyr_row_major_Upper_stride_any = specialize_syr( - syr_row_major_Upper_template, "f64" -) +exo_dsyr_row_major_Upper_stride_any = specialize_syr(syr_row_major_Upper, "f64") exo_dsyr_row_major_Upper_stride_any = rename( exo_dsyr_row_major_Upper_stride_any, exo_dsyr_row_major_Upper_stride_any.name() + "_stride_any", ) -exo_dsyr_row_major_Lower_stride_any = specialize_syr( - syr_row_major_Lower_template, "f64" -) +exo_dsyr_row_major_Lower_stride_any = specialize_syr(syr_row_major_Lower, "f64") exo_dsyr_row_major_Lower_stride_any = rename( exo_dsyr_row_major_Lower_stride_any, exo_dsyr_row_major_Lower_stride_any.name() + "_stride_any", @@ -139,7 +131,7 @@ def schedule_interleave_syr_row_major_stride_1( ] exo_dsyr_row_major_Upper_stride_1 = schedule_interleave_syr_row_major_stride_1( - syr_row_major_Upper_template, + syr_row_major_Upper, C.Machine.f32_vec_width // 2, 1, C.Machine.mem_type, @@ -147,7 +139,7 @@ def schedule_interleave_syr_row_major_stride_1( "f64", ) exo_dsyr_row_major_Lower_stride_1 = schedule_interleave_syr_row_major_stride_1( - syr_row_major_Lower_template, + syr_row_major_Lower, C.Machine.f32_vec_width // 2, 1, C.Machine.mem_type, diff --git a/src/level2/syr2.py b/src/level2/syr2.py index 520c5353..cbdb6ea8 100644 --- a/src/level2/syr2.py +++ b/src/level2/syr2.py @@ -11,9 +11,7 @@ ### EXO_LOC ALGORITHM START ### @proc -def syr2_row_major_Upper_template( - n: size, alpha: R, x: [R][n], y: [R][n], A: [R][n, n] -): +def syr2_row_major_Upper(n: size, alpha: R, x: [R][n], y: [R][n], A: [R][n, n]): assert stride(A, 1) == 1 for i in seq(0, n): @@ -22,9 +20,7 @@ def syr2_row_major_Upper_template( @proc -def syr2_row_major_Lower_template( - n: size, alpha: R, x: [R][n], y: [R][n], A: [R][n, n] -): +def syr2_row_major_Lower(n: size, alpha: R, x: [R][n], y: [R][n], A: [R][n, n]): assert stride(A, 1) == 1 for i in seq(0, n): @@ -39,7 +35,7 @@ def syr2_row_major_Lower_template( def specialize_syr2(syr2, precision): prefix = "s" if precision == "f32" else "d" name = syr2.name() - name = name.replace("_template", "") + name = name.replace("", "") specialized = rename(syr2, "exo_" + prefix + name) args = ["alpha", "x", "y", "A"] @@ -69,16 +65,12 @@ def schedule_interleave_syr2_row_major_stride_1( # Generate specialized kernels for f32 precision ################################################# -exo_ssyr2_row_major_Upper_stride_any = specialize_syr2( - syr2_row_major_Upper_template, "f32" -) +exo_ssyr2_row_major_Upper_stride_any = specialize_syr2(syr2_row_major_Upper, "f32") exo_ssyr2_row_major_Upper_stride_any = rename( exo_ssyr2_row_major_Upper_stride_any, exo_ssyr2_row_major_Upper_stride_any.name() + "_stride_any", ) -exo_ssyr2_row_major_Lower_stride_any = specialize_syr2( - syr2_row_major_Lower_template, "f32" -) +exo_ssyr2_row_major_Lower_stride_any = specialize_syr2(syr2_row_major_Lower, "f32") exo_ssyr2_row_major_Lower_stride_any = rename( exo_ssyr2_row_major_Lower_stride_any, exo_ssyr2_row_major_Lower_stride_any.name() + "_stride_any", @@ -93,7 +85,7 @@ def schedule_interleave_syr2_row_major_stride_1( ] exo_ssyr2_row_major_Upper_stride_1 = schedule_interleave_syr2_row_major_stride_1( - syr2_row_major_Upper_template, + syr2_row_major_Upper, C.Machine.f32_vec_width, 1, C.Machine.mem_type, @@ -101,7 +93,7 @@ def schedule_interleave_syr2_row_major_stride_1( "f32", ) exo_ssyr2_row_major_Lower_stride_1 = schedule_interleave_syr2_row_major_stride_1( - syr2_row_major_Lower_template, + syr2_row_major_Lower, C.Machine.f32_vec_width, 1, C.Machine.mem_type, @@ -113,16 +105,12 @@ def schedule_interleave_syr2_row_major_stride_1( # Generate specialized kernels for f64 precision ################################################# -exo_dsyr2_row_major_Upper_stride_any = specialize_syr2( - syr2_row_major_Upper_template, "f64" -) +exo_dsyr2_row_major_Upper_stride_any = specialize_syr2(syr2_row_major_Upper, "f64") exo_dsyr2_row_major_Upper_stride_any = rename( exo_dsyr2_row_major_Upper_stride_any, exo_dsyr2_row_major_Upper_stride_any.name() + "_stride_any", ) -exo_dsyr2_row_major_Lower_stride_any = specialize_syr2( - syr2_row_major_Lower_template, "f64" -) +exo_dsyr2_row_major_Lower_stride_any = specialize_syr2(syr2_row_major_Lower, "f64") exo_dsyr2_row_major_Lower_stride_any = rename( exo_dsyr2_row_major_Lower_stride_any, exo_dsyr2_row_major_Lower_stride_any.name() + "_stride_any", @@ -138,7 +126,7 @@ def schedule_interleave_syr2_row_major_stride_1( ] exo_dsyr2_row_major_Upper_stride_1 = schedule_interleave_syr2_row_major_stride_1( - syr2_row_major_Upper_template, + syr2_row_major_Upper, C.Machine.f32_vec_width // 2, 1, C.Machine.mem_type, @@ -146,7 +134,7 @@ def schedule_interleave_syr2_row_major_stride_1( "f64", ) exo_dsyr2_row_major_Lower_stride_1 = schedule_interleave_syr2_row_major_stride_1( - syr2_row_major_Lower_template, + syr2_row_major_Lower, C.Machine.f32_vec_width // 2, 1, C.Machine.mem_type, diff --git a/src/level2/tbmv.py b/src/level2/tbmv.py index ed12f446..1ae635f5 100644 --- a/src/level2/tbmv.py +++ b/src/level2/tbmv.py @@ -10,7 +10,7 @@ @proc -def tbmv_row_major_Upper_NonTrans_template( +def tbmv_row_major_Upper_NonTrans( n: size, k: size, x: [R][n], A: [R][n, k + 1], Diag: size ): assert stride(A, 1) == 1 @@ -46,7 +46,7 @@ def tbmv_row_major_Upper_NonTrans_template( @proc -def tbmv_row_major_Lower_NonTrans_template( +def tbmv_row_major_Lower_NonTrans( n: size, k: size, x: [R][n], A: [R][n, k + 1], Diag: size ): assert stride(A, 1) == 1 @@ -82,7 +82,7 @@ def tbmv_row_major_Lower_NonTrans_template( @proc -def tbmv_row_major_Upper_Trans_template( +def tbmv_row_major_Upper_Trans( n: size, k: size, x: [R][n], A: [R][n, k + 1], Diag: size ): assert stride(A, 1) == 1 @@ -115,7 +115,7 @@ def tbmv_row_major_Upper_Trans_template( @proc -def tbmv_row_major_Lower_Trans_template( +def tbmv_row_major_Lower_Trans( n: size, k: size, x: [R][n], A: [R][n, k + 1], Diag: size ): assert stride(A, 1) == 1 @@ -156,12 +156,11 @@ def tbmv_row_major_Lower_Trans_template( def specialize_tbmv(tbmv, precision): prefix = "s" if precision == "f32" else "d" name = tbmv.name() - name = name.replace("_template", "") specialized = rename(tbmv, "exo_" + prefix + name) args = ["x", "A"] - if "_Trans_" in tbmv.name(): + if "_Trans" in tbmv.name(): args.append("xRes") else: args.append("dot") @@ -194,28 +193,28 @@ def schedule_interleave_tbmv_row_major_stride_1( ################################################# exo_stbmv_row_major_Upper_NonTrans_stride_any = specialize_tbmv( - tbmv_row_major_Upper_NonTrans_template, "f32" + tbmv_row_major_Upper_NonTrans, "f32" ) exo_stbmv_row_major_Upper_NonTrans_stride_any = rename( exo_stbmv_row_major_Upper_NonTrans_stride_any, exo_stbmv_row_major_Upper_NonTrans_stride_any.name() + "_stride_any", ) exo_stbmv_row_major_Lower_NonTrans_stride_any = specialize_tbmv( - tbmv_row_major_Lower_NonTrans_template, "f32" + tbmv_row_major_Lower_NonTrans, "f32" ) exo_stbmv_row_major_Lower_NonTrans_stride_any = rename( exo_stbmv_row_major_Lower_NonTrans_stride_any, exo_stbmv_row_major_Lower_NonTrans_stride_any.name() + "_stride_any", ) exo_stbmv_row_major_Upper_Trans_stride_any = specialize_tbmv( - tbmv_row_major_Upper_Trans_template, "f32" + tbmv_row_major_Upper_Trans, "f32" ) exo_stbmv_row_major_Upper_Trans_stride_any = rename( exo_stbmv_row_major_Upper_Trans_stride_any, exo_stbmv_row_major_Upper_Trans_stride_any.name() + "_stride_any", ) exo_stbmv_row_major_Lower_Trans_stride_any = specialize_tbmv( - tbmv_row_major_Lower_Trans_template, "f32" + tbmv_row_major_Lower_Trans, "f32" ) exo_stbmv_row_major_Lower_Trans_stride_any = rename( exo_stbmv_row_major_Lower_Trans_stride_any, @@ -233,7 +232,7 @@ def schedule_interleave_tbmv_row_major_stride_1( exo_stbmv_row_major_Upper_NonTrans_stride_1 = ( schedule_interleave_tbmv_row_major_stride_1( - tbmv_row_major_Upper_NonTrans_template, + tbmv_row_major_Upper_NonTrans, C.Machine.f32_vec_width, ROW_INTERLEAVE_FACTOR, C.Machine.mem_type, @@ -243,7 +242,7 @@ def schedule_interleave_tbmv_row_major_stride_1( ) exo_stbmv_row_major_Lower_NonTrans_stride_1 = ( schedule_interleave_tbmv_row_major_stride_1( - tbmv_row_major_Lower_NonTrans_template, + tbmv_row_major_Lower_NonTrans, C.Machine.f32_vec_width, ROW_INTERLEAVE_FACTOR, C.Machine.mem_type, @@ -252,7 +251,7 @@ def schedule_interleave_tbmv_row_major_stride_1( ) ) exo_stbmv_row_major_Upper_Trans_stride_1 = schedule_interleave_tbmv_row_major_stride_1( - tbmv_row_major_Upper_Trans_template, + tbmv_row_major_Upper_Trans, C.Machine.f32_vec_width, ROW_INTERLEAVE_FACTOR, C.Machine.mem_type, @@ -260,7 +259,7 @@ def schedule_interleave_tbmv_row_major_stride_1( "f32", ) exo_stbmv_row_major_Lower_Trans_stride_1 = schedule_interleave_tbmv_row_major_stride_1( - tbmv_row_major_Lower_Trans_template, + tbmv_row_major_Lower_Trans, C.Machine.f32_vec_width, ROW_INTERLEAVE_FACTOR, C.Machine.mem_type, @@ -273,28 +272,28 @@ def schedule_interleave_tbmv_row_major_stride_1( ################################################# exo_dtbmv_row_major_Upper_NonTrans_stride_any = specialize_tbmv( - tbmv_row_major_Upper_NonTrans_template, "f64" + tbmv_row_major_Upper_NonTrans, "f64" ) exo_dtbmv_row_major_Upper_NonTrans_stride_any = rename( exo_dtbmv_row_major_Upper_NonTrans_stride_any, exo_dtbmv_row_major_Upper_NonTrans_stride_any.name() + "_stride_any", ) exo_dtbmv_row_major_Lower_NonTrans_stride_any = specialize_tbmv( - tbmv_row_major_Lower_NonTrans_template, "f64" + tbmv_row_major_Lower_NonTrans, "f64" ) exo_dtbmv_row_major_Lower_NonTrans_stride_any = rename( exo_dtbmv_row_major_Lower_NonTrans_stride_any, exo_dtbmv_row_major_Lower_NonTrans_stride_any.name() + "_stride_any", ) exo_dtbmv_row_major_Upper_Trans_stride_any = specialize_tbmv( - tbmv_row_major_Upper_Trans_template, "f64" + tbmv_row_major_Upper_Trans, "f64" ) exo_dtbmv_row_major_Upper_Trans_stride_any = rename( exo_dtbmv_row_major_Upper_Trans_stride_any, exo_dtbmv_row_major_Upper_Trans_stride_any.name() + "_stride_any", ) exo_dtbmv_row_major_Lower_Trans_stride_any = specialize_tbmv( - tbmv_row_major_Lower_Trans_template, "f64" + tbmv_row_major_Lower_Trans, "f64" ) exo_dtbmv_row_major_Lower_Trans_stride_any = rename( exo_dtbmv_row_major_Lower_Trans_stride_any, @@ -312,7 +311,7 @@ def schedule_interleave_tbmv_row_major_stride_1( exo_dtbmv_row_major_Upper_NonTrans_stride_1 = ( schedule_interleave_tbmv_row_major_stride_1( - tbmv_row_major_Upper_NonTrans_template, + tbmv_row_major_Upper_NonTrans, C.Machine.f32_vec_width // 2, ROW_INTERLEAVE_FACTOR, C.Machine.mem_type, @@ -322,7 +321,7 @@ def schedule_interleave_tbmv_row_major_stride_1( ) exo_dtbmv_row_major_Lower_NonTrans_stride_1 = ( schedule_interleave_tbmv_row_major_stride_1( - tbmv_row_major_Lower_NonTrans_template, + tbmv_row_major_Lower_NonTrans, C.Machine.f32_vec_width // 2, ROW_INTERLEAVE_FACTOR, C.Machine.mem_type, @@ -331,7 +330,7 @@ def schedule_interleave_tbmv_row_major_stride_1( ) ) exo_dtbmv_row_major_Upper_Trans_stride_1 = schedule_interleave_tbmv_row_major_stride_1( - tbmv_row_major_Upper_Trans_template, + tbmv_row_major_Upper_Trans, C.Machine.f32_vec_width // 2, ROW_INTERLEAVE_FACTOR, C.Machine.mem_type, @@ -339,7 +338,7 @@ def schedule_interleave_tbmv_row_major_stride_1( "f64", ) exo_dtbmv_row_major_Lower_Trans_stride_1 = schedule_interleave_tbmv_row_major_stride_1( - tbmv_row_major_Lower_Trans_template, + tbmv_row_major_Lower_Trans, C.Machine.f32_vec_width // 2, ROW_INTERLEAVE_FACTOR, C.Machine.mem_type, diff --git a/src/level2/tbsv.py b/src/level2/tbsv.py index bdad1a8e..a749e3ea 100644 --- a/src/level2/tbsv.py +++ b/src/level2/tbsv.py @@ -10,7 +10,7 @@ @proc -def tbsv_row_major_Upper_NonTrans_template( +def tbsv_row_major_Upper_NonTrans( n: size, k: size, x: [R][n], A: [R][n, n], Diag: size ): assert stride(A, 1) == 1 @@ -48,7 +48,7 @@ def tbsv_row_major_Upper_NonTrans_template( @proc -def tbsv_row_major_Lower_NonTrans_template( +def tbsv_row_major_Lower_NonTrans( n: size, k: size, x: [R][n], A: [R][n, n], Diag: size ): assert stride(A, 1) == 1 @@ -86,9 +86,7 @@ def tbsv_row_major_Lower_NonTrans_template( @proc -def tbsv_row_major_Upper_Trans_template( - n: size, k: size, x: [R][n], A: [R][n, n], Diag: size -): +def tbsv_row_major_Upper_Trans(n: size, k: size, x: [R][n], A: [R][n, n], Diag: size): assert stride(A, 1) == 1 assert k <= n - 1 @@ -112,9 +110,7 @@ def tbsv_row_major_Upper_Trans_template( @proc -def tbsv_row_major_Lower_Trans_template( - n: size, k: size, x: [R][n], A: [R][n, n], Diag: size -): +def tbsv_row_major_Lower_Trans(n: size, k: size, x: [R][n], A: [R][n, n], Diag: size): assert stride(A, 1) == 1 assert k <= n - 1 @@ -140,7 +136,7 @@ def tbsv_row_major_Lower_Trans_template( def specialize_tbsv(tbsv, precision): prefix = "s" if precision == "f32" else "d" name = tbsv.name() - name = name.replace("_template", "") + name = name.replace("", "") specialized = rename(tbsv, "exo_" + prefix + name) args = ["x", "A", "dot", "pivot"] @@ -179,28 +175,28 @@ def schedule_interleave_tbsv_row_major_stride_1( ################################################# exo_stbsv_row_major_Upper_NonTrans_stride_any = specialize_tbsv( - tbsv_row_major_Upper_NonTrans_template, "f32" + tbsv_row_major_Upper_NonTrans, "f32" ) exo_stbsv_row_major_Upper_NonTrans_stride_any = rename( exo_stbsv_row_major_Upper_NonTrans_stride_any, exo_stbsv_row_major_Upper_NonTrans_stride_any.name() + "_stride_any", ) exo_stbsv_row_major_Lower_NonTrans_stride_any = specialize_tbsv( - tbsv_row_major_Lower_NonTrans_template, "f32" + tbsv_row_major_Lower_NonTrans, "f32" ) exo_stbsv_row_major_Lower_NonTrans_stride_any = rename( exo_stbsv_row_major_Lower_NonTrans_stride_any, exo_stbsv_row_major_Lower_NonTrans_stride_any.name() + "_stride_any", ) exo_stbsv_row_major_Upper_Trans_stride_any = specialize_tbsv( - tbsv_row_major_Upper_Trans_template, "f32" + tbsv_row_major_Upper_Trans, "f32" ) exo_stbsv_row_major_Upper_Trans_stride_any = rename( exo_stbsv_row_major_Upper_Trans_stride_any, exo_stbsv_row_major_Upper_Trans_stride_any.name() + "_stride_any", ) exo_stbsv_row_major_Lower_Trans_stride_any = specialize_tbsv( - tbsv_row_major_Lower_Trans_template, "f32" + tbsv_row_major_Lower_Trans, "f32" ) exo_stbsv_row_major_Lower_Trans_stride_any = rename( exo_stbsv_row_major_Lower_Trans_stride_any, @@ -218,7 +214,7 @@ def schedule_interleave_tbsv_row_major_stride_1( exo_stbsv_row_major_Upper_NonTrans_stride_1 = ( schedule_interleave_tbsv_row_major_stride_1( - tbsv_row_major_Upper_NonTrans_template, + tbsv_row_major_Upper_NonTrans, C.Machine.f32_vec_width, ROW_INTERLEAVE_FACTOR, C.Machine.mem_type, @@ -228,7 +224,7 @@ def schedule_interleave_tbsv_row_major_stride_1( ) exo_stbsv_row_major_Lower_NonTrans_stride_1 = ( schedule_interleave_tbsv_row_major_stride_1( - tbsv_row_major_Lower_NonTrans_template, + tbsv_row_major_Lower_NonTrans, C.Machine.f32_vec_width, ROW_INTERLEAVE_FACTOR, C.Machine.mem_type, @@ -242,28 +238,28 @@ def schedule_interleave_tbsv_row_major_stride_1( ################################################# exo_dtbsv_row_major_Upper_NonTrans_stride_any = specialize_tbsv( - tbsv_row_major_Upper_NonTrans_template, "f64" + tbsv_row_major_Upper_NonTrans, "f64" ) exo_dtbsv_row_major_Upper_NonTrans_stride_any = rename( exo_dtbsv_row_major_Upper_NonTrans_stride_any, exo_dtbsv_row_major_Upper_NonTrans_stride_any.name() + "_stride_any", ) exo_dtbsv_row_major_Lower_NonTrans_stride_any = specialize_tbsv( - tbsv_row_major_Lower_NonTrans_template, "f64" + tbsv_row_major_Lower_NonTrans, "f64" ) exo_dtbsv_row_major_Lower_NonTrans_stride_any = rename( exo_dtbsv_row_major_Lower_NonTrans_stride_any, exo_dtbsv_row_major_Lower_NonTrans_stride_any.name() + "_stride_any", ) exo_dtbsv_row_major_Upper_Trans_stride_any = specialize_tbsv( - tbsv_row_major_Upper_Trans_template, "f64" + tbsv_row_major_Upper_Trans, "f64" ) exo_dtbsv_row_major_Upper_Trans_stride_any = rename( exo_dtbsv_row_major_Upper_Trans_stride_any, exo_dtbsv_row_major_Upper_Trans_stride_any.name() + "_stride_any", ) exo_dtbsv_row_major_Lower_Trans_stride_any = specialize_tbsv( - tbsv_row_major_Lower_Trans_template, "f64" + tbsv_row_major_Lower_Trans, "f64" ) exo_dtbsv_row_major_Lower_Trans_stride_any = rename( exo_dtbsv_row_major_Lower_Trans_stride_any, @@ -281,7 +277,7 @@ def schedule_interleave_tbsv_row_major_stride_1( exo_dtbsv_row_major_Upper_NonTrans_stride_1 = ( schedule_interleave_tbsv_row_major_stride_1( - tbsv_row_major_Upper_NonTrans_template, + tbsv_row_major_Upper_NonTrans, C.Machine.f32_vec_width // 2, ROW_INTERLEAVE_FACTOR, C.Machine.mem_type, @@ -291,7 +287,7 @@ def schedule_interleave_tbsv_row_major_stride_1( ) exo_dtbsv_row_major_Lower_NonTrans_stride_1 = ( schedule_interleave_tbsv_row_major_stride_1( - tbsv_row_major_Lower_NonTrans_template, + tbsv_row_major_Lower_NonTrans, C.Machine.f32_vec_width // 2, ROW_INTERLEAVE_FACTOR, C.Machine.mem_type, diff --git a/src/level2/trmv.py b/src/level2/trmv.py index ce1cba29..c2adb07c 100644 --- a/src/level2/trmv.py +++ b/src/level2/trmv.py @@ -1,20 +1,14 @@ from __future__ import annotations from exo import * -from exo.libs.memories import DRAM_STATIC -from exo.platforms.x86 import * -from exo.syntax import * -from exo.stdlib.scheduling import * -import exo_blas_config as C -from composed_schedules import * from blaslib import * from codegen_helpers import * ### EXO_LOC ALGORITHM START ### @proc -def trmv_rm_un_template(Diag: index, n: size, x: [R][n], A: [R][n, n]): +def trmv_rm_un(Diag: index, n: size, x: [R][n], A: [R][n, n]): assert stride(A, 1) == 1 xCopy: R[n] @@ -33,7 +27,7 @@ def trmv_rm_un_template(Diag: index, n: size, x: [R][n], A: [R][n, n]): @proc -def trmv_rm_ln_template(Diag: index, n: size, x: [R][n], A: [R][n, n]): +def trmv_rm_ln(Diag: index, n: size, x: [R][n], A: [R][n, n]): assert stride(A, 1) == 1 xCopy: R[n] @@ -52,7 +46,7 @@ def trmv_rm_ln_template(Diag: index, n: size, x: [R][n], A: [R][n, n]): @proc -def trmv_rm_ut_template(Diag: index, n: size, x: [R][n], A: [R][n, n]): +def trmv_rm_ut(Diag: index, n: size, x: [R][n], A: [R][n, n]): assert stride(A, 1) == 1 xCopy: R[n] @@ -72,7 +66,7 @@ def trmv_rm_ut_template(Diag: index, n: size, x: [R][n], A: [R][n, n]): @proc -def trmv_rm_lt_template(Diag: index, n: size, x: [R][n], A: [R][n, n]): +def trmv_rm_lt(Diag: index, n: size, x: [R][n], A: [R][n, n]): assert stride(A, 1) == 1 xCopy: R[n] @@ -96,21 +90,7 @@ def trmv_rm_lt_template(Diag: index, n: size, x: [R][n], A: [R][n, n]): ### EXO_LOC SCHEDULE START ### -template_sched_list = [ - trmv_rm_un_template, - trmv_rm_ln_template, - trmv_rm_ut_template, - trmv_rm_lt_template, -] - -for precision in ("f32", "f64"): - for template in template_sched_list: - proc_stride_any = generate_stride_any_proc(template, precision) - export_exo_proc(globals(), proc_stride_any) - proc_stride_1 = generate_stride_1_proc(template, precision) - proc_stride_1 = optimize_level_2( - proc_stride_1, proc_stride_1.find_loop("i"), precision, C.Machine, 4, 2 - ) - export_exo_proc(globals(), proc_stride_1) +for proc in trmv_rm_un, trmv_rm_ln, trmv_rm_ut, trmv_rm_lt: + variants_generator(optimize_level_2)(proc, "i", 4, 2, globals=globals()) ### EXO_LOC SCHEDULE END ### diff --git a/src/level2/trsv.py b/src/level2/trsv.py index b3d7183b..d57d5753 100644 --- a/src/level2/trsv.py +++ b/src/level2/trsv.py @@ -10,9 +10,7 @@ @proc -def trsv_row_major_Upper_NonTrans_template( - n: size, x: [R][n], A: [R][n, n], Diag: size -): +def trsv_row_major_Upper_NonTrans(n: size, x: [R][n], A: [R][n, n], Diag: size): assert stride(A, 1) == 1 for i in seq(0, n): @@ -35,9 +33,7 @@ def trsv_row_major_Upper_NonTrans_template( @proc -def trsv_row_major_Lower_NonTrans_template( - n: size, x: [R][n], A: [R][n, n], Diag: size -): +def trsv_row_major_Lower_NonTrans(n: size, x: [R][n], A: [R][n, n], Diag: size): assert stride(A, 1) == 1 for i in seq(0, n): @@ -59,7 +55,7 @@ def trsv_row_major_Lower_NonTrans_template( @proc -def trsv_row_major_Upper_Trans_template(n: size, x: [R][n], A: [R][n, n], Diag: size): +def trsv_row_major_Upper_Trans(n: size, x: [R][n], A: [R][n, n], Diag: size): assert stride(A, 1) == 1 for i in seq(0, n): @@ -81,7 +77,7 @@ def trsv_row_major_Upper_Trans_template(n: size, x: [R][n], A: [R][n, n], Diag: @proc -def trsv_row_major_Lower_Trans_template(n: size, x: [R][n], A: [R][n, n], Diag: size): +def trsv_row_major_Lower_Trans(n: size, x: [R][n], A: [R][n, n], Diag: size): assert stride(A, 1) == 1 for i in seq(0, n): @@ -106,7 +102,7 @@ def trsv_row_major_Lower_Trans_template(n: size, x: [R][n], A: [R][n, n], Diag: def specialize_trsv(trsv, precision): prefix = "s" if precision == "f32" else "d" name = trsv.name() - name = name.replace("_template", "") + name = name.replace("", "") specialized = rename(trsv, "exo_" + prefix + name) args = ["x", "A", "dot", "pivot"] @@ -138,28 +134,28 @@ def schedule_interleave_trsv_row_major_stride_1( ################################################# exo_strsv_row_major_Upper_NonTrans_stride_any = specialize_trsv( - trsv_row_major_Upper_NonTrans_template, "f32" + trsv_row_major_Upper_NonTrans, "f32" ) exo_strsv_row_major_Upper_NonTrans_stride_any = rename( exo_strsv_row_major_Upper_NonTrans_stride_any, exo_strsv_row_major_Upper_NonTrans_stride_any.name() + "_stride_any", ) exo_strsv_row_major_Lower_NonTrans_stride_any = specialize_trsv( - trsv_row_major_Lower_NonTrans_template, "f32" + trsv_row_major_Lower_NonTrans, "f32" ) exo_strsv_row_major_Lower_NonTrans_stride_any = rename( exo_strsv_row_major_Lower_NonTrans_stride_any, exo_strsv_row_major_Lower_NonTrans_stride_any.name() + "_stride_any", ) exo_strsv_row_major_Upper_Trans_stride_any = specialize_trsv( - trsv_row_major_Upper_Trans_template, "f32" + trsv_row_major_Upper_Trans, "f32" ) exo_strsv_row_major_Upper_Trans_stride_any = rename( exo_strsv_row_major_Upper_Trans_stride_any, exo_strsv_row_major_Upper_Trans_stride_any.name() + "_stride_any", ) exo_strsv_row_major_Lower_Trans_stride_any = specialize_trsv( - trsv_row_major_Lower_Trans_template, "f32" + trsv_row_major_Lower_Trans, "f32" ) exo_strsv_row_major_Lower_Trans_stride_any = rename( exo_strsv_row_major_Lower_Trans_stride_any, @@ -177,7 +173,7 @@ def schedule_interleave_trsv_row_major_stride_1( exo_strsv_row_major_Upper_NonTrans_stride_1 = ( schedule_interleave_trsv_row_major_stride_1( - trsv_row_major_Upper_NonTrans_template, + trsv_row_major_Upper_NonTrans, C.Machine.f32_vec_width, ROW_INTERLEAVE_FACTOR, C.Machine.mem_type, @@ -187,7 +183,7 @@ def schedule_interleave_trsv_row_major_stride_1( ) exo_strsv_row_major_Lower_NonTrans_stride_1 = ( schedule_interleave_trsv_row_major_stride_1( - trsv_row_major_Lower_NonTrans_template, + trsv_row_major_Lower_NonTrans, C.Machine.f32_vec_width, ROW_INTERLEAVE_FACTOR, C.Machine.mem_type, @@ -201,28 +197,28 @@ def schedule_interleave_trsv_row_major_stride_1( ################################################# exo_dtrsv_row_major_Upper_NonTrans_stride_any = specialize_trsv( - trsv_row_major_Upper_NonTrans_template, "f64" + trsv_row_major_Upper_NonTrans, "f64" ) exo_dtrsv_row_major_Upper_NonTrans_stride_any = rename( exo_dtrsv_row_major_Upper_NonTrans_stride_any, exo_dtrsv_row_major_Upper_NonTrans_stride_any.name() + "_stride_any", ) exo_dtrsv_row_major_Lower_NonTrans_stride_any = specialize_trsv( - trsv_row_major_Lower_NonTrans_template, "f64" + trsv_row_major_Lower_NonTrans, "f64" ) exo_dtrsv_row_major_Lower_NonTrans_stride_any = rename( exo_dtrsv_row_major_Lower_NonTrans_stride_any, exo_dtrsv_row_major_Lower_NonTrans_stride_any.name() + "_stride_any", ) exo_dtrsv_row_major_Upper_Trans_stride_any = specialize_trsv( - trsv_row_major_Upper_Trans_template, "f64" + trsv_row_major_Upper_Trans, "f64" ) exo_dtrsv_row_major_Upper_Trans_stride_any = rename( exo_dtrsv_row_major_Upper_Trans_stride_any, exo_dtrsv_row_major_Upper_Trans_stride_any.name() + "_stride_any", ) exo_dtrsv_row_major_Lower_Trans_stride_any = specialize_trsv( - trsv_row_major_Lower_Trans_template, "f64" + trsv_row_major_Lower_Trans, "f64" ) exo_dtrsv_row_major_Lower_Trans_stride_any = rename( exo_dtrsv_row_major_Lower_Trans_stride_any, @@ -240,7 +236,7 @@ def schedule_interleave_trsv_row_major_stride_1( exo_dtrsv_row_major_Upper_NonTrans_stride_1 = ( schedule_interleave_trsv_row_major_stride_1( - trsv_row_major_Upper_NonTrans_template, + trsv_row_major_Upper_NonTrans, C.Machine.f32_vec_width // 2, ROW_INTERLEAVE_FACTOR, C.Machine.mem_type, @@ -250,7 +246,7 @@ def schedule_interleave_trsv_row_major_stride_1( ) exo_dtrsv_row_major_Lower_NonTrans_stride_1 = ( schedule_interleave_trsv_row_major_stride_1( - trsv_row_major_Lower_NonTrans_template, + trsv_row_major_Lower_NonTrans, C.Machine.f32_vec_width // 2, ROW_INTERLEAVE_FACTOR, C.Machine.mem_type, diff --git a/src/level3/gemm.py b/src/level3/gemm.py index c246e05b..f2e2f558 100644 --- a/src/level3/gemm.py +++ b/src/level3/gemm.py @@ -11,16 +11,11 @@ import exo_blas_config as C from composed_schedules import * -from codegen_helpers import ( - generate_stride_any_proc, - export_exo_proc, -) +from codegen_helpers import * @proc -def gemm_matmul_template( - M: size, N: size, K: size, A: [R][M, K], B: [R][K, N], C: [R][M, N] -): +def gemm_matmul(M: size, N: size, K: size, A: [R][M, K], B: [R][K, N], C: [R][M, N]): assert stride(A, 1) == 1 assert stride(B, 1) == 1 assert stride(C, 1) == 1 @@ -229,7 +224,8 @@ def schedule_outer_product_gemm_as_tiles( def schedule_gemm_matmul(gemm, precision): - gemm = generate_stride_any_proc(gemm, precision) + gemm = specialize_precision(gemm, precision) + gemm = generate_stride_any_proc(gemm) i_loop = gemm.find_loop("i") j_loop = i_loop.body()[0] @@ -267,7 +263,7 @@ def schedule_gemm_matmul(gemm, precision): template_sched_list = [ - (gemm_matmul_template, schedule_gemm_matmul), + (gemm_matmul, schedule_gemm_matmul), ] for precision in ("f32",): diff --git a/src/level3/trmm.py b/src/level3/trmm.py index d73c8cc5..df13dc5a 100644 --- a/src/level3/trmm.py +++ b/src/level3/trmm.py @@ -10,7 +10,7 @@ @proc -def trmm_row_major_Left_Upper_NonTrans_template( +def trmm_row_major_Left_Upper_NonTrans( m: size, n: size, alpha: R, A: [R][m, m], B: [R][m, n], Diag: size ): assert stride(A, 1) == 1 @@ -26,7 +26,7 @@ def trmm_row_major_Left_Upper_NonTrans_template( @proc -def trmm_row_major_Left_Lower_NonTrans_template( +def trmm_row_major_Left_Lower_NonTrans( m: size, n: size, alpha: R, A: [R][m, m], B: [R][m, n], Diag: size ): assert stride(A, 1) == 1 @@ -51,7 +51,7 @@ def trmm_row_major_Left_Lower_NonTrans_template( @proc -def trmm_row_major_Left_Upper_Trans_template( +def trmm_row_major_Left_Upper_Trans( m: size, n: size, alpha: R, A: [R][m, m], B: [R][m, n], Diag: size ): assert stride(A, 1) == 1 @@ -76,7 +76,7 @@ def trmm_row_major_Left_Upper_Trans_template( @proc -def trmm_row_major_Left_Lower_Trans_template( +def trmm_row_major_Left_Lower_Trans( m: size, n: size, alpha: R, A: [R][m, m], B: [R][m, n], Diag: size ): assert stride(A, 1) == 1 @@ -94,7 +94,7 @@ def trmm_row_major_Left_Lower_Trans_template( def specialize_trmm(trmm, precision): prefix = "s" if precision == "f32" else "d" name = trmm.name() - name = name.replace("_template", "") + name = name.replace("", "") specialized = rename(trmm, "exo_" + prefix + name) args = ["alpha", "A", "B"] @@ -114,28 +114,28 @@ def specialize_trmm(trmm, precision): ################################################# exo_strmm_row_major_Left_Upper_NonTrans = specialize_trmm( - trmm_row_major_Left_Upper_NonTrans_template, "f32" + trmm_row_major_Left_Upper_NonTrans, "f32" ) exo_strmm_row_major_Left_Upper_NonTrans = rename( exo_strmm_row_major_Left_Upper_NonTrans, exo_strmm_row_major_Left_Upper_NonTrans.name() + "", ) exo_strmm_row_major_Left_Lower_NonTrans = specialize_trmm( - trmm_row_major_Left_Lower_NonTrans_template, "f32" + trmm_row_major_Left_Lower_NonTrans, "f32" ) exo_strmm_row_major_Left_Lower_NonTrans = rename( exo_strmm_row_major_Left_Lower_NonTrans, exo_strmm_row_major_Left_Lower_NonTrans.name() + "", ) exo_strmm_row_major_Left_Upper_Trans = specialize_trmm( - trmm_row_major_Left_Upper_Trans_template, "f32" + trmm_row_major_Left_Upper_Trans, "f32" ) exo_strmm_row_major_Left_Upper_Trans = rename( exo_strmm_row_major_Left_Upper_Trans, exo_strmm_row_major_Left_Upper_Trans.name() + "", ) exo_strmm_row_major_Left_Lower_Trans = specialize_trmm( - trmm_row_major_Left_Lower_Trans_template, "f32" + trmm_row_major_Left_Lower_Trans, "f32" ) exo_strmm_row_major_Left_Lower_Trans = rename( exo_strmm_row_major_Left_Lower_Trans, diff --git a/test/codegen/reference/sha256/avx2.json b/test/codegen/reference/sha256/avx2.json index 382afb6e..66f7feb6 100644 --- a/test/codegen/reference/sha256/avx2.json +++ b/test/codegen/reference/sha256/avx2.json @@ -5,7 +5,7 @@ "exo_dot": "da24d02bccfd273cf39f532ad0561cd3295228a09501db4b4f9a10667412588f", "exo_dsdot": "076de19e797eda245093589de4800eaf7c1079c29a4933c815c76a4a7c38bb8d", "exo_gbmv": "09d9bdd7281966cef0d62f142996cc6c6a0707bc86b5079b851abd1fbf1d9688", - "exo_gemm": "965563d4dff624a32d87aacc9a9164ddcb19455e8c000760f0ed5e4c49b5d5b7", + "exo_gemm": "07e0d06e41e12262473fb297f719016079ec46ac140b956ccfafdcb6a1d17df6", "exo_gemv": "96aa66dc3c19bf5b57b3ad9ab731653badd9190dbd76c18dfa8e5bd827bfebff", "exo_ger": "2a4b78708d6e10e52e1fa62697b9c0c4f26ada31a3a2ce4371db0328df368e56", "exo_iamax": "49c60714c479234683166e5651fbe95ed5a43ecd370a391732769588948cc842",