Skip to content

Commit

Permalink
Merge branch 'main' into move_gemmini
Browse files Browse the repository at this point in the history
  • Loading branch information
SamirDroubi authored Mar 17, 2024
2 parents a641720 + cebbddf commit 3cef652
Show file tree
Hide file tree
Showing 8 changed files with 381 additions and 369 deletions.
8 changes: 6 additions & 2 deletions src/common/codegen_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from higher_order import *
import exo_blas_config as C
from perf_features import *
from stdlib import *


def specialize_precision(proc, precision, all_buffs=True):
Expand Down Expand Up @@ -109,7 +110,7 @@ def export_perf_features(kernel_name, perf_features):
json.dump(perf_features, f, sort_keys=True, indent=4, separators=(",", ": "))


def variants_generator(blas_op):
def variants_generator(blas_op, opt_precisions=("f32", "f64")):
def generate(proc, loop_name, *args, globals=None, **kwargs):
perf_features = {}
for precision in ("f32", "f64"):
Expand All @@ -125,7 +126,10 @@ def generate(proc, loop_name, *args, globals=None, **kwargs):
loop = stride_1.find_loop(loop_name)
algorithm = get_perf_features(stride_1)

stride_1 = blas_op(stride_1, loop, precision, C.Machine, *args, **kwargs)
if precision in opt_precisions:
stride_1 = blas_op(
stride_1, loop, precision, C.Machine, *args, **kwargs
)
stride_1 = bind_builtins_args(stride_1, stride_1.body(), precision)
scheduled = get_perf_features(stride_1)

Expand Down
37 changes: 36 additions & 1 deletion src/common/higher_order.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from exo.stdlib.scheduling import *

from exceptions import *
from inspection import *

Expand Down Expand Up @@ -64,4 +66,37 @@ def rewrite(p, *args, **kwargs):
return rewrite


__all__ = ["apply", "attempt", "make_pass", "lift_rc", "repeate", "predicate"]
def repeate_n(op):
    """Lift ``op`` into an operator that applies itself ``n`` times in a row.

    Args:
        op: Callable invoked as ``op(p, *args, **kwargs)`` that returns the
            next value of ``p``.

    Returns:
        A ``rewrite(p, *args, n=1, **kwargs)`` callable. ``n`` is
        keyword-only and is consumed here (it is not forwarded to ``op``);
        every other argument is passed to each application unchanged.
        With ``n=0`` the input ``p`` is returned untouched.
    """

    def rewrite(p, *args, n=1, **kwargs):
        # Thread the result of each application into the next one.
        for _ in range(n):
            p = op(p, *args, **kwargs)
        return p

    return rewrite


def extract_and_schedule(op):
    """Lift ``op`` into an operator that schedules a block as a sub-procedure.

    The returned ``rewrite`` extracts ``block`` out of ``proc`` into a new
    sub-procedure named ``subproc_name``, applies ``op`` to that
    sub-procedure, and redirects the generated call to the scheduled version.

    Args:
        op: Callable invoked as ``op(subproc, *args, **kwargs)`` that returns
            the scheduled sub-procedure.

    Returns:
        ``rewrite(proc, block, subproc_name, *args, rc=False, **kwargs)``.
        With ``rc=False`` it returns the new ``proc``; with ``rc=True`` it
        returns ``(proc, (subproc, subproc_sched))``.
    """

    def rewrite(proc, block, subproc_name, *args, rc=False, **kwargs):
        block = proc.forward(block)
        block = block.as_block()
        proc, subproc = extract_subproc(proc, block, subproc_name)
        subproc_sched = op(subproc, *args, **kwargs)
        # After extraction the forwarded block is expected to hold the call
        # statement that replaced the original statements -- TODO confirm.
        call = proc.forward(block)[0]
        # Point the call at the *scheduled* sub-procedure. Passing the
        # unscheduled `subproc` here would leave the call unchanged and
        # silently drop the work done by `op`.
        proc = call_eqv(proc, call, subproc_sched)
        if not rc:
            return proc
        return proc, (subproc, subproc_sched)

    return rewrite


# Names exported by a star-import of this module (`from higher_order import *`).
__all__ = [
"apply",
"attempt",
"make_pass",
"lift_rc",
"repeate",
"predicate",
"extract_and_schedule",
"repeate_n",
]
21 changes: 19 additions & 2 deletions src/common/inspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def stmt_children(stmt):
elif isinstance(stmt, CallCursor):
yield from stmt.args()
elif isinstance(stmt, WindowStmtCursor):
yield from stmt.idx()
yield stmt.winexpr()
elif isinstance(stmt, AssignConfigCursor):
yield stmt.rhs()
elif isinstance(stmt, PassCursor):
Expand Down Expand Up @@ -380,6 +380,23 @@ def is_unary_minus(proc, expr):
return isinstance(expr, UnaryMinusCursor)


def is_call(proc, call, subproc=None):
    """Tell whether ``call`` (forwarded to ``proc``) is a call statement.

    When ``subproc`` is given, the call must additionally target that
    procedure (compared with ``==``).
    """
    forwarded = proc.forward(call)
    if not isinstance(forwarded, CallCursor):
        return False
    return subproc is None or forwarded.subproc() == subproc


def is_invalid(proc, inv):
    """Tell whether ``inv`` is unusable as a cursor into ``proc``.

    Returns ``True`` when ``inv`` is already an ``InvalidCursor`` or when
    forwarding it to ``proc`` raises ``InvalidCursorError``; ``False``
    otherwise.
    """
    if isinstance(inv, InvalidCursor):
        return True
    try:
        # Only the success/failure of forwarding matters; the result is unused.
        proc.forward(inv)
    except InvalidCursorError:
        return True
    return False


def is_start_of_body(proc, stmt):
    """Tell whether ``stmt`` has no preceding sibling.

    ``stmt`` is forwarded to ``proc`` first; it is considered the start of
    its body when its ``prev()`` cursor is an ``InvalidCursor``.
    """
    previous = proc.forward(stmt).prev()
    return isinstance(previous, InvalidCursor)
Expand All @@ -394,7 +411,7 @@ def get_depth(proc, cursor):
cursor = proc.forward(cursor)

depth = 1
while not isinstance(cursor, InvalidCursor):
while not isinstance(cursor.parent(), InvalidCursor):
cursor = cursor.parent()
depth += 1
return depth
Expand Down
59 changes: 18 additions & 41 deletions src/common/rc_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,13 @@ def __iter__(self):
yield self.tail_loop


def divide_loop_(proc, loop_cursor, div_const, tail="guard", perfect=False, rc=False):
def divide_loop_(proc, loop_cursor, div_const, tail="guard", rc=False):
loop_cursor = proc.forward(loop_cursor)
loop_iter = loop_cursor.name()
perfect = tail == "perfect"
if tail == "perfect":
tail = "cut"
perfect = True
proc = divide_loop(
proc,
loop_cursor,
Expand Down Expand Up @@ -80,28 +84,33 @@ def divide_loop_(proc, loop_cursor, div_const, tail="guard", perfect=False, rc=F
@dataclass
class stage_mem_cursors:
    """Cursors returned by ``stage_mem_`` when called with ``rc=True``.

    Iterating (or tuple-unpacking) an instance yields the cursors in the
    order ``alloc, load, block, store``.
    """

    # Statement two positions before the staged block (as computed by
    # stage_mem_); presumably the allocation of the staging buffer.
    alloc: AllocCursor
    # Statement immediately before the staged block.
    load: Cursor
    # The staged block itself, forwarded to the new procedure.
    block: BlockCursor
    # Statement immediately after the staged block.
    store: Cursor

    def __iter__(self):
        yield self.alloc
        yield self.load
        yield self.block
        yield self.store


def stage_mem_(proc, block, buff, new_buff_name, accum=False, rc=False):
    """Apply ``stage_mem`` to ``block`` and optionally return result cursors.

    Args:
        proc: Procedure to rewrite.
        block: Cursor to the statement(s) to stage around; it is forwarded
            to ``proc`` and converted with ``as_block()``.
        buff: Buffer/window argument, passed through to ``stage_mem``.
        new_buff_name: Name of the staging buffer, passed through.
        accum: Passed through to ``stage_mem``.
        rc: When true, also return the cursors of the generated statements.

    Returns:
        The rewritten ``proc`` when ``rc`` is false, otherwise
        ``(proc, stage_mem_cursors(alloc, load, block, store))``.
    """
    block = proc.forward(block)
    block = block.as_block()

    # Remember the block's end points so they can be re-located afterwards.
    block_first = block[0]
    block_last = block[-1]
    proc = stage_mem(proc, block, buff, new_buff_name, accum)
    block_first = proc.forward(block_first)
    block_last = proc.forward(block_last)
    # NOTE(review): assumes stage_mem emitted exactly an allocation and a
    # load in front of the block and a store right after it -- confirm this
    # holds for every accum / read-only / write-only case.
    alloc = block_first.prev().prev()
    load = block_first.prev()
    # Rebuild a block of the original length, starting at its first statement.
    block = block_first.as_block().expand(0, len(block) - 1)
    store = block_last.next()
    if not rc:
        return proc
    return proc, stage_mem_cursors(alloc, load, block, store)


@dataclass
Expand All @@ -123,35 +132,3 @@ def cut_loop_(proc, loop, expr, rc=False):
loop1 = proc.forward(loop)
loop2 = loop1.next()
return proc, cut_loop_cursors(loop1, loop2)


@dataclass
class specialize_cursors:
    """Cursor bundle returned by ``specialize_`` when called with ``rc=True``.

    Iterating an instance yields its single cursor, ``if_stmt``.
    """

    # Cursor that specialize_ looks up at the original statement's position.
    if_stmt: Cursor

    def __iter__(self):
        yield from (self.if_stmt,)


def specialize_(proc, stmt, cond, rc=False):
    """Apply ``specialize`` to ``stmt`` and optionally return the new cursor.

    Returns the rewritten ``proc`` when ``rc`` is false; otherwise returns
    ``(proc, specialize_cursors(if_stmt))`` where ``if_stmt`` is whatever
    statement now sits at the original statement's index in its parent's
    body (or ``orelse`` branch).
    """
    stmt = proc.forward(stmt)
    parent = stmt.parent()
    index = get_index_in_body(proc, stmt)
    proc = specialize(proc, stmt, cond)
    if not rc:
        return proc

    # Determine, using the pre-rewrite cursors, whether the statement lived
    # in the else-branch of an enclosing `if`.
    in_orelse = (
        isinstance(parent, IfCursor)
        and not isinstance(parent.orelse(), InvalidCursor)
        and index < len(parent.orelse())
        and parent.orelse()[index] == stmt
    )

    # A statement with no parent is looked up in the procedure's own body.
    if isinstance(parent, InvalidCursor):
        container = proc
    else:
        container = proc.forward(parent)

    if in_orelse:
        if_stmt = container.orelse()[index]
    else:
        if_stmt = container.body()[index]
    return proc, specialize_cursors(if_stmt)
Loading

0 comments on commit 3cef652

Please sign in to comment.