forked from openvinotoolkit/openvino_tokenizers
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[TF FE] Extend conversion for RaggedTensorToTensor in case rowids for…
…mat and introduce Equal for 1D string tensors (openvinotoolkit#70) * [TF FE] Extend conversion for RaggedTensorToTensor in case rowids format and introduce Equal for 1D string tensors Signed-off-by: Kazantsev, Roman <[email protected]> * Fix conversion of Equal operation Signed-off-by: Kazantsev, Roman <[email protected]> * Fix RaggedToRagged operation Signed-off-by: Kazantsev, Roman <[email protected]> * Fix RaggedToRagged operation extension Signed-off-by: Kazantsev, Roman <[email protected]> * Fix conversion for RaggedTensorToTensor operation Signed-off-by: Kazantsev, Roman <[email protected]> --------- Signed-off-by: Kazantsev, Roman <[email protected]>
- Loading branch information
Showing
8 changed files
with
315 additions
and
25 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "equal_str.hpp" | ||
#include "utils.hpp" | ||
|
||
using namespace ov; | ||
|
||
|
||
void EqualStr::validate_and_infer_types() { | ||
OPENVINO_ASSERT(get_input_size() == 6); | ||
|
||
auto begins_type1 = this->get_input_element_type(0); | ||
auto ends_type1 = this->get_input_element_type(1); | ||
auto begins_type2 = this->get_input_element_type(3); | ||
auto ends_type2 = this->get_input_element_type(4); | ||
|
||
OPENVINO_ASSERT(begins_type1 == element::i32 && begins_type2 == element::i32, | ||
"Expected an i32 begins for string tensor representation."); | ||
OPENVINO_ASSERT(ends_type1 == element::i32 && ends_type2 == element::i32, | ||
"Expected an i32 ends for string tensor representation."); | ||
|
||
set_output_type(0, ov::element::boolean, PartialShape({ Dimension::dynamic() })); | ||
} | ||
|
||
bool EqualStr::evaluate(ov::TensorVector& outputs, const ov::TensorVector& inputs) const { | ||
auto begins1 = inputs[0].data<const int32_t>(); | ||
auto ends1 = inputs[1].data<const int32_t>(); | ||
auto chars1 = inputs[2].data<const uint8_t>(); | ||
auto begins2 = inputs[3].data<const int32_t>(); | ||
auto ends2 = inputs[4].data<const int32_t>(); | ||
auto chars2 = inputs[5].data<const uint8_t>(); | ||
|
||
size_t num_elems1 = inputs[0].get_size(); | ||
size_t num_elems2 = inputs[3].get_size(); | ||
size_t num_elems = std::max(num_elems1, num_elems2); | ||
outputs[0].set_shape(ov::Shape{ num_elems }); | ||
auto result = outputs[0].data<bool>(); | ||
|
||
for (size_t idx = 0; idx < num_elems; ++idx) { | ||
// handle indices due to broadcasting case | ||
size_t idx1 = (idx < num_elems1) ? idx : 0; | ||
size_t idx2 = (idx < num_elems2) ? idx : 0; | ||
|
||
std::vector<uint8_t> op1(chars1 + begins1[idx1], chars1 + ends1[idx1]); | ||
std::vector<uint8_t> op2(chars2 + begins2[idx2], chars2 + ends2[idx2]); | ||
if (op1 == op2) { | ||
result[idx] = true; | ||
} | ||
else { | ||
result[idx] = false; | ||
} | ||
} | ||
return true; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include <openvino/op/op.hpp> | ||
|
||
// EqualStr compares two unpacked string tensors and outputs 1D boolean tensor | ||
// The operation is only applicable if output shape of string tensor corresponds to 1D tensor | ||
class EqualStr : public ov::op::Op { | ||
public: | ||
OPENVINO_OP("EqualStr"); | ||
|
||
EqualStr() = default; | ||
|
||
EqualStr(ov::OutputVector inputs) | ||
: ov::op::Op(inputs) { | ||
constructor_validate_and_infer_types(); | ||
} | ||
|
||
void validate_and_infer_types() override; | ||
|
||
std::shared_ptr<ov::Node> clone_with_new_inputs(const ov::OutputVector& inputs) const override { | ||
auto result = std::make_shared<EqualStr>(inputs); | ||
return result; | ||
} | ||
|
||
bool visit_attributes(ov::AttributeVisitor& visitor) override { | ||
return true; | ||
} | ||
|
||
bool has_evaluate() const override { | ||
return true; | ||
} | ||
|
||
bool evaluate(ov::TensorVector& outputs, const ov::TensorVector& inputs) const override; | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include <openvino/op/constant.hpp> | ||
|
||
#include "ragged_to_ragged.hpp" | ||
#include "utils.hpp" | ||
|
||
using namespace ov; | ||
using op::v0::Constant; | ||
|
||
void RaggedToRagged::validate_and_infer_types() { | ||
OPENVINO_ASSERT(get_input_size() == 2); | ||
|
||
auto rowids_type = this->get_input_element_type(0); | ||
auto first_dim_size_type = this->get_input_element_type(1); | ||
|
||
OPENVINO_ASSERT(rowids_type == element::i32, "Expected an i32 rowids tensor ragged representation."); | ||
OPENVINO_ASSERT(first_dim_size_type == element::i32, "Expected an i32 first dim size tensor ragged representation."); | ||
|
||
set_output_type(0, get_input_element_type(0), PartialShape({ Dimension::dynamic() })); | ||
set_output_type(1, get_input_element_type(0), PartialShape({ Dimension::dynamic() })); | ||
} | ||
|
||
|
||
bool RaggedToRagged::evaluate(ov::TensorVector& outputs, const ov::TensorVector& inputs) const { | ||
auto rowids = inputs[0].data<const int32_t>(); | ||
auto rowids_size = static_cast<int32_t>(inputs[0].get_size()); | ||
auto first_dim_size = inputs[1].data<const int32_t>(); | ||
|
||
const uint64_t batch_size = static_cast<uint64_t>(first_dim_size[0]); | ||
outputs[0].set_shape(ov::Shape{ batch_size }); | ||
outputs[1].set_shape(ov::Shape{ batch_size }); | ||
|
||
auto begins = outputs[0].data<int32_t>(); | ||
auto ends = outputs[1].data<int32_t>(); | ||
|
||
// prev_row_id_idx stores value idx for previous row | ||
int32_t prev_row_id_idx = 0; | ||
// prev_row_id stores row id for previous row | ||
int32_t prev_row_id = -1; | ||
for (int32_t rowids_idx = 0; rowids_idx < rowids_size; ++rowids_idx) { | ||
int32_t curr_row_id = rowids[rowids_idx]; | ||
OPENVINO_ASSERT(0 <= curr_row_id, "row id must be non-negative"); | ||
if (curr_row_id >= batch_size) { | ||
break; | ||
} | ||
|
||
if (prev_row_id != curr_row_id) { | ||
if (prev_row_id != -1) { | ||
begins[prev_row_id] = prev_row_id_idx; | ||
ends[prev_row_id] = rowids_idx; | ||
} | ||
|
||
int32_t idx = prev_row_id + 1; | ||
while (idx < curr_row_id) { | ||
begins[idx] = rowids_idx; | ||
ends[idx] = rowids_idx; | ||
++idx; | ||
} | ||
|
||
prev_row_id_idx = rowids_idx; | ||
prev_row_id = curr_row_id; | ||
} | ||
|
||
if (rowids_idx + 1 == rowids_size) { | ||
begins[curr_row_id] = prev_row_id_idx; | ||
ends[curr_row_id] = rowids_size; | ||
prev_row_id = curr_row_id; | ||
prev_row_id_idx = rowids_size; | ||
} | ||
} | ||
|
||
prev_row_id = (prev_row_id < 0) ? 0 : prev_row_id + 1; | ||
for (int32_t batch_idx = prev_row_id; batch_idx < batch_size; ++batch_idx) { | ||
begins[batch_idx] = prev_row_id_idx; | ||
ends[batch_idx] = prev_row_id_idx; | ||
} | ||
|
||
return true; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include <openvino/op/op.hpp> | ||
|
||
// Operation that transforms ragged tensor from rowids format to begins-ends format | ||
// value_rowids just defines to which row each value from values vector belongs | ||
// for example, rowids = [0, 0, 2, 3, 3, 3] and first_dims_size = 5 | ||
// it corresponds to ragged tensor with | ||
// begins = [0, 2, 2, 3, 6] | ||
// ends = [2, 2, 3, 6, 6] | ||
class RaggedToRagged : public ov::op::Op { | ||
public: | ||
OPENVINO_OP("RaggedToRagged"); | ||
|
||
RaggedToRagged() = default; | ||
|
||
RaggedToRagged(const ov::OutputVector& arguments) : | ||
ov::op::Op(arguments) { | ||
constructor_validate_and_infer_types(); | ||
} | ||
|
||
void validate_and_infer_types() override; | ||
|
||
std::shared_ptr<ov::Node> clone_with_new_inputs(const ov::OutputVector& inputs) const override { | ||
return std::make_shared<RaggedToRagged>(inputs); | ||
} | ||
|
||
bool visit_attributes(ov::AttributeVisitor& visitor) override { | ||
return true; | ||
} | ||
|
||
bool evaluate(ov::TensorVector& outputs, const ov::TensorVector& inputs) const override; | ||
|
||
bool has_evaluate() const override { | ||
return true; | ||
} | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.