Skip to content

Commit

Permalink
move dispatching rules to global file
Browse files Browse the repository at this point in the history
  • Loading branch information
jorendumoulin committed Nov 30, 2023
1 parent ad8dca1 commit 28b1aac
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 24 deletions.
17 changes: 2 additions & 15 deletions compiler/transforms/dispatch_regions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from xdsl.dialects import builtin, memref, func, scf, linalg
from xdsl.dialects import builtin, func, scf
from xdsl.ir.core import Operation, Block
from xdsl.ir import MLContext
from xdsl.passes import ModulePass
Expand All @@ -10,6 +10,7 @@
op_type_rewrite_pattern,
)
from xdsl.traits import SymbolTable
from compiler.util.dispatching_rules import dispatch_to_compute, dispatch_to_dm


class DispatchRegionsRewriter(RewritePattern):
Expand Down Expand Up @@ -58,20 +59,6 @@ def dispatcher(block: Block, core_cond: builtin.i1, dispatch_rule: Callable):

return changes_made

def dispatch_to_dm(op):
"""Rule to dispatch operations to the dm core:
for now, this is only memref copy operations"""
if isinstance(op, memref.CopyOp):
return True
return False

def dispatch_to_compute(op):
"""Rule to dispatch operations to the dm core:
for now, this is only linalg generic operations"""
if isinstance(op, linalg.Generic):
return True
return False

# find root module op of func op:
# this is necessary to declare the external functions
# such as is_dm_core and is_compute core at the correct place
Expand Down
14 changes: 5 additions & 9 deletions compiler/transforms/insert_sync_barrier.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from xdsl.dialects import builtin, memref, linalg
from xdsl.dialects import builtin
from compiler.dialects import snax
from xdsl.ir import MLContext
from xdsl.passes import ModulePass
Expand All @@ -8,6 +8,7 @@
RewritePattern,
op_type_rewrite_pattern,
)
from compiler.util.dispatching_rules import dispatch_to_compute, dispatch_to_dm


class InsertSyncBarrierRewriter(RewritePattern):
Expand Down Expand Up @@ -40,15 +41,10 @@ def match_and_rewrite(self, module: builtin.ModuleOp, rewriter: PatternRewriter)
# is used on another core - if yes, there must be a synchronisation
# barrier between the two ops

# basic dispatching rules for now
def check_core(op):
if isinstance(op, memref.CopyOp):
return "dm"
if isinstance(op, linalg.Generic):
return "compute"
return "global"
if dispatch_to_dm(op) and not dispatch_to_dm(op_use.operation):
ops_to_sync.append(op_use.operation)

if check_core(op) != "global" and check_core(op) != check_core(
if dispatch_to_compute(op) and not dispatch_to_compute(
op_use.operation
):
ops_to_sync.append(op_use.operation)
Expand Down
17 changes: 17 additions & 0 deletions compiler/util/dispatching_rules.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from xdsl.dialects import memref, linalg


def dispatch_to_dm(op):
"""Rule to dispatch operations to the dm core:
for now, this is only memref copy operations"""
if isinstance(op, memref.CopyOp):
return True
return False


def dispatch_to_compute(op):
"""Rule to dispatch operations to the dm core:
for now, this is only linalg generic operations"""
if isinstance(op, linalg.Generic):
return True
return False

0 comments on commit 28b1aac

Please sign in to comment.