Skip to content

Commit

Permalink
make use of isa
Browse files Browse the repository at this point in the history
  • Loading branch information
jorendumoulin committed Jan 6, 2025
1 parent 15b6ef5 commit 135de0b
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 17 deletions.
9 changes: 3 additions & 6 deletions compiler/transforms/realize_memref_casts.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import cast

from xdsl.context import MLContext
from xdsl.dialects import arith, builtin, func, linalg, memref
from xdsl.dialects.memref import MemorySpaceCastOp
Expand All @@ -12,6 +10,7 @@
op_type_rewrite_pattern,
)
from xdsl.rewriter import InsertPoint
from xdsl.utils.hints import isa

from compiler.dialects import stream
from compiler.dialects.snax import LayoutCast
Expand Down Expand Up @@ -56,11 +55,9 @@ def match_and_rewrite(

# now perform casting by inserting memref copies and allocs
source_type = source_op.source.type
assert isinstance(source_type, builtin.MemRefType)
source_type = cast(builtin.MemRefType[Attribute], source_type)
assert isa(source_type, builtin.MemRefType[Attribute])
dest_type = op.dest.type
assert isinstance(dest_type, builtin.MemRefType)
dest_type = cast(builtin.MemRefType[Attribute], dest_type)
assert isa(dest_type, builtin.MemRefType[Attribute])

# create allocation

Expand Down
19 changes: 8 additions & 11 deletions compiler/transforms/set_memory_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
RewritePattern,
op_type_rewrite_pattern,
)
from xdsl.utils.hints import isa

from compiler.dialects import stream
from compiler.util.snax_memory import L1, L3
Expand Down Expand Up @@ -37,8 +38,7 @@ def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter):

# Mapping function to assign default memory space "L3"
def change_to_memory_space(t: Attribute) -> Attribute:
if isinstance(t, builtin.MemRefType):
t = cast(builtin.MemRefType[Attribute], t)
if isa(t, builtin.MemRefType[Attribute]):
if isinstance(t.memory_space, builtin.NoneAttr):
return builtin.MemRefType(
t.element_type,
Expand Down Expand Up @@ -155,10 +155,10 @@ def get_cast_op(operand: SSAValue) -> memref.MemorySpaceCastOp:
cast_op = use.operation
break
# If cast op not found, create and insert new one
assert isinstance(optype := operand.type, builtin.MemRefType)
assert isa(optype := operand.type, builtin.MemRefType[Attribute])
if cast_op is None:
cast_op = memref.MemorySpaceCastOp.from_type_and_target_space(
operand, cast(builtin.MemRefType[Attribute], optype), L1
operand, optype, L1
)
rewriter.insert_op_before_matched_op(cast_op)

Expand Down Expand Up @@ -195,19 +195,16 @@ def match_and_rewrite(self, op: func.ReturnOp, rewriter: PatternRewriter):
func_op_output = outputs[i]
func_return_output = op.arguments[i]

if not isinstance(func_op_output, builtin.MemRefType):
if not isa(func_op_output, builtin.MemRefType[Attribute]):
new_arguments.append(func_return_output)
continue
if not isinstance(
func_return_output_type := func_return_output.type, builtin.MemRefType
if not isa(
func_return_output_type := func_return_output.type,
builtin.MemRefType[Attribute],
):
new_arguments.append(func_return_output)
continue

func_return_output_type = cast(
builtin.MemRefType[Attribute], func_return_output_type
)

if func_op_output.memory_space != func_return_output_type.memory_space:
# create cast op
cast_op = memref.MemorySpaceCastOp.from_type_and_target_space(
Expand Down

0 comments on commit 135de0b

Please sign in to comment.