Skip to content

Commit

Permalink
Schedule gemm tiling
Browse files Browse the repository at this point in the history
  • Loading branch information
SamirDroubi committed Mar 15, 2024
1 parent cd292d1 commit f0aab7e
Show file tree
Hide file tree
Showing 5 changed files with 157 additions and 51 deletions.
8 changes: 5 additions & 3 deletions src/common/higher_order.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,16 @@ def rewrite(p, *args, n=1, **kwargs):


def extract_and_schedule(op):
def rewrite(proc, block, subproc_name, *args, **kwargs):
def rewrite(proc, block, subproc_name, *args, rc=False, **kwargs):
block = proc.forward(block)
block = block.as_block()
proc, subproc = extract_subproc(proc, block, subproc_name)
subproc = op(subproc, *args, **kwargs)
subproc_sched = op(subproc, *args, **kwargs)
call = proc.forward(block)[0]
proc = call_eqv(proc, call, subproc)
return proc
if not rc:
return proc
return proc, (subproc, subproc_sched)

return rewrite

Expand Down
21 changes: 19 additions & 2 deletions src/common/inspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def stmt_children(stmt):
elif isinstance(stmt, CallCursor):
yield from stmt.args()
elif isinstance(stmt, WindowStmtCursor):
yield from stmt.idx()
yield stmt.winexpr()
elif isinstance(stmt, AssignConfigCursor):
yield stmt.rhs()
elif isinstance(stmt, PassCursor):
Expand Down Expand Up @@ -380,6 +380,23 @@ def is_unary_minus(proc, expr):
return isinstance(expr, UnaryMinusCursor)


def is_call(proc, call, subproc=None):
call = proc.forward(call)
return isinstance(call, CallCursor) and (
subproc is None or call.subproc() == subproc
)


def is_invalid(proc, inv):
if isinstance(inv, InvalidCursor):
return True
try:
inv = proc.forward(inv)
return False
except InvalidCursorError:
return True


def is_start_of_body(proc, stmt):
stmt = proc.forward(stmt)
return isinstance(stmt.prev(), InvalidCursor)
Expand All @@ -394,7 +411,7 @@ def get_depth(proc, cursor):
cursor = proc.forward(cursor)

depth = 1
while not isinstance(cursor, InvalidCursor):
while not isinstance(cursor.parent(), InvalidCursor):
cursor = cursor.parent()
depth += 1
return depth
Expand Down
6 changes: 6 additions & 0 deletions src/common/rc_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,12 @@ class stage_mem_cursors:
block: BlockCursor
store_stage: Cursor

def __iter__(self):
yield self.alloc
yield self.load_stage
yield self.block
yield self.store_stage


def stage_mem_(proc, block, buff, new_buff_name, accum=False, rc=False):
block = proc.forward(block)
Expand Down
83 changes: 56 additions & 27 deletions src/common/stdlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,7 +586,7 @@ def tile_loops_top_down(proc, loop_tile_pairs):


def tile_loops_bottom_up(proc, loop, tiles, const_allocs=True):

loop = proc.forward(loop)
cur_loop = loop
for i in tiles[:-1]:
if not len(cur_loop.body()) == 1:
Expand Down Expand Up @@ -1012,33 +1012,19 @@ def __iter__(self):
yield self.call


def checked_replace(proc, stmt, subproc, quiet=False):
def checked_replace(proc, stmt, subproc, quiet=False, rc=False):
stmt = proc.forward(stmt)
parent = stmt.parent()
index = get_index_in_body(proc, stmt)

block = stmt.as_block()
try:
proc = replace(proc, stmt, subproc, quiet=quiet)
except:
raise BLAS_SchedulingError("failed to replace")

is_else = False
if (
isinstance(parent, IfCursor)
and not isinstance(parent.orelse(), InvalidCursor)
and index < len(parent.orelse())
and parent.orelse()[index] == stmt
):
is_else = True
if not isinstance(parent, InvalidCursor):
parent = proc.forward(parent)
else:
parent = proc

call = parent.body()[index] if not is_else else parent.orelse()[index]
call = proc.forward(block)[0]
if not check_call_site(proc, call):
raise BLAS_SchedulingError("Call site inconsistency")
return proc
if not rc:
return proc
return proc, replace_cursors(call)


def replace_all_stmts(proc, instructions):
Expand All @@ -1050,12 +1036,17 @@ def replace_all_stmts(proc, instructions):
stmt = proc.forward(stmt)
except InvalidCursorError:
continue
for instr in instructions:
for val in instructions:
if isinstance(val, tuple):
instr, eqv = val
else:
instr = val
eqv = val
try:
proc = checked_replace(proc, stmt, instr, quiet=True)
break
proc, (call,) = checked_replace(proc, stmt, instr, quiet=True, rc=True)
except BLAS_SchedulingError:
pass
continue
proc = call_eqv(proc, call, eqv)
return proc


Expand Down Expand Up @@ -1190,7 +1181,6 @@ def binary_specialize(proc, block, expr, values, rc=False):

def rewrite(proc, block, values):
block = proc.forward(block)
print(block)
if len(values) == 1:
# This should be redundant if the user provided correct inputs!
# So, it is really a check that the inputs the user provided cover the full range.
Expand Down Expand Up @@ -1314,8 +1304,47 @@ def cut_loop_and_unroll(proc, loop, const, front=True, rc=False):


def bound_alloc(proc, alloc, bounds):
alloc = porc.forward(alloc)
alloc = proc.forward(alloc)
for idx, bound in enumerate(bounds):
if bound is not None:
proc = resize_dim(proc, alloc, idx, bound, 0)
return proc


@dataclass
class inline_proc_and_wins_cursors:
block: BlockCursor

def __iter__(self):
yield self.block


def inline_proc_and_wins(proc, call, rc=False):
call = proc.forward(call)
proc = inline(proc, call)
block = proc.forward(call.as_block())
windows = filter(lambda s: isinstance(s, WindowStmtCursor), block)
proc = apply(inline_window)(proc, windows)
if not rc:
return proc
return proc, inline_proc_and_wins_cursors(proc.forward(block))


def squash_buffers(proc, buffers):
buffers = [proc.forward(b) for b in buffers]
buffer = buffers[0]
depths = [get_depth(proc, b) for b in buffers]

squash = apply(attempt(lambda p, e, b: reuse_buffer(p, b, e)))
proc = squash(proc, buffers[1:], buffer)
max_d = max(depths)
while max_d > 1:
for i, (b, d) in enumerate(zip(buffers, depths)):
if max_d != d:
continue
depths[i] -= 1
if not is_invalid(proc, b):
proc = lift_alloc(proc, b)
proc = squash(proc, buffers[1:], buffer)
max_d = max(depths)
return proc
90 changes: 71 additions & 19 deletions src/level3/gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,38 +75,90 @@ def rewrite(gemm_uk):
return gemm_uk


def schedule_macro(gemm, i_loop, precision, machine, m_r, n_r_fac, do_br=False):
def schedule_macro(
gemm_mk, precision, machine, max_M, max_N, max_K, m_r, n_r_fac, do_br=False
):
gemm_mk = specialize_precision(gemm_mk, precision)
for var, max_var in zip(("M", "N", "K"), (max_M, max_N, max_K)):
gemm_mk = gemm_mk.add_assertion(f"{var} <= {max_var}")
gemm_mk_starter = gemm_mk
gemm_mk = rename(gemm_mk, gemm_mk.name() + "_mk")
i_loop = gemm_mk.body()[0]
gemm_mk, (A_alloc, A_load, _, _) = auto_stage_mem(
gemm_mk, i_loop, "A", "packed_A", rc=True
)
gemm_mk, (B_alloc, B_load, _, _) = auto_stage_mem(
gemm_mk, i_loop, "B", "packed_B", rc=True
)
gemm_mk = bound_alloc(gemm_mk, A_alloc, (max_M, max_K))

gemm_mk = bound_alloc(gemm_mk, B_alloc, (max_K, max_N))
gemm_mk, _ = extract_subproc(gemm_mk, A_load, "A_pack_kernel")
gemm_mk, _ = extract_subproc(gemm_mk, B_load, "B_pack_kernel")
gemm_mk, _ = extract_subproc(gemm_mk, i_loop, "compute")

return gemm_mk_starter, gemm_mk
n_r = machine.vec_width(precision) * n_r_fac
j_loop = get_inner_loop(gemm_mk, i_loop)
k_loop = get_inner_loop(gemm_mk, j_loop)

j_loop = get_inner_loop(gemm, i_loop)
k_loop = get_inner_loop(gemm, j_loop)
gemm_mk = auto_stage_mem(gemm_mk, k_loop, "C", "C_tile", accum=True)
gemm_mk = lift_reduce_constant(gemm_mk, gemm_mk.forward(k_loop).expand(1, 0))
gemm_mk = inline_assign(gemm_mk, gemm_mk.find("C_tile = _ * _"))

gemm = auto_stage_mem(gemm, k_loop, "C", "C_tile", accum=True)
gemm = lift_reduce_constant(gemm, gemm.forward(k_loop).expand(1, 0))
gemm = inline_assign(gemm, gemm.find("C_tile = _ * _"))
gemm_mk = tile_loops_bottom_up(gemm_mk, i_loop, (m_r, n_r, None))
gemm_mk = apply(repeate_n(lift_scope))(
gemm_mk, gemm_mk.find_loop("k", many=True), n=2
)

gemm = tile_loops_bottom_up(gemm, i_loop, (m_r, n_r, None))
gemm = apply(repeate_n(lift_scope))(gemm, gemm.find_loop("k", many=True), n=2)

tiles = gemm.find("C_tile:_", many=True)
tiles = gemm_mk.find("C_tile:_", many=True)
names = ["_uk", "_r_uk", "_b_uk", "_br_uk"]
names = [gemm.name() + su for su in names]
names = [gemm_mk.name() + su for su in names]
if not do_br:
tiles = tiles[:-1]
names = names[:-1]
for tile, name in zip(tiles, names):
gemm = extract_and_schedule(schedule_micro)(
gemm, tile.expand(), name, precision, machine, m_r, n_r_fac
gemm_mk = extract_and_schedule(schedule_micro)(
gemm_mk, tile.expand(), name, precision, machine, m_r, n_r_fac
)

gemm = cleanup(gemm)
print(gemm)
return gemm
gemm_mk = cleanup(gemm_mk)
return gemm_mk


def schedule(main_gemm, i_loop, precision, machine):
m_r = 4
n_r_fac = 3
vw = machine.vec_width(precision)

def schedule(gemm, i_loop, precision, machine):
macro = schedule_macro(gemm, i_loop, precision, machine, 4, 3)
return macro
M_tile = m_r * 512
N_tile = n_r_fac * vw * 512
K_tile = 512

gemm_macro = schedule_macro(
gemm, precision, machine, M_tile, N_tile, K_tile, m_r, n_r_fac
)

gemm_tiled = reorder_loops(
main_gemm, i_loop.body()[0]
) # Iterate over the main gemm as (i, k, j)
gemm_tiled = tile_loops_bottom_up(gemm_tiled, i_loop, [M_tile, K_tile, N_tile])
gemm_tiled = apply(reorder_loops)(
gemm_tiled, gemm_tiled.find_loop("ki", many=True)
) # Change macrokernel loops to original order

gemm_tiled = replace_all_stmts(gemm_tiled, [gemm_macro])
macro_calls = filter(
lambda c: is_call(gemm_tiled, c, gemm_macro[1]), nlr_stmts(gemm_tiled)
)
gemm_tiled = simplify(apply(inline_proc_and_wins)(gemm_tiled, macro_calls))

gemm_tiled = apply(hoist_from_loop)(
gemm_tiled, gemm_tiled.find_loop("jo", many=True)
)
gemm_tiled = squash_buffers(gemm_tiled, gemm_tiled.find("packed_A : _", many=True))
gemm_tiled = squash_buffers(gemm_tiled, gemm_tiled.find("packed_B : _", many=True))
return gemm_tiled


variants_generator(schedule, ("f32",))(gemm, "i", globals=globals())

0 comments on commit f0aab7e

Please sign in to comment.