From 6bda719363bbe830666ea43095a1c4aed38c0733 Mon Sep 17 00:00:00 2001
From: Kamalraj Kannan <157608228+kamalrajkannan78@users.noreply.github.com>
Date: Tue, 12 Nov 2024 13:48:40 +0530
Subject: [PATCH] deprecate sparse matmul op in CNN models (#668)

---
 forge/forge/op/eval/forge/pooling.py | 21 ++++---
 forge/forge/op/eval/forge/resize.py  | 89 +---------------------------
 forge/test/mlir/test_ops.py          | 30 ++++++++++
 3 files changed, 44 insertions(+), 96 deletions(-)

diff --git a/forge/forge/op/eval/forge/pooling.py b/forge/forge/op/eval/forge/pooling.py
index 4f6c0324b..62f167d06 100644
--- a/forge/forge/op/eval/forge/pooling.py
+++ b/forge/forge/op/eval/forge/pooling.py
@@ -543,10 +543,14 @@ def decompose(type, attr, dc, inputs):
         if not padding == [0, 0, 0, 0] and (ceil_mode == True or count_include_pad == False):
             if channel_last:
                 _, y_out, x_out, _ = (result.shape.w, result.shape.z, result.shape.r, result.shape.c)
-                result = dc.op("reshape", [result], (w, 1, y_out * x_out, cin))
+                result = dc.op_with_named_attrs(
+                    "reshape", [result], {"shape": (w, 1, y_out * x_out, cin)}, (w, 1, y_out * x_out, cin)
+                )
             else:
                 _, _, y_out, x_out = (result.shape.w, result.shape.z, result.shape.r, result.shape.c)
-                result = dc.op("reshape", [result], (w, 1, cin, y_out * x_out))
+                result = dc.op_with_named_attrs(
+                    "reshape", [result], {"shape": (w, 1, cin, y_out * x_out)}, (w, 1, cin, y_out * x_out)
+                )
                 result = dc.op(TransposeTM.create(2, 3), [result])
 
             # Since count_include_pad=False undoes math in all padded regions, it takes precedence:
@@ -567,16 +571,17 @@ def decompose(type, attr, dc, inputs):
                 tile_align=False,
             )
             undo_math_picker_tensor = dc.tensor(undo_math_picker)
-            # TODO: This sparse matmul can definitely be fused the same way the sparse mm of convtransposed2d was fused
-            # Ideally, conv2d op should be aware of the ceil_mode param (convtranspose2d has a similar thing -
-            # output_padding) as that way it could create this sparse mm itself and easily fuse it
-            result = dc.op("sparse_matmul", [undo_math_picker_tensor, result])
+            result = dc.op("matmul", [undo_math_picker_tensor, result])
 
             if channel_last:
-                result = dc.op("reshape", [result], (w, y_out, x_out, cin))
+                result = dc.op_with_named_attrs(
+                    "reshape", [result], {"shape": (w, y_out, x_out, cin)}, (w, y_out, x_out, cin)
+                )
             else:
                 result = dc.op(TransposeTM.create(2, 3), [result])
-                result = dc.op("reshape", [result], (w, cin, y_out, x_out))
+                result = dc.op_with_named_attrs(
+                    "reshape", [result], {"shape": (w, cin, y_out, x_out)}, (w, cin, y_out, x_out)
+                )
 
         dc.fuse(result)
 
diff --git a/forge/forge/op/eval/forge/resize.py b/forge/forge/op/eval/forge/resize.py
index b48dd7320..4d1408140 100644
--- a/forge/forge/op/eval/forge/resize.py
+++ b/forge/forge/op/eval/forge/resize.py
@@ -167,91 +167,6 @@ def backward(type, attr, ac, operand, inputs, output, grad):
     raise RuntimeError("This should never be called.")
 
 
-def decompose_upsample_2d(attr, dc, inputs, resize_method):
-    activations = inputs[0]
-    shape = inputs[0].shape
-    channel_last = attr[-1]
-    if channel_last:
-        w, y, x, cin = (shape.w, shape.z, shape.r, shape.c)
-        activations = dc.op("reshape", [activations], (w, 1, y * x, cin))
-        scale_factor_y = attr[0] // shape[-3]
-        scale_factor_x = attr[1] // shape[-2]
-        scale_factor = (scale_factor_x, scale_factor_y)
-    else:
-        w, cin, y, x = (shape.w, shape.z, shape.r, shape.c)
-        activations = dc.op(
-            "reshape",
-            [inputs[0]],
-            (w, 1, cin, y * x),
-        )
-        activations = dc.op(TransposeTM.create(2, 3), [activations])
-
-        scale_factor_y = attr[0] // shape[-2]
-        scale_factor_x = attr[1] // shape[-1]
-        scale_factor = (scale_factor_x, scale_factor_y)
-
-    if resize_method == "nearest":
-        dident = create_nearest_neighbor_upsample_picker_matrix(scale_factor, shape, channel_last=channel_last)
-        dident_tensor = dc.tensor(dident)
-        result = dc.op("sparse_matmul", [dident_tensor, activations])
-
-    elif resize_method == "bilinear":
-        dident = create_bilinear_upsample_picker_matrix(
-            scale_factor, shape, align_corners=attr[-2], channel_last=channel_last
-        )
-        dident_dense = dident.unsqueeze(0).unsqueeze(0).to_dense()
-        if int(os.environ.get("FORGE_SPLIT_RESIZE2D", "0")) == inputs[0].shape[-2]:
-            dd = []
-            split_factor = 8
-            for s in range(split_factor):
-                dd.append(
-                    create_bilinear_upsample_picker_matrix(
-                        scale_factor,
-                        shape,
-                        align_corners=attr[-2],
-                        channel_last=channel_last,
-                        split_idx=s,
-                        split_factor=split_factor,
-                    )
-                )
-
-        # Choose whether to use sparse or dense matmul based on sparsity of dident
-        if torch.count_nonzero(dident_dense) > (torch.numel(dident_dense) // 2) or int(
-            os.environ.get("FORGE_FORCE_RESIZE_DENSE_MM", "0")
-        ):
-            dident_tensor = dc.tensor(dident_dense)
-            result = dc.op("matmul", [dident_tensor, activations])
-        else:
-            if int(os.environ.get("FORGE_SPLIT_RESIZE2D", "0")) == inputs[0].shape[-2]:
-                dd_tensor = [dc.tensor(d) for d in dd]
-                res = []
-                for d in dd_tensor:
-                    res.append(dc.op("sparse_matmul", [d, activations]))
-                result = dc.op("concatenate", res, (-3,))
-                result = dc.op("vstack", [result], (len(res),))
-            else:
-                dident_tensor = dc.tensor(dident)
-                result = dc.op("sparse_matmul", [dident_tensor, activations])
-
-    if channel_last:
-        result = dc.op("reshape", [result], (w, y * scale_factor_y, x * scale_factor_x, cin))
-        dc.fuse(result)
-    else:
-        result = dc.op(TransposeTM.create(2, 3), [result])
-        result = dc.op(
-            "reshape",
-            [result],
-            (
-                w,
-                cin,
-                y * scale_factor_y,
-                x * scale_factor_x,
-            ),
-        )
-
-        dc.fuse(result)
-
-
 def decompose_upsample_3d(attr, dc, inputs, resize_method):
     activations = inputs[0]
     shape = inputs[0].shape
@@ -376,9 +291,7 @@ def decompose_resize2d(attr, dc, inputs, resize_method):
         dc.fuse(result)
         return
 
-    if upsample:
-        decompose_upsample_2d(attr, dc, inputs, resize_method)
-    else:
+    if not upsample:
         decompose_downsample_2d(attr, dc, inputs, resize_method)
 
 
diff --git a/forge/test/mlir/test_ops.py b/forge/test/mlir/test_ops.py
index ab4f2b95c..30b00839a 100644
--- a/forge/test/mlir/test_ops.py
+++ b/forge/test/mlir/test_ops.py
@@ -14,6 +14,36 @@ from forge.tensor import to_forge_tensors, to_pt_tensors
 
 
 
+@pytest.mark.parametrize(
+    "shape, mode",
+    [
+        ((1, 2048, 7, 7), "nearest"),
+        ((1, 2048, 7, 7), "bilinear"),
+    ],
+)
+@pytest.mark.xfail(reason="Found Unsupported operations while lowering from TTForge to TTIR in forward graph")
+@pytest.mark.push
+def test_interpolate(shape, mode):
+    class interpolate(nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x):
+            return nn.functional.interpolate(x, scale_factor=2, mode=mode)
+
+    inputs = [torch.rand(shape)]
+
+    framework_model = interpolate()
+    fw_out = framework_model(*inputs)
+
+    compiled_model = forge.compile(framework_model, sample_inputs=inputs)
+    co_out = compiled_model(*inputs)
+
+    co_out = [co.to("cpu") for co in co_out]
+    fw_out = [fw_out] if isinstance(fw_out, torch.Tensor) else fw_out
+    assert all([compare_with_golden_pcc(golden=fo, calculated=co, pcc=0.99) for fo, co in zip(fw_out, co_out)])
+
+
 @pytest.mark.parametrize(
     "shape",
     [