From d5b18d2da0454d816f70591ecc3cd903893cf4fc Mon Sep 17 00:00:00 2001 From: Samir Droubi Date: Mon, 18 Mar 2024 17:30:21 -0400 Subject: [PATCH 01/14] Change tiling in macro kernel --- src/common/higher_order.py | 2 +- src/common/stdlib.py | 95 +++++++++++++++++++------------------- src/level3/gemm.py | 65 ++++++++++++-------------- 3 files changed, 79 insertions(+), 83 deletions(-) diff --git a/src/common/higher_order.py b/src/common/higher_order.py index 9c64aa2..148bc3e 100644 --- a/src/common/higher_order.py +++ b/src/common/higher_order.py @@ -81,7 +81,7 @@ def rewrite(proc, block, subproc_name, *args, rc=False, **kwargs): proc, subproc = extract_subproc(proc, block, subproc_name) subproc_sched = op(subproc, *args, **kwargs) call = proc.forward(block)[0] - proc = call_eqv(proc, call, subproc) + proc = call_eqv(proc, call, subproc_sched) if not rc: return proc return proc, (subproc, subproc_sched) diff --git a/src/common/stdlib.py b/src/common/stdlib.py index 3810639..b314387 100644 --- a/src/common/stdlib.py +++ b/src/common/stdlib.py @@ -585,69 +585,70 @@ def tile_loops_top_down(proc, loop_tile_pairs): return proc, [proc.forward(l) for l in inner_loops] -def tile_loops_bottom_up(proc, loop, tiles, const_allocs=True): +def tile_loops_bottom_up(proc, loop, tiles, const_allocs=True, tail="cut_and_guard"): loop = proc.forward(loop) + + loops = [(loop, tiles[0])] cur_loop = loop - for i in tiles[:-1]: - if not len(cur_loop.body()) == 1: - raise BLAS_SchedulingError("All loop must have a body length of 1") - if not isinstance(cur_loop.body()[0], ForCursor): - raise BLAS_SchedulingError("Did not find a nested loop") - cur_loop = cur_loop.body()[0] - - loops = [] - cur_loop = loop - for i in tiles: - loops.append((cur_loop, i)) - cur_loop = cur_loop.body()[0] + for tile in tiles[1:]: + cur_loop = get_inner_loop(proc, cur_loop) + loops.append((cur_loop, tile)) - def get_depth(loop): - if not isinstance(loop, (ForCursor, IfCursor)): + def get_depth(proc, scope): + scope = proc.forward(scope) + if not isinstance(scope, (ForCursor, IfCursor)): return 0 - return max([get_depth(i) for i in loop.body()]) + 1 + return max([get_depth(proc, i) for i in scope.body()]) + 1 - def push_loop_in(proc, loop, depth, size=None): - loop = proc.forward(loop) - allocs = list(filter_cursors(is_alloc)(proc, loop.body())) + def push_scope_in(proc, scope, depth, size=None): + scope = proc.forward(scope) + allocs = list(filter_cursors(is_alloc)(proc, scope.body())) proc = apply(attempt(parallelize_and_lift_alloc))(proc, allocs) + if const_allocs and size: proc = apply(lambda p, a: resize_dim(p, a, 0, size, 0))(proc, allocs) - loop = proc.forward(loop) - count = len(loop.body()) - for stmt in list(loop.body())[:-1]: - proc = fission(proc, stmt.after()) - loop = proc.forward(loop) - loops = [] - for i in range(count): - loops.append(loop) - loop = loop.next() - for loop in loops: - if get_depth(loop) <= depth: + + scope = proc.forward(scope) + count = len(scope.body()) + scopes = [scope] + for stmt in list(scope.body())[:-1]: + proc, (scope1, scope2) = fission_(proc, stmt.after(), rc=True) + scopes = scopes[:-1] + scopes += [scope1, scope2] + + for scope in scopes: + if get_depth(proc, scope) <= depth: continue - loop = proc.forward(loop) - child = loop.body()[0] - if isinstance(child, ForCursor): - proc = reorder_loops(proc, loop) - forwarded_loop = proc.forward(loop) - elif isinstance(child, IfCursor): + scope = proc.forward(scope) + child = scope.body()[0] + if isinstance(child, (ForCursor, IfCursor)): proc = lift_scope(proc, child) child = proc.forward(child) - forwarded_loop = child.body()[0] + forwarded_scope = child.body()[0] else: continue - proc = push_loop_in(proc, forwarded_loop, depth, size) + proc = push_scope_in(proc, forwarded_scope, depth, size) return proc + guards = 0 for depth, (loop, tile) in enumerate(loops[::-1]): - if tile is not None: - proc, (_, inner, tail) = divide_loop_( - proc, loop, tile, tail="cut_and_guard", rc=True - ) - proc = push_loop_in(proc, inner, depth + 1, tile) - proc = push_loop_in(proc, tail.body()[0], depth + 1, tile) - else: - proc = push_loop_in(proc, loop, depth + 1) - + if tail == "cut_and_guard": + if tile is not None: + proc, (_, inner, tail_l) = divide_loop_( + proc, loop, tile, tail=tail, rc=True + ) + proc = push_scope_in(proc, inner, depth + 1, tile) + proc = push_scope_in(proc, tail_l.body()[0], depth + 1, tile) + else: + proc = push_scope_in(proc, loop, depth + 1) + elif tail == "guard": + if tile is not None: + proc, (_, inner, _) = divide_loop_(proc, loop, tile, tail=tail, rc=True) + proc = push_scope_in(proc, inner.body()[0], guards + 1) + guards += 1 + proc = push_scope_in(proc, inner, depth + guards + 1, tile) + else: + proc = push_scope_in(proc, loop, depth + guards + 1) return proc diff --git a/src/level3/gemm.py b/src/level3/gemm.py index ea76444..2c7069e 100644 --- a/src/level3/gemm.py +++ b/src/level3/gemm.py @@ -65,9 +65,7 @@ def rewrite(gemm_uk): gemm_uk = divide_dim(gemm_uk, tile, 1, vw) loops = [init, main_i, axpy] - gemm_uk = apply(optimize_level_2)( - gemm_uk, loops, precision, machine, m_r, n_r // vw, vec_tail="perfect" - ) + return gemm_uk blocks = map(lambda c: c.expand(), gemm_uk.find("C_tile:_", many=True)) @@ -75,9 +73,30 @@ def rewrite(gemm_uk): return gemm_uk -def schedule_macro( - gemm_mk, precision, machine, max_M, max_N, max_K, m_r, n_r_fac, do_br=False -): +def schedule_compute(gemm_compute, precision, machine, m_r, n_r_fac): + vw = machine.vec_width(precision) + n_r = vw * n_r_fac + + i_loop = gemm_compute.body()[0] + j_loop = get_inner_loop(gemm_compute, i_loop) + k_loop = get_inner_loop(gemm_compute, j_loop) + + gemm_compute, cs = auto_stage_mem( + gemm_compute, k_loop, "C", "C_tile", accum=True, rc=1 + ) + gemm_compute = lift_reduce_constant(gemm_compute, cs.load.expand(0, 1)) + assign = gemm_compute.forward(cs.store).prev() + gemm_compute = inline_assign(gemm_compute, assign) + gemm_compute = set_memory(gemm_compute, cs.alloc, machine.mem_type) + + gemm_compute = tile_loops_bottom_up( + gemm_compute, i_loop, (m_r, n_r, None), tail="guard" + ) + gemm_compute = repeate_n(lift_scope)(gemm_compute, gemm_compute.find_loop("k"), n=2) + return gemm_compute + + +def schedule_macro(gemm_mk, precision, machine, max_M, max_N, max_K, m_r, n_r_fac): vw = machine.vec_width(precision) n_r = vw * n_r_fac gemm_mk = specialize_precision(gemm_mk, precision) @@ -91,41 +110,17 @@ def schedule_macro( packed_A_shape = ((0, max_M // m_r), (1, max_K), (0, m_r)) gemm_mk, cursors = pack_mem(gemm_mk, i_loop, "A", packed_A_shape, "packed_A", rc=1) # TODO: Schedule packing kernel - gemm_mk, _ = extract_subproc(gemm_mk, cursors.load, "A_pack_kernel") + gemm_mk, _ = extract_subproc(gemm_mk, cursors.load, gemm_mk.name() + "A_pack") packed_B_shape = ((1, max_N // n_r), (0, max_K), (1, n_r)) gemm_mk, cursors = pack_mem(gemm_mk, i_loop, "B", packed_B_shape, "packed_B", rc=1) # TODO: Schedule packing kernel - gemm_mk, _ = extract_subproc(gemm_mk, cursors.load, "B_pack_kernel") - - gemm_mk, _ = extract_subproc(gemm_mk, i_loop, "compute") - return gemm_mk_starter, gemm_mk - n_r = machine.vec_width(precision) * n_r_fac - j_loop = get_inner_loop(gemm_mk, i_loop) - k_loop = get_inner_loop(gemm_mk, j_loop) + gemm_mk, _ = extract_subproc(gemm_mk, cursors.load, gemm_mk.name() + "B_pack") - gemm_mk = auto_stage_mem(gemm_mk, k_loop, "C", "C_tile", accum=True) - gemm_mk = lift_reduce_constant(gemm_mk, gemm_mk.forward(k_loop).expand(1, 0)) - gemm_mk = inline_assign(gemm_mk, gemm_mk.find("C_tile = _ * _")) - - gemm_mk = tile_loops_bottom_up(gemm_mk, i_loop, (m_r, n_r, None)) - gemm_mk = apply(repeate_n(lift_scope))( - gemm_mk, gemm_mk.find_loop("k", many=True), n=2 + gemm_mk = extract_and_schedule(schedule_compute)( + gemm_mk, i_loop, gemm_mk.name() + "_compute", precision, machine, m_r, n_r_fac ) - - tiles = gemm_mk.find("C_tile:_", many=True) - names = ["_uk", "_r_uk", "_b_uk", "_br_uk"] - names = [gemm_mk.name() + su for su in names] - if not do_br: - tiles = tiles[:-1] - names = names[:-1] - for tile, name in zip(tiles, names): - gemm_mk = extract_and_schedule(schedule_micro)( - gemm_mk, tile.expand(), name, precision, machine, m_r, n_r_fac - ) - - gemm_mk = cleanup(gemm_mk) - return gemm_mk + return gemm_mk_starter, gemm_mk def schedule(main_gemm, i_loop, precision, machine): From 4a859cf7ab192d98b160818fac7ee1a5206523d2 Mon Sep 17 00:00:00 2001 From: Samir Droubi Date: Tue, 19 Mar 2024 10:48:53 -0400 Subject: [PATCH 02/14] Schedule the microkernel --- src/level3/gemm.py | 114 ++++++++++++++++----------------------------- 1 file changed, 41 insertions(+), 73 deletions(-) diff --git a/src/level3/gemm.py b/src/level3/gemm.py index 2c7069e..bbf86cb 100644 --- a/src/level3/gemm.py +++ b/src/level3/gemm.py @@ -22,61 +22,9 @@ def gemm(M: size, N: size, K: size, alpha: R, A: [R][M, K], B: [R][K, N], C: [R] C[i, j] += alpha * (A[i, k] * B[k, j]) -def specialize_micro(gemm_uk, precision, machine, m_r, n_r_fac): - vw = machine.vec_width(precision) - _, init, main_k, axpy = gemm_uk.body() - - main_i = get_inner_loop(gemm_uk, main_k) - j_loops = [get_inner_loop(gemm_uk, c) for c in (init, main_i, axpy)] - - gemm_uk = simplify(apply(lambda p, c: round_loop(p, c, vw))(gemm_uk, j_loops)) - j_loop = gemm_uk.forward(j_loops[0]) - - specialize_i = not is_loop_bounds_const(gemm_uk, main_i) - specialize_j = not is_loop_bounds_const(gemm_uk, j_loop) - make_conds = lambda e, mx: [f"{expr_to_string(e)} == {i + 1}" for i in range(mx)] - - if specialize_i: - gemm_uk = specialize(gemm_uk, gemm_uk.body(), make_conds(main_i.hi(), m_r)) - if specialize_j: - for tile in gemm_uk.find("C_tile:_", many=True): - gemm_uk = specialize( - gemm_uk, tile.expand(), make_conds(j_loop.hi().lhs(), n_r_fac) - ) - gemm_uk = dce(simplify(gemm_uk)) - return gemm_uk - - -def schedule_micro(gemm_uk, precision, machine, m_r, n_r_fac): - vw = machine.vec_width(precision) - n_r = vw * n_r_fac - - gemm_uk = specialize_micro(gemm_uk, precision, machine, m_r, n_r_fac) - - def rewrite(gemm_uk): - tile, init, main_k, axpy = gemm_uk.body() - main_i = get_inner_loop(gemm_uk, main_k) - main_j = get_inner_loop(gemm_uk, main_i) - m_r = main_i.hi().value() - n_r = main_j.hi().value() - - gemm_uk = rename(gemm_uk, f"{gemm_uk.name()}_{m_r}x{n_r}") - gemm_uk = set_memory(gemm_uk, tile, machine.mem_type) - gemm_uk = divide_dim(gemm_uk, tile, 1, vw) - - loops = [init, main_i, axpy] - - return gemm_uk - - blocks = map(lambda c: c.expand(), gemm_uk.find("C_tile:_", many=True)) - gemm_uk = apply(extract_and_schedule(rewrite))(gemm_uk, blocks, gemm_uk.name()) - return gemm_uk - - def schedule_compute(gemm_compute, precision, machine, m_r, n_r_fac): vw = machine.vec_width(precision) n_r = vw * n_r_fac - i_loop = gemm_compute.body()[0] j_loop = get_inner_loop(gemm_compute, i_loop) k_loop = get_inner_loop(gemm_compute, j_loop) @@ -92,7 +40,35 @@ def schedule_compute(gemm_compute, precision, machine, m_r, n_r_fac): gemm_compute = tile_loops_bottom_up( gemm_compute, i_loop, (m_r, n_r, None), tail="guard" ) - gemm_compute = repeate_n(lift_scope)(gemm_compute, gemm_compute.find_loop("k"), n=2) + gemm_compute = repeate_n(lift_scope)(gemm_compute, k_loop, n=2) + gemm_compute = divide_dim(gemm_compute, cs.alloc, 1, vw) + + loops = gemm_compute.find_loop("ii", many=True) + gemm_compute = apply(optimize_level_2)( + gemm_compute, loops, precision, machine, m_r, n_r_fac, vec_tail="perfect" + ) + + def cut(proc, loop, cond, rng): + loop = proc.forward(loop) + cut_val = FormattedExprStr(f"_ - 1", loop.hi()) + proc, (loop1, loop2) = cut_loop_(proc, loop, cut_val, rc=True) + proc = specialize(proc, loop2.body(), [f"{cond(loop2, i)} == {i}" for i in rng]) + return proc + + right_cond = lambda l, i: f"(N - {l.name()} * {n_r} + {vw - 1}) / {vw}" + gemm_compute = cut(gemm_compute, j_loop, right_cond, range(1, n_r_fac)) + bottom_cond = lambda l, i: f"M - {l.name()} * {m_r}" + gemm_compute = cut(gemm_compute, i_loop, bottom_cond, range(1, m_r)) + + def rewrite(p): + p = simplify(dce(p)) + p = replace_all_stmts(p, machine.get_instructions(precision)) + return p + + for i, tile in enumerate(gemm_compute.find_loop("C_tile:_", many=True)): + name = gemm_compute.name() + str(i) + gemm_compute = extract_and_schedule(rewrite)(gemm_compute, tile.expand(), name) + return gemm_compute @@ -110,41 +86,31 @@ def schedule_macro(gemm_mk, precision, machine, max_M, max_N, max_K, m_r, n_r_fa packed_A_shape = ((0, max_M // m_r), (1, max_K), (0, m_r)) gemm_mk, cursors = pack_mem(gemm_mk, i_loop, "A", packed_A_shape, "packed_A", rc=1) # TODO: Schedule packing kernel - gemm_mk, _ = extract_subproc(gemm_mk, cursors.load, gemm_mk.name() + "A_pack") + gemm_mk, _ = extract_subproc(gemm_mk, cursors.load, gemm_mk.name() + "_A_pack") packed_B_shape = ((1, max_N // n_r), (0, max_K), (1, n_r)) gemm_mk, cursors = pack_mem(gemm_mk, i_loop, "B", packed_B_shape, "packed_B", rc=1) # TODO: Schedule packing kernel - gemm_mk, _ = extract_subproc(gemm_mk, cursors.load, gemm_mk.name() + "B_pack") - + gemm_mk, _ = extract_subproc(gemm_mk, cursors.load, gemm_mk.name() + "_B_pack") gemm_mk = extract_and_schedule(schedule_compute)( gemm_mk, i_loop, gemm_mk.name() + "_compute", precision, machine, m_r, n_r_fac ) return gemm_mk_starter, gemm_mk -def schedule(main_gemm, i_loop, precision, machine): - m_r = 4 - n_r_fac = 3 - vw = machine.vec_width(precision) - - M_tile = m_r * 512 - N_tile = n_r_fac * vw * 512 - K_tile = 512 - +def schedule( + main_gemm, i_loop, precision, machine, m_r, n_r_fac, M_tile, N_tile, K_tile +): gemm_macro = schedule_macro( gemm, precision, machine, M_tile, N_tile, K_tile, m_r, n_r_fac ) - gemm_tiled = reorder_loops( - main_gemm, i_loop.body()[0] - ) # Iterate over the main gemm as (i, k, j) + # Iterate over the main gemm as (i, k, j) + gemm_tiled = reorder_loops(main_gemm, i_loop.body()[0]) gemm_tiled = tile_loops_bottom_up(gemm_tiled, i_loop, [M_tile, K_tile, N_tile]) - gemm_tiled = apply(reorder_loops)( - gemm_tiled, gemm_tiled.find_loop("ki", many=True) - ) # Change macrokernel loops to original order - + gemm_tiled = apply(reorder_loops)(gemm_tiled, gemm_tiled.find_loop("ki", many=True)) gemm_tiled = replace_all_stmts(gemm_tiled, [gemm_macro]) + macro_calls = filter_cursors(is_call)(gemm_tiled, nlr_stmts(gemm_tiled)) gemm_tiled = simplify(apply(inline_proc_and_wins)(gemm_tiled, macro_calls)) @@ -156,4 +122,6 @@ def schedule(main_gemm, i_loop, precision, machine): return gemm_tiled -variants_generator(schedule, ("f32",))(gemm, "i", globals=globals()) +variants_generator(schedule, ("f32",))( + gemm, "i", 4, 3, 512, 512, 512, globals=globals() +) From d70145504744ba10ba915a252bc3ea7bcdf6823d Mon Sep 17 00:00:00 2001 From: Samir Droubi Date: Tue, 19 Mar 2024 19:19:33 -0400 Subject: [PATCH 03/14] Simplify inner kernel --- src/level3/gemm.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/src/level3/gemm.py b/src/level3/gemm.py index bbf86cb..cd50404 100644 --- a/src/level3/gemm.py +++ b/src/level3/gemm.py @@ -58,17 +58,25 @@ def cut(proc, loop, cond, rng): right_cond = lambda l, i: f"(N - {l.name()} * {n_r} + {vw - 1}) / {vw}" gemm_compute = cut(gemm_compute, j_loop, right_cond, range(1, n_r_fac)) bottom_cond = lambda l, i: f"M - {l.name()} * {m_r}" - gemm_compute = cut(gemm_compute, i_loop, bottom_cond, range(1, m_r)) + gemm_compute = cut(gemm_compute, i_loop, bottom_cond, range(m_r, 1, -1)) def rewrite(p): - p = simplify(dce(p)) + p = dce(p) p = replace_all_stmts(p, machine.get_instructions(precision)) + p = simplify(p) + try: + p = delete_pass(p) + except: + pass return p - for i, tile in enumerate(gemm_compute.find_loop("C_tile:_", many=True)): + blocks = gemm_compute.find_loop("C_tile:_", many=True) + for i, tile in enumerate(blocks[:8]): name = gemm_compute.name() + str(i) gemm_compute = extract_and_schedule(rewrite)(gemm_compute, tile.expand(), name) - + # for i, tile in enumerate(blocks[8:]): + # name = gemm_compute.name() + str(8 + i) + # gemm_compute = extract_and_schedule(lambda p:p)(gemm_compute, tile.expand(), name) return gemm_compute @@ -123,5 +131,5 @@ def schedule( variants_generator(schedule, ("f32",))( - gemm, "i", 4, 3, 512, 512, 512, globals=globals() + gemm, "i", 4, 3, 512 * 4, 512 * 24, 512, globals=globals() ) From b54bc6f6c82de803e446aa831597500ac1a4f54c Mon Sep 17 00:00:00 2001 From: Samir Droubi Date: Tue, 19 Mar 2024 19:20:49 -0400 Subject: [PATCH 04/14] Simplify in tiling --- src/common/stdlib.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/common/stdlib.py b/src/common/stdlib.py index b314387..eb118ad 100644 --- a/src/common/stdlib.py +++ b/src/common/stdlib.py @@ -637,6 +637,7 @@ def push_scope_in(proc, scope, depth, size=None): proc, (_, inner, tail_l) = divide_loop_( proc, loop, tile, tail=tail, rc=True ) + proc = simplify(proc) proc = push_scope_in(proc, inner, depth + 1, tile) proc = push_scope_in(proc, tail_l.body()[0], depth + 1, tile) else: @@ -644,6 +645,7 @@ def push_scope_in(proc, scope, depth, size=None): elif tail == "guard": if tile is not None: proc, (_, inner, _) = divide_loop_(proc, loop, tile, tail=tail, rc=True) + proc = simplify(proc) proc = push_scope_in(proc, inner.body()[0], guards + 1) guards += 1 proc = push_scope_in(proc, inner, depth + guards + 1, tile) From c325670f9bf9604a07e5010266f87cea5d50cbbf Mon Sep 17 00:00:00 2001 From: Samir Droubi Date: Tue, 19 Mar 2024 23:38:59 -0400 Subject: [PATCH 05/14] Add graphing verbose mode --- analytics_tools/graphing/graph.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/analytics_tools/graphing/graph.py b/analytics_tools/graphing/graph.py index c8af7bf..8966999 100644 --- a/analytics_tools/graphing/graph.py +++ b/analytics_tools/graphing/graph.py @@ -31,12 +31,12 @@ def kernel_graphs_dir(kernel): def help_msg(): - print("python graph.py !") + print("python graph.py !") exit(1) def check_args(): - if len(sys.argv) != 2: + if len(sys.argv) > 3: help_msg() @@ -146,7 +146,7 @@ def plot_bandwidth_throughput(kernel, data, peaks, loads=True): plt.savefig(filename) -def plot_flops_throughput(kernel, data, peaks): +def plot_flops_throughput(kernel, data, peaks, verbose): plt.clf() some_point = next(iter(data.values()))[0] @@ -163,6 +163,8 @@ def plot_flops_throughput(kernel, data, peaks): assert len(runs) == len(set(runs)) # No duplicates x = [run.get_size_param() for run in sorted_runs] y = [run.get_gflops_per_sec() for run in sorted_runs] + if verbose: + print(libname, ": ", list(zip(x, y))) plt.plot(x, y, label=libname) plt.legend() @@ -185,6 +187,7 @@ def plot_flops_throughput(kernel, data, peaks): check_args() kernel = sys.argv[1] + verbose = sys.argv[2] init_directories(kernel) jsons = get_jsons(kernel) @@ -196,4 +199,4 @@ def plot_flops_throughput(kernel, data, peaks): for data in bench_type_dict.values(): plot_bandwidth_throughput(kernel, data, peaks, loads=True) plot_bandwidth_throughput(kernel, data, peaks, loads=False) - plot_flops_throughput(kernel, data, peaks) + plot_flops_throughput(kernel, data, peaks, verbose) From 01817c03e69d23657ba264290ed9e78bd3715839 Mon Sep 17 00:00:00 2001 From: Samir Droubi Date: Tue, 19 Mar 2024 23:39:54 -0400 Subject: [PATCH 06/14] Handle all micro cases --- src/level3/gemm.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/src/level3/gemm.py b/src/level3/gemm.py index cd50404..2b9971f 100644 --- a/src/level3/gemm.py +++ b/src/level3/gemm.py @@ -3,6 +3,7 @@ from exo import * from exo.stdlib.scheduling import * from exo.API_cursors import * +from exo.libs.memories import DRAM_STATIC import exo_blas_config as C from stdlib import * @@ -71,12 +72,9 @@ def rewrite(p): return p blocks = gemm_compute.find_loop("C_tile:_", many=True) - for i, tile in enumerate(blocks[:8]): + for i, tile in enumerate(blocks): name = gemm_compute.name() + str(i) gemm_compute = extract_and_schedule(rewrite)(gemm_compute, tile.expand(), name) - # for i, tile in enumerate(blocks[8:]): - # name = gemm_compute.name() + str(8 + i) - # gemm_compute = extract_and_schedule(lambda p:p)(gemm_compute, tile.expand(), name) return gemm_compute @@ -93,13 +91,14 @@ def schedule_macro(gemm_mk, precision, machine, max_M, max_N, max_K, m_r, n_r_fa packed_A_shape = ((0, max_M // m_r), (1, max_K), (0, m_r)) gemm_mk, cursors = pack_mem(gemm_mk, i_loop, "A", packed_A_shape, "packed_A", rc=1) - # TODO: Schedule packing kernel + gemm_mk = set_memory(gemm_mk, cursors.alloc, DRAM_STATIC) gemm_mk, _ = extract_subproc(gemm_mk, cursors.load, gemm_mk.name() + "_A_pack") packed_B_shape = ((1, max_N // n_r), (0, max_K), (1, n_r)) gemm_mk, cursors = pack_mem(gemm_mk, i_loop, "B", packed_B_shape, "packed_B", rc=1) - # TODO: Schedule packing kernel + gemm_mk = set_memory(gemm_mk, cursors.alloc, DRAM_STATIC) gemm_mk, _ = extract_subproc(gemm_mk, cursors.load, gemm_mk.name() + "_B_pack") + gemm_mk = extract_and_schedule(schedule_compute)( gemm_mk, i_loop, gemm_mk.name() + "_compute", precision, machine, m_r, n_r_fac ) @@ -130,6 +129,15 @@ def schedule( return gemm_tiled +m_r = 4 +n_r_fac = 3 +n_r = n_r_fac * C.Machine.vec_width("f32") +M_tile_fac = 4096 // m_r +N_tile_fac = 4096 // n_r +M_tile = M_tile_fac * m_r +N_tile = N_tile_fac * n_r +K_tile = 32 + variants_generator(schedule, ("f32",))( - gemm, "i", 4, 3, 512 * 4, 512 * 24, 512, globals=globals() + gemm, "i", m_r, n_r_fac, M_tile, N_tile, K_tile, globals=globals() ) From afb4a1042e611e3b1183ee5378f66db7fdc06710 Mon Sep 17 00:00:00 2001 From: Samir Droubi Date: Tue, 19 Mar 2024 23:40:38 -0400 Subject: [PATCH 07/14] Contrl MKL Instr target --- CMakePresets.json | 7 ++++--- test/CMakeLists.txt | 4 ++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/CMakePresets.json b/CMakePresets.json index d04b4e6..2090979 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -38,11 +38,12 @@ "generator": "Ninja", "binaryDir": "${sourceDir}/build/avx2", "cacheVariables": { - "CMAKE_C_FLAGS": "-march=native -fno-tree-vectorize -fno-unroll-loops -O3 -ffast-math -std=c11", - "CMAKE_CXX_FLAGS": "-std=c++17 -march=native -fno-tree-vectorize -fno-unroll-loops -O3 -ffast-math", + "CMAKE_C_FLAGS": "-march=native -O3 -ffast-math -std=c11", + "CMAKE_CXX_FLAGS": "-std=c++17 -march=native -O3 -ffast-math", "CXX_STANDARD" : "C++17", "CMAKE_BUILD_TYPE": "Release", - "TARGET_ARCH": "avx2" + "TARGET_ARCH": "avx2", + "MKL_ENABLE_INSTRUCTIONS": "AVX2" } } ] diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 38c9f2c..b300ef6 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -52,14 +52,14 @@ function(add_exo_blas_test level kernel precision) set_tests_properties( cblas_${precision_kernel}_bench PROPERTIES - ENVIRONMENT "OPENBLAS_NUM_THREADS=1;MKL_NUM_THREADS=1;VECLIB_MAXIMUM_THREADS=1;BENCHMARK_FORMAT=console;BENCHMARK_OUT=${bla_vendor_benchmark_json_output}" + ENVIRONMENT "OPENBLAS_NUM_THREADS=1;MKL_NUM_THREADS=1;VECLIB_MAXIMUM_THREADS=1;BENCHMARK_FORMAT=console;BENCHMARK_OUT=${bla_vendor_benchmark_json_output};MKL_ENABLE_INSTRUCTIONS=${MKL_ENABLE_INSTRUCTIONS}" ) add_test(NAME exo_${precision_kernel}_bench COMMAND ${precision_kernel}_bench --benchmark_filter=exo|EXO ${benchmark_min_warmup_time} ${benchmark_min_time}) set_tests_properties( exo_${precision_kernel}_bench PROPERTIES - ENVIRONMENT "OPENBLAS_NUM_THREADS=1;MKL_NUM_THREADS=1;VECLIB_MAXIMUM_THREADS=1;BENCHMARK_FORMAT=console;BENCHMARK_OUT=${exo_benchmark_json_output}" + ENVIRONMENT "BENCHMARK_OUT=${exo_benchmark_json_output}" ) # Add the correctness test From 82ce08cbcde5ede5cd03b1f88e53626f71199757 Mon Sep 17 00:00:00 2001 From: Samir Droubi Date: Wed, 20 Mar 2024 10:23:01 -0400 Subject: [PATCH 08/14] Compile for generic avx2 backends --- CMakePresets.json | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CMakePresets.json b/CMakePresets.json index 2090979..3d10a83 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -38,8 +38,8 @@ "generator": "Ninja", "binaryDir": "${sourceDir}/build/avx2", "cacheVariables": { - "CMAKE_C_FLAGS": "-march=native -O3 -ffast-math -std=c11", - "CMAKE_CXX_FLAGS": "-std=c++17 -march=native -O3 -ffast-math", + "CMAKE_C_FLAGS": "-march=core-avx2 -O3 -ffast-math -std=c11", + "CMAKE_CXX_FLAGS": "-std=c++17 -march=core-avx2 -O3 -ffast-math", "CXX_STANDARD" : "C++17", "CMAKE_BUILD_TYPE": "Release", "TARGET_ARCH": "avx2", From b8dd14b29ce1b4d60048f538bf54d49c572bc347 Mon Sep 17 00:00:00 2001 From: Samir Droubi Date: Wed, 20 Mar 2024 10:23:34 -0400 Subject: [PATCH 09/14] Enable tiling the packing routines --- src/common/stdlib.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/common/stdlib.py b/src/common/stdlib.py index eb118ad..706617d 100644 --- a/src/common/stdlib.py +++ b/src/common/stdlib.py @@ -1396,7 +1396,9 @@ def add_loop_nest(loop): for src_dim, _ in reversed(list(enumerate(alloc.shape()))): for dst_dim, size in reversed(divisions[src_dim][1:]): proc = divide_dim(proc, alloc, src_dim, size) - # proc = apply(divide_loop_)(proc, loop_nest[src_dim], size, tail="cut") # TODO: This should be enabled but it slows down compilation + proc = apply(divide_loop_)( + proc, loop_nest[src_dim], size, tail="cut" + ) # TODO: This should be enabled but it slows down compilation perm.append(dst_dim) perm.append(divisions[src_dim][0][0]) From 68bdfde892cff515523a05f78129097c756d308b Mon Sep 17 00:00:00 2001 From: Samir Droubi Date: Wed, 20 Mar 2024 10:24:39 -0400 Subject: [PATCH 10/14] Change outer gemm structure --- src/level3/gemm.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/src/level3/gemm.py b/src/level3/gemm.py index 2b9971f..3440081 100644 --- a/src/level3/gemm.py +++ b/src/level3/gemm.py @@ -112,10 +112,13 @@ def schedule( gemm, precision, machine, M_tile, N_tile, K_tile, m_r, n_r_fac ) - # Iterate over the main gemm as (i, k, j) - gemm_tiled = reorder_loops(main_gemm, i_loop.body()[0]) - gemm_tiled = tile_loops_bottom_up(gemm_tiled, i_loop, [M_tile, K_tile, N_tile]) - gemm_tiled = apply(reorder_loops)(gemm_tiled, gemm_tiled.find_loop("ki", many=True)) + gemm_tiled = main_gemm + k_loop = get_inner_loop(gemm_tiled, get_inner_loop(gemm_tiled, i_loop)) + gemm_tiled = repeate_n(lift_scope)(gemm_tiled, k_loop, n=2) + gemm_tiled = tile_loops_bottom_up(gemm_tiled, k_loop, [K_tile, M_tile, N_tile]) + gemm_tiled = apply(repeate_n(reorder_loops))( + gemm_tiled, gemm_tiled.find_loop("ki", many=True), n=2 + ) gemm_tiled = replace_all_stmts(gemm_tiled, [gemm_macro]) macro_calls = filter_cursors(is_call)(gemm_tiled, nlr_stmts(gemm_tiled)) @@ -132,11 +135,11 @@ def schedule( m_r = 4 n_r_fac = 3 n_r = n_r_fac * C.Machine.vec_width("f32") -M_tile_fac = 4096 // m_r -N_tile_fac = 4096 // n_r +M_tile_fac = 66 +N_tile_fac = 3 M_tile = M_tile_fac * m_r N_tile = N_tile_fac * n_r -K_tile = 32 +K_tile = 512 variants_generator(schedule, ("f32",))( gemm, "i", m_r, n_r_fac, M_tile, N_tile, K_tile, globals=globals() From eb2752b47f63acd1326b4a8bdeba775cfe086c47 Mon Sep 17 00:00:00 2001 From: Samir Droubi Date: Wed, 20 Mar 2024 11:12:25 -0400 Subject: [PATCH 11/14] Control level 1 op instrs from caller --- src/common/blaslib.py | 6 ++++-- src/level3/gemm.py | 9 ++++++++- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/src/common/blaslib.py b/src/common/blaslib.py index ebfc8bf..ff03848 100644 --- a/src/common/blaslib.py +++ b/src/common/blaslib.py @@ -19,12 +19,14 @@ def optimize_level_1( precision, machine, interleave_factor, + instrs=None, vec_tail=None, inter_tail="recursive", ): vec_width = machine.vec_width(precision) memory = machine.mem_type - instructions = machine.get_instructions(precision) + if instrs is None: + instrs = machine.get_instructions(precision) if vec_tail is None: vec_tail = "cut_and_predicate" if machine.supports_predication else "cut" @@ -44,7 +46,7 @@ def optimize_level_1( ) proc = cleanup(proc) - proc = replace_all_stmts(proc, instructions) + proc = replace_all_stmts(proc, instrs) return proc diff --git a/src/level3/gemm.py b/src/level3/gemm.py index 3440081..cc70560 100644 --- a/src/level3/gemm.py +++ b/src/level3/gemm.py @@ -46,7 +46,14 @@ def schedule_compute(gemm_compute, precision, machine, m_r, n_r_fac): loops = gemm_compute.find_loop("ii", many=True) gemm_compute = apply(optimize_level_2)( - gemm_compute, loops, precision, machine, m_r, n_r_fac, vec_tail="perfect" + gemm_compute, + loops, + precision, + machine, + m_r, + n_r_fac, + instrs=[], + vec_tail="perfect", ) def cut(proc, loop, cond, rng): From 5c9f6a0d2851740c0c0020834bda234e63d927fe Mon Sep 17 00:00:00 2001 From: Samir Droubi Date: Wed, 20 Mar 2024 11:12:58 -0400 Subject: [PATCH 12/14] Update codegen test --- test/codegen/reference/sha256/avx2.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/codegen/reference/sha256/avx2.json b/test/codegen/reference/sha256/avx2.json index 4217b9e..81881b3 100644 --- a/test/codegen/reference/sha256/avx2.json +++ b/test/codegen/reference/sha256/avx2.json @@ -5,7 +5,7 @@ "exo_dot": "8e7a71353e80273839bc522cfc8c889cbe25d8701a8c82a3c2559b12d9e90f5f", "exo_dsdot": "c901be9d30928e042c35aeb9dab421a34db6d593534f31b4967f9685ecf9628b", "exo_gbmv": "cb92744337cfdf3aa97d250a18e540c2e8787380ba88bfef790a3e14aeb19f37", - "exo_gemm": "fff2030c7cf1d83f841d368235c396387057dd27b29512559aba468dd662500f", + "exo_gemm": "89b5ac6074fa7eca77a59fbebbe8a143456892209ced6c2d20f319e9987adbb3", "exo_gemv": "80d79a5752c20874fbe3d2c94989f5f5663687d405966065a18224800c931559", "exo_ger": "38ba46c410ea9bd1d616add516e9ab82ed811e5785848bce17191d89f81752ce", "exo_iamax": "49c60714c479234683166e5651fbe95ed5a43ecd370a391732769588948cc842", From a80d73e5d818c809d06184d5fc234405f3d0b117 Mon Sep 17 00:00:00 2001 From: Samir Droubi Date: Wed, 20 Mar 2024 11:20:07 -0400 Subject: [PATCH 13/14] Add control for targets at codegen --- src/common/codegen_helpers.py | 4 ++-- src/level1/asum.py | 8 +------- src/level3/gemm.py | 2 +- 3 files changed, 4 insertions(+), 10 deletions(-) diff --git a/src/common/codegen_helpers.py b/src/common/codegen_helpers.py index c9bdae7..2461ddb 100644 --- a/src/common/codegen_helpers.py +++ b/src/common/codegen_helpers.py @@ -110,7 +110,7 @@ def export_perf_features(kernel_name, perf_features): json.dump(perf_features, f, sort_keys=True, indent=4, separators=(",", ": ")) -def variants_generator(blas_op, opt_precisions=("f32", "f64")): +def variants_generator(blas_op, opt_precisions=("f32", "f64"), targets=(AVX2, Neon)): def generate(proc, loop_name, *args, globals=None, **kwargs): perf_features = {} for precision in ("f32", "f64"): @@ -126,7 +126,7 @@ def generate(proc, loop_name, *args, globals=None, **kwargs): loop = stride_1.find_loop(loop_name) algorithm = get_perf_features(stride_1) - if precision in opt_precisions: + if precision in opt_precisions and C.Machine.mem_type in targets: stride_1 = blas_op( stride_1, loop, precision, C.Machine, *args, **kwargs ) diff --git a/src/level1/asum.py b/src/level1/asum.py index f455689..e74a653 100644 --- a/src/level1/asum.py +++ b/src/level1/asum.py @@ -17,11 +17,5 @@ def asum(n: size, x: [f32][n] @ DRAM, result: f32 @ DRAM): ### EXO_LOC ALGORITHM END ### ### EXO_LOC SCHEDULE START ### -def schedule_asum(asum, loop, precision, machine, interleave_factor): - if machine.mem_type is not AVX2: - return asum - return optimize_level_1(asum, loop, precision, machine, interleave_factor) - - -variants_generator(schedule_asum)(asum, "i", 8, globals=globals()) +variants_generator(optimize_level_1, targets=(AVX2,))(asum, "i", 8, globals=globals()) ### EXO_LOC SCHEDULE END ### diff --git a/src/level3/gemm.py b/src/level3/gemm.py index cc70560..b355d14 100644 --- a/src/level3/gemm.py +++ b/src/level3/gemm.py @@ -148,6 +148,6 @@ def schedule( N_tile = N_tile_fac * n_r K_tile = 512 -variants_generator(schedule, ("f32",))( +variants_generator(schedule, ("f32",), (AVX2,))( gemm, "i", m_r, n_r_fac, M_tile, N_tile, K_tile, globals=globals() ) From bcecf0b9f6eba0a55c6561bd4d33560395c9dcea Mon Sep 17 00:00:00 2001 From: Samir Droubi Date: Wed, 20 Mar 2024 12:09:31 -0400 Subject: [PATCH 14/14] Update codegen tests --- src/level3/gemm.py | 6 +++--- test/codegen/reference/sha256/avx2.json | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/level3/gemm.py b/src/level3/gemm.py index b355d14..c8e6481 100644 --- a/src/level3/gemm.py +++ b/src/level3/gemm.py @@ -82,7 +82,7 @@ def rewrite(p): for i, tile in enumerate(blocks): name = gemm_compute.name() + str(i) gemm_compute = extract_and_schedule(rewrite)(gemm_compute, tile.expand(), name) - return gemm_compute + return simplify(gemm_compute) def schedule_macro(gemm_mk, precision, machine, max_M, max_N, max_K, m_r, n_r_fac): @@ -109,7 +109,7 @@ def schedule_macro(gemm_mk, precision, machine, max_M, max_N, max_K, m_r, n_r_fa gemm_mk = extract_and_schedule(schedule_compute)( gemm_mk, i_loop, gemm_mk.name() + "_compute", precision, machine, m_r, n_r_fac ) - return gemm_mk_starter, gemm_mk + return gemm_mk_starter, simplify(gemm_mk) def schedule( @@ -136,7 +136,7 @@ def schedule( ) gemm_tiled = squash_buffers(gemm_tiled, gemm_tiled.find("packed_A : _", many=True)) gemm_tiled = squash_buffers(gemm_tiled, gemm_tiled.find("packed_B : _", many=True)) - return gemm_tiled + return simplify(gemm_tiled) m_r = 4 diff --git a/test/codegen/reference/sha256/avx2.json b/test/codegen/reference/sha256/avx2.json index 81881b3..1cce1dc 100644 --- a/test/codegen/reference/sha256/avx2.json +++ b/test/codegen/reference/sha256/avx2.json @@ -5,7 +5,7 @@ "exo_dot": "8e7a71353e80273839bc522cfc8c889cbe25d8701a8c82a3c2559b12d9e90f5f", "exo_dsdot": "c901be9d30928e042c35aeb9dab421a34db6d593534f31b4967f9685ecf9628b", "exo_gbmv": "cb92744337cfdf3aa97d250a18e540c2e8787380ba88bfef790a3e14aeb19f37", - "exo_gemm": "89b5ac6074fa7eca77a59fbebbe8a143456892209ced6c2d20f319e9987adbb3", + "exo_gemm": "67f63149439fd0ba8e0761b1d1ff9f20a8b6d443780328e75df5cfaaf50675db", "exo_gemv": "80d79a5752c20874fbe3d2c94989f5f5663687d405966065a18224800c931559", "exo_ger": "38ba46c410ea9bd1d616add516e9ab82ed811e5785848bce17191d89f81752ce", "exo_iamax": "49c60714c479234683166e5651fbe95ed5a43ecd370a391732769588948cc842",