Skip to content

Commit

Permalink
fix shape attribute issue in reshape op in hoist_unsqueeze_squeeze_to…
Browse files Browse the repository at this point in the history
…_reshape (#664)
  • Loading branch information
kamalrajkannan78 authored Nov 11, 2024
1 parent 64b912c commit 652174e
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 4 deletions.
5 changes: 4 additions & 1 deletion forge/csrc/passes/explicate_unsqueeze.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,10 @@ void hoist_unsqueeze_squeeze_to_reshape(graphlib::Graph *graph)
{
new_reshape_attr.push_back((int)dim);
}
op->change_op_type(graphlib::OpType("reshape", new_reshape_attr));
std::vector<int> shape_vector(target_shape.begin(), target_shape.end());
graphlib::OpType::Attrs named_attrs;
named_attrs["shape"] = shape_vector;
op->change_op_type(graphlib::OpType("reshape", new_reshape_attr, {}, named_attrs));
op->set_shape(user_op->shape());
nodes_to_remove.insert(users[0]);
}
Expand Down
37 changes: 37 additions & 0 deletions forge/test/mlir/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,43 @@
from forge.tensor import to_forge_tensors, to_pt_tensors


@pytest.mark.parametrize(
"shape",
[
(1, 256, 6, 6),
(1, 3, 64, 64),
(1, 512, 14, 14),
(1, 3, 224, 224),
(2, 256, 10, 10),
(1, 512, 3, 3),
(1, 1000, 1, 1),
(2, 128, 8, 8),
(4, 1, 32, 32),
(8, 64, 32, 32),
],
)
@pytest.mark.push
def test_flatten(shape):
class flatten(nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return torch.flatten(x, 1)

inputs = [torch.rand(shape)]

framework_model = flatten()
fw_out = framework_model(*inputs)

compiled_model = forge.compile(framework_model, sample_inputs=inputs)
co_out = compiled_model(*inputs)

co_out = [co.to("cpu") for co in co_out]
fw_out = [fw_out] if isinstance(fw_out, torch.Tensor) else fw_out
assert all([compare_with_golden_pcc(golden=fo, calculated=co, pcc=0.99) for fo, co in zip(fw_out, co_out)])


@pytest.mark.parametrize("operand_and_cast_dtype", [(torch.float32, torch.int32), (torch.int32, torch.float32)])
@pytest.mark.push
def test_cast(operand_and_cast_dtype):
Expand Down
2 changes: 1 addition & 1 deletion forge/test/model_demos/high_prio/cnn/pytorch/test_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def test_rcnn_pytorch(test_device):

# Forge configuration parameters
compiler_cfg = forge.config._get_global_compiler_config()
compiler_cfg.compile_depth = forge.CompileDepth.GENERATE_INITIAL_GRAPH
compiler_cfg.compile_depth = forge.CompileDepth.SPLIT_GRAPH

# Proposals generated by selective search were fed to a model in a loop manner to compute features.
# [Refer line No.151 in https://github.com/object-detection-algorithm/R-CNN/blob/master/py/car_detector.py]
Expand Down
4 changes: 2 additions & 2 deletions forge/test/model_demos/high_prio/cnn/pytorch/test_vgg.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def test_vgg_19_hf_pytorch(test_device):

# STEP 1: Set Forge configuration parameters
compiler_cfg = forge.config._get_global_compiler_config() # load global compiler config object
compiler_cfg.compile_depth = forge.CompileDepth.GENERATE_INITIAL_GRAPH
compiler_cfg.compile_depth = forge.CompileDepth.SPLIT_GRAPH

"""
# https://pypi.org/project/vgg-pytorch/
Expand Down Expand Up @@ -136,7 +136,7 @@ def test_vgg_bn19_torchhub_pytorch(test_device):

# STEP 1: Set Forge configuration parameters
compiler_cfg = forge.config._get_global_compiler_config() # load global compiler config object
compiler_cfg.compile_depth = forge.CompileDepth.GENERATE_INITIAL_GRAPH
compiler_cfg.compile_depth = forge.CompileDepth.SPLIT_GRAPH

model = download_model(torch.hub.load, "pytorch/vision:v0.10.0", "vgg19_bn", pretrained=True)
model.eval()
Expand Down

0 comments on commit 652174e

Please sign in to comment.