diff --git a/forge/forge/op/eval/forge/transpose.py b/forge/forge/op/eval/forge/transpose.py index 46641d429..bfb7cb544 100644 --- a/forge/forge/op/eval/forge/transpose.py +++ b/forge/forge/op/eval/forge/transpose.py @@ -63,67 +63,6 @@ def decompose(self, dc, inputs): ) ) - def decompose_post_optimize(self, dc, inputs): - orig_shape = inputs[0].shape - if ( - len(orig_shape) > 2 - and self.dim0 == -3 - and self.dim1 == -1 - and ((len(orig_shape) == 4 and orig_shape[-4] == 1) or len(orig_shape) < 4) - ): - # XZ transpose - result = inputs[0] - use_sparse_mm = True - - result = inputs[0] - result = dc.op("pad_tile", [result], (-2, orig_shape[-2])) - result = dc.op("pad_tile", [result], (-1, orig_shape[-1])) - result = dc.op(TransposeTM.create(-2, -1), [result]) - - if result.shape[-3] > 1: - result = dc.op("vstack", [result], (orig_shape[-3],)) - i_spm = sparse_utils.create_sparse_interleave_picker_matrix( - result.shape[-2], orig_shape[-1], orig_shape[-3] - ) - result = picker_matmul(use_sparse_mm, dc, i_spm, result) - - if orig_shape[-1] > 1: - result = dc.op("vslice", [result], (orig_shape[-1],)) - result = dc.op(TransposeTM.create(-2, -1), [result]) - - result = dc.op("narrow", [result], (-2, 0, orig_shape[-2], result.shape[-2])) - result = dc.op("narrow", [result], (-1, 0, orig_shape[-3], result.shape[-1])) - - dc.fuse(result) - - elif ( - self.dim0 == -3 - and self.dim1 == -2 - and ((len(orig_shape) == 4 and orig_shape[0] == 1) or len(orig_shape) == 3) - ): - # YZ transpose - result = inputs[0] - use_sparse_mm = True - - result = inputs[0] - result = dc.op("pad_tile", [result], (-2, orig_shape[-2])) - result = dc.op("pad_tile", [result], (-1, orig_shape[-1])) - - if result.shape[-3] > 1: - result = dc.op("vstack", [result], (orig_shape[-3],)) - i_spm = sparse_utils.create_sparse_interleave_picker_matrix( - result.shape[-2], orig_shape[-2], orig_shape[-3] - ) - result = picker_matmul(use_sparse_mm, dc, i_spm, result) - - if orig_shape[-2] > 1: - result = dc.op("vslice", [result], (orig_shape[-2],)) - - result = dc.op("narrow", [result], (-2, 0, orig_shape[-3], result.shape[-2])) - result = dc.op("narrow", [result], (-1, 0, orig_shape[-1], result.shape[-1])) - - dc.fuse(result) - def picker_matmul(use_sparse_mm, dc, s, result): if use_sparse_mm: diff --git a/forge/test/mlir/test_ops.py b/forge/test/mlir/test_ops.py index 48dc50257..37e7a8e47 100644 --- a/forge/test/mlir/test_ops.py +++ b/forge/test/mlir/test_ops.py @@ -275,6 +275,9 @@ def forward(self, a, b): ((32, 128), (0, 1)), ((18, 65), (1, 0)), ((6, 33, 34), (-1, 1)), + ((1, 32, 64), (-2, -3)), + ((6, 33, 34), (-1, -3)), + ((32, 128, 24), (1, -3)), ], ) def test_transpose(params): @@ -293,6 +296,10 @@ def forward(self, a): fw_out = framework_model(*inputs) compiled_model = forge.compile(framework_model, sample_inputs=inputs) + + if params[1][1] == -3: + pytest.xfail("Currently the lowering to TTNN is not supported for -3 dim") + co_out = compiled_model(*inputs) co_out = [co.to("cpu") for co in co_out]