Skip to content

Commit

Permalink
[luci/service] Support dynamic shape inference for reshape
Browse files Browse the repository at this point in the history
This commit supports dynamic shape inference for reshape operation

ONE-DCO-1.0-Signed-off-by: Jongwon Yang <[email protected]>
  • Loading branch information
jongwonyang committed Sep 13, 2024
1 parent 5149447 commit 29ade8e
Show file tree
Hide file tree
Showing 2 changed files with 184 additions and 98 deletions.
170 changes: 88 additions & 82 deletions compiler/luci/service/src/Nodes/CircleReshape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,29 +20,7 @@
#include "CircleShapeInferenceHelper.h"
#include "CircleCloneNode.h"

#include <luci/Log.h>

namespace
{

std::ostream &operator<<(std::ostream &os, const loco::TensorShape &tensor_shape)
{
os << "[";
for (uint32_t r = 0; r < tensor_shape.rank(); ++r)
{
if (r)
os << ",";

if (tensor_shape.dim(r).known())
os << tensor_shape.dim(r).value();
else
os << "?";
}
os << "]";
return os;
}

} // namespace
#include <oops/InternalExn.h>

namespace luci
{
Expand All @@ -65,93 +43,121 @@ luci::CircleNode *CloneNodeLet<CN::OPQR>::visit(const luci::CircleReshape *node)
namespace sinf
{

/**
* @note CircleReshape always has two inputs: `tensor` and `shape`.
* The `shape` input can be CircleConst, CircleOutputDummy, or CircleNode.
* - If the `shape` input is CircleConst, the shape is inferred from the constant.
* - If the `shape` input is CircleOutputDummy, the shape is inferred from
* the attribute if it exists. If the attribute does not exist,
* the shape is inferred from the node iteself.
* - If the `shape` input is CircleNode, the dynamic shape is propagated.
*/
loco::TensorShape Algorithm::visit(const luci::CircleReshape *node)
{
LOGGER(l);

const loco::DataType S32 = loco::DataType::S32;

loco::TensorShape shape_by_input;
{
LUCI_ASSERT(node->shape(), "2nd input shape() should not be nullptr");
// CircleReshape node must have `shape` input
LUCI_ASSERT(node->shape(), "2nd input shape() should not be nullptr");

// Only support node's shape() is CircleConst with S32
// TODO support other node with other types
auto const_shape_node = dynamic_cast<luci::CircleConst *>(node->shape());
if (const_shape_node != nullptr)
bool should_infer = true;
loco::TensorShape output_shape;
{
// Check if `shape` is CircleConst
auto const_shape = dynamic_cast<luci::CircleConst *>(node->shape());
if (const_shape != nullptr)
{
LUCI_ASSERT(const_shape_node->dtype() == S32, "Only support int32 CircleConst");
LUCI_ASSERT(const_shape->dtype() == S32, "Only support int32 CircleConst");
output_shape.rank(const_shape->size<S32>());

shape_by_input.rank(const_shape_node->size<S32>());

for (uint32_t axis = 0; axis < shape_by_input.rank(); ++axis)
for (uint32_t axis = 0; axis < output_shape.rank(); ++axis)
{
shape_by_input.dim(axis) = const_shape_node->at<S32>(axis);
output_shape.dim(axis) = const_shape->at<S32>(axis);
if (const_shape->at<S32>(axis) < 0)
{
output_shape.dim(axis).unset();
}
}
}
else
{
// We use shape from the node itself
loco::TensorShape shape;
shape.rank(node->rank());
for (uint32_t r = 0; r < node->rank(); ++r)
// Check if `shape` is CircleOutputDummy
auto dummy_shape = dynamic_cast<luci::CircleOutputDummy *>(node->shape());
if (dummy_shape != nullptr)
{
// TODO remove this copy from `use_own(node);`
// Shape inference rules in this file did not consider unknown dimension.
// If some node has unknown dimension, 0 is inserted and wrong shape
// inference was done as a result.
// To fix this, new shape inference algorithm is being implemented.
// Until new inference algorithm is fully implemented, unknown dimension
// would be represented as 1 along with TFLite expression.
shape.dim(r) = node->dim(r).known() ? node->dim(r).value() : 1;
if (node->newShape()->rank() > 0)
{
output_shape.rank(node->newShape()->rank());

for (uint32_t axis = 0; axis < output_shape.rank(); ++axis)
{
output_shape.dim(axis) = node->newShape()->dim(axis);
if (node->newShape()->dim(axis) < 0)
{
output_shape.dim(axis).unset();
}
}
}
else
{
output_shape = circle_shape(node);
}
}
else
{
// Check if `shape` is CircleNode
auto node_shape = dynamic_cast<luci::CircleNode *>(node->shape());
if (node_shape != nullptr)
{
output_shape.rank(node_shape->dim(0).value());

for (uint32_t axis = 0; axis < output_shape.rank(); ++axis)
{
output_shape.dim(axis).unset();
}

should_infer = false;
}
}
shape_by_input = shape;
}
}

loco::TensorShape shape_by_attr;
{
shape_by_attr.rank(node->newShape()->rank());

for (uint32_t axis = 0; axis < shape_by_attr.rank(); ++axis)
{
shape_by_attr.dim(axis) = node->newShape()->dim(axis);
}
}

if (!(shape_by_input == shape_by_attr))
{
INFO(l) << "CircleReshape: Two new shape information mismatched : " << std::endl;
INFO(l) << " shape_by_input : " << shape_by_input << std::endl;
INFO(l) << " shape_by_attr : " << shape_by_attr << std::endl;
}

loco::TensorShape output_shape = shape_by_input;

// One of the dimensions can have special value -1, meaning its actual value should be inferred.
const auto input = loco::must_cast<luci::CircleNode *>(node->tensor());
const auto input_shape = circle_shape(input);
uint32_t input_element_count = 1;
uint32_t output_element_count = 1;
uint32_t unknown_dim_index = UINT32_MAX;
for (uint32_t i = 0; i < input_shape.rank(); ++i)
input_element_count *= (input_shape.dim(i).known() ? input_shape.dim(i).value() : 1);
for (uint32_t dim_index = 0; dim_index < output_shape.rank(); ++dim_index)
for (uint32_t axis = 0; axis < input_shape.rank(); ++axis)
{
const uint32_t dim_value = output_shape.dim(dim_index).value();
if (static_cast<int>(dim_value) == -1)
if (input_shape.dim(axis).known())
{
LUCI_ASSERT(unknown_dim_index == UINT32_MAX, "More than one unknown dimension");
unknown_dim_index = dim_index;
input_element_count *= input_shape.dim(axis).value();
}
else
{
output_element_count *= dim_value;
should_infer = false;
break;
}
}
if (unknown_dim_index != UINT32_MAX)

if (should_infer)
{
output_shape.dim(unknown_dim_index) = input_element_count / output_element_count;
uint32_t output_element_count = 1;
uint32_t unknown_dim_index = UINT32_MAX;
for (uint32_t dim_index = 0; dim_index < output_shape.rank(); ++dim_index)
{
if (output_shape.dim(dim_index).known() == false)
{
LUCI_ASSERT(unknown_dim_index == UINT32_MAX, "More than one unknown dimension");
unknown_dim_index = dim_index;
}
else
{
const uint32_t dim_value = output_shape.dim(dim_index).value();
output_element_count *= dim_value;
}
}
if (unknown_dim_index != UINT32_MAX)
{
output_shape.dim(unknown_dim_index) = input_element_count / output_element_count;
}
}

return output_shape;
Expand Down
112 changes: 96 additions & 16 deletions compiler/luci/service/src/Nodes/CircleReshape.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,25 +39,25 @@ TEST(CloneNodeTest, clone_Reshape)
ASSERT_EQ(node_reshape->newShape()->dim(1), cloned_reshape->newShape()->dim(1));
}

TEST(ShapeRuleTest, reshape_by_input_const_static)
TEST(ShapeRuleTest, reshape_by_circle_const)
{
auto g = loco::make_graph();
auto node_reshape = g->nodes()->create<luci::CircleReshape>();
auto tensor_input = g->nodes()->create<luci::CircleInput>();
auto shape_by_input = g->nodes()->create<luci::CircleConst>();
auto shape_input = g->nodes()->create<luci::CircleConst>();

tensor_input->dtype(loco::DataType::S32);
tensor_input->shape({2, 3, 4});
tensor_input->shape_status(luci::ShapeStatus::VALID);

shape_by_input->dtype(loco::DataType::S32);
shape_by_input->size<loco::DataType::S32>(2);
shape_by_input->at<loco::DataType::S32>(0) = 6;
shape_by_input->at<loco::DataType::S32>(1) = 4;
shape_by_input->shape_status(luci::ShapeStatus::VALID);
shape_input->dtype(loco::DataType::S32);
shape_input->size<loco::DataType::S32>(2);
shape_input->at<loco::DataType::S32>(0) = -1;
shape_input->at<loco::DataType::S32>(1) = 4;
shape_input->shape_status(luci::ShapeStatus::VALID);

node_reshape->tensor(tensor_input);
node_reshape->shape(shape_by_input);
node_reshape->shape(shape_input);

loco::TensorShape output_shape;
luci::sinf::Rule shape_inf_rule;
Expand All @@ -71,25 +71,25 @@ TEST(ShapeRuleTest, reshape_by_input_const_static)
ASSERT_EQ(4, output_shape.dim(1).value());
}

TEST(ShapeRuleTest, reshape_by_input_const_dynamic)
TEST(ShapeRuleTest, reshape_by_circle_dummy)
{
auto g = loco::make_graph();
auto node_reshape = g->nodes()->create<luci::CircleReshape>();
auto tensor_input = g->nodes()->create<luci::CircleInput>();
auto shape_by_input = g->nodes()->create<luci::CircleConst>();
auto shape_input = g->nodes()->create<luci::CircleOutputDummy>();

tensor_input->dtype(loco::DataType::S32);
tensor_input->shape({2, 3, 4});
tensor_input->shape_status(luci::ShapeStatus::VALID);

shape_by_input->dtype(loco::DataType::S32);
shape_by_input->size<loco::DataType::S32>(2);
shape_by_input->at<loco::DataType::S32>(0) = -1;
shape_by_input->at<loco::DataType::S32>(1) = 4;
shape_by_input->shape_status(luci::ShapeStatus::VALID);
shape_input->dtype(loco::DataType::S32);
shape_input->shape_status(luci::ShapeStatus::VALID);

node_reshape->tensor(tensor_input);
node_reshape->shape(shape_by_input);
node_reshape->shape(shape_input);
node_reshape->newShape()->rank(2);
node_reshape->newShape()->dim(0) = -1;
node_reshape->newShape()->dim(1) = 4;

loco::TensorShape output_shape;
luci::sinf::Rule shape_inf_rule;
Expand All @@ -102,3 +102,83 @@ TEST(ShapeRuleTest, reshape_by_input_const_dynamic)
ASSERT_EQ(6, output_shape.dim(0).value());
ASSERT_EQ(4, output_shape.dim(1).value());
}

TEST(ShapeRuleTest, reshape_by_circle_node)
{
auto g = loco::make_graph();
auto node_reshape = g->nodes()->create<luci::CircleReshape>();
auto tensor_input = g->nodes()->create<luci::CircleInput>();
auto shape_input = g->nodes()->create<luci::CircleInput>();

tensor_input->dtype(loco::DataType::S32);
tensor_input->shape({2, 3, 4});
tensor_input->shape_status(luci::ShapeStatus::VALID);

shape_input->dtype(loco::DataType::S32);
shape_input->shape({2});
shape_input->shape_status(luci::ShapeStatus::VALID);

node_reshape->tensor(tensor_input);
node_reshape->shape(shape_input);

loco::TensorShape output_shape;
luci::sinf::Rule shape_inf_rule;

ASSERT_TRUE(shape_inf_rule.infer(node_reshape, output_shape));

ASSERT_EQ(2, output_shape.rank());
ASSERT_FALSE(output_shape.dim(0).known());
ASSERT_FALSE(output_shape.dim(1).known());
}

TEST(ShapeRuleTest, reshape_input_tensor_undefined_NEG)
{
auto g = loco::make_graph();
auto node_reshape = g->nodes()->create<luci::CircleReshape>();
auto tensor_input = g->nodes()->create<luci::CircleInput>();
auto shape_by_input = g->nodes()->create<luci::CircleConst>();

tensor_input->dtype(loco::DataType::S32);
tensor_input->shape({2, 3, 4});
tensor_input->shape_status(luci::ShapeStatus::UNDEFINED);

shape_by_input->dtype(loco::DataType::S32);
shape_by_input->size<loco::DataType::S32>(2);
shape_by_input->at<loco::DataType::S32>(0) = 6;
shape_by_input->at<loco::DataType::S32>(1) = 4;
shape_by_input->shape_status(luci::ShapeStatus::VALID);

node_reshape->tensor(tensor_input);
node_reshape->shape(shape_by_input);

loco::TensorShape output_shape;
luci::sinf::Rule shape_inf_rule;

ASSERT_FALSE(shape_inf_rule.infer(node_reshape, output_shape));
}

TEST(ShapeRuleTest, reshape_input_shape_undefined_NEG)
{
auto g = loco::make_graph();
auto node_reshape = g->nodes()->create<luci::CircleReshape>();
auto tensor_input = g->nodes()->create<luci::CircleInput>();
auto shape_by_input = g->nodes()->create<luci::CircleConst>();

tensor_input->dtype(loco::DataType::S32);
tensor_input->shape({2, 3, 4});
tensor_input->shape_status(luci::ShapeStatus::VALID);

shape_by_input->dtype(loco::DataType::S32);
shape_by_input->size<loco::DataType::S32>(2);
shape_by_input->at<loco::DataType::S32>(0) = 6;
shape_by_input->at<loco::DataType::S32>(1) = 4;
shape_by_input->shape_status(luci::ShapeStatus::UNDEFINED);

node_reshape->tensor(tensor_input);
node_reshape->shape(shape_by_input);

loco::TensorShape output_shape;
luci::sinf::Rule shape_inf_rule;

ASSERT_FALSE(shape_inf_rule.infer(node_reshape, output_shape));
}

0 comments on commit 29ade8e

Please sign in to comment.