diff --git a/compiler/transforms/dispatch_regions.py b/compiler/transforms/dispatch_regions.py index 5d93d551..e17657a7 100644 --- a/compiler/transforms/dispatch_regions.py +++ b/compiler/transforms/dispatch_regions.py @@ -15,11 +15,11 @@ class DispatchRegionsRewriter(RewritePattern): @op_type_rewrite_pattern def match_and_rewrite(self, func_op: func.FuncOp, rewriter: PatternRewriter): - """Helper function to create dispatches in a block. If an operation is - dispatchable according to dispatch_rule, this function will enclose it in - an scf.if block based on the condition core_cond""" - def dispatcher(block: Block, core_cond: builtin.i1, dispatch_rule: Callable): + """Helper function to create dispatches in a block. If an operation is + dispatchable according to dispatch_rule, this function will enclose it in + an scf.if block based on the condition core_cond""" + # return if the dispatcher made any changes changes_made = False @@ -58,18 +58,16 @@ def dispatcher(block: Block, core_cond: builtin.i1, dispatch_rule: Callable): return changes_made - """ Rule to dispatch operations to the dm core: - for now, this is only memref copy operations """ - def dispatch_to_dm(op): + """Rule to dispatch operations to the dm core: + for now, this is only memref copy operations""" if isinstance(op, memref.CopyOp): return True return False - """ Rule to dispatch operations to the dm core: - for now, this is only linalg generic operations """ - def dispatch_to_compute(op): + """Rule to dispatch operations to the dm core: + for now, this is only linalg generic operations""" if isinstance(op, linalg.Generic): return True return False