Skip to content

Commit

Permalink
make use of isa
Browse files Browse the repository at this point in the history
  • Loading branch information
jorendumoulin committed Jan 6, 2025
1 parent 6d67bf7 commit bd0105a
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 17 deletions.
18 changes: 8 additions & 10 deletions compiler/transforms/clear_memory_space.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from typing import cast

from xdsl.context import MLContext
from xdsl.dialects import builtin, func
from xdsl.ir import Attribute, BlockArgument
from xdsl.passes import ModulePass
from xdsl.utils.hints import isa

from compiler.dialects.tsl import TiledStridedLayoutAttr

Expand All @@ -15,20 +14,19 @@ def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
# helper function to clear the memory space of a memref
# also clears the layout information of the memref - not used anymore
def clear_memory_space(t: Attribute) -> Attribute:
if isinstance(t, builtin.MemRefType):
memref_t = cast(builtin.MemRefType[Attribute], t)
if isinstance(memref_t.layout, TiledStridedLayoutAttr):
if isa(t, builtin.MemRefType[Attribute]):
if isinstance(t.layout, TiledStridedLayoutAttr):
return builtin.MemRefType(
memref_t.element_type,
memref_t.get_shape(),
t.element_type,
t.get_shape(),
builtin.NoneAttr(),
builtin.NoneAttr(),
)
else:
return builtin.MemRefType(
memref_t.element_type,
memref_t.get_shape(),
memref_t.layout,
t.element_type,
t.get_shape(),
t.layout,
builtin.NoneAttr(),
)
return t
Expand Down
10 changes: 3 additions & 7 deletions compiler/transforms/convert_tosa_to_kernel.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import cast

from xdsl.builder import Builder
from xdsl.context import MLContext
from xdsl.dialects import arith, builtin, linalg, tensor, tosa
Expand All @@ -13,6 +11,7 @@
RewritePattern,
op_type_rewrite_pattern,
)
from xdsl.utils.hints import isa

from compiler.dialects import kernel

Expand Down Expand Up @@ -44,14 +43,11 @@ def match_and_rewrite(self, rescale_op: tosa.RescaleOp, rewriter: PatternRewrite
clamp_op = rescale_op

# should have tensor inputs
if not isinstance(inp_type := rescale_op.input.type, builtin.TensorType):
if not isa(inp_type := rescale_op.input.type, builtin.TensorType[Attribute]):
return
if not isinstance(out_type := clamp_op.output.type, builtin.TensorType):
if not isa(out_type := clamp_op.output.type, builtin.TensorType[Attribute]):
return

inp_type = cast(builtin.TensorType[Attribute], inp_type)
out_type = cast(builtin.TensorType[Attribute], out_type)

# create linalg body with kernel op with the params of tosa ops

# Extract all values:
Expand Down

0 comments on commit bd0105a

Please sign in to comment.