diff --git a/src/common/codegen_helpers.py b/src/common/codegen_helpers.py index 55fc694..c9bdae7 100644 --- a/src/common/codegen_helpers.py +++ b/src/common/codegen_helpers.py @@ -39,7 +39,7 @@ def is_numeric(proc, s, *arg): set_type = set_numerics if all_buffs else set_R_type proc = apply(set_type)(proc, proc.args(), precision) - proc = make_pass(set_type)(proc, proc.body(), precision) + proc = make_pass(set_type, nlr_stmts)(proc, proc.body(), precision) return proc diff --git a/src/common/higher_order.py b/src/common/higher_order.py index 7e4b8f2..9c64aa2 100644 --- a/src/common/higher_order.py +++ b/src/common/higher_order.py @@ -1,7 +1,6 @@ from exo.stdlib.scheduling import * from exceptions import * -from inspection import * def attempt(op, errs=exo_exceptions): @@ -37,7 +36,7 @@ def rewrite(proc, *args, **kwargs): return rewrite -def make_pass(op, trav_start=nlr_stmts): +def make_pass(op, trav_start): def rewrite(proc, block=InvalidCursor(), *args, **kwargs): stmts = trav_start(proc, block) return apply(op)(proc, stmts, *args, **kwargs) @@ -90,6 +89,15 @@ def rewrite(proc, block, subproc_name, *args, rc=False, **kwargs): return rewrite +def filter_cursors(op): + def filter_c(proc, cursors, *args, **kwargs): + for c in cursors: + if op(proc, c, *args, **kwargs): + yield c + + return filter_c + + __all__ = [ "apply", "attempt", diff --git a/src/common/inspection.py b/src/common/inspection.py index b26378d..4cb175b 100644 --- a/src/common/inspection.py +++ b/src/common/inspection.py @@ -3,6 +3,7 @@ from exo.stdlib.analysis import * from exceptions import * +from higher_order import filter_cursors def get_children(proc, cursor=InvalidCursor(), lr=True): @@ -85,8 +86,7 @@ def generator(): def get_numeric_children(proc, cursor=InvalidCursor()): - check = lambda c: hasattr(c, "type") and c.type().is_numeric() - yield from filter(check, get_children(proc, cursor)) + yield from filter_cursors(is_type_numeric)(proc, get_children(proc, cursor)) def _get_cursors( @@ -223,6 +223,16 @@ def is_single_stmt_loop(proc, loop): return loop_body_len(proc, loop) == 1 +def is_type_numeric(proc, c): + c = proc.forward(c) + return hasattr(c, "type") and c.type().is_numeric() + + +def is_alloc(proc, a): + a = proc.forward(a) + return isinstance(a, AllocCursor) + + def get_enclosing_scope(proc, cursor, scope_type): cursor = proc.forward(cursor) cursor = cursor.parent() @@ -293,7 +303,7 @@ def get_parents(proc, stmt): def get_nth_inner_loop(proc, loop, n): loop = proc.forward(loop) - inner_loops = list(filter(lambda s: is_loop(proc, s), loop.body())) + inner_loops = list(filter_cursors(is_loop)(proc, loop.body())) if n >= len(inner_loops): raise BLAS_SchedulingError( f"Expected exactly at least {n + 1} loops, found {len(inner_loops)}" @@ -371,8 +381,15 @@ def is_write(proc, write): return is_reduce(proc, write) or is_assign(proc, write) -def is_access(proc, access): - return is_read(proc, access) or is_write(proc, access) +def is_access(proc, access, name=None): + return (is_read(proc, access) or is_write(proc, access)) and ( + name is None or access.name() == name + ) + + +def is_window_stmt(proc, window): + window = proc.forward(window) + return isinstance(window, WindowStmtCursor) def is_unary_minus(proc, expr): diff --git a/src/common/stdlib.py b/src/common/stdlib.py index 0f40c54..3810639 100644 --- a/src/common/stdlib.py +++ b/src/common/stdlib.py @@ -74,7 +74,7 @@ def parallelize_allocs(proc, cursor): f"Got type {type(cursor)}, expected {ForCursor} or {IfCursor}" ) - allocs = filter(lambda s: isinstance(s, AllocCursor), nlr_stmts(proc, cursor)) + allocs = filter_cursors(is_alloc)(proc, nlr_stmts(proc, cursor)) func = lambda proc, alloc: parallelize_and_lift_alloc( proc, alloc, get_distance(proc, alloc, cursor) ) @@ -109,7 +109,7 @@ def rewrite(proc, loop, factor=None, par_reduce=False, memory=DRAM, tail="cut"): if par_reduce: proc = parallelize_all_reductions(proc, outer, memory=memory, unroll=True) loop = proc.forward(outer).body()[0] - allocs = filter(lambda s: isinstance(s, AllocCursor), loop.body()) + allocs = filter_cursors(is_alloc)(proc, loop.body()) proc = apply(parallelize_and_lift_alloc)(proc, allocs) stmts = list(proc.forward(loop).body()) @@ -384,7 +384,7 @@ def rewrite(proc, s): break return parallelize_reduction(proc, s, factor, memory, nth_loop, unroll) - return make_pass(attempt(rewrite))(proc, loop.body()) + return make_pass(attempt(rewrite), nlr_stmts)(proc, loop.body()) def unroll_and_jam(proc, loop, factor, unroll=(True, True, True)): @@ -478,7 +478,7 @@ def dfs(proc, cursor, n_lifts=0): if n_lifts and not is_end_of_body(proc, cursor): proc = fission(proc, cursor.after(), n_lifts) children = get_children(proc, cursor) - children = filter(lambda s: isinstance(s, StmtCursor), children) + children = filter_cursors(is_stmt)(proc, children) return apply(dfs)(proc, children, n_lifts + 1) proc = parallelize_allocs(proc, cursor) @@ -608,7 +608,7 @@ def get_depth(loop): def push_loop_in(proc, loop, depth, size=None): loop = proc.forward(loop) - allocs = list(filter(lambda a: isinstance(a, AllocCursor), loop.body())) + allocs = list(filter_cursors(is_alloc)(proc, loop.body())) proc = apply(attempt(parallelize_and_lift_alloc))(proc, allocs) if const_allocs and size: proc = apply(lambda p, a: resize_dim(p, a, 0, size, 0))(proc, allocs) @@ -659,10 +659,8 @@ def auto_stage_mem(proc, block, buff, new_buff_name=None, accum=False, rc=False) block = block.as_block() block_nodes = list(lrn(proc, block)) - block_loops = list(filter(lambda s: is_loop(proc, s), block_nodes)) - block_accesses = filter( - lambda s: is_access(proc, s) and s.name() == buff, block_nodes - ) + block_loops = list(filter_cursors(is_loop)(proc, block_nodes)) + block_accesses = filter_cursors(is_access)(proc, block_nodes, name=buff) def eval_rng(expr, env): expr = proc.forward(expr) @@ -850,7 +848,7 @@ def _eliminate_dead_code_pruned(proc, s): return eliminate_dead_code(proc, s) -dce = make_pass(attempt(_eliminate_dead_code_pruned)) +dce = make_pass(attempt(_eliminate_dead_code_pruned), nlr_stmts) def unroll_buffers(proc, block=InvalidCursor(), mem=None): @@ -867,7 +865,7 @@ def rewrite(proc, alloc): return proc while True: - new_proc = make_pass(rewrite)(proc, block) + new_proc = make_pass(rewrite, nlr_stmts)(proc, block) if new_proc == proc: break proc = new_proc @@ -963,15 +961,15 @@ def stage(proc, exprs): return apply(stage)(proc, children) block = proc.forward(block) - allocs = filter(lambda s: isinstance(s, AllocCursor), nlr_stmts(proc, block)) + allocs = filter_cursors(is_alloc)(proc, nlr_stmts(proc, block)) proc = apply(set_memory)(proc, allocs, memory) - proc = make_pass(attempt(unfold_reduce))(proc, block) - assigns = filter(lambda s: isinstance(s, AssignCursor), lrn_stmts(proc, block)) + proc = make_pass(attempt(unfold_reduce), nlr_stmts)(proc, block) + assigns = filter_cursors(is_assign)(proc, lrn_stmts(proc, block)) exprs = [assign.rhs() for assign in assigns] proc = apply(stage)(proc, exprs) # TODO: uncomment once bug in Exo is fixed # proc = inline_copies(proc, block) - proc = make_pass(attempt(fold_into_reduce))(proc, block) + proc = make_pass(attempt(fold_into_reduce), nlr_stmts)(proc, block) proc = dealiasing_pass(proc, block) return proc @@ -986,7 +984,7 @@ def check_call_site(proc, call_cursor): ################################################################### env = {} obs_stmts = get_observed_stmts(call_cursor) - allocs = filter(lambda s: isinstance(s, AllocCursor), obs_stmts) + allocs = filter_cursors(is_alloc)(proc, obs_stmts) for s in list(proc.args()) + list(allocs): if s.type().is_numeric(): env[s.name()] = (s.mem(), s.type()) @@ -1218,8 +1216,10 @@ def cse(proc, block, precision): nodes = list(lrn(proc, block)) # First do CSE on the buffer accesses - accesses = filter(lambda c: is_access(proc, c), nodes) - accesses = filter(lambda c: is_stmt(proc, c) or c.type().is_numeric(), accesses) + accesses = filter_cursors(is_access)(proc, nodes) + accesses = filter_cursors(lambda p, c: is_stmt(p, c) or c.type().is_numeric())( + proc, accesses + ) buff_map = {} @@ -1236,7 +1236,7 @@ def cse(proc, block, precision): return proc -inline_copies = make_pass(predicate(attempt(inline_assign), is_copy)) +inline_copies = make_pass(predicate(attempt(inline_assign), is_copy), nlr_stmts) def dealias(proc, stmt): @@ -1247,8 +1247,10 @@ def dealias(proc, stmt): nodes = list(lrn(proc, stmt.rhs())) - accesses = filter(lambda c: is_access(proc, c), nodes) - accesses = filter(lambda c: is_stmt(proc, c) or c.type().is_numeric(), accesses) + accesses = filter_cursors(is_access)(proc, nodes) + accesses = filter_cursors(lambda p, c: is_stmt(p, c) or c.type().is_numeric())( + proc, accesses + ) buff_map = {} for access in accesses: @@ -1266,7 +1268,7 @@ def dealias(proc, stmt): return proc -dealiasing_pass = make_pass(attempt(dealias, errs=(TypeError,))) +dealiasing_pass = make_pass(attempt(dealias, errs=(TypeError,)), nlr_stmts) def round_loop(proc, loop, factor, up=True): @@ -1323,7 +1325,7 @@ def inline_proc_and_wins(proc, call, rc=False): call = proc.forward(call) proc = inline(proc, call) block = proc.forward(call.as_block()) - windows = filter(lambda s: isinstance(s, WindowStmtCursor), block) + windows = filter_cursors(is_window_stmt)(proc, block) proc = apply(inline_window)(proc, windows) if not rc: return proc diff --git a/src/level3/gemm.py b/src/level3/gemm.py index d04a675..ea76444 100644 --- a/src/level3/gemm.py +++ b/src/level3/gemm.py @@ -90,15 +90,13 @@ def schedule_macro( packed_A_shape = ((0, max_M // m_r), (1, max_K), (0, m_r)) gemm_mk, cursors = pack_mem(gemm_mk, i_loop, "A", packed_A_shape, "packed_A", rc=1) - gemm_mk, _ = extract_subproc( - gemm_mk, cursors.load, "A_pack_kernel" - ) # TODO: Schedule packing kernel + # TODO: Schedule packing kernel + gemm_mk, _ = extract_subproc(gemm_mk, cursors.load, "A_pack_kernel") packed_B_shape = ((1, max_N // n_r), (0, max_K), (1, n_r)) gemm_mk, cursors = pack_mem(gemm_mk, i_loop, "B", packed_B_shape, "packed_B", rc=1) - gemm_mk, _ = extract_subproc( - gemm_mk, cursors.load, "B_pack_kernel" - ) # TODO: Schedule packing kernel + # TODO: Schedule packing kernel + gemm_mk, _ = extract_subproc(gemm_mk, cursors.load, "B_pack_kernel") gemm_mk, _ = extract_subproc(gemm_mk, i_loop, "compute") return gemm_mk_starter, gemm_mk @@ -152,9 +150,7 @@ def schedule(main_gemm, i_loop, precision, machine): ) # Change macrokernel loops to original order gemm_tiled = replace_all_stmts(gemm_tiled, [gemm_macro]) - macro_calls = filter( - lambda c: is_call(gemm_tiled, c, gemm_macro[1]), nlr_stmts(gemm_tiled) - ) + macro_calls = filter_cursors(is_call)(gemm_tiled, nlr_stmts(gemm_tiled)) gemm_tiled = simplify(apply(inline_proc_and_wins)(gemm_tiled, macro_calls)) gemm_tiled = apply(hoist_from_loop)(