diff --git a/compiler/luci/service/src/Nodes/CircleReshape.cpp b/compiler/luci/service/src/Nodes/CircleReshape.cpp index 0cad3c8db8e..403c2623650 100644 --- a/compiler/luci/service/src/Nodes/CircleReshape.cpp +++ b/compiler/luci/service/src/Nodes/CircleReshape.cpp @@ -72,8 +72,7 @@ namespace sinf * - If the `shape` is CircleOutputDummy, the shape is inferred from: * - the attribute if it exists. * - the node itself if the attribute does not exist. - * - Else, the shape is inferred from the node iteself. - * - TODO support CircleNode + * - If the `shape` is CircleNode, the dynamic shape is propagated. */ loco::TensorShape Algorithm::visit(const luci::CircleReshape *node) { @@ -81,6 +80,7 @@ loco::TensorShape Algorithm::visit(const luci::CircleReshape *node) const loco::DataType S32 = loco::DataType::S32; + bool should_infer = true; loco::TensorShape shape_by_input; { LUCI_ASSERT(node->shape(), "2nd input shape() should not be nullptr"); @@ -127,21 +127,16 @@ loco::TensorShape Algorithm::visit(const luci::CircleReshape *node) } else { - // We use shape from the node itself - loco::TensorShape shape; - shape.rank(node->rank()); - for (uint32_t r = 0; r < node->rank(); ++r) + auto node_shape = loco::must_cast(node->shape()); + + shape_by_input.rank(node_shape->dim(0).value()); + + for (uint32_t axis = 0; axis < shape_by_input.rank(); ++axis) { - // TODO remove this copy from `use_own(node);` - // Shape inference rules in this file did not consider unknown dimension. - // If some node has unknown dimension, 0 is inserted and wrong shape - // inference was done as a result. - // To fix this, new shape inference algorithm is being implemented. - // Until new inference algorithm is fully implemented, unknown dimension - // would be represented as 1 along with TFLite expression. - shape.dim(r) = node->dim(r).known() ? node->dim(r).value() : 1; + shape_by_input.dim(axis).unset(); } - shape_by_input = shape; + + should_infer = false; } } @@ -170,7 +165,6 @@ loco::TensorShape Algorithm::visit(const luci::CircleReshape *node) uint32_t input_element_count = 1; uint32_t output_element_count = 1; uint32_t unknown_dim_index = UINT32_MAX; - bool should_infer = true; for (uint32_t i = 0; i < input_shape.rank(); ++i) { if (input_shape.dim(i).known()) @@ -178,24 +172,28 @@ loco::TensorShape Algorithm::visit(const luci::CircleReshape *node) else should_infer = false; } - for (uint32_t dim_index = 0; dim_index < output_shape.rank(); ++dim_index) + + if (should_infer) { - const uint32_t dim_value = output_shape.dim(dim_index).value(); - if (output_shape.dim(dim_index).known() == false) + for (uint32_t dim_index = 0; dim_index < output_shape.rank(); ++dim_index) { - LUCI_ASSERT(unknown_dim_index == UINT32_MAX, "More than one unknown dimension"); - unknown_dim_index = dim_index; + const uint32_t dim_value = output_shape.dim(dim_index).value(); + if (output_shape.dim(dim_index).known() == false) + { + LUCI_ASSERT(unknown_dim_index == UINT32_MAX, "More than one unknown dimension"); + unknown_dim_index = dim_index; + } + else + { + output_element_count *= dim_value; + } } - else + if (unknown_dim_index != UINT32_MAX) { - output_element_count *= dim_value; + output_shape.dim(unknown_dim_index) = input_element_count / output_element_count; } } - if (unknown_dim_index != UINT32_MAX && should_infer) - { - output_shape.dim(unknown_dim_index) = input_element_count / output_element_count; - } - + return output_shape; } diff --git a/compiler/luci/service/src/Nodes/CircleReshape.test.cpp b/compiler/luci/service/src/Nodes/CircleReshape.test.cpp index 8f6175cfb9e..4393d4bfbe0 100644 --- a/compiler/luci/service/src/Nodes/CircleReshape.test.cpp +++ b/compiler/luci/service/src/Nodes/CircleReshape.test.cpp @@ -198,3 +198,31 @@ TEST(ShapeRuleTest, reshape_by_dummy_dynamic) ASSERT_EQ(6, output_shape.dim(0).value()); ASSERT_EQ(4, output_shape.dim(1).value()); } + +TEST(ShapeRuleTest, reshape_by_node) +{ + auto g = loco::make_graph(); + auto node_reshape = g->nodes()->create(); + auto tensor_input = g->nodes()->create(); + auto shape_node = g->nodes()->create(); + + tensor_input->dtype(loco::DataType::S32); + tensor_input->shape({2, 3, 4}); + tensor_input->shape_status(luci::ShapeStatus::VALID); + + shape_node->dtype(loco::DataType::S32); + shape_node->shape({2}); + shape_node->shape_status(luci::ShapeStatus::VALID); + + node_reshape->tensor(tensor_input); + node_reshape->shape(shape_node); + + loco::TensorShape output_shape; + luci::sinf::Rule shape_inf_rule; + + ASSERT_TRUE(shape_inf_rule.infer(node_reshape, output_shape)); + + ASSERT_EQ(2, output_shape.rank()); + ASSERT_FALSE(output_shape.dim(0).known()); + ASSERT_FALSE(output_shape.dim(1).known()); +}