Skip to content

Commit

Permalink
add bounds
Browse files Browse the repository at this point in the history
  • Loading branch information
jorendumoulin committed Jan 6, 2025
1 parent 5e934e3 commit 2020d01
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 8 deletions.
27 changes: 21 additions & 6 deletions compiler/dialects/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,11 @@ def __init__(
result_types=[result_types],
)


@irdl_op_definition
class StreamingRegionOp(StreamingRegionOpBase):
name = "stream.streaming_region"

def get_pattern_bounds_to_shapes_map(self) -> AffineMap:
"""
Returns mapping from pattern iteration bounds to operand shapes
Expand Down Expand Up @@ -149,15 +154,14 @@ def get_static_pattern_bounds(self) -> Iterable[int]:
)


@irdl_op_definition
class StreamingRegionOp(StreamingRegionOpBase):
name = "stream.streaming_region"


@irdl_op_definition
class ScheduleOp(StreamingRegionOpBase):
name = "stream.schedule"

# The bounds of the iteration space of the schedule
bounds = prop_def(ParameterDef[ArrayAttr[IntegerAttr[IndexType]]])

# The tiling factors for the different dimensions of inputs and outputs
tiles = prop_def(ParameterDef[ArrayAttr[ArrayAttr[IntegerAttr[IndexType]]]])

def __init__(
Expand All @@ -166,6 +170,7 @@ def __init__(
outputs: Sequence[SSAValue | Operation],
patterns: ArrayAttr[AffineMapAttr],
body: Region,
bounds: ArrayAttr[IntegerAttr[IndexType]] | Sequence[int],
tiles: ArrayAttr[ArrayAttr[IntegerAttr[IndexType]]] | Sequence[Sequence[int]],
accelerator: str | StringAttr | None = None,
result_types: Sequence[Attribute] = (),
Expand All @@ -177,8 +182,18 @@ def __init__(
for tile in tiles
]
)
if isinstance(bounds, Sequence):
bounds = ArrayAttr(
[IntegerAttr.from_index_int_value(val) for val in bounds]
)
super().__init__(
inputs, outputs, patterns, body, accelerator, result_types, {"tiles": tiles}
inputs,
outputs,
patterns,
body,
accelerator,
result_types,
{"tiles": tiles, "bounds": bounds},
)


Expand Down
4 changes: 2 additions & 2 deletions tests/filecheck/dialects/stream/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
%s0, %s1, %s2 = "test.op"() : () -> (!stream.stream<i8>, !stream.stream<i32>, !stream.stream<f32>)
%t0, %t1 = "test.op"() : () -> (tensor<16x16xi8>, tensor<16x16xi32>)

%0 = "stream.schedule"(%t0, %t0, %t1) <{"tiles" = [[1 : index, 3: index], [5: index, 7: index]], "patterns" = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d2)>], "accelerator" = "snax_gemmx_stream", "operandSegmentSizes" = array<i32: 2, 1>}> ({
%0 = "stream.schedule"(%t0, %t0, %t1) <{"bounds" = [1: index, 2: index, 3: index], "tiles" = [[1 : index, 3: index], [5: index, 7: index]], "patterns" = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d2)>], "accelerator" = "snax_gemmx_stream", "operandSegmentSizes" = array<i32: 2, 1>}> ({
^0(%1 : !stream.stream<i8>, %2 : !stream.stream<i8>, %3 : !stream.stream<i32>):
%4 = "stream.generic"(%1, %2) ({
^1(%in : i8, %in_1 : i8):
Expand All @@ -48,7 +48,7 @@
// CHECK: builtin.module {
// CHECK-NEXT: %s0, %s1, %s2 = "test.op"() : () -> (!stream.stream<i8>, !stream.stream<i32>, !stream.stream<f32>)
// CHECK-NEXT: %t0, %t1 = "test.op"() : () -> (tensor<16x16xi8>, tensor<16x16xi32>)
// CHECK-NEXT: %0 = "stream.schedule"(%t0, %t0, %t1) <{"tiles" = [[1 : index, 3 : index], [5 : index, 7 : index]], "patterns" = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d2)>], "accelerator" = "snax_gemmx_stream", "operandSegmentSizes" = array<i32: 2, 1>}> ({
// CHECK-NEXT: %0 = "stream.schedule"(%t0, %t0, %t1) <{"bounds" = [1 : index, 2 : index, 3 : index], "tiles" = [[1 : index, 3 : index], [5 : index, 7 : index]], "patterns" = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d2)>], "accelerator" = "snax_gemmx_stream", "operandSegmentSizes" = array<i32: 2, 1>}> ({
// CHECK-NEXT: ^0(%1 : !stream.stream<i8>, %2 : !stream.stream<i8>, %3 : !stream.stream<i32>):
// CHECK-NEXT: %4 = "stream.generic"(%1, %2) ({
// CHECK-NEXT: ^1(%in : i8, %in_1 : i8):
Expand Down

0 comments on commit 2020d01

Please sign in to comment.