From 28b1aac4c925d232f796d9e2fc0e39615d210f10 Mon Sep 17 00:00:00 2001 From: Joren Dumoulin Date: Thu, 30 Nov 2023 14:46:09 +0100 Subject: [PATCH] move dispatching rules to global file --- compiler/transforms/dispatch_regions.py | 17 ++--------------- compiler/transforms/insert_sync_barrier.py | 14 +++++--------- compiler/util/dispatching_rules.py | 17 +++++++++++++++++ 3 files changed, 24 insertions(+), 24 deletions(-) create mode 100644 compiler/util/dispatching_rules.py diff --git a/compiler/transforms/dispatch_regions.py b/compiler/transforms/dispatch_regions.py index e17657a7..1a316ef2 100644 --- a/compiler/transforms/dispatch_regions.py +++ b/compiler/transforms/dispatch_regions.py @@ -1,4 +1,4 @@ -from xdsl.dialects import builtin, memref, func, scf, linalg +from xdsl.dialects import builtin, func, scf from xdsl.ir.core import Operation, Block from xdsl.ir import MLContext from xdsl.passes import ModulePass @@ -10,6 +10,7 @@ op_type_rewrite_pattern, ) from xdsl.traits import SymbolTable +from compiler.util.dispatching_rules import dispatch_to_compute, dispatch_to_dm class DispatchRegionsRewriter(RewritePattern): @@ -58,20 +59,6 @@ def dispatcher(block: Block, core_cond: builtin.i1, dispatch_rule: Callable): return changes_made - def dispatch_to_dm(op): - """Rule to dispatch operations to the dm core: - for now, this is only memref copy operations""" - if isinstance(op, memref.CopyOp): - return True - return False - - def dispatch_to_compute(op): - """Rule to dispatch operations to the dm core: - for now, this is only linalg generic operations""" - if isinstance(op, linalg.Generic): - return True - return False - # find root module op of func op: # this is necessary to declare the external functions # such as is_dm_core and is_compute core at the correct place diff --git a/compiler/transforms/insert_sync_barrier.py b/compiler/transforms/insert_sync_barrier.py index 05ccf7a6..e539ba7f 100644 --- a/compiler/transforms/insert_sync_barrier.py +++ b/compiler/transforms/insert_sync_barrier.py @@ -1,4 +1,4 @@ -from xdsl.dialects import builtin, memref, linalg +from xdsl.dialects import builtin from compiler.dialects import snax from xdsl.ir import MLContext from xdsl.passes import ModulePass @@ -8,6 +8,7 @@ RewritePattern, op_type_rewrite_pattern, ) +from compiler.util.dispatching_rules import dispatch_to_compute, dispatch_to_dm class InsertSyncBarrierRewriter(RewritePattern): @@ -40,15 +41,10 @@ def match_and_rewrite(self, module: builtin.ModuleOp, rewriter: PatternRewriter) # is used on another core - if yes, there must be a synchronisation # barrier between the two ops - # basic dispatching rules for now - def check_core(op): - if isinstance(op, memref.CopyOp): - return "dm" - if isinstance(op, linalg.Generic): - return "compute" - return "global" + if dispatch_to_dm(op) and not dispatch_to_dm(op_use.operation): + ops_to_sync.append(op_use.operation) - if check_core(op) != "global" and check_core(op) != check_core( + if dispatch_to_compute(op) and not dispatch_to_compute( op_use.operation ): ops_to_sync.append(op_use.operation) diff --git a/compiler/util/dispatching_rules.py b/compiler/util/dispatching_rules.py new file mode 100644 index 00000000..e0ea95d0 --- /dev/null +++ b/compiler/util/dispatching_rules.py @@ -0,0 +1,17 @@ +from xdsl.dialects import memref, linalg + + +def dispatch_to_dm(op): + """Rule to dispatch operations to the dm core: + for now, this is only memref copy operations""" + if isinstance(op, memref.CopyOp): + return True + return False + + +def dispatch_to_compute(op): + """Rule to dispatch operations to the dm core: + for now, this is only linalg generic operations""" + if isinstance(op, linalg.Generic): + return True + return False