Skip to content

Commit

Permalink
Implement sgemm (#32)
Browse files Browse the repository at this point in the history
  • Loading branch information
SamirDroubi authored Nov 6, 2023
1 parent 2c902cc commit b47e8be
Show file tree
Hide file tree
Showing 14 changed files with 600 additions and 2,093 deletions.
190 changes: 168 additions & 22 deletions src/common/composed_schedules.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from exo.stdlib.scheduling import *
from exo.API_cursors import *

from introspection import get_stmt_dependencies, get_declaration
from introspection import get_stmt_dependencies, get_declaration, get_expr_dependencies


class BLAS_SchedulingError(Exception):
Expand Down Expand Up @@ -130,7 +130,7 @@ def stage_expr(proc, expr_cursors, new_name, precision="R", memory=DRAM, n_lifts
return proc


def stage_alloc(proc, alloc_cursor, n_lifts=1):
def parallelize_and_lift_alloc(proc, alloc_cursor, n_lifts=1):
"""
for i in seq(0, hi):
B1;
Expand All @@ -145,11 +145,17 @@ def stage_alloc(proc, alloc_cursor, n_lifts=1):
B2;
"""
alloc_cursor = proc.forward(alloc_cursor)
enclosing_loop = get_enclosing_loop(alloc_cursor)
proc = expand_dim(
proc, alloc_cursor, expr_to_string(enclosing_loop.hi()), enclosing_loop.name()
)
proc = lift_alloc(proc, alloc_cursor, n_lifts=n_lifts)
for i in range(n_lifts):
alloc_cursor = proc.forward(alloc_cursor)
enclosing_scope = alloc_cursor.parent()
if isinstance(enclosing_scope, ForSeqCursor):
proc = expand_dim(
proc,
alloc_cursor,
expr_to_string(enclosing_scope.hi()),
enclosing_scope.name(),
)
proc = lift_alloc(proc, alloc_cursor)
return proc


Expand All @@ -174,7 +180,7 @@ def auto_divide_loop(proc, loop_cursor, div_const, tail="guard", perfect=False):
outer_loop_cursor = proc.forward(loop_cursor)
inner_loop_cursor = outer_loop_cursor.body()[0]

if perfect == True or tail in ("cut", "cut_and_guard"):
if perfect == True or tail == "guard":
tail_loop_cursor = InvalidCursor()
else:
tail_loop_cursor = outer_loop_cursor.next()
Expand All @@ -184,6 +190,10 @@ def auto_divide_loop(proc, loop_cursor, div_const, tail="guard", perfect=False):
)


def scalar_loop_to_simd_loops(proc, loop_cursor, vec_width, memory_type, precision):
return vectorize_to_loops(proc, loop_cursor, vec_width, memory_type, precision)


def vectorize_to_loops(proc, loop_cursor, vec_width, memory_type, precision):
"""
for i in seq(0, hi):
Expand Down Expand Up @@ -245,7 +255,7 @@ def fission_stmts(proc, body, depth=1):
body_list = list(body)
for stmt in body_list[:-1]:
if isinstance(stmt, AllocCursor):
proc = stage_alloc(proc, stmt, n_lifts=depth)
proc = parallelize_and_lift_alloc(proc, stmt, n_lifts=depth)
proc = set_memory(proc, stmt, memory_type)
proc = set_precision(proc, stmt, precision)
else:
Expand Down Expand Up @@ -312,6 +322,14 @@ def get_expr_subtree_cursors(expr, stmt, alias):
return lhs_cursors + rhs_cursors + [expr]
else:
return lhs_cursors + rhs_cursors
elif isinstance(expr, UnaryMinusCursor):
return get_expr_subtree_cursors(expr.arg(), stmt, False) + [expr]
elif isinstance(expr, BuiltInFunctionCursor):
exprs = []
for arg in expr.args():
exprs = exprs + get_expr_subtree_cursors(arg, stmt, False)
exprs = exprs + [expr]
return exprs
else:
return [expr]

Expand Down Expand Up @@ -343,7 +361,7 @@ def vectorize_stmt(proc, stmt, depth=1):
alloc_cursor = forwarded_stmt.prev().prev()
if depth > 1:
proc = lift_alloc(proc, alloc_cursor, n_lifts=depth - 1)
proc = stage_alloc(proc, alloc_cursor)
proc = parallelize_and_lift_alloc(proc, alloc_cursor)

forwarded_stmt = proc.forward(stmt)
proc = fission(proc, forwarded_stmt.after(), n_lifts=depth)
Expand Down Expand Up @@ -422,7 +440,7 @@ def interleave_execution(proc, loop_cursor, interleave_factor):

for stmt in inner_loop_stmts:
if isinstance(stmt, AllocCursor):
proc = stage_alloc(proc, stmt)
proc = parallelize_and_lift_alloc(proc, stmt)

inner_loop_cursor = proc.forward(inner_loop_cursor)

Expand Down Expand Up @@ -548,7 +566,7 @@ def parallelize_reduction(
proc = set_precision(proc, alloc_cursor, precision)

outer_loop_cursor = proc.forward(outer_loop_cursor)
proc = stage_alloc(proc, outer_loop_cursor.prev().prev())
proc = parallelize_and_lift_alloc(proc, outer_loop_cursor.prev().prev())
proc = fission(proc, outer_loop_cursor.before())
proc = fission(proc, outer_loop_cursor.after())
outer_loop_cursor = proc.forward(outer_loop_cursor)
Expand Down Expand Up @@ -609,7 +627,7 @@ def interleave_outer_loop_with_inner_loop(

for stmt in middle_loop_stmts:
if isinstance(stmt, AllocCursor):
proc = stage_alloc(proc, stmt)
proc = parallelize_and_lift_alloc(proc, stmt)

inner_loop_cursor = proc.forward(inner_loop_cursor)

Expand Down Expand Up @@ -742,6 +760,9 @@ def vectorize(
precision,
tail="cut",
)
outer_loop_cursor = proc.forward(outer_loop_cursor)
proc = unroll_loop(proc, outer_loop_cursor.prev())
proc = unroll_loop(proc, outer_loop_cursor.next())
outer_loop_cursor = proc.forward(outer_loop_cursor)
inner_loop_cursor = outer_loop_cursor.body()[0]
inner_loop_cursor = proc.forward(inner_loop_cursor)
Expand All @@ -755,7 +776,7 @@ def vectorize(
return proc


def tile_loops(proc, loop_tile_pairs):
def tile_loops_top_down(proc, loop_tile_pairs):

loop_tile_pairs = [(proc.forward(i[0]), i[1]) for i in loop_tile_pairs]

Expand Down Expand Up @@ -784,13 +805,60 @@ def tile_loops(proc, loop_tile_pairs):
return proc, [proc.forward(l) for l in inner_loops]


def auto_stage_mem(proc, read_cursor, new_buff_name, n_lifts=1):
if not isinstance(read_cursor, ReadCursor):
def tile_loops_bottom_up(proc, outer_most_loop, tiles):
loop = outer_most_loop
for i in tiles[:-1]:
if not len(loop.body()) == 1:
raise BLAS_SchedulingError("All loop must have a body length of 1")
if not isinstance(loop.body()[0], ForSeqCursor):
raise BLAS_SchedulingError("Did not find a nested loop")

loops = []
loop = outer_most_loop
for i in tiles:
loops.append((loop, i))
loop = loop.body()[0]

def get_depth(loop):
if not isinstance(loop, ForSeqCursor):
return 0
return max([get_depth(i) for i in loop.body()]) + 1

def push_loop_in(proc, loop, depth):
if get_depth(loop) == depth:
return proc
count = len(loop.body())
for stmt in list(loop.body())[:-1]:
proc = fission(proc, stmt.after())
loop = proc.forward(loop)
loops = []
for i in range(count):
loops.append(loop)
loop = loop.next()
for loop in loops:
if get_depth(loop) == depth:
continue
proc = reorder_loops(proc, loop)
proc = push_loop_in(proc, proc.forward(loop), depth)
return proc

for depth, (loop, tile) in enumerate(loops[::-1]):
proc, cursors = auto_divide_loop(proc, loop, tile, tail="cut")
proc = push_loop_in(proc, cursors.inner_loop_cursor, depth + 1)
proc = push_loop_in(proc, cursors.tail_loop_cursor, depth + 1)

return proc


def auto_stage_mem(proc, cursor, new_buff_name, n_lifts=1, accum=False):
if not isinstance(cursor, (ReadCursor, ReduceCursor, AssignCursor)):
raise BLAS_SchedulingError("auto_stage_mem expects a read a cursor")

cursor = proc.forward(cursor)

lo = []
hi = []
loop = get_enclosing_loop(read_cursor)
loop = get_enclosing_loop(cursor)
loops = [loop]
for _ in range(n_lifts - 1):
loop = get_enclosing_loop(loop)
Expand All @@ -801,13 +869,13 @@ def auto_stage_mem(proc, read_cursor, new_buff_name, n_lifts=1):
loop = loops[i]
subst[loop.name()] = f"(({expr_to_string(loop.hi(), subst)})-1)"

for idx in read_cursor.idx():
for idx in cursor.idx():
hi.append(expr_to_string(idx, subst))

for key in subst:
subst[key] = 0

for idx in read_cursor.idx():
for idx in cursor.idx():
lo.append(expr_to_string(idx, subst))

def ith_idx(i):
Expand All @@ -816,6 +884,84 @@ def ith_idx(i):
else:
return f"{lo[i]}:(({hi[i]})+1)"

window = ",".join([ith_idx(i) for i in range(len(read_cursor.idx()))])
window = f"{read_cursor.name()}[{window}]"
return stage_mem(proc, loops[-1], window, new_buff_name)
window = ",".join([ith_idx(i) for i in range(len(cursor.idx()))])
window = f"{cursor.name()}[{window}]"
return stage_mem(proc, loops[-1], window, new_buff_name, accum=accum)


def ordered_stage_expr(proc, expr_cursors, new_buff_name, precision, n_lifts=1):
if not isinstance(expr_cursors, list):
expr_cursors = [expr_cursors]

if not all([isinstance(cursor, ExprCursor) for cursor in expr_cursors]):
raise BLAS_SchedulingError("auto_stage_mem expects a read a cursor")

expr_cursors = [proc.forward(c) for c in expr_cursors]
original_stmt = get_statement(expr_cursors[0])

proc = bind_expr(proc, expr_cursors, new_buff_name, cse=True)
original_stmt = proc.forward(original_stmt)
assign_cursor = original_stmt.prev()
alloc_cursor = assign_cursor.prev()
expr_cursor = assign_cursor.rhs()
deps = list(get_expr_dependencies(expr_cursor))

assert isinstance(assign_cursor, AssignCursor)
assert isinstance(alloc_cursor, AllocCursor)

anchor_stmt = assign_cursor

def hoist_as_loop(proc, stmt_cursor):
stmt_cursor = proc.forward(stmt_cursor)
while not isinstance(stmt_cursor.prev(), InvalidCursor):
proc = reorder_stmts(proc, stmt_cursor.expand(1, 0))
stmt_cursor = proc.forward(stmt_cursor)

proc = fission(proc, stmt_cursor.after())

return proc

for i in range(n_lifts):
parent = anchor_stmt.parent()

if not isinstance(parent, ForSeqCursor):
raise BLAS_SchedulingError("Not implemented yet")
if parent.name() in deps:
proc = parallelize_and_lift_alloc(proc, alloc_cursor)
else:
proc = lift_alloc(proc, alloc_cursor)

proc = hoist_as_loop(proc, anchor_stmt)
anchor_stmt = proc.forward(anchor_stmt)
anchor_stmt = anchor_stmt.parent()

alloc_cursor = proc.forward(alloc_cursor)
loop_nest = alloc_cursor.next()

def try_removing_loops(proc, loop):
child_stmt = loop.body()[0]
if isinstance(child_stmt, ForSeqCursor):
proc = try_removing_loops(proc, child_stmt)
try:
proc = remove_loop(proc, loop)
except:
pass
return proc

proc = try_removing_loops(proc, loop_nest)
alloc_cursor = proc.forward(alloc_cursor)
proc = set_precision(proc, alloc_cursor, precision)
scopes_nest = alloc_cursor.next()

def lift_all_ifs(proc, scope, depth=0):
if isinstance(scope, IfCursor):
for i in range(depth):
proc = lift_scope(proc, scope)
child_stmt = scope.body()[0]
if isinstance(child_stmt, (ForSeqCursor, IfCursor)):
proc = lift_all_ifs(proc, child_stmt, depth + 1)
return proc

proc = lift_all_ifs(proc, scopes_nest)

return proc
2 changes: 0 additions & 2 deletions src/level1/dot.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,6 @@ def schedule_dot_stride_1(dot, params):
dot = generate_stride_1_proc(dot, params.precision)
main_loop = dot.find_loop("i")
dot = blas_vectorize(dot, main_loop, params)
dot = unroll_loop(dot, dot.find_loop("ioi"))
dot = unroll_loop(dot, dot.find_loop("ioi"))
return simplify(dot)


Expand Down
Loading

0 comments on commit b47e8be

Please sign in to comment.