Skip to content

Commit

Permalink
Deprecated decompose of Transpose op (#543)
Browse files Browse the repository at this point in the history
* deleted depricated decompose

* test for -3 dim

* early stop change

* handling -3 case

* xfail instead of return
  • Loading branch information
mstojkovicTT authored Oct 29, 2024
1 parent 8079c41 commit 89eb521
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 61 deletions.
61 changes: 0 additions & 61 deletions forge/forge/op/eval/forge/transpose.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,67 +63,6 @@ def decompose(self, dc, inputs):
)
)

def decompose_post_optimize(self, dc, inputs):
orig_shape = inputs[0].shape
if (
len(orig_shape) > 2
and self.dim0 == -3
and self.dim1 == -1
and ((len(orig_shape) == 4 and orig_shape[-4] == 1) or len(orig_shape) < 4)
):
# XZ transpose
result = inputs[0]
use_sparse_mm = True

result = inputs[0]
result = dc.op("pad_tile", [result], (-2, orig_shape[-2]))
result = dc.op("pad_tile", [result], (-1, orig_shape[-1]))
result = dc.op(TransposeTM.create(-2, -1), [result])

if result.shape[-3] > 1:
result = dc.op("vstack", [result], (orig_shape[-3],))
i_spm = sparse_utils.create_sparse_interleave_picker_matrix(
result.shape[-2], orig_shape[-1], orig_shape[-3]
)
result = picker_matmul(use_sparse_mm, dc, i_spm, result)

if orig_shape[-1] > 1:
result = dc.op("vslice", [result], (orig_shape[-1],))
result = dc.op(TransposeTM.create(-2, -1), [result])

result = dc.op("narrow", [result], (-2, 0, orig_shape[-2], result.shape[-2]))
result = dc.op("narrow", [result], (-1, 0, orig_shape[-3], result.shape[-1]))

dc.fuse(result)

elif (
self.dim0 == -3
and self.dim1 == -2
and ((len(orig_shape) == 4 and orig_shape[0] == 1) or len(orig_shape) == 3)
):
# YZ transpose
result = inputs[0]
use_sparse_mm = True

result = inputs[0]
result = dc.op("pad_tile", [result], (-2, orig_shape[-2]))
result = dc.op("pad_tile", [result], (-1, orig_shape[-1]))

if result.shape[-3] > 1:
result = dc.op("vstack", [result], (orig_shape[-3],))
i_spm = sparse_utils.create_sparse_interleave_picker_matrix(
result.shape[-2], orig_shape[-2], orig_shape[-3]
)
result = picker_matmul(use_sparse_mm, dc, i_spm, result)

if orig_shape[-2] > 1:
result = dc.op("vslice", [result], (orig_shape[-2],))

result = dc.op("narrow", [result], (-2, 0, orig_shape[-3], result.shape[-2]))
result = dc.op("narrow", [result], (-1, 0, orig_shape[-1], result.shape[-1]))

dc.fuse(result)


def picker_matmul(use_sparse_mm, dc, s, result):
if use_sparse_mm:
Expand Down
7 changes: 7 additions & 0 deletions forge/test/mlir/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,9 @@ def forward(self, a, b):
((32, 128), (0, 1)),
((18, 65), (1, 0)),
((6, 33, 34), (-1, 1)),
((1, 32, 64), (-2, -3)),
((6, 33, 34), (-1, -3)),
((32, 128, 24), (1, -3)),
],
)
def test_transpose(params):
Expand All @@ -293,6 +296,10 @@ def forward(self, a):
fw_out = framework_model(*inputs)

compiled_model = forge.compile(framework_model, sample_inputs=inputs)

if params[1][1] == -3:
pytest.xfail("Currently the lowering to TTNN is not supported for -3 dim")

co_out = compiled_model(*inputs)

co_out = [co.to("cpu") for co in co_out]
Expand Down

0 comments on commit 89eb521

Please sign in to comment.