deprecate sparse matmul op in CNN models (#668)
kamalrajkannan78 authored Nov 12, 2024
1 parent 43533af commit 6bda719
Showing 3 changed files with 44 additions and 96 deletions.
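This change drops the explicit sparse_matmul ops emitted by the CNN decompositions (pooling and resize) and applies the same picker tensors with a plain matmul. A minimal stand-alone sketch (plain PyTorch, not the Forge decomposition API) of why the two routes agree once the picker matrix is materialized densely:

```python
import torch

# Hypothetical picker matrix: a mostly-zero matrix that selects and rescales
# rows of the activation when multiplied from the left.
picker_dense = torch.tensor(
    [
        [1.0, 0.0, 0.0],
        [0.0, 0.5, 0.5],
    ]
)
activations = torch.randn(3, 4)

# What the decompositions used to emit: a sparse matmul on a sparse picker ...
out_sparse = torch.sparse.mm(picker_dense.to_sparse(), activations)

# ... and what they emit after this change: a dense matmul on the same picker.
out_dense = torch.matmul(picker_dense, activations)

assert torch.allclose(out_sparse, out_dense)
```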
21 changes: 13 additions & 8 deletions forge/forge/op/eval/forge/pooling.py
@@ -543,10 +543,14 @@ def decompose(type, attr, dc, inputs):
if not padding == [0, 0, 0, 0] and (ceil_mode == True or count_include_pad == False):
if channel_last:
_, y_out, x_out, _ = (result.shape.w, result.shape.z, result.shape.r, result.shape.c)
result = dc.op("reshape", [result], (w, 1, y_out * x_out, cin))
result = dc.op_with_named_attrs(
"reshape", [result], {"shape": (w, 1, y_out * x_out, cin)}, (w, 1, y_out * x_out, cin)
)
else:
_, _, y_out, x_out = (result.shape.w, result.shape.z, result.shape.r, result.shape.c)
result = dc.op("reshape", [result], (w, 1, cin, y_out * x_out))
result = dc.op_with_named_attrs(
"reshape", [result], {"shape": (w, 1, cin, y_out * x_out)}, (w, 1, cin, y_out * x_out)
)
result = dc.op(TransposeTM.create(2, 3), [result])

# Since count_include_pad=False undoes math in all padded regions, it takes precedence:
@@ -567,16 +571,17 @@ def decompose(type, attr, dc, inputs):
tile_align=False,
)
undo_math_picker_tensor = dc.tensor(undo_math_picker)
# TODO: This sparse matmul can definitely be fused the same way the sparse mm of convtransposed2d was fused
# Ideally, conv2d op should be aware of the ceil_mode param (convtranspose2d has a similar thing -
# output_padding) as that way it could create this sparse mm itself and easily fuse it
result = dc.op("sparse_matmul", [undo_math_picker_tensor, result])
result = dc.op("matmul", [undo_math_picker_tensor, result])

if channel_last:
result = dc.op("reshape", [result], (w, y_out, x_out, cin))
result = dc.op_with_named_attrs(
"reshape", [result], {"shape": (w, y_out, x_out, cin)}, (w, y_out, x_out, cin)
)
else:
result = dc.op(TransposeTM.create(2, 3), [result])
result = dc.op("reshape", [result], (w, cin, y_out, x_out))
result = dc.op_with_named_attrs(
"reshape", [result], {"shape": (w, cin, y_out, x_out)}, (w, cin, y_out, x_out)
)

dc.fuse(result)

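In pooling.py, the avg-pool decomposition still builds the "undo math" picker that corrects for ceil_mode / count_include_pad=False, but now applies it with a plain matmul instead of sparse_matmul; the surrounding reshapes also switch from dc.op to dc.op_with_named_attrs so the target shape is passed as a named "shape" attribute. A rough 1-D PyTorch illustration of what that diagonal correction does (an assumption for exposition only, not Forge's actual picker construction):

```python
import torch
import torch.nn.functional as F

# Average pooling with count_include_pad=True divides edge windows by the full
# kernel size, so a diagonal rescaling ("undo math") matrix can convert the
# result to the count_include_pad=False answer.
x = torch.arange(1.0, 5.0).reshape(1, 1, 4)  # [1, 2, 3, 4]
pooled_incl = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1, count_include_pad=True)
pooled_excl = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False)

# Edge windows contain one pad element: rescale them by 3/2, interior windows by 1.
undo_math_picker = torch.diag(torch.tensor([1.5, 1.0, 1.0, 1.5]))
corrected = undo_math_picker @ pooled_incl.squeeze(0).T  # plain matmul, as in this commit

assert torch.allclose(corrected.T, pooled_excl.squeeze(0))
```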
89 changes: 1 addition & 88 deletions forge/forge/op/eval/forge/resize.py
@@ -167,91 +167,6 @@ def backward(type, attr, ac, operand, inputs, output, grad):
raise RuntimeError("This should never be called.")


def decompose_upsample_2d(attr, dc, inputs, resize_method):
    activations = inputs[0]
    shape = inputs[0].shape
    channel_last = attr[-1]
    if channel_last:
        w, y, x, cin = (shape.w, shape.z, shape.r, shape.c)
        activations = dc.op("reshape", [activations], (w, 1, y * x, cin))
        scale_factor_y = attr[0] // shape[-3]
        scale_factor_x = attr[1] // shape[-2]
        scale_factor = (scale_factor_x, scale_factor_y)
    else:
        w, cin, y, x = (shape.w, shape.z, shape.r, shape.c)
        activations = dc.op(
            "reshape",
            [inputs[0]],
            (w, 1, cin, y * x),
        )
        activations = dc.op(TransposeTM.create(2, 3), [activations])

        scale_factor_y = attr[0] // shape[-2]
        scale_factor_x = attr[1] // shape[-1]
        scale_factor = (scale_factor_x, scale_factor_y)

    if resize_method == "nearest":
        dident = create_nearest_neighbor_upsample_picker_matrix(scale_factor, shape, channel_last=channel_last)
        dident_tensor = dc.tensor(dident)
        result = dc.op("sparse_matmul", [dident_tensor, activations])

    elif resize_method == "bilinear":
        dident = create_bilinear_upsample_picker_matrix(
            scale_factor, shape, align_corners=attr[-2], channel_last=channel_last
        )
        dident_dense = dident.unsqueeze(0).unsqueeze(0).to_dense()
        if int(os.environ.get("FORGE_SPLIT_RESIZE2D", "0")) == inputs[0].shape[-2]:
            dd = []
            split_factor = 8
            for s in range(split_factor):
                dd.append(
                    create_bilinear_upsample_picker_matrix(
                        scale_factor,
                        shape,
                        align_corners=attr[-2],
                        channel_last=channel_last,
                        split_idx=s,
                        split_factor=split_factor,
                    )
                )

        # Choose whether to use sparse or dense matmul based on sparsity of dident
        if torch.count_nonzero(dident_dense) > (torch.numel(dident_dense) // 2) or int(
            os.environ.get("FORGE_FORCE_RESIZE_DENSE_MM", "0")
        ):
            dident_tensor = dc.tensor(dident_dense)
            result = dc.op("matmul", [dident_tensor, activations])
        else:
            if int(os.environ.get("FORGE_SPLIT_RESIZE2D", "0")) == inputs[0].shape[-2]:
                dd_tensor = [dc.tensor(d) for d in dd]
                res = []
                for d in dd_tensor:
                    res.append(dc.op("sparse_matmul", [d, activations]))
                result = dc.op("concatenate", res, (-3,))
                result = dc.op("vstack", [result], (len(res),))
            else:
                dident_tensor = dc.tensor(dident)
                result = dc.op("sparse_matmul", [dident_tensor, activations])

    if channel_last:
        result = dc.op("reshape", [result], (w, y * scale_factor_y, x * scale_factor_x, cin))
        dc.fuse(result)
    else:
        result = dc.op(TransposeTM.create(2, 3), [result])
        result = dc.op(
            "reshape",
            [result],
            (
                w,
                cin,
                y * scale_factor_y,
                x * scale_factor_x,
            ),
        )

        dc.fuse(result)


def decompose_upsample_3d(attr, dc, inputs, resize_method):
activations = inputs[0]
shape = inputs[0].shape
@@ -376,9 +291,7 @@ def decompose_resize2d(attr, dc, inputs, resize_method):
dc.fuse(result)
return

if upsample:
decompose_upsample_2d(attr, dc, inputs, resize_method)
else:
if not upsample:
decompose_downsample_2d(attr, dc, inputs, resize_method)


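The removed decompose_upsample_2d built a picker ("dident") matrix for nearest or bilinear upsampling and applied it with sparse_matmul, falling back to a dense matmul when the picker was mostly non-zero or when FORGE_FORCE_RESIZE_DENSE_MM was set. For context, a small stand-alone sketch (plain PyTorch, simplified to a single flattened spatial dimension) of the picker-matrix view of 2x nearest-neighbour upsampling:

```python
import torch


def nearest_upsample_picker(n: int, scale: int) -> torch.Tensor:
    """Return an (n * scale, n) matrix that repeats every input row `scale` times."""
    return torch.eye(n).repeat_interleave(scale, dim=0)


x = torch.arange(1.0, 5.0).reshape(4, 1)      # a flattened spatial dimension
picker = nearest_upsample_picker(4, scale=2)  # shape (8, 4)

out = picker @ x  # dense matmul; a sparse matmul on picker.to_sparse() gives the same result
assert torch.equal(out.flatten(), x.flatten().repeat_interleave(2))
```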
30 changes: 30 additions & 0 deletions forge/test/mlir/test_ops.py
@@ -14,6 +14,36 @@
from forge.tensor import to_forge_tensors, to_pt_tensors


@pytest.mark.parametrize(
    "shape, mode",
    [
        ((1, 2048, 7, 7), "nearest"),
        ((1, 2048, 7, 7), "bilinear"),
    ],
)
@pytest.mark.xfail(reason="Found Unsupported operations while lowering from TTForge to TTIR in forward graph")
@pytest.mark.push
def test_interpolate(shape, mode):
    class interpolate(nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            return nn.functional.interpolate(x, scale_factor=2, mode=mode)

    inputs = [torch.rand(shape)]

    framework_model = interpolate()
    fw_out = framework_model(*inputs)

    compiled_model = forge.compile(framework_model, sample_inputs=inputs)
    co_out = compiled_model(*inputs)

    co_out = [co.to("cpu") for co in co_out]
    fw_out = [fw_out] if isinstance(fw_out, torch.Tensor) else fw_out
    assert all([compare_with_golden_pcc(golden=fo, calculated=co, pcc=0.99) for fo, co in zip(fw_out, co_out)])


@pytest.mark.parametrize(
    "shape",
    [

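The new test_interpolate is marked xfail for now ("Found Unsupported operations while lowering from TTForge to TTIR"), but it drives nn.functional.interpolate through forge.compile for both modes. For reference, a PyTorch-only look at what the op itself produces on an NCHW tensor with scale_factor=2:

```python
import torch
import torch.nn as nn

x = torch.tensor([[[[1.0, 2.0],
                    [3.0, 4.0]]]])  # shape (1, 1, 2, 2)

# "nearest" duplicates each pixel into a 2x2 block; "bilinear" interpolates between pixels.
nearest = nn.functional.interpolate(x, scale_factor=2, mode="nearest")
bilinear = nn.functional.interpolate(x, scale_factor=2, mode="bilinear")

print(nearest.shape, bilinear.shape)  # both torch.Size([1, 1, 4, 4])
print(nearest[0, 0])  # rows: [1, 1, 2, 2], [1, 1, 2, 2], [3, 3, 4, 4], [3, 3, 4, 4]
```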