From f0aab7e5e984eecce2436d2fef9b4b9f36838497 Mon Sep 17 00:00:00 2001 From: Samir Droubi Date: Fri, 15 Mar 2024 15:45:20 -0400 Subject: [PATCH] Schedule gemm tiling --- src/common/higher_order.py | 8 ++-- src/common/inspection.py | 21 ++++++++- src/common/rc_wrappers.py | 6 +++ src/common/stdlib.py | 83 +++++++++++++++++++++++------------ src/level3/gemm.py | 90 ++++++++++++++++++++++++++++++-------- 5 files changed, 157 insertions(+), 51 deletions(-) diff --git a/src/common/higher_order.py b/src/common/higher_order.py index 6f6b192..7e4b8f2 100644 --- a/src/common/higher_order.py +++ b/src/common/higher_order.py @@ -76,14 +76,16 @@ def rewrite(p, *args, n=1, **kwargs): def extract_and_schedule(op): - def rewrite(proc, block, subproc_name, *args, **kwargs): + def rewrite(proc, block, subproc_name, *args, rc=False, **kwargs): block = proc.forward(block) block = block.as_block() proc, subproc = extract_subproc(proc, block, subproc_name) - subproc = op(subproc, *args, **kwargs) + subproc_sched = op(subproc, *args, **kwargs) call = proc.forward(block)[0] proc = call_eqv(proc, call, subproc) - return proc + if not rc: + return proc + return proc, (subproc, subproc_sched) return rewrite diff --git a/src/common/inspection.py b/src/common/inspection.py index e5295f0..cb4fca7 100644 --- a/src/common/inspection.py +++ b/src/common/inspection.py @@ -48,7 +48,7 @@ def stmt_children(stmt): elif isinstance(stmt, CallCursor): yield from stmt.args() elif isinstance(stmt, WindowStmtCursor): - yield from stmt.idx() + yield stmt.winexpr() elif isinstance(stmt, AssignConfigCursor): yield stmt.rhs() elif isinstance(stmt, PassCursor): @@ -380,6 +380,23 @@ def is_unary_minus(proc, expr): return isinstance(expr, UnaryMinusCursor) +def is_call(proc, call, subproc=None): + call = proc.forward(call) + return isinstance(call, CallCursor) and ( + subproc is None or call.subproc() == subproc + ) + + +def is_invalid(proc, inv): + if isinstance(inv, InvalidCursor): + return True + try: + inv = proc.forward(inv) + return False + except InvalidCursorError: + return True + + def is_start_of_body(proc, stmt): stmt = proc.forward(stmt) return isinstance(stmt.prev(), InvalidCursor) @@ -394,7 +411,7 @@ def get_depth(proc, cursor): cursor = proc.forward(cursor) depth = 1 - while not isinstance(cursor, InvalidCursor): + while not isinstance(cursor.parent(), InvalidCursor): cursor = cursor.parent() depth += 1 return depth diff --git a/src/common/rc_wrappers.py b/src/common/rc_wrappers.py index b779799..a7a8f31 100644 --- a/src/common/rc_wrappers.py +++ b/src/common/rc_wrappers.py @@ -88,6 +88,12 @@ class stage_mem_cursors: block: BlockCursor store_stage: Cursor + def __iter__(self): + yield self.alloc + yield self.load_stage + yield self.block + yield self.store_stage + def stage_mem_(proc, block, buff, new_buff_name, accum=False, rc=False): block = proc.forward(block) diff --git a/src/common/stdlib.py b/src/common/stdlib.py index 99f5fa2..197e047 100644 --- a/src/common/stdlib.py +++ b/src/common/stdlib.py @@ -586,7 +586,7 @@ def tile_loops_top_down(proc, loop_tile_pairs): def tile_loops_bottom_up(proc, loop, tiles, const_allocs=True): - + loop = proc.forward(loop) cur_loop = loop for i in tiles[:-1]: if not len(cur_loop.body()) == 1: @@ -1012,33 +1012,19 @@ def __iter__(self): yield self.call -def checked_replace(proc, stmt, subproc, quiet=False): +def checked_replace(proc, stmt, subproc, quiet=False, rc=False): stmt = proc.forward(stmt) - parent = stmt.parent() - index = get_index_in_body(proc, stmt) - + block = stmt.as_block() try: proc = replace(proc, stmt, subproc, quiet=quiet) except: raise BLAS_SchedulingError("failed to replace") - - is_else = False - if ( - isinstance(parent, IfCursor) - and not isinstance(parent.orelse(), InvalidCursor) - and index < len(parent.orelse()) - and parent.orelse()[index] == stmt - ): - is_else = True - if not isinstance(parent, InvalidCursor): - parent = proc.forward(parent) - else: - parent = proc - - call = parent.body()[index] if not is_else else parent.orelse()[index] + call = proc.forward(block)[0] if not check_call_site(proc, call): raise BLAS_SchedulingError("Call site inconsistency") - return proc + if not rc: + return proc + return proc, replace_cursors(call) def replace_all_stmts(proc, instructions): @@ -1050,12 +1036,17 @@ def replace_all_stmts(proc, instructions): stmt = proc.forward(stmt) except InvalidCursorError: continue - for instr in instructions: + for val in instructions: + if isinstance(val, tuple): + instr, eqv = val + else: + instr = val + eqv = val try: - proc = checked_replace(proc, stmt, instr, quiet=True) - break + proc, (call,) = checked_replace(proc, stmt, instr, quiet=True, rc=True) except BLAS_SchedulingError: - pass + continue + proc = call_eqv(proc, call, eqv) return proc @@ -1190,7 +1181,6 @@ def binary_specialize(proc, block, expr, values, rc=False): def rewrite(proc, block, values): block = proc.forward(block) - print(block) if len(values) == 1: # This should be redundant if the user provided correct inputs! # So, it is really a check that the inputs the user provided cover the full range. @@ -1314,8 +1304,47 @@ def cut_loop_and_unroll(proc, loop, const, front=True, rc=False): def bound_alloc(proc, alloc, bounds): - alloc = porc.forward(alloc) + alloc = proc.forward(alloc) for idx, bound in enumerate(bounds): if bound is not None: proc = resize_dim(proc, alloc, idx, bound, 0) return proc + + +@dataclass +class inline_proc_and_wins_cursors: + block: BlockCursor + + def __iter__(self): + yield self.block + + +def inline_proc_and_wins(proc, call, rc=False): + call = proc.forward(call) + proc = inline(proc, call) + block = proc.forward(call.as_block()) + windows = filter(lambda s: isinstance(s, WindowStmtCursor), block) + proc = apply(inline_window)(proc, windows) + if not rc: + return proc + return proc, inline_proc_and_wins_cursors(proc.forward(block)) + + +def squash_buffers(proc, buffers): + buffers = [proc.forward(b) for b in buffers] + buffer = buffers[0] + depths = [get_depth(proc, b) for b in buffers] + + squash = apply(attempt(lambda p, e, b: reuse_buffer(p, b, e))) + proc = squash(proc, buffers[1:], buffer) + max_d = max(depths) + while max_d > 1: + for i, (b, d) in enumerate(zip(buffers, depths)): + if max_d != d: + continue + depths[i] -= 1 + if not is_invalid(proc, b): + proc = lift_alloc(proc, b) + proc = squash(proc, buffers[1:], buffer) + max_d = max(depths) + return proc diff --git a/src/level3/gemm.py b/src/level3/gemm.py index 9c0a541..9947d60 100644 --- a/src/level3/gemm.py +++ b/src/level3/gemm.py @@ -75,38 +75,90 @@ def rewrite(gemm_uk): return gemm_uk -def schedule_macro(gemm, i_loop, precision, machine, m_r, n_r_fac, do_br=False): +def schedule_macro( + gemm_mk, precision, machine, max_M, max_N, max_K, m_r, n_r_fac, do_br=False +): + gemm_mk = specialize_precision(gemm_mk, precision) + for var, max_var in zip(("M", "N", "K"), (max_M, max_N, max_K)): + gemm_mk = gemm_mk.add_assertion(f"{var} <= {max_var}") + gemm_mk_starter = gemm_mk + gemm_mk = rename(gemm_mk, gemm_mk.name() + "_mk") + i_loop = gemm_mk.body()[0] + gemm_mk, (A_alloc, A_load, _, _) = auto_stage_mem( + gemm_mk, i_loop, "A", "packed_A", rc=True + ) + gemm_mk, (B_alloc, B_load, _, _) = auto_stage_mem( + gemm_mk, i_loop, "B", "packed_B", rc=True + ) + gemm_mk = bound_alloc(gemm_mk, A_alloc, (max_M, max_K)) + + gemm_mk = bound_alloc(gemm_mk, B_alloc, (max_K, max_N)) + gemm_mk, _ = extract_subproc(gemm_mk, A_load, "A_pack_kernel") + gemm_mk, _ = extract_subproc(gemm_mk, B_load, "B_pack_kernel") + gemm_mk, _ = extract_subproc(gemm_mk, i_loop, "compute") + + return gemm_mk_starter, gemm_mk n_r = machine.vec_width(precision) * n_r_fac + j_loop = get_inner_loop(gemm_mk, i_loop) + k_loop = get_inner_loop(gemm_mk, j_loop) - j_loop = get_inner_loop(gemm, i_loop) - k_loop = get_inner_loop(gemm, j_loop) + gemm_mk = auto_stage_mem(gemm_mk, k_loop, "C", "C_tile", accum=True) + gemm_mk = lift_reduce_constant(gemm_mk, gemm_mk.forward(k_loop).expand(1, 0)) + gemm_mk = inline_assign(gemm_mk, gemm_mk.find("C_tile = _ * _")) - gemm = auto_stage_mem(gemm, k_loop, "C", "C_tile", accum=True) - gemm = lift_reduce_constant(gemm, gemm.forward(k_loop).expand(1, 0)) - gemm = inline_assign(gemm, gemm.find("C_tile = _ * _")) + gemm_mk = tile_loops_bottom_up(gemm_mk, i_loop, (m_r, n_r, None)) + gemm_mk = apply(repeate_n(lift_scope))( + gemm_mk, gemm_mk.find_loop("k", many=True), n=2 + ) - gemm = tile_loops_bottom_up(gemm, i_loop, (m_r, n_r, None)) - gemm = apply(repeate_n(lift_scope))(gemm, gemm.find_loop("k", many=True), n=2) - - tiles = gemm.find("C_tile:_", many=True) + tiles = gemm_mk.find("C_tile:_", many=True) names = ["_uk", "_r_uk", "_b_uk", "_br_uk"] - names = [gemm.name() + su for su in names] + names = [gemm_mk.name() + su for su in names] if not do_br: tiles = tiles[:-1] names = names[:-1] for tile, name in zip(tiles, names): - gemm = extract_and_schedule(schedule_micro)( - gemm, tile.expand(), name, precision, machine, m_r, n_r_fac + gemm_mk = extract_and_schedule(schedule_micro)( + gemm_mk, tile.expand(), name, precision, machine, m_r, n_r_fac ) - gemm = cleanup(gemm) - print(gemm) - return gemm + gemm_mk = cleanup(gemm_mk) + return gemm_mk + +def schedule(main_gemm, i_loop, precision, machine): + m_r = 4 + n_r_fac = 3 + vw = machine.vec_width(precision) -def schedule(gemm, i_loop, precision, machine): - macro = schedule_macro(gemm, i_loop, precision, machine, 4, 3) - return macro + M_tile = m_r * 512 + N_tile = n_r_fac * vw * 512 + K_tile = 512 + + gemm_macro = schedule_macro( + gemm, precision, machine, M_tile, N_tile, K_tile, m_r, n_r_fac + ) + + gemm_tiled = reorder_loops( + main_gemm, i_loop.body()[0] + ) # Iterate over the main gemm as (i, k, j) + gemm_tiled = tile_loops_bottom_up(gemm_tiled, i_loop, [M_tile, K_tile, N_tile]) + gemm_tiled = apply(reorder_loops)( + gemm_tiled, gemm_tiled.find_loop("ki", many=True) + ) # Change macrokernel loops to original order + + gemm_tiled = replace_all_stmts(gemm_tiled, [gemm_macro]) + macro_calls = filter( + lambda c: is_call(gemm_tiled, c, gemm_macro[1]), nlr_stmts(gemm_tiled) + ) + gemm_tiled = simplify(apply(inline_proc_and_wins)(gemm_tiled, macro_calls)) + + gemm_tiled = apply(hoist_from_loop)( + gemm_tiled, gemm_tiled.find_loop("jo", many=True) + ) + gemm_tiled = squash_buffers(gemm_tiled, gemm_tiled.find("packed_A : _", many=True)) + gemm_tiled = squash_buffers(gemm_tiled, gemm_tiled.find("packed_B : _", many=True)) + return gemm_tiled variants_generator(schedule, ("f32",))(gemm, "i", globals=globals())