From 652174ed2acdc1d1a384f71414430269e0e5ddbd Mon Sep 17 00:00:00 2001 From: Kamalraj Kannan <157608228+kamalrajkannan78@users.noreply.github.com> Date: Mon, 11 Nov 2024 11:25:47 +0530 Subject: [PATCH] fix shape attribute issue in reshape op in hoist_unsqueeze_squeeze_to_reshape (#664) --- forge/csrc/passes/explicate_unsqueeze.cpp | 5 ++- forge/test/mlir/test_ops.py | 37 +++++++++++++++++++ .../high_prio/cnn/pytorch/test_rcnn.py | 2 +- .../high_prio/cnn/pytorch/test_vgg.py | 4 +- 4 files changed, 44 insertions(+), 4 deletions(-) diff --git a/forge/csrc/passes/explicate_unsqueeze.cpp b/forge/csrc/passes/explicate_unsqueeze.cpp index 957127fd3..3d145d30b 100644 --- a/forge/csrc/passes/explicate_unsqueeze.cpp +++ b/forge/csrc/passes/explicate_unsqueeze.cpp @@ -129,7 +129,10 @@ void hoist_unsqueeze_squeeze_to_reshape(graphlib::Graph *graph) { new_reshape_attr.push_back((int)dim); } - op->change_op_type(graphlib::OpType("reshape", new_reshape_attr)); + std::vector shape_vector(target_shape.begin(), target_shape.end()); + graphlib::OpType::Attrs named_attrs; + named_attrs["shape"] = shape_vector; + op->change_op_type(graphlib::OpType("reshape", new_reshape_attr, {}, named_attrs)); op->set_shape(user_op->shape()); nodes_to_remove.insert(users[0]); } diff --git a/forge/test/mlir/test_ops.py b/forge/test/mlir/test_ops.py index b7ade500f..ab4f2b95c 100644 --- a/forge/test/mlir/test_ops.py +++ b/forge/test/mlir/test_ops.py @@ -14,6 +14,43 @@ from forge.tensor import to_forge_tensors, to_pt_tensors +@pytest.mark.parametrize( + "shape", + [ + (1, 256, 6, 6), + (1, 3, 64, 64), + (1, 512, 14, 14), + (1, 3, 224, 224), + (2, 256, 10, 10), + (1, 512, 3, 3), + (1, 1000, 1, 1), + (2, 128, 8, 8), + (4, 1, 32, 32), + (8, 64, 32, 32), + ], +) +@pytest.mark.push +def test_flatten(shape): + class flatten(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.flatten(x, 1) + + inputs = [torch.rand(shape)] + + framework_model = flatten() + fw_out = framework_model(*inputs) + + compiled_model = forge.compile(framework_model, sample_inputs=inputs) + co_out = compiled_model(*inputs) + + co_out = [co.to("cpu") for co in co_out] + fw_out = [fw_out] if isinstance(fw_out, torch.Tensor) else fw_out + assert all([compare_with_golden_pcc(golden=fo, calculated=co, pcc=0.99) for fo, co in zip(fw_out, co_out)]) + + @pytest.mark.parametrize("operand_and_cast_dtype", [(torch.float32, torch.int32), (torch.int32, torch.float32)]) @pytest.mark.push def test_cast(operand_and_cast_dtype): diff --git a/forge/test/model_demos/high_prio/cnn/pytorch/test_rcnn.py b/forge/test/model_demos/high_prio/cnn/pytorch/test_rcnn.py index 1407aace8..9d4492553 100644 --- a/forge/test/model_demos/high_prio/cnn/pytorch/test_rcnn.py +++ b/forge/test/model_demos/high_prio/cnn/pytorch/test_rcnn.py @@ -59,7 +59,7 @@ def test_rcnn_pytorch(test_device): # Forge configuration parameters compiler_cfg = forge.config._get_global_compiler_config() - compiler_cfg.compile_depth = forge.CompileDepth.GENERATE_INITIAL_GRAPH + compiler_cfg.compile_depth = forge.CompileDepth.SPLIT_GRAPH # Proposals generated by selective search were fed to a model in a loop manner to compute features. # [Refer line No.151 in https://github.com/object-detection-algorithm/R-CNN/blob/master/py/car_detector.py] diff --git a/forge/test/model_demos/high_prio/cnn/pytorch/test_vgg.py b/forge/test/model_demos/high_prio/cnn/pytorch/test_vgg.py index b3db7c046..a4d14561c 100644 --- a/forge/test/model_demos/high_prio/cnn/pytorch/test_vgg.py +++ b/forge/test/model_demos/high_prio/cnn/pytorch/test_vgg.py @@ -64,7 +64,7 @@ def test_vgg_19_hf_pytorch(test_device): # STEP 1: Set Forge configuration parameters compiler_cfg = forge.config._get_global_compiler_config() # load global compiler config object - compiler_cfg.compile_depth = forge.CompileDepth.GENERATE_INITIAL_GRAPH + compiler_cfg.compile_depth = forge.CompileDepth.SPLIT_GRAPH """ # https://pypi.org/project/vgg-pytorch/ @@ -136,7 +136,7 @@ def test_vgg_bn19_torchhub_pytorch(test_device): # STEP 1: Set Forge configuration parameters compiler_cfg = forge.config._get_global_compiler_config() # load global compiler config object - compiler_cfg.compile_depth = forge.CompileDepth.GENERATE_INITIAL_GRAPH + compiler_cfg.compile_depth = forge.CompileDepth.SPLIT_GRAPH model = download_model(torch.hub.load, "pytorch/vision:v0.10.0", "vgg19_bn", pretrained=True) model.eval()