Skip to content

Commit

Permalink
Implement filter_cursors op (#89)
Browse files Browse the repository at this point in the history
  • Loading branch information
SamirDroubi authored Mar 18, 2024
1 parent ec98343 commit c1a05e0
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 40 deletions.
2 changes: 1 addition & 1 deletion src/common/codegen_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def is_numeric(proc, s, *arg):

set_type = set_numerics if all_buffs else set_R_type
proc = apply(set_type)(proc, proc.args(), precision)
proc = make_pass(set_type)(proc, proc.body(), precision)
proc = make_pass(set_type, nlr_stmts)(proc, proc.body(), precision)
return proc


Expand Down
12 changes: 10 additions & 2 deletions src/common/higher_order.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from exo.stdlib.scheduling import *

from exceptions import *
from inspection import *


def attempt(op, errs=exo_exceptions):
Expand Down Expand Up @@ -37,7 +36,7 @@ def rewrite(proc, *args, **kwargs):
return rewrite


def make_pass(op, trav_start=nlr_stmts):
def make_pass(op, trav_start):
def rewrite(proc, block=InvalidCursor(), *args, **kwargs):
stmts = trav_start(proc, block)
return apply(op)(proc, stmts, *args, **kwargs)
Expand Down Expand Up @@ -90,6 +89,15 @@ def rewrite(proc, block, subproc_name, *args, rc=False, **kwargs):
return rewrite


def filter_cursors(op):
    """Lift a cursor predicate into a lazy filter over a sequence of cursors.

    `op` is called as `op(proc, cursor, *args, **kwargs)`; any extra
    positional or keyword arguments given to the returned function are
    forwarded unchanged to every predicate call.

    Returns a generator function `filter_c(proc, cursors, *args, **kwargs)`
    that yields, in order, each cursor for which `op` returns a truthy value.
    """

    def filter_c(proc, cursors, *args, **kwargs):
        # Stay a generator function so nothing is evaluated until iteration.
        yield from filter(lambda cur: op(proc, cur, *args, **kwargs), cursors)

    return filter_c


__all__ = [
"apply",
"attempt",
Expand Down
27 changes: 22 additions & 5 deletions src/common/inspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from exo.stdlib.analysis import *

from exceptions import *
from higher_order import filter_cursors


def get_children(proc, cursor=InvalidCursor(), lr=True):
Expand Down Expand Up @@ -85,8 +86,7 @@ def generator():


def get_numeric_children(proc, cursor=InvalidCursor()):
check = lambda c: hasattr(c, "type") and c.type().is_numeric()
yield from filter(check, get_children(proc, cursor))
yield from filter_cursors(is_type_numeric)(proc, get_children(proc, cursor))


def _get_cursors(
Expand Down Expand Up @@ -223,6 +223,16 @@ def is_single_stmt_loop(proc, loop):
return loop_body_len(proc, loop) == 1


def is_type_numeric(proc, c):
    """Tell whether cursor `c`, forwarded into `proc`, has a numeric type.

    Cursors that expose no `type` attribute are reported as non-numeric.
    """
    forwarded = proc.forward(c)
    if not hasattr(forwarded, "type"):
        return False
    return forwarded.type().is_numeric()


def is_alloc(proc, a):
    """Return True when cursor `a`, forwarded into `proc`, is an AllocCursor."""
    return isinstance(proc.forward(a), AllocCursor)


def get_enclosing_scope(proc, cursor, scope_type):
cursor = proc.forward(cursor)
cursor = cursor.parent()
Expand Down Expand Up @@ -293,7 +303,7 @@ def get_parents(proc, stmt):

def get_nth_inner_loop(proc, loop, n):
loop = proc.forward(loop)
inner_loops = list(filter(lambda s: is_loop(proc, s), loop.body()))
inner_loops = list(filter_cursors(is_loop)(proc, loop.body()))
if n >= len(inner_loops):
raise BLAS_SchedulingError(
f"Expected exactly at least {n + 1} loops, found {len(inner_loops)}"
Expand Down Expand Up @@ -371,8 +381,15 @@ def is_write(proc, write):
return is_reduce(proc, write) or is_assign(proc, write)


def is_access(proc, access):
return is_read(proc, access) or is_write(proc, access)
def is_access(proc, access, name=None):
    """Tell whether `access` is a read or a write, optionally of a given name.

    When `name` is None any read/write qualifies; otherwise the cursor's
    `name()` must also equal `name`.
    """
    accessed = is_read(proc, access) or is_write(proc, access)
    if not accessed:
        # Hand back the falsy result itself, as a short-circuiting `and` would.
        return accessed
    if name is None:
        return True
    return access.name() == name


def is_window_stmt(proc, window):
    """Return True when `window`, forwarded into `proc`, is a WindowStmtCursor."""
    return isinstance(proc.forward(window), WindowStmtCursor)


def is_unary_minus(proc, expr):
Expand Down
48 changes: 25 additions & 23 deletions src/common/stdlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def parallelize_allocs(proc, cursor):
f"Got type {type(cursor)}, expected {ForCursor} or {IfCursor}"
)

allocs = filter(lambda s: isinstance(s, AllocCursor), nlr_stmts(proc, cursor))
allocs = filter_cursors(is_alloc)(proc, nlr_stmts(proc, cursor))
func = lambda proc, alloc: parallelize_and_lift_alloc(
proc, alloc, get_distance(proc, alloc, cursor)
)
Expand Down Expand Up @@ -109,7 +109,7 @@ def rewrite(proc, loop, factor=None, par_reduce=False, memory=DRAM, tail="cut"):
if par_reduce:
proc = parallelize_all_reductions(proc, outer, memory=memory, unroll=True)
loop = proc.forward(outer).body()[0]
allocs = filter(lambda s: isinstance(s, AllocCursor), loop.body())
allocs = filter_cursors(is_alloc)(proc, loop.body())
proc = apply(parallelize_and_lift_alloc)(proc, allocs)

stmts = list(proc.forward(loop).body())
Expand Down Expand Up @@ -384,7 +384,7 @@ def rewrite(proc, s):
break
return parallelize_reduction(proc, s, factor, memory, nth_loop, unroll)

return make_pass(attempt(rewrite))(proc, loop.body())
return make_pass(attempt(rewrite), nlr_stmts)(proc, loop.body())


def unroll_and_jam(proc, loop, factor, unroll=(True, True, True)):
Expand Down Expand Up @@ -478,7 +478,7 @@ def dfs(proc, cursor, n_lifts=0):
if n_lifts and not is_end_of_body(proc, cursor):
proc = fission(proc, cursor.after(), n_lifts)
children = get_children(proc, cursor)
children = filter(lambda s: isinstance(s, StmtCursor), children)
children = filter_cursors(is_stmt)(proc, children)
return apply(dfs)(proc, children, n_lifts + 1)

proc = parallelize_allocs(proc, cursor)
Expand Down Expand Up @@ -608,7 +608,7 @@ def get_depth(loop):

def push_loop_in(proc, loop, depth, size=None):
loop = proc.forward(loop)
allocs = list(filter(lambda a: isinstance(a, AllocCursor), loop.body()))
allocs = list(filter_cursors(is_alloc)(proc, loop.body()))
proc = apply(attempt(parallelize_and_lift_alloc))(proc, allocs)
if const_allocs and size:
proc = apply(lambda p, a: resize_dim(p, a, 0, size, 0))(proc, allocs)
Expand Down Expand Up @@ -659,10 +659,8 @@ def auto_stage_mem(proc, block, buff, new_buff_name=None, accum=False, rc=False)
block = block.as_block()

block_nodes = list(lrn(proc, block))
block_loops = list(filter(lambda s: is_loop(proc, s), block_nodes))
block_accesses = filter(
lambda s: is_access(proc, s) and s.name() == buff, block_nodes
)
block_loops = list(filter_cursors(is_loop)(proc, block_nodes))
block_accesses = filter_cursors(is_access)(proc, block_nodes, name=buff)

def eval_rng(expr, env):
expr = proc.forward(expr)
Expand Down Expand Up @@ -850,7 +848,7 @@ def _eliminate_dead_code_pruned(proc, s):
return eliminate_dead_code(proc, s)


dce = make_pass(attempt(_eliminate_dead_code_pruned))
dce = make_pass(attempt(_eliminate_dead_code_pruned), nlr_stmts)


def unroll_buffers(proc, block=InvalidCursor(), mem=None):
Expand All @@ -867,7 +865,7 @@ def rewrite(proc, alloc):
return proc

while True:
new_proc = make_pass(rewrite)(proc, block)
new_proc = make_pass(rewrite, nlr_stmts)(proc, block)
if new_proc == proc:
break
proc = new_proc
Expand Down Expand Up @@ -963,15 +961,15 @@ def stage(proc, exprs):
return apply(stage)(proc, children)

block = proc.forward(block)
allocs = filter(lambda s: isinstance(s, AllocCursor), nlr_stmts(proc, block))
allocs = filter_cursors(is_alloc)(proc, nlr_stmts(proc, block))
proc = apply(set_memory)(proc, allocs, memory)
proc = make_pass(attempt(unfold_reduce))(proc, block)
assigns = filter(lambda s: isinstance(s, AssignCursor), lrn_stmts(proc, block))
proc = make_pass(attempt(unfold_reduce), nlr_stmts)(proc, block)
assigns = filter_cursors(is_assign)(proc, lrn_stmts(proc, block))
exprs = [assign.rhs() for assign in assigns]
proc = apply(stage)(proc, exprs)
# TODO: uncomment once bug in Exo is fixed
# proc = inline_copies(proc, block)
proc = make_pass(attempt(fold_into_reduce))(proc, block)
proc = make_pass(attempt(fold_into_reduce), nlr_stmts)(proc, block)
proc = dealiasing_pass(proc, block)
return proc

Expand All @@ -986,7 +984,7 @@ def check_call_site(proc, call_cursor):
###################################################################
env = {}
obs_stmts = get_observed_stmts(call_cursor)
allocs = filter(lambda s: isinstance(s, AllocCursor), obs_stmts)
allocs = filter_cursors(is_alloc)(proc, obs_stmts)
for s in list(proc.args()) + list(allocs):
if s.type().is_numeric():
env[s.name()] = (s.mem(), s.type())
Expand Down Expand Up @@ -1218,8 +1216,10 @@ def cse(proc, block, precision):
nodes = list(lrn(proc, block))

# First do CSE on the buffer accesses
accesses = filter(lambda c: is_access(proc, c), nodes)
accesses = filter(lambda c: is_stmt(proc, c) or c.type().is_numeric(), accesses)
accesses = filter_cursors(is_access)(proc, nodes)
accesses = filter_cursors(lambda p, c: is_stmt(p, c) or c.type().is_numeric())(
proc, accesses
)

buff_map = {}

Expand All @@ -1236,7 +1236,7 @@ def cse(proc, block, precision):
return proc


inline_copies = make_pass(predicate(attempt(inline_assign), is_copy))
inline_copies = make_pass(predicate(attempt(inline_assign), is_copy), nlr_stmts)


def dealias(proc, stmt):
Expand All @@ -1247,8 +1247,10 @@ def dealias(proc, stmt):

nodes = list(lrn(proc, stmt.rhs()))

accesses = filter(lambda c: is_access(proc, c), nodes)
accesses = filter(lambda c: is_stmt(proc, c) or c.type().is_numeric(), accesses)
accesses = filter_cursors(is_access)(proc, nodes)
accesses = filter_cursors(lambda p, c: is_stmt(p, c) or c.type().is_numeric())(
proc, accesses
)

buff_map = {}
for access in accesses:
Expand All @@ -1266,7 +1268,7 @@ def dealias(proc, stmt):
return proc


dealiasing_pass = make_pass(attempt(dealias, errs=(TypeError,)))
dealiasing_pass = make_pass(attempt(dealias, errs=(TypeError,)), nlr_stmts)


def round_loop(proc, loop, factor, up=True):
Expand Down Expand Up @@ -1323,7 +1325,7 @@ def inline_proc_and_wins(proc, call, rc=False):
call = proc.forward(call)
proc = inline(proc, call)
block = proc.forward(call.as_block())
windows = filter(lambda s: isinstance(s, WindowStmtCursor), block)
windows = filter_cursors(is_window_stmt)(proc, block)
proc = apply(inline_window)(proc, windows)
if not rc:
return proc
Expand Down
14 changes: 5 additions & 9 deletions src/level3/gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,15 +90,13 @@ def schedule_macro(

packed_A_shape = ((0, max_M // m_r), (1, max_K), (0, m_r))
gemm_mk, cursors = pack_mem(gemm_mk, i_loop, "A", packed_A_shape, "packed_A", rc=1)
gemm_mk, _ = extract_subproc(
gemm_mk, cursors.load, "A_pack_kernel"
) # TODO: Schedule packing kernel
# TODO: Schedule packing kernel
gemm_mk, _ = extract_subproc(gemm_mk, cursors.load, "A_pack_kernel")

packed_B_shape = ((1, max_N // n_r), (0, max_K), (1, n_r))
gemm_mk, cursors = pack_mem(gemm_mk, i_loop, "B", packed_B_shape, "packed_B", rc=1)
gemm_mk, _ = extract_subproc(
gemm_mk, cursors.load, "B_pack_kernel"
) # TODO: Schedule packing kernel
# TODO: Schedule packing kernel
gemm_mk, _ = extract_subproc(gemm_mk, cursors.load, "B_pack_kernel")

gemm_mk, _ = extract_subproc(gemm_mk, i_loop, "compute")
return gemm_mk_starter, gemm_mk
Expand Down Expand Up @@ -152,9 +150,7 @@ def schedule(main_gemm, i_loop, precision, machine):
) # Change macrokernel loops to original order

gemm_tiled = replace_all_stmts(gemm_tiled, [gemm_macro])
macro_calls = filter(
lambda c: is_call(gemm_tiled, c, gemm_macro[1]), nlr_stmts(gemm_tiled)
)
# Select only the calls that target the macro kernel. filter_cursors forwards
# the extra positional argument to is_call; without it is_call is never told
# which callee to match, so the filter would not be restricted to gemm_macro.
macro_calls = filter_cursors(is_call)(
    gemm_tiled, nlr_stmts(gemm_tiled), gemm_macro[1]
)
gemm_tiled = simplify(apply(inline_proc_and_wins)(gemm_tiled, macro_calls))

gemm_tiled = apply(hoist_from_loop)(
Expand Down

0 comments on commit c1a05e0

Please sign in to comment.